diff --git a/tools/codegen/api/dispatcher.py b/tools/codegen/api/dispatcher.py index f802576e13a35..2c3b3adb9115c 100644 --- a/tools/codegen/api/dispatcher.py +++ b/tools/codegen/api/dispatcher.py @@ -6,7 +6,7 @@ import tools.codegen.local as local import itertools -from typing import Sequence, Optional +from typing import Sequence, Optional, Tuple # This file describes the translation of JIT schema to the dispatcher # API, the *unboxed* calling convention by which invocations through @@ -65,14 +65,14 @@ def argument(a: Argument) -> DispatcherArgument: def name(func: FunctionSchema) -> str: return cpp.name(func) -def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]: +def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]: if local.use_c10_dispatcher().dispatcher_uses_new_style(): - return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments))) + return tuple(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments))) else: - return [ + return tuple( DispatcherArgument(type=la.type, name=la.name, argument=la.argument) for la in native.arguments(func) - ] + ) # Given a set of CppArguments in scope, return a sequence of dispatcher # expressions that translate the cpp API into dispatcher API diff --git a/tools/codegen/api/native.py b/tools/codegen/api/native.py index ef8ebab287643..aa40538736da3 100644 --- a/tools/codegen/api/native.py +++ b/tools/codegen/api/native.py @@ -3,7 +3,7 @@ from tools.codegen.api.types import TensorOptionsArguments, NativeArgument, ThisArgument import tools.codegen.api.cpp as cpp -from typing import Union, Sequence +from typing import Union, Sequence, Tuple # This file describes the translation of JIT schema to the native functions API. # This looks a lot like the C++ API (which makes historical sense, because the @@ -74,5 +74,5 @@ def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> Native else: assert_never(a) -def arguments(func: FunctionSchema) -> Sequence[NativeArgument]: - return list(map(argument, cpp.group_arguments(func, method=False))) +def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]: + return tuple(map(argument, cpp.group_arguments(func, method=False))) diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index 7952ced42f441..268592590ff50 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -198,9 +198,11 @@ def decl(self) -> str: # Render the C++ definition for this signature, not including # the body (with curly braces) - def defn(self, prefix: str = "") -> str: + def defn(self, name: Optional[str] = None, *, prefix: str = "") -> str: cpp_args_str = ', '.join(a.str_no_default() for a in self.arguments()) - return f"{self._returns_type} {prefix}{cpp.name(self.func)}({cpp_args_str})" + if name is None: + name = prefix + cpp.name(self.func) + return f"{self._returns_type} {name}({cpp_args_str})" # NB: This constructor knows how to disambiguate defaults when # faithful is True. Ideally this would live as an external process @@ -280,6 +282,47 @@ class DispatcherArgument: def __str__(self) -> str: return f"{self.type} {self.name}" +@dataclass(frozen=True) +class DispatcherSignature: + # The schema this signature is derived from + func: FunctionSchema + + # Note to self: if we ever need to reassemble tensor options, we may need to + # also preserve grouping with DispatcherTensorOptionsArguments. This should + # be an unlikely situation, however, since the general direction we are + # headed is to make native:: take everything in expanded form, so you + # shouldn't need to reassemble + _arguments: Tuple[DispatcherArgument, ...] + _returns_type: str + + def arguments(self) -> Tuple[DispatcherArgument, ...]: + return self._arguments + + def defn(self, name: Optional[str] = None) -> str: + args_str = ', '.join(map(str, self.arguments())) + if name is None: + name = native.name(self.func) + return f"{self._returns_type} {name}({args_str})" + + def exprs(self) -> Sequence[DispatcherExpr]: + return dispatcher.exprs(self.arguments()) + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + dispatcher_args_types_str = ', '.join(map(lambda a: a.type, self._arguments)) + return f'{self._returns_type} ({dispatcher_args_types_str})' + + @staticmethod + def from_schema(func: FunctionSchema) -> 'DispatcherSignature': + arguments = dispatcher.arguments(func) + returns_type = dispatcher.returns_type(func.returns) + + return DispatcherSignature( + func=func, + _arguments=arguments, + _returns_type=returns_type, + ) + # ------------------------------------------------------------------- # # native types (NativeFunctions.h) @@ -320,5 +363,36 @@ def str_with_default(self) -> str: mb_default = f"={self.default}" return f"{self.type} {self.name}{mb_default}" +@dataclass(frozen=True) +class NativeSignature: + # The schema this signature is derived from + func: FunctionSchema + + _arguments: Tuple[NativeArgument, ...] + _returns_type: str + + def defn(self, name: Optional[str] = None) -> str: + args_str = ', '.join(map(str, self.arguments())) + if name is None: + name = dispatcher.name(self.func) + return f"{self._returns_type} {name}({args_str})" + + def arguments(self) -> Tuple[NativeArgument, ...]: + return self._arguments + + def dispatcher_exprs(self) -> Sequence['DispatcherExpr']: + return dispatcher.nativearguments_exprs(self.arguments()) + + @staticmethod + def from_schema(func: FunctionSchema) -> 'NativeSignature': + arguments = native.arguments(func) + returns_type = native.returns_type(func.returns) + + return NativeSignature( + func=func, + _arguments=arguments, + _returns_type=returns_type, + ) + # Functions only, no types -import tools.codegen.api.cpp as cpp +from tools.codegen.api import cpp, dispatcher, native diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index cd8d3937ea70f..e4a13ccebeadc 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -269,9 +269,8 @@ def func(f: NativeFunction) -> Optional[str]: """ elif target is Target.REGISTRATION: - assert returns_type == dispatcher.returns_type(f.func.returns) - dispatcher_args = dispatcher.arguments(f.func) - dispatcher_args_types_str = ', '.join(map(lambda a: a.type, dispatcher_args)) + dispatcher_sig = DispatcherSignature.from_schema(f.func) + if dispatch is None or dispatch == 'Math' or dispatch == 'DefaultBackend': type_name = f'TypeDefault::{name}' else: @@ -289,7 +288,8 @@ def func(f: NativeFunction) -> Optional[str]: payload = f"TORCH_FN({type_name})" elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \ - f"{returns_type} ({dispatcher_args_types_str})>(TORCH_FN({type_name}))" + f"{dispatcher_sig.type()}>(TORCH_FN({type_name}))" + else: assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})" @@ -338,9 +338,9 @@ def go(f: NativeFunction) -> Optional[str]: assert target is Target.DEFINITION def generate_defn(sig: CppSignature) -> str: + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs()) - dispatcher_returns_type = dispatcher.returns_type(f.func.returns) - dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs)) dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs)) return f""" @@ -348,7 +348,7 @@ def generate_defn(sig: CppSignature) -> str: {sig.defn()} {{ static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{dispatcher_returns_type} ({dispatcher_types_str})>(); + .typed<{dispatcher_sig.type()}>(); return op.call({dispatcher_exprs_str}); }} """ @@ -388,17 +388,17 @@ def go(f: NativeFunction) -> Optional[str]: assert target is Target.DEFINITION def generate_defn(sig: CppSignature) -> str: + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs()) - dispatcher_returns_type = dispatcher.returns_type(f.func.returns) - dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs)) dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs)) return f""" // aten::{f.func} -{sig.defn("Tensor::")} const {{ +{sig.defn(prefix="Tensor::")} const {{ static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{dispatcher_returns_type} ({dispatcher_types_str})>(); + .typed<{dispatcher_sig.type()}>(); return op.call({dispatcher_exprs_str}); }} """ @@ -455,30 +455,26 @@ def go(f: NativeFunction) -> Optional[str]: return None name = native.name(f.func) - native_returns_type = native.returns_type(f.func.returns) - native_args = native.arguments(f.func) + native_sig = NativeSignature.from_schema(f.func) - if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_args): + if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()): return None native_tensor_args = [ - a for a in native_args + a for a in native_sig.arguments() if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like() ] - dispatcher_returns_type = dispatcher.returns_type(f.func.returns) - dispatcher_args = dispatcher.arguments(f.func) + dispatcher_sig = DispatcherSignature.from_schema(f.func) - args: Union[Sequence[DispatcherArgument], Sequence[NativeArgument]] + sig: Union[NativeSignature, DispatcherSignature] if local.use_c10_dispatcher().dispatcher_uses_new_style(): - returns_type = dispatcher_returns_type - args = dispatcher_args - exprs = dispatcher.exprs(dispatcher_args) + sig = dispatcher_sig + dispatcher_exprs = dispatcher_sig.exprs() dispatch_key = "c10::computeDispatchKey(dtype, layout, device)" else: - returns_type = native_returns_type - args = native_args - exprs = dispatcher.nativearguments_exprs(native_args) + sig = native_sig + dispatcher_exprs = native_sig.dispatcher_exprs() dispatch_key = "options.computeDispatchKey()" if target is Target.DEFINITION: @@ -496,16 +492,16 @@ def go(f: NativeFunction) -> Optional[str]: compute_dk = f"DispatchKey _dk = {dispatch_key};" return f"""\ // aten::{f.func} -{returns_type} {name}({', '.join(str(a) for a in args)}) {{ +{sig.defn(name)} {{ static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>(); + .typed<{dispatcher_sig.type()}>(); {compute_dk} DispatchKey _autograd_dk = c10::getAutogradKeyFromBackend(_dk); // This trick allows calling Autograd backend kernel first and then backend kernel, // without adding another AutogradBackendSelect dispatch key. DispatchKey _current_dk = at::impl::variable_excluded_from_dispatch() ? _dk : _autograd_dk; - return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in exprs)}); + return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in dispatcher_exprs)}); }} """ elif target is Target.REGISTRATION: @@ -513,7 +509,7 @@ def go(f: NativeFunction) -> Optional[str]: return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));""" elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: return f"""m.impl("aten::{f.func.name}", - c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>( + c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_sig.type()}>( TORCH_FN({name})));""" else: assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper