Skip to content

Commit

Permalink
Refactor dispatcher and native to use Signature structure. (pytorch#4…
Browse files Browse the repository at this point in the history
…5990)

Summary:
Pull Request resolved: pytorch#45990

In pytorch#45890 we introduced the concept of a CppSignature, which bundled
up all of the information necessary to declare a C++ signature for
the cpp API.  This PR introduces analogous concepts for dispatcher
and native: DispatcherSignature and NativeSignature.

The three interfaces are not particularly well coupled right now,
but they do have some duck typing coincidences:

- defn() which renders the C++ definition "bool f(int x)"
- decl() which renders the C++ declaration "bool f(int x = 2)"
- type() which renders the C++ function type "bool(int)"

Maybe at some point we'll introduce a Protocol, or a supertype.
Many other methods (like arguments()) have varying types.  These
signatures also have some helper methods that forward back to real
implementations in the api modules.  Something to think about is
whether or not we should attempt to reduce boilerplate here or
not; I'm not too sure about it yet.

The net effect is we get to reduce the number of variables we
have to explicitly write out in the codegen, since now these are all
bundled together into a signature.  Something extra special happens
in BackendSelect, where we now dynamically select between dispatcher_sig
and native_sig as "how" the backend select is implemented.

A little bit of extra cleanup:
- Some places where we previously advertised Sequence, we now advertise
  a more informative Tuple.
- defn() may take an optional positional parameter overriding the entire
  name, or a kwarg-only prefix parameter to just add a prefix to the
  name.

Signed-off-by: Edward Z. Yang <[email protected]>

Test Plan: Imported from OSS

Reviewed By: smessmer

Differential Revision: D24223100

Pulled By: ezyang

fbshipit-source-id: f985eced08af4a60ba9641d125d0f260f8cda9eb
  • Loading branch information
ezyang authored and facebook-github-bot committed Oct 13, 2020
1 parent f086032 commit d705083
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 39 deletions.
10 changes: 5 additions & 5 deletions tools/codegen/api/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tools.codegen.local as local

import itertools
from typing import Sequence, Optional
from typing import Sequence, Optional, Tuple

# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
Expand Down Expand Up @@ -65,14 +65,14 @@ def argument(a: Argument) -> DispatcherArgument:
def name(func: FunctionSchema) -> str:
return cpp.name(func)

def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]:
if local.use_c10_dispatcher().dispatcher_uses_new_style():
return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
return tuple(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
else:
return [
return tuple(
DispatcherArgument(type=la.type, name=la.name, argument=la.argument)
for la in native.arguments(func)
]
)

# Given a set of CppArguments in scope, return a sequence of dispatcher
# expressions that translate the cpp API into dispatcher API
Expand Down
6 changes: 3 additions & 3 deletions tools/codegen/api/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tools.codegen.api.types import TensorOptionsArguments, NativeArgument, ThisArgument
import tools.codegen.api.cpp as cpp

from typing import Union, Sequence
from typing import Union, Sequence, Tuple

# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
Expand Down Expand Up @@ -74,5 +74,5 @@ def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> Native
else:
assert_never(a)

def arguments(func: FunctionSchema) -> Sequence[NativeArgument]:
return list(map(argument, cpp.group_arguments(func, method=False)))
def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]:
return tuple(map(argument, cpp.group_arguments(func, method=False)))
80 changes: 77 additions & 3 deletions tools/codegen/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,11 @@ def decl(self) -> str:

# Render the C++ definition for this signature, not including
# the body (with curly braces)
def defn(self, prefix: str = "") -> str:
def defn(self, name: Optional[str] = None, *, prefix: str = "") -> str:
cpp_args_str = ', '.join(a.str_no_default() for a in self.arguments())
return f"{self._returns_type} {prefix}{cpp.name(self.func)}({cpp_args_str})"
if name is None:
name = prefix + cpp.name(self.func)
return f"{self._returns_type} {name}({cpp_args_str})"

# NB: This constructor knows how to disambiguate defaults when
# faithful is True. Ideally this would live as an external process
Expand Down Expand Up @@ -280,6 +282,47 @@ class DispatcherArgument:
def __str__(self) -> str:
return f"{self.type} {self.name}"

@dataclass(frozen=True)
class DispatcherSignature:
# The schema this signature is derived from
func: FunctionSchema

# Note to self: if we ever need to reassemble tensor options, we may need to
# also preserve grouping with DispatcherTensorOptionsArguments. This should
# be an unlikely situation, however, since the general direction we are
# headed is to make native:: take everything in expanded form, so you
# shouldn't need to reassemble
_arguments: Tuple[DispatcherArgument, ...]
_returns_type: str

def arguments(self) -> Tuple[DispatcherArgument, ...]:
return self._arguments

def defn(self, name: Optional[str] = None) -> str:
args_str = ', '.join(map(str, self.arguments()))
if name is None:
name = native.name(self.func)
return f"{self._returns_type} {name}({args_str})"

def exprs(self) -> Sequence[DispatcherExpr]:
return dispatcher.exprs(self.arguments())

# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
dispatcher_args_types_str = ', '.join(map(lambda a: a.type, self._arguments))
return f'{self._returns_type} ({dispatcher_args_types_str})'

@staticmethod
def from_schema(func: FunctionSchema) -> 'DispatcherSignature':
arguments = dispatcher.arguments(func)
returns_type = dispatcher.returns_type(func.returns)

return DispatcherSignature(
func=func,
_arguments=arguments,
_returns_type=returns_type,
)

# ------------------------------------------------------------------- #

# native types (NativeFunctions.h)
Expand Down Expand Up @@ -320,5 +363,36 @@ def str_with_default(self) -> str:
mb_default = f"={self.default}"
return f"{self.type} {self.name}{mb_default}"

@dataclass(frozen=True)
class NativeSignature:
# The schema this signature is derived from
func: FunctionSchema

_arguments: Tuple[NativeArgument, ...]
_returns_type: str

def defn(self, name: Optional[str] = None) -> str:
args_str = ', '.join(map(str, self.arguments()))
if name is None:
name = dispatcher.name(self.func)
return f"{self._returns_type} {name}({args_str})"

def arguments(self) -> Tuple[NativeArgument, ...]:
return self._arguments

def dispatcher_exprs(self) -> Sequence['DispatcherExpr']:
return dispatcher.nativearguments_exprs(self.arguments())

@staticmethod
def from_schema(func: FunctionSchema) -> 'NativeSignature':
arguments = native.arguments(func)
returns_type = native.returns_type(func.returns)

return NativeSignature(
func=func,
_arguments=arguments,
_returns_type=returns_type,
)

# Functions only, no types
import tools.codegen.api.cpp as cpp
from tools.codegen.api import cpp, dispatcher, native
52 changes: 24 additions & 28 deletions tools/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,8 @@ def func(f: NativeFunction) -> Optional[str]:
"""

elif target is Target.REGISTRATION:
assert returns_type == dispatcher.returns_type(f.func.returns)
dispatcher_args = dispatcher.arguments(f.func)
dispatcher_args_types_str = ', '.join(map(lambda a: a.type, dispatcher_args))
dispatcher_sig = DispatcherSignature.from_schema(f.func)

if dispatch is None or dispatch == 'Math' or dispatch == 'DefaultBackend':
type_name = f'TypeDefault::{name}'
else:
Expand All @@ -289,7 +288,8 @@ def func(f: NativeFunction) -> Optional[str]:
payload = f"TORCH_FN({type_name})"
elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \
f"{returns_type} ({dispatcher_args_types_str})>(TORCH_FN({type_name}))"
f"{dispatcher_sig.type()}>(TORCH_FN({type_name}))"

else:
assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper
payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})"
Expand Down Expand Up @@ -338,17 +338,17 @@ def go(f: NativeFunction) -> Optional[str]:
assert target is Target.DEFINITION

def generate_defn(sig: CppSignature) -> str:
dispatcher_sig = DispatcherSignature.from_schema(f.func)

dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs())
dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs))
dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))

return f"""
// aten::{f.func}
{sig.defn()} {{
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
.typed<{dispatcher_returns_type} ({dispatcher_types_str})>();
.typed<{dispatcher_sig.type()}>();
return op.call({dispatcher_exprs_str});
}}
"""
Expand Down Expand Up @@ -388,17 +388,17 @@ def go(f: NativeFunction) -> Optional[str]:
assert target is Target.DEFINITION

def generate_defn(sig: CppSignature) -> str:
dispatcher_sig = DispatcherSignature.from_schema(f.func)

dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs())
dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs))
dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))

return f"""
// aten::{f.func}
{sig.defn("Tensor::")} const {{
{sig.defn(prefix="Tensor::")} const {{
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
.typed<{dispatcher_returns_type} ({dispatcher_types_str})>();
.typed<{dispatcher_sig.type()}>();
return op.call({dispatcher_exprs_str});
}}
"""
Expand Down Expand Up @@ -455,30 +455,26 @@ def go(f: NativeFunction) -> Optional[str]:
return None

name = native.name(f.func)
native_returns_type = native.returns_type(f.func.returns)
native_args = native.arguments(f.func)
native_sig = NativeSignature.from_schema(f.func)

if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_args):
if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()):
return None

native_tensor_args = [
a for a in native_args
a for a in native_sig.arguments()
if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
]

dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
dispatcher_args = dispatcher.arguments(f.func)
dispatcher_sig = DispatcherSignature.from_schema(f.func)

args: Union[Sequence[DispatcherArgument], Sequence[NativeArgument]]
sig: Union[NativeSignature, DispatcherSignature]
if local.use_c10_dispatcher().dispatcher_uses_new_style():
returns_type = dispatcher_returns_type
args = dispatcher_args
exprs = dispatcher.exprs(dispatcher_args)
sig = dispatcher_sig
dispatcher_exprs = dispatcher_sig.exprs()
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
else:
returns_type = native_returns_type
args = native_args
exprs = dispatcher.nativearguments_exprs(native_args)
sig = native_sig
dispatcher_exprs = native_sig.dispatcher_exprs()
dispatch_key = "options.computeDispatchKey()"

if target is Target.DEFINITION:
Expand All @@ -496,24 +492,24 @@ def go(f: NativeFunction) -> Optional[str]:
compute_dk = f"DispatchKey _dk = {dispatch_key};"
return f"""\
// aten::{f.func}
{returns_type} {name}({', '.join(str(a) for a in args)}) {{
{sig.defn(name)} {{
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
.typed<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>();
.typed<{dispatcher_sig.type()}>();
{compute_dk}
DispatchKey _autograd_dk = c10::getAutogradKeyFromBackend(_dk);
// This trick allows calling Autograd backend kernel first and then backend kernel,
// without adding another AutogradBackendSelect dispatch key.
DispatchKey _current_dk = at::impl::variable_excluded_from_dispatch() ? _dk : _autograd_dk;
return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in exprs)});
return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
elif target is Target.REGISTRATION:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
return f"""m.impl("aten::{f.func.name}",
c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>(
c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_sig.type()}>(
TORCH_FN({name})));"""
else:
assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper
Expand Down

0 comments on commit d705083

Please sign in to comment.