Skip to content

Commit

Permalink
Output model.explain.onnx with mask and encoding output. Not validate…
Browse files Browse the repository at this point in the history
…d and probably not correct but a start to plumb through to the UI
  • Loading branch information
rcurrie committed Dec 10, 2024
1 parent dba08d9 commit a4e895d
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions scripts/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
parser.add_argument("sample", type=str, help="Path to the sample for validation")
args = parser.parse_args()

model_name = args.onnx.split("/")[-1].split(".")[0]
dest = os.path.dirname(args.onnx)

# Load the checkpoint
sims = SIMS(weights_path=args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
sims.model.eval() # Turns off training mode
Expand All @@ -34,12 +37,12 @@
sims_logits = sims.model(x)[0][0].detach().numpy()
np.testing.assert_array_almost_equal(onnx_logits, sims_logits, decimal=3)


embedded_x = sims.model.network.embedder(x)
np.count_nonzero(embedded_x)


# Load the current production model
model = onnx.load(args.onnx)
g = model.graph

# Expose the last concat output of 3 x 32, encodings?
# candidate = "/network/tabnet/Concat_output_0" # 3 x 32
Expand All @@ -49,61 +52,50 @@
if node.name == candidate:
print(idx, node)
break
model.graph.output.extend([node])
so.list_outputs(model.graph)
g = so.rename_output(model.graph, candidate, "encoding")
g.output.extend([node])
so.list_outputs(g)
g = so.rename_output(g, candidate, "encoding")
so.list_outputs(g)

onnx.checker.check_model(model)
so.graph_to_file(model.graph, "data/temp.onnx")
model = onnx.load("data/temp.onnx")
result = so.run(model.graph,
inputs={"input": x.detach().numpy()},
outputs=["encoding"])



# Expose the masks
candidate = "/network/tabnet/encoder/att_transformers.0/selector/Clip_output_0"
model = onnx.load(args.onnx)
shape_info = onnx.shape_inference.infer_shapes(model)
for idx, node in enumerate(shape_info.graph.value_info):
# if re.search("Clip", node.name, re.IGNORECASE):
if node.name == candidate:
# print(idx, node)
print(node)
break
model.graph.output.extend([node])
so.list_outputs(model.graph)
onnx.checker.check_model(model)
so.graph_to_file(model.graph, "data/temp.onnx")
model = onnx.load("data/temp.onnx")
result = so.run(model.graph,
inputs={"input": x.detach().numpy()},
outputs=["logits", candidate])
np.count_nonzero(result[1][0])
np.nonzero(result[1][0])

g = so.rename_output(g, candidate, "mask")
so.list_outputs(g)

so.graph_to_file(g, f"{dest}/{model_name}.explain.onnx")

# Get forward mask from tabnet
explain, masks = sims.model.network.forward_masks(x)
np.count_nonzero(explain.detach().numpy())

model = onnx.load(f"{dest}/{model_name}.explain.onnx")
onnx.checker.check_model(model)
so.list_outputs(model.graph)
result = so.run(model.graph,
inputs={"input": x.detach().numpy()},
outputs=["encoding", "mask"])

np.count_nonzero(result[1][0])
np.nonzero(result[1][0])



# # Get forward mask from tabnet
# explain, masks = sims.model.network.forward_masks(x)
# np.count_nonzero(explain.detach().numpy())

# Predict the cell types and get an explanation matrix
cell_predictions = sims.predict(args.sample)
explainability_matrix = sims.explain(args.sample)

# # Predict the cell types and get an explanation matrix
# cell_predictions = sims.predict(args.sample)
# explainability_matrix = sims.explain(args.sample)


a = explainability_matrix[0][0]
b = explain.detach().numpy()
np.max(np.setdiff1d(a, b))
# a = explainability_matrix[0][0]
# b = explain.detach().numpy()
# np.max(np.setdiff1d(a, b))



Expand Down

0 comments on commit a4e895d

Please sign in to comment.