Skip to content

Commit

Permalink
use lazy import
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Mar 16, 2023
1 parent 47e6ca8 commit 78c6968
Show file tree
Hide file tree
Showing 5 changed files with 4,738 additions and 23 deletions.
360 changes: 360 additions & 0 deletions lazy_importer/lazy_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
import ast
import copy
import importlib
import inspect
import logging
import sys
from textwrap import dedent
from types import FrameType
from typing import Any, List, Optional, Set, Union

import pyccolo as pyc
from pyccolo._fast.misc_ast_utils import subscript_to_slice

logger = logging.getLogger(__name__)


# Sentinel stored in ``_LazySymbol.value`` until the symbol has been resolved.
# A dedicated object is used so that ``None`` remains a valid resolved value.
_unresolved = object()

# Module-level accumulators. In the code visible here only
# ``nn_modules_file``, ``functions_file`` and ``classes_file`` are written to
# (by ``handle_unwrapped_ret``, which adds dedented source text);
# ``imports_file`` and ``import_uses_file`` are not referenced in this section.
imports_file = set()
import_uses_file = set()
nn_modules_file = set()
functions_file = set()
classes_file = set()


class _LazySymbol:
    """Placeholder bound to an imported name until that name is first used.

    An instance wraps a single-alias ``import`` / ``from ... import``
    statement (``spec``). Calling :meth:`unwrap` performs the real import
    (or attribute lookup on the parent module) and caches the result in
    ``value``.
    """

    # Dotted names for which ``importlib.import_module`` already raised
    # ImportError; they are subsequently resolved as attributes of their
    # parent instead of being imported again.
    non_modules: Set[str] = set()
    # Not read anywhere inside this class.
    blocklist_packages: Set[str] = set()

    def __init__(self, spec: Union[ast.Import, ast.ImportFrom]):
        self.spec = spec
        self.value = _unresolved

    @property
    def qualified_module(self) -> str:
        """Return the dotted name this symbol refers to.

        Only the first alias of ``spec`` is considered. For ``import a.b``
        this is ``"a.b"``; for ``from a import b`` it is ``"a.b"`` as well
        (``b`` may turn out to be a submodule or an attribute of ``a``).
        """
        node = self.spec
        name = node.names[0].name
        if isinstance(node, ast.Import):
            return name
        return f"{node.module}.{name}"

    @staticmethod
    def top_level_package(module: str) -> str:
        """Return the first component of a dotted module name."""
        return module.split(".", 1)[0]

    @classmethod
    def _unwrap_module(cls, module: str) -> Any:
        """Resolve the dotted name ``module`` to a module or attribute.

        Resolution order:

        1. an entry already present in ``sys.modules``;
        2. ``importlib.import_module`` (skipped for names that previously
           failed with ImportError);
        3. attribute lookup of the last component on the recursively
           resolved parent.

        Raises:
            ImportError: if the name cannot be imported and has no parent
                to fall back to.
            AttributeError: if the parent resolves but lacks the attribute.
        """
        if module in sys.modules:
            return sys.modules[module]
        exc: Optional[ImportError] = None
        if module not in cls.non_modules:
            try:
                with pyc.allow_reentrant_event_handling():
                    return importlib.import_module(module)
            except ImportError as e:
                cls.non_modules.add(module)
                exc = e
            except Exception:
                logger.error("fatal error trying to import module %s", module)
                raise
        parent, sep, symbol = module.rpartition(".")
        if not sep:
            # No parent to fall back to. ``exc`` is None when the failure was
            # cached by an earlier call; raising None would itself blow up
            # with a TypeError, so raise a proper ImportError instead.
            if exc is None:
                raise ImportError(
                    f"cannot import {module!r}: an earlier import attempt "
                    "failed and the name has no parent package",
                    name=module,
                )
            raise exc
        ret = getattr(cls._unwrap_module(parent), symbol)
        if isinstance(ret, _LazySymbol):
            # The parent module itself holds a lazy placeholder for the name.
            ret = ret.unwrap()
        handle_unwrapped_ret(ret)
        return ret

    def _unwrap_helper(self) -> Any:
        """Resolve this symbol's qualified name without touching the cache."""
        return self._unwrap_module(self.qualified_module)

    def unwrap(self) -> Any:
        """Return the real object, resolving and caching it on first use."""
        if self.value is not _unresolved:
            return self.value
        ret = self._unwrap_helper()
        handle_unwrapped_ret(ret)
        self.value = ret
        return ret

    def __call__(self, *args, **kwargs):
        # A placeholder must be unwrapped before use; reaching this point
        # means a load was not intercepted.
        raise TypeError("cant call _LazyName for spec %s" % ast.unparse(self.spec))

    def __getattr__(self, item):
        # Only invoked for attributes not found through normal lookup.
        raise TypeError(
            "cant __getattr__ on _LazyName for spec %s" % ast.unparse(self.spec)
        )


class _GetLazyNames(ast.NodeVisitor):
    """Collect the names that import statements in a module bind.

    ``lazy_names`` starts as an empty set and collapses to ``None`` as soon
    as a ``from x import *`` is seen, because the bound names are then
    unknown.
    """

    def __init__(self) -> None:
        self.lazy_names: Optional[Set[str]] = set()

    def visit_Import(self, node: ast.Import) -> None:
        """Record aliases of ``import a as b``; skip statements lacking ``as``."""
        if self.lazy_names is None:
            return
        # A single alias without ``as`` disqualifies the whole statement.
        if any(alias.asname is None for alias in node.names):
            return
        self.lazy_names.update(alias.asname for alias in node.names)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record names bound by ``from m import ...``; a star import resets to None."""
        if self.lazy_names is None:
            return
        if any(alias.name == "*" for alias in node.names):
            self.lazy_names = None
            return
        self.lazy_names.update(alias.asname or alias.name for alias in node.names)

    @classmethod
    def compute(cls, node: ast.Module) -> Set[str]:
        """Visit ``node`` and return the collected names.

        Note: despite the annotation, the result is ``None`` when the module
        contains a star import.
        """
        visitor = cls()
        visitor.visit(node)
        return visitor.lazy_names


def _make_attr_guard_helper(node: ast.Attribute) -> Optional[str]:
    """Flatten an attribute chain rooted at a plain name into one string.

    ``a.b.c`` becomes ``"a_O_b_O_c"``. Returns ``None`` when the chain is
    rooted at anything other than an ``ast.Name`` (a call, a subscript, ...).
    """
    attrs = [node.attr]
    base = node.value
    # Walk inwards through nested attributes, collecting names outermost-first.
    while isinstance(base, ast.Attribute):
        attrs.append(base.attr)
        base = base.value
    if not isinstance(base, ast.Name):
        return None
    return "_O_".join([base.id, *reversed(attrs)])


def _make_attr_guard(node: ast.Attribute) -> Optional[str]:
    """Return the guard name for an attribute chain, or ``None`` if it has none.

    The name is the flattened chain from ``_make_attr_guard_helper`` with a
    ``_Xix_`` prefix.
    """
    flattened = _make_attr_guard_helper(node)
    return None if flattened is None else f"_Xix_{flattened}"


def _make_subscript_guard_helper(node: ast.Subscript) -> Optional[str]:
    """Flatten a subscript chain rooted at a plain name into one string.

    Each index must be a bare name (used verbatim) or a constant (rendered
    as ``_<value>_``); levels are joined with ``_S_``. Returns ``None`` for
    any other index expression or when the chain is not rooted at an
    ``ast.Name``.
    """
    slice_val = subscript_to_slice(node)
    if isinstance(slice_val, ast.Name):
        subscript = slice_val.id
    elif isinstance(slice_val, ast.Constant):
        # ``ast.Str`` / ``ast.Num`` and the ``.s`` / ``.n`` aliases are
        # deprecated stand-ins for ``ast.Constant`` / ``.value``; using the
        # modern spelling gives the same text and keeps working once the
        # aliases are removed.
        subscript = f"_{slice_val.value}_"
    else:
        return None

    base = node.value
    if isinstance(base, ast.Name):
        return f"{base.id}_S_{subscript}"
    if isinstance(base, ast.Subscript):
        prefix = _make_subscript_guard_helper(base)
        return None if prefix is None else f"{prefix}_S_{subscript}"
    return None


def _make_subscript_guard(node: ast.Subscript) -> Optional[str]:
    """Return the guard name for a subscript chain, or ``None`` if it has none.

    The name is the flattened chain from ``_make_subscript_guard_helper``
    with a ``_Xix_`` prefix.
    """
    flattened = _make_subscript_guard_helper(node)
    return None if flattened is None else f"_Xix_{flattened}"


# Prefix compared against an object's ``__module__`` in
# ``handle_unwrapped_ret``; only objects whose module name starts with it
# have their source recorded.
MODULE_TARGET = "diffusers"


def handle_unwrapped_ret(ret):
    """Record the source of ``ret`` if it is defined under ``MODULE_TARGET``.

    Objects whose ``__module__`` is missing, not a string, or does not start
    with ``MODULE_TARGET`` are ignored. Otherwise the dedented source is
    added to one of the module-level sets:

    * ``nn_modules_file`` - ``torch.nn.Module`` subclasses, or the class of
      a ``torch.nn.Module`` instance;
    * ``classes_file`` - any other class;
    * ``functions_file`` - plain functions, except ones named ``__init__``.

    Anything else is printed for inspection. Objects whose source cannot be
    retrieved are logged and skipped. Always returns ``None``.
    """
    module_name = getattr(ret, "__module__", None)
    # ``__module__`` may legitimately be None, so check the type before
    # calling ``startswith`` on it.
    if not isinstance(module_name, str) or not module_name.startswith(MODULE_TARGET):
        return

    # Imported lazily, and only after the cheap prefix filter has passed.
    import torch

    if inspect.isclass(ret):
        bucket = nn_modules_file if issubclass(ret, torch.nn.Module) else classes_file
        source_obj = ret
    elif isinstance(ret, torch.nn.Module):
        bucket = nn_modules_file
        source_obj = ret.__class__
    elif inspect.isfunction(ret):
        if ret.__name__ == "__init__":
            # wrapped initializers (somehow not bound?)
            return
        bucket = functions_file
        source_obj = ret
    else:
        print("wtfbbq unwrapped", ret)
        return

    try:
        src = dedent(inspect.getsource(source_obj))
    except (OSError, TypeError):
        # Source is unavailable (e.g. dynamically created or compiled
        # objects); recording is best-effort and must not abort the trace.
        logger.warning("could not retrieve source for %r", source_obj)
        return
    bucket.add(src)


class LazyImportTracer(pyc.BaseTracer):
    """Tracer that defers eligible imports and resolves them on first use.

    Eligible outer-level import statements are replaced by ``_LazySymbol``
    placeholders written into the module globals. Name, attribute and
    subscript loads rooted at such names are intercepted; when a load yields
    a placeholder it is unwrapped, passed to ``handle_unwrapped_ret`` and
    written back so later loads see the real object.

    NOTE(review): the exact timing and arguments of each handler are defined
    by pyccolo's event decorators and are not visible in this file.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Names bound by imports in the current module. ``static_init_module``
        # may replace this with None (module contains a star import).
        self.cur_module_lazy_names: Set[str] = set()
        # Stack of objects whose attribute is being loaded; pushed in
        # ``before_attr_load`` and popped in ``after_attr_load``.
        self.saved_attributes: List[Any] = []
        # Parallel stacks of containers and keys for subscript loads; pushed
        # in ``before_subscript_load`` and popped in ``after_subscript_load``.
        self.saved_subscripts: List[Any] = []
        self.saved_slices: List[Any] = []

    def _is_name_lazy_load(self, node: Union[ast.Attribute, ast.Name]) -> bool:
        """Return True if ``node`` is rooted at a name that may be lazy.

        Attributes and subscripts are followed through ``.value`` and calls
        through ``.func``. When the set of lazy names is unknown (None),
        every node is treated as potentially lazy.
        """
        if self.cur_module_lazy_names is None:
            return True
        elif isinstance(node, ast.Name):
            return node.id in self.cur_module_lazy_names
        elif isinstance(node, (ast.Attribute, ast.Subscript)):
            return self._is_name_lazy_load(node.value)  # type: ignore
        elif isinstance(node, ast.Call):
            return self._is_name_lazy_load(node.func)
        else:
            return False

    def static_init_module(self, node: ast.Module) -> None:
        """Compute the import-bound names of ``node`` for later predicates."""
        self.cur_module_lazy_names = _GetLazyNames.compute(node)

    @staticmethod
    def _convert_relative_to_absolute(
        package: str, module: Optional[str], level: int
    ) -> str:
        """Turn a relative import target into an absolute dotted name.

        ``level - 1`` trailing components are stripped from ``package`` and
        ``module`` (if given) is appended to what remains.
        """
        prefix = package.rsplit(".", level - 1)[0]
        if not module:
            return prefix
        else:
            return f"{prefix}.{module}"

    @pyc.init_module
    def init_module(
        self, _ret: None, node: ast.Module, frame: FrameType, *_, **__
    ) -> None:
        """Initialise every guard registered for this module to False."""
        assert node is not None
        for guard in self.local_guards_by_module_id.get(id(node), []):
            frame.f_globals[guard] = False

    @pyc.before_call()
    def before_call(
        self,
        ret: None,
        node: Union[ast.Import, ast.ImportFrom],
        frame: FrameType,
        *_,
        **__,
    ) -> Any:
        """Pass ``ret`` to ``handle_unwrapped_ret`` and return it unchanged.

        NOTE(review): ``ret`` is presumably the object about to be called and
        ``node`` an ``ast.Call`` rather than an import node - confirm against
        pyccolo's ``before_call`` event.
        """
        handle_unwrapped_ret(ret)
        return ret

    @pyc.before_stmt(
        when=pyc.Predicate(
            lambda node: isinstance(node, (ast.Import, ast.ImportFrom))
            and pyc.is_outer_stmt(node),
            static=True,
        )
    )
    def before_stmt(
        self,
        _ret: None,
        node: Union[ast.Import, ast.ImportFrom],
        frame: FrameType,
        *_,
        **__,
    ) -> Any:
        """Replace an eligible import statement with lazy placeholders.

        Star imports and ``import x`` without ``as`` are left alone (returns
        None). Otherwise one ``_LazySymbol`` per alias is stored in the
        module globals and ``pyc.Pass`` is returned.
        """
        is_import = isinstance(node, ast.Import)
        # Bail out on the same statements that ``_GetLazyNames`` ignores.
        for alias in node.names:
            if alias.name == "*":
                return None
            elif is_import and alias.asname is None:
                return None
        package = frame.f_globals["__package__"]
        level = getattr(node, "level", 0)
        if is_import:
            module = None
        else:
            module = node.module  # type: ignore
            if level > 0:
                # Relative import: resolve against the importing package.
                module = self._convert_relative_to_absolute(package, module, level)
        for alias in node.names:
            # Each placeholder gets its own single-alias copy of the
            # statement, since ``_LazySymbol`` only reads ``names[0]``.
            node_cpy = copy.deepcopy(node)
            node_cpy.names = [alias]
            if module is not None:
                node_cpy.module = module  # type: ignore
                node_cpy.level = 0  # type: ignore
            lz = _LazySymbol(spec=node_cpy)
            frame.f_globals[alias.asname or alias.name] = lz
        return pyc.Pass

    @pyc.before_attribute_load(
        when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard
    )
    def before_attr_load(self, ret: Any, *_, **__) -> Any:
        """Remember the object whose attribute is about to be loaded."""
        self.saved_attributes.append(ret)
        return ret

    @pyc.after_attribute_load(
        when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard
    )
    def after_attr_load(
        self, ret: Any, node: ast.Attribute, frame: FrameType, _evt, guard, *_, **__
    ) -> Any:
        """Unwrap a lazy attribute value and write it back onto its owner."""
        if guard is not None:
            frame.f_globals[guard] = True
        saved_attr_obj = self.saved_attributes.pop()
        if isinstance(ret, _LazySymbol):
            ret = ret.unwrap()
            handle_unwrapped_ret(ret)
            # Replace the placeholder on the owner with the real object.
            setattr(saved_attr_obj, node.attr, ret)
        # ``pyc.Null`` is returned in place of a literal None result.
        return pyc.Null if ret is None else ret

    @pyc.before_subscript_load(
        when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard
    )
    def before_subscript_load(self, ret: Any, *_, attr_or_subscript: Any, **__) -> Any:
        """Remember the container and key of the subscript being loaded."""
        self.saved_subscripts.append(ret)
        self.saved_slices.append(attr_or_subscript)
        return ret

    @pyc.after_subscript_load(
        when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard
    )
    def after_subscript_load(
        self, ret: Any, _node, frame: FrameType, _evt, guard, *_, **__
    ) -> Any:
        """Unwrap a lazy subscript value and store it back in its container."""
        if guard is not None:
            frame.f_globals[guard] = True
        saved_subscript_obj = self.saved_subscripts.pop()
        saved_slice_obj = self.saved_slices.pop()
        if isinstance(ret, _LazySymbol):
            ret = ret.unwrap()
            handle_unwrapped_ret(ret)
            # Replace the placeholder in the container with the real object.
            saved_subscript_obj[saved_slice_obj] = ret
        return pyc.Null if ret is None else ret

    @pyc.load_name(
        when=pyc.Predicate(_is_name_lazy_load, static=True),
        guard=lambda node: f"_Xix_{node.id}_guard",
    )
    def load_name(
        self, ret: Any, node: ast.Name, frame: FrameType, _evt, guard, *_, **__
    ) -> Any:
        """Unwrap a lazy global on first load and rebind the name to it."""
        if guard is not None:
            frame.f_globals[guard] = True
        if isinstance(ret, _LazySymbol):
            ret = ret.unwrap()
            handle_unwrapped_ret(ret)
            frame.f_globals[node.id] = ret
        else:
            # Diagnostic output for loads that were not placeholders.
            print("notlazysymbol", ret)
        return pyc.Null if ret is None else ret
Loading

0 comments on commit 78c6968

Please sign in to comment.