Maxcut fixes (sony#1312)
1. Fix A* estimate value.
2. Fix cuts to include last op input tensor.
elad-c authored Jan 6, 2025
1 parent 5ed46c7 commit ce318c0
Showing 11 changed files with 170 additions and 186 deletions.
@@ -49,7 +49,7 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
it = 0
while it < n_iter:
estimate = (u_bound + l_bound) / 2
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate_factor=estimate, iter_limit=astar_n_iter)
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter)
if schedule is None:
l_bound = estimate
else:
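As an aside, the hunk above sits inside a bisection over the estimate value: when the A* solver returns no schedule the lower bound is raised, and (in the code hidden below this hunk) the upper bound is presumably lowered on success. A minimal, hypothetical sketch of that control flow, with a made-up solver standing in for max_cut_astar.solve:

```python
# Hypothetical sketch of the bisection loop around an A*-style solver.
# `solve` is a stand-in for max_cut_astar.solve, and the feasibility
# threshold (42) is made up for illustration.
from typing import List, Optional, Tuple

def solve(estimate: float, iter_limit: int = 500) -> Tuple[Optional[List[str]], float, Optional[List[set]]]:
    if estimate >= 42:  # pretend a schedule exists once the allowed cut size is large enough
        return ["op1", "op2"], 42.0, [{"op1_out"}, {"op1_out", "op2_out"}]
    return None, 0.0, None

def compute_graph_max_cut_sketch(l_bound: float, u_bound: float, n_iter: int = 20):
    best = None
    for _ in range(n_iter):
        estimate = (u_bound + l_bound) / 2        # bisect the current interval
        schedule, max_cut_size, cuts = solve(estimate=estimate)
        if schedule is None:
            l_bound = estimate                    # infeasible -> raise the lower bound
        else:
            u_bound = estimate                    # feasible -> tighten the upper bound
            best = (schedule, max_cut_size, cuts)
    return best

print(compute_graph_max_cut_sketch(l_bound=1.0, u_bound=100.0))
```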
@@ -64,7 +64,10 @@ def __eq__(self, other) -> bool:
"""
if isinstance(other, Cut):
return self.mem_elements == other.mem_elements
return False
return False # pragma: no cover

def __hash__(self):
return hash((frozenset(self.op_order), frozenset(self.op_record), self.mem_elements))
return hash((frozenset(self.op_order), frozenset(self.op_record), self.mem_elements))

def __repr__(self):
return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>" # pragma: no cover
@@ -100,7 +100,7 @@ def __init__(self, memory_graph: MemoryGraph):
edges_src_ab = [(src_dummy_a, src_dummy_b)]
edges_src_ba = [(src_dummy_b, src_a) for src_a in memory_graph.sources_a]

# Target Cut
# Target Cut (Adding 2 consecutive dummy nodes so the final cut will include only dummy tensors).
target_dummy_a = next(gen_a)
target_dummy_a2 = next(gen_a)
target_dummy_b = next(gen_b)
@@ -122,13 +122,13 @@ def __init__(self, memory_graph: MemoryGraph):
self.target_cut = Cut([], set(), MemoryElements(elements={target_dummy_b, target_dummy_b2},
total_size=0))

def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[BaseNode], float, List[Cut]]:
def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode], float, List[Cut]]:
"""
The AStar solver function. This method runs an AStar-like search on the memory graph,
using the given estimate_factor as a heuristic gap for solutions to consider.
using the given estimate as a heuristic gap for solutions to consider.
Args:
estimate_factor: A multiplication factor which allows the search to consider larger size of nodes in each
estimate: Cut size estimation to consider larger size of nodes in each
expansion step, in order to speed up the algorithm's convergence towards a solution.
iter_limit: An upper limit for the number of expansion steps that the algorithm performs.
@@ -148,17 +148,14 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas

while expansion_count < iter_limit and len(open_list) > 0:
# Choose next node to expand
next_cut = self._get_cut_to_expand(open_list, costs, routes, estimate_factor)
next_cut = self._get_cut_to_expand(open_list, costs, routes, estimate)

cut_cost = costs[next_cut]
cut_route = routes[next_cut]

if next_cut == self.target_cut:
# TODO maxcut: Why do we filter the cuts (cut_route) but not the max cut size (cut_sost).
# This is a mismatch between max_cut and max(cuts).
# Also, unfiltered cut_route seems perfect, including input and output tensor sizes of current op.
return self._remove_dummys_from_path(cut_route[0].op_order), cut_cost,\
list(set([self._remove_dummys_from_cut(self.clean_memory_for_next_step(c)) for c in cut_route]))
return self._remove_dummy_nodes_from_path(cut_route[0].op_order), cut_cost,\
list(set([self._remove_dummy_tensors_from_cut(c) for c in cut_route]))

if self.is_pivot(next_cut):
# Can clear all search history
Expand All @@ -176,7 +173,7 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
expansion_count += 1

# Only consider nodes that were not already visited
expanded_cuts = list(filter(lambda _c: _c not in closed_list, expanded_cuts))
expanded_cuts = [_c for _c in expanded_cuts if _c not in closed_list]
for c in expanded_cuts:
cost = self.accumulate(cut_cost, c.memory_size())
if c not in open_list:
@@ -191,7 +188,7 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)

# Halt or No Solution
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
# TODO maxcut: this isn't covered in the coverage test. Add test and remove no cover
return None, 0, None # pragma: no cover
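Stepping back to the success path a few lines up (when next_cut equals the target cut): the route's op order is stripped of dummy scheduling nodes, each cut on the route is stripped of dummy tensors, and duplicate cuts are collapsed via set(). A hypothetical, self-contained sketch of that post-processing, using plain strings and sets in place of the library's node and memory-element objects:

```python
# Hypothetical sketch of the result post-processing once the target cut is reached.
# Node and tensor names are illustrative; the real code filters on DUMMY_NODE /
# DUMMY_TENSOR substrings in the same way.
DUMMY_NODE = "DummyNode"
DUMMY_TENSOR = "DummyTensor"

def finalize(route_op_order, route_cuts):
    schedule = [n for n in route_op_order if DUMMY_NODE not in n]          # drop dummy ops
    cuts = list({frozenset(e for e in cut if DUMMY_TENSOR not in e)        # drop dummy tensors
                 for cut in route_cuts})                                    # set() deduplicates cuts
    return schedule, cuts

op_order = ["DummyNode_0", "conv1", "relu1", "DummyNode_1"]
route_cuts = [{"DummyTensor_0"}, {"conv1_out"}, {"conv1_out", "relu1_out"}]
print(finalize(op_order, route_cuts))
```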

@staticmethod
@@ -214,21 +211,23 @@ def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: Li
routes.update({cut: [cut] + route})

def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]],
estimate_factor: float) -> Cut:
estimate: float) -> Cut:
"""
An auxiliary method that selects the next cut to expand out of a set of candidate cuts.
Args:
open_list: The search open list.
costs: The search utility mapping between cuts and their cost.
routes: The search utility mapping between cuts and their routes.
estimate_factor: A multiplication factor to set extended boundaries on the potential cuts to exapand.
estimate: Cut size estimation to set extended boundaries on the potential cuts to expand.
Returns: The candidate cut with the lowest expansion cost (ties broken in favor of longer routes).
"""
max_cut_len = max([len(routes[c]) for c in open_list])
ordered_cuts_list = sorted(open_list,
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), -len(routes[c])))
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)),
max_cut_len - len(routes[c])))

assert len(ordered_cuts_list) > 0
return ordered_cuts_list[0]
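To make the selection rule above concrete, here is a small sketch (an assumed simplification: cost accumulation is taken to be plain addition, and the estimate is a fixed value added to every candidate, as in the new code above):

```python
# Hypothetical standalone sketch of the candidate ordering used by _get_cut_to_expand.
def pick_next_cut(open_list, costs, routes, estimate):
    max_cut_len = max(len(routes[c]) for c in open_list)
    ordered = sorted(
        open_list,
        key=lambda c: (costs[c] + estimate,             # accumulated cost + heuristic gap
                       max_cut_len - len(routes[c])))   # tie-break: prefer longer routes
    return ordered[0]

costs = {"cut_a": 10.0, "cut_b": 10.0, "cut_c": 12.0}
routes = {"cut_a": ["c0", "cut_a"], "cut_b": ["c0", "c1", "cut_b"], "cut_c": ["cut_c"]}
# cut_a and cut_b tie on cost; cut_b wins because its route is longer.
print(pick_next_cut(["cut_a", "cut_b", "cut_c"], costs, routes, estimate=5.0))  # -> cut_b
```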
@@ -356,23 +355,24 @@ def ordering(cost_1, cost_2) -> bool:
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return cost_1 < cost_2 # pragma: no cover

def estimate(self, cut: Cut, estimate_factor: float) -> float:
@staticmethod
def estimate(cut: Cut, estimate: float) -> float:
"""
A function that defines the estimation gap for the Astar search.
The estimation gap is used to sort the cuts that are considered for expanding the search in each iteration.
Args:
cut: A cut (not used in the default implementation, but can be used if overriding the method to consider
the actual cut in the estimation computation).
estimate_factor: The given estimate factor to the search.
estimate: The given estimate to the search.
Returns: An estimation value.
"""
return estimate_factor * self.memory_graph.memory_lbound_single_op
return estimate

@staticmethod
def get_init_estimate_factor(memory_graph: MemoryGraph) -> float:
def get_init_estimate(memory_graph: MemoryGraph) -> float: # pragma: no cover
"""
Returns an initial estimation value, which is based on the memory graph's upper and lower bounds.
@@ -383,12 +383,12 @@ def get_init_estimate_factor(memory_graph: MemoryGraph) -> float:
"""
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
l_bound = memory_graph.memory_lbound_single_op # pragma: no cover
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound # pragma: no cover
return (u_bound + l_bound) / 2 # pragma: no cover
l_bound = memory_graph.memory_lbound_single_op
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound
return (u_bound + l_bound) / 2
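A quick worked example of the formula above, with made-up numbers: a single-op lower bound of 40 and tensor sizes summing to 100 give u_bound = 2 * 100 - 40 = 160 and an initial estimate of (160 + 40) / 2 = 100.

```python
# Illustrative numbers only; not taken from the library.
l_bound = 40.0                               # memory_graph.memory_lbound_single_op
tensor_sizes = [10.0, 30.0, 60.0]            # stand-ins for t.total_size over memory_graph.b_nodes
u_bound = 2 * sum(tensor_sizes) - l_bound    # 2 * 100 - 40 = 160
init_estimate = (u_bound + l_bound) / 2      # (160 + 40) / 2 = 100
print(init_estimate)                         # 100.0
```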

@staticmethod
def _remove_dummys_from_path(path: List[BaseNode]) -> List[BaseNode]:
def _remove_dummy_nodes_from_path(path: List[BaseNode]) -> List[BaseNode]:
"""
An auxiliary method which removes dummy nodes from a given list of nodes (a path in the graph).
@@ -401,7 +401,7 @@ def _remove_dummys_from_path(path: List[BaseNode]) -> List[BaseNode]:
return list(filter(lambda n: DUMMY_NODE not in n.name, path))

@staticmethod
def _remove_dummys_from_cut(cut: Cut) -> Cut:
def _remove_dummy_tensors_from_cut(cut: Cut) -> Cut:
"""
An auxiliary method which removes dummy tensor elements from a given cut.
@@ -411,7 +411,7 @@ def _remove_dummys_from_cut(cut: Cut) -> Cut:
Returns: The same cut without dummy tensor elements.
"""
filtered_memory_elements = set(filter(lambda elm: DUMMY_TENSOR not in elm.node_name, cut.mem_elements.elements))
filtered_memory_elements = set([elm for elm in cut.mem_elements.elements if DUMMY_TENSOR not in elm.node_name])

return Cut(cut.op_order, cut.op_record,
mem_elements=MemoryElements(filtered_memory_elements,
@@ -49,38 +49,6 @@
BATCH_INPUT_SHAPE = 'batch_input_shape'


def get_node_name_from_layer(layer: Layer) -> str:
"""
Get a node's name from the layer it was built from. For TensorFlowOpLayer
we remove the prefix "tf_op_layer".
Args:
layer: Keras Layer to get its corresponding node's name.
Returns:
Name of the node that was built from the passed layer.
"""

name = layer.name
if isinstance(layer, TensorFlowOpLayer): # remove TF op layer prefix
name = '_'.join(name.split('_')[3:])
return name


def is_layer_fake_quant(layer: Layer) -> bool:
"""
Check whether a Keras layer is a fake quantization layer or not.
Args:
layer: Layer to check if it's a fake quantization layer or not.
Returns:
Whether a Keras layer is a fake quantization layer or not.
"""
# in tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda
return (isinstance(layer, TensorFlowOpLayer) and layer.node_def.op == FQ_NODE_OP_V2_3) or (
isinstance(layer, TFOpLambda) and layer.symbol == FQ_NODE_OP_V2_4)


class KerasModelBuilder(BaseModelBuilder):
"""
Builder for Keras models.
@@ -291,7 +259,7 @@ def _run_operation(self,
arg = n.weights.get(pos)
if arg is None:
if len(input_tensors) == 0:
Logger.critical(f"Couldn't find a weight or input tensor matching operator's "
Logger.critical(f"Couldn't find a weight or input tensor matching operator's " # pragma: no cover
f"argument name '{k}' in location {pos} for node {n.name}.")
arg = input_tensors.pop(0)
op_call_kwargs.update({k: arg})
@@ -141,7 +141,7 @@ def substitute(self,
strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES)
if strides is None:
# Non-standard strides -> skip substitution.
return graph
return graph # pragma: no cover
conv_fw_attr[STRIDES] = strides

padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID'
@@ -153,7 +153,7 @@
dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS)
if dilations is None:
# Non-standard dilations -> skip substitution.
return graph
return graph # pragma: no cover
conv_fw_attr[DILATION_RATE] = dilations

if b is None:
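Both hunks above follow the same guard-clause pattern: a parse helper returns None for layouts the substitution cannot handle, and the pass then returns the graph unchanged rather than producing an incorrect replacement. A hypothetical sketch of that pattern; the helper name and the accepted [1, h, w, 1] layout are assumptions, not the library's exact logic:

```python
from typing import Optional, Tuple

def parse_stride_like(value) -> Optional[Tuple[int, int]]:
    # Accept only a "standard" [1, h, w, 1] layout; anything else is unsupported.
    if isinstance(value, (list, tuple)) and len(value) == 4 and value[0] == value[3] == 1:
        return (value[1], value[2])
    return None

def substitute_sketch(graph, strides_attr):
    strides = parse_stride_like(strides_attr)
    if strides is None:
        # Non-standard strides -> leave the graph untouched, like the guard above.
        return graph
    # ... otherwise the substitution would proceed using `strides` ...
    return graph

print(substitute_sketch(graph={"nodes": []}, strides_attr=[1, 2, 2, 1]))  # substitution proceeds
print(substitute_sketch(graph={"nodes": []}, strides_attr=[2, 2, 2, 2]))  # substitution skipped
```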