Skip to content

Commit

Permalink
chore: minor cleanup in layer dependency solver
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Jan 3, 2024
1 parent 13ded09 commit cc943b6
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/HGQ/proxy/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
def get_all_nodes(model: keras.Model) -> set[Node]:
    """Collect every inbound and outbound node attached to the model's layers.

    Only nodes whose ``node.layer`` is one of this model's own layers are
    kept — this drops nodes that belong to a different model sharing the
    same layer objects. (NOTE(review): relies on the private Keras
    ``_inbound_nodes`` / ``_outbound_nodes`` attributes — confirm against
    the pinned Keras version.)

    Args:
        model: The Keras model whose graph nodes should be gathered.

    Returns:
        The set of all nodes reachable from the model's layers whose
        owning layer belongs to the model.
    """
    # Build the membership set once: O(1) lookups instead of an O(n)
    # `in model.layers` scan per node.
    own_layers = set(model.layers)
    nodes: set[Node] = set()
    for layer in model.layers:
        nodes.update(n for n in layer._inbound_nodes if n.layer in own_layers)
        nodes.update(n for n in layer._outbound_nodes if n.layer in own_layers)
    return nodes


Expand All @@ -40,9 +43,7 @@ def solve_dependencies(model: keras.Model):
for node in nodes:
if node.is_input:
continue
layer = node.layer
if layer not in model.layers:
continue
layer = node.layer # layer that is called on the node
requires = list(node.parent_nodes)
provides = node
dependencies_list.append((layer, requires, provides))
Expand Down Expand Up @@ -261,7 +262,6 @@ def to_proxy_model(model: keras.Model, aggressive: bool = True, accum_fp_max_off
if accum_fp_max_offset is not None and accum_fp_max_offset < 0:
warn('You are using a negative value for bias_accum_bits. Please make sure you know what you are doing.')

nof_output = len(output_nodes)
inputs = [keras.layers.Input(shape=node.input_shapes[0][1:]) for node in input_nodes]
satisfied = {node: tensor for node, tensor in zip(input_nodes, inputs)}
outputs = []
Expand All @@ -271,7 +271,7 @@ def to_proxy_model(model: keras.Model, aggressive: bool = True, accum_fp_max_off
SAT = 'WRAP'
else:
SAT = 'SAT'
while dependencies_list and not len(outputs) == nof_output:
while dependencies_list:
layer, requires, provides = dependencies_list.pop(0)
if set(requires).issubset(satisfied):
inps = [satisfied[node] for node in requires]
Expand Down

0 comments on commit cc943b6

Please sign in to comment.