From 48946cff04b50f591c6e6fe65a175b92df46246d Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Sun, 8 Oct 2023 09:42:00 +0000 Subject: [PATCH] update --- sot/opcode_translator/executor/function_graph.py | 3 +-- sot/opcode_translator/executor/variables/callable.py | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 349c5fed..0ecc57c5 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -301,7 +301,7 @@ def start_compile(self, *ret_vars: VariableBase): found = False for variable in self.input_variables: if ( - isinstance(variable, (TensorVariable, PaddleLayerVariable)) + isinstance(variable, TensorVariable) and variable.get_symbol().name == name ): variable.tracker.gen_instructions(self.pycode_gen) @@ -426,7 +426,6 @@ def call_layer( """ def infer_meta_fn(layer, *metas, **kwmetas): - metas = metas[1:] metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas) return metas diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index 59dc8c98..3c9406c8 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -9,10 +9,8 @@ import paddle from .... import psdb -from ....symbolic.statement_ir import Symbol from ....utils import ( EventGuard, - NameGenerator, is_break_graph_api, is_break_graph_tensor_methods, is_builtin_fn, @@ -503,16 +501,10 @@ class PaddleLayerVariable(LayerVariable): tracker(Tracker): The Tracker object that tracks the information of this variable. """ - layer_name_generator = NameGenerator("layer_") - def __init__( self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker ): super().__init__(layer, graph, tracker) - self.name = self.layer_name_generator.next() - - def get_symbol(self) -> Symbol: - return Symbol(self.name) def call_function(self, /, *args, **kwargs): return self.graph.call_layer(self, *args, **kwargs)