Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 committed Oct 8, 2023
1 parent 5fc1ee6 commit 4648ed8
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 25 deletions.
6 changes: 2 additions & 4 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,8 @@ def infer_meta_fn(layer, *metas, **kwmetas):
return metas

def compute_fn(layer, inputs, outputs, stacks):
inputs = (layer.get_symbol(), *inputs)
inputs = inputs[1:]
self.sir_ctx.call_LAYER(
layer.value.__class__.__name__,
layer.value,
inputs=inputs,
outputs=outputs,
stacks=stacks,
Expand All @@ -444,7 +442,7 @@ def message_handler(*args, **kwargs):
return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?"

return inner_error_default_handler(self.symbolic_call, message_handler)(
infer_meta_fn, compute_fn, layer, *[layer, *args], **kwargs
infer_meta_fn, compute_fn, layer, *args, **kwargs
)

@event_register("symbolic_call", event_level=2)
Expand Down
16 changes: 7 additions & 9 deletions sot/symbolic/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,26 +120,24 @@ def _set(v, s):
return replace_symbol(SIR.outputs, state)

def call(self, stmt: Statement, inputs):
SIR = self.get_sir(stmt.name)
SIR = self.get_sir(stmt.sir_name)
state = prepare_state(SIR, inputs)
return self.run_sir(stmt.name, state)
return self.run_sir(stmt.sir_name, state)

def api(self, stmt, inputs):
args, kwargs = inputs
return stmt.name(*args, **kwargs)
return stmt.api(*args, **kwargs)

def method(self, stmt, inputs):
args, kwargs = inputs
var = args[0]
return getattr(var, stmt.name)(*args[1:], **kwargs)
return getattr(var, stmt.method)(*args[1:], **kwargs)

def layer(self, stmt, inputs):
args, kwargs = inputs
layer, args = args[0], args[1:]
return layer(*args, **kwargs)

def delete(self, stmt, inputs):
pass
layer = stmt.layer()
assert layer is not None, "SIR bounded layer"
return stmt.layer(*args, **kwargs)


def compile_sir(context: SymbolicTraceContext, name: str):
Expand Down
63 changes: 57 additions & 6 deletions sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""
from __future__ import annotations

import weakref
from typing import Callable

import paddle
from paddle.utils import is_sequence, map_structure

from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend
Expand Down Expand Up @@ -69,22 +73,69 @@ def to_string(inps):
inps = (x.__str__() for x in inps)
return ", ".join(inps)

name = (
self.name
if isinstance(self.name, str)
else "paddle." + self.name.__name__
)
return "{} || {} = {} ({}) ".format(
self.type + " " * (10 - len(self.type)),
to_string(self.outputs),
name,
self.name,
to_string(self.inputs),
)

def __repr__(self):
return self.__str__()


class CallStatement(Statement):
def __init__(
self,
name: str,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__("call", name, inputs, outputs, stacks)
self.sir_name = name


class ApiStatement(Statement):
def __init__(
self,
api: Callable,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__(
"api", "paddle." + api.__name__, inputs, outputs, stacks
)
self.api = api


class MethodStatement(Statement):
def __init__(
self,
name: str,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__("method", name, inputs, outputs, stacks)
self.method = name


class LayerStatement(Statement):
def __init__(
self,
layer: paddle.nn.Layer,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__(
"layer", layer.__class__.__name__, inputs, outputs, stacks
)
self.layer = weakref.ref(layer)


class StatementIR:
"""
StatementIR is the carrier that records the code for building the neural network model.It is
Expand Down
20 changes: 14 additions & 6 deletions sot/symbolic/symbolic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from ..utils import event_register, log
from .compile_cache import CompileSIRCache
from .statement_ir import Statement, StatementIR, StatementIRFactory, Symbol
from .statement_ir import (
ApiStatement,
CallStatement,
LayerStatement,
MethodStatement,
StatementIR,
StatementIRFactory,
Symbol,
)


class SymbolicTraceContext:
Expand Down Expand Up @@ -41,7 +49,7 @@ def call_SIR(self, sirname, inputs, outputs, stacks):
Call a SIR, which is a subgraph.
"""

stmt = Statement("call", sirname, inputs, outputs, stacks)
stmt = CallStatement(sirname, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_API", event_level=2)
Expand All @@ -51,7 +59,7 @@ def call_API(self, api, inputs, outputs, stacks):
"""

assert callable(api), "call_API must receive a paddle api."
stmt = Statement("api", api, inputs, outputs, stacks)
stmt = ApiStatement(api, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_METHOD", event_level=2)
Expand All @@ -65,15 +73,15 @@ def call_METHOD(self, method_name, inputs, outputs, stacks):
assert isinstance(
inputs[0][0], Symbol
), "call_METHOD must first augument must be Symbol Variable."
stmt = Statement("method", method_name, inputs, outputs, stacks)
stmt = MethodStatement(method_name, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_LAYER", event_level=2)
def call_LAYER(self, layer_name, inputs, outputs, stacks):
def call_LAYER(self, layer, inputs, outputs, stacks):
"""
Call a layer of a api.
"""
stmt = Statement("layer", layer_name, inputs, outputs, stacks)
stmt = LayerStatement(layer, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

def get_sir(self, name: str):
Expand Down

0 comments on commit 4648ed8

Please sign in to comment.