
Revert "Fix infer layer" (#239)
SigureMo authored Jul 4, 2023
1 parent 6af0a9e commit 019b846
Showing 1 changed file with 24 additions and 39 deletions.
sot/infer_meta.py (63 changes: 24 additions & 39 deletions)
@@ -1,5 +1,3 @@
-import contextlib
-
 import paddle
 from paddle.fluid.unique_name import UniqueNameGenerator
 from paddle.fluid.unique_name import guard as UniqueNameGuard
@@ -86,19 +84,6 @@ def __init__(self):
         self.startup_program = Program()
         self.var_name_generator = UniqueNameGenerator("infer_meta_variable_")
 
-    def static_guard(self):
-        @contextlib.contextmanager
-        def _static_guard():
-            with paddle.fluid.framework._dygraph_guard(None), UniqueNameGuard(
-                self.var_name_generator
-            ):
-                with paddle.static.program_guard(
-                    self.main_program, self.startup_program
-                ):
-                    yield
-
-        return _static_guard()
-
     def gen_name(self, meta):
         name = f"{meta.dtype}_{meta.stop_gradient}"
         for l in meta.shape:
@@ -123,21 +108,27 @@ def get_variable(self, meta):
         return self.var_cache[var_feature_name]
 
     def infer_meta(self, func, *args, **kwargs):
-        with self.static_guard():
-            args, kwargs = convert_meta_to_variable(
-                args
-            ), convert_meta_to_variable(kwargs)
-            if isinstance(func, str):
-                # TODO(Aurelius84): Is length of args always greater than 0?
-                # Do we need add condition check here?
-                out = getattr(args[0], func)(*args[1:], **kwargs)
-            else:
-                out = func(*args, **kwargs)
+        with paddle.fluid.framework._dygraph_guard(None), UniqueNameGuard(
+            self.var_name_generator
+        ):
+            args, kwargs = convert_to_variable(args), convert_to_variable(
+                kwargs
+            )
+
+            with paddle.static.program_guard(
+                self.main_program, self.startup_program
+            ):
+                if isinstance(func, str):
+                    # TODO(Aurelius84): Is length of args always greater than 0?
+                    # Do we need add condition check here?
+                    out = getattr(args[0], func)(*args[1:], **kwargs)
+                else:
+                    out = func(*args, **kwargs)
 
-        return convert_variable_to_meta_info(out)
+        return variable_to_meta_info(out)
 
 
-def convert_meta_to_variable(args):
+def convert_to_variable(args):
     return map_if(
         args,
         pred=lambda x: isinstance(x, MetaInfo),
@@ -146,7 +137,7 @@ def convert_meta_to_variable(args):
     )
 
 
-def convert_meta_to_input_spec(args):
+def convert_to_input_spec(args):
     return map_if(
         args,
         pred=lambda x: isinstance(x, MetaInfo),
@@ -155,7 +146,7 @@ def convert_meta_to_input_spec(args):
     )
 
 
-def convert_variable_to_meta_info(args):
+def variable_to_meta_info(args):
     return map_if(
         args,
         pred=lambda x: isinstance(x, paddle.static.Variable),
@@ -177,16 +168,10 @@ def infer_meta_for_layer(layer, *args, **kwargs):
     ), f"Expect a Layer, but got {layer}."
     layer = paddle.jit.to_static(layer, enable_fallback=False)
 
-    args_, kwargs_ = convert_meta_to_input_spec(
-        args
-    ), convert_meta_to_input_spec(kwargs)
-    (
-        concrete_program,
-        partial_program_layer,
-    ) = layer.forward.get_concrete_program(*args_, **kwargs_)
-    out = partial_program_layer._restore_out(
-        convert_variable_to_meta_info(concrete_program.outputs)
-    )
+    args, kwargs = convert_to_input_spec(args), convert_to_input_spec(kwargs)
+    concrete_program = layer.forward.get_concrete_program(*args, **kwargs)[0]
+    out = concrete_program.outputs[0]
+    out = MetaInfo.from_tensor(out)
    layer.forward.rollback()
     return out
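
For readers skimming the diff: the revert drops the static_guard helper and inlines the guard nesting directly in infer_meta. Below is a minimal runnable sketch of that restored pattern, assuming a Paddle release that still ships the legacy paddle.fluid namespace (as PaddleSOT targeted at the time); the variable names and the toy op are illustrative, not part of the repository.

import paddle
from paddle.fluid.unique_name import UniqueNameGenerator
from paddle.fluid.unique_name import guard as UniqueNameGuard

# Private programs so meta inference never touches the user's default program.
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
name_gen = UniqueNameGenerator("infer_meta_variable_")  # prefix taken from the diff

# Leave dygraph mode (legacy private API) and redirect newly created ops
# into the private programs while generating collision-free variable names.
with paddle.fluid.framework._dygraph_guard(None), UniqueNameGuard(name_gen):
    with paddle.static.program_guard(main_program, startup_program):
        x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
        y = paddle.add(x, x)  # only builds the op and infers shape/dtype
        # y is a static Variable: y.shape == [2, 3]; no kernel has run.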

