-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Start of a prototype for string-based, very strict, type dispatching #2
Draft
seberg
wants to merge
6
commits into
scientific-python:main
Choose a base branch
from
seberg:prototype
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
b8ff81c
Start of a prototype for string-based, very strict, type dispatching
seberg f1a8580
should use get_identifier function if I have it
seberg b4dd3d9
Use entr-points (that was simple! :))
seberg 7ae0355
Allow overriding module and some smaller things
seberg 0d8a973
oops, forgot return
seberg 6cb54d8
Have to overwrite original funcs module (or reorganize a bit)
seberg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import textwrap | ||
|
||
# This should be fetched via an entry-point, but lets do it here for now! | ||
|
||
# This type of dictionary could be all there is to a backend for now. | ||
# Of course... the backend will need machinery to conveniently build it. | ||
# The backend-info dict, after build needs to live in a minimal/cheap to | ||
# import module. | ||
backend_info = { | ||
"name": "my_backend", | ||
"types": ["numpy:matrix"], | ||
"symbol_mapping": {}, | ||
} | ||
|
||
|
||
from spatch import WillNotHandle | ||
|
||
|
||
def implements(func): | ||
"""Helper decorator. Since/if we name our modules identically to the | ||
main module, we can just do a simple replace the module and be done. | ||
""" | ||
mod = func.__module__ | ||
# TODO: May want to make sure to replace the start exactly, but OK... | ||
orig_mod = mod.replace("prototype_backend", "prototype_module") | ||
name = func.__qualname__ | ||
|
||
backend_info["symbol_mapping"][f"{orig_mod}:{name}"] = { | ||
"impl_symbol": f"{mod}:{name}", | ||
"doc_blurp": textwrap.dedent(func.__doc__).strip("\n"), | ||
} | ||
|
||
# We don't actually change the function, just keep track of it. | ||
return func | ||
|
||
|
||
# TODO/NOTE: This function would of course be in a different module! | ||
@implements | ||
def func1(arg, optional=None, parameter="param"): | ||
""" | ||
This text is added by `my_backend`! | ||
""" | ||
if parameter != "param": | ||
return WillNotHandle("Don't know how to do param.") | ||
return "my_backend called" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import numpy as np | ||
|
||
import prototype_module | ||
|
||
|
||
print(prototype_module.func1.__doc__) | ||
|
||
print("\n\nTrivial call examples:\n") | ||
|
||
print(prototype_module.func1([1, 2, 3])) | ||
print(prototype_module.func1(np.matrix([1, 2, 3]))) | ||
|
||
print("\nNobody can handle this example:") | ||
prototype_module.func1(np.matrix([1, 2, 3]), parameter="uhoh!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# This is a silly module that might be used by a user. | ||
from spatch import BackendSystem, WillNotHandle | ||
|
||
_backend_sys = BackendSystem("prototype_modules_backends") | ||
|
||
|
||
@_backend_sys.dispatchable("arg", "optional") | ||
def func1(arg, /, optional=None, parameter="param"): | ||
""" | ||
This is my function | ||
|
||
Parameters | ||
---------- | ||
... | ||
|
||
""" | ||
if parameter != "param": | ||
# I suppose even the fallback/default can refuse to handle it? | ||
return WillNotHandle("Don't know how to do param.") | ||
|
||
return "default implementation called!" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
import inspect | ||
import functools | ||
import importlib | ||
import importlib_metadata | ||
import textwrap | ||
import warnings | ||
|
||
|
||
def get_identifier(obj): | ||
"""Helper to get any objects identifier. Is there an exiting short-hand? | ||
""" | ||
return f"{obj.__module__}:{obj.__qualname__}" | ||
|
||
|
||
def from_identifier(ident): | ||
module, qualname = ident.split(":") | ||
obj = importlib.import_module(module) | ||
for name in qualname.split("."): | ||
obj = getattr(obj, name) | ||
|
||
return obj | ||
|
||
|
||
class WillNotHandle: | ||
"""Class to return when an implementation does not want to handle | ||
args/kwargs. | ||
""" | ||
def __init__(self, info="<unknown reason>"): | ||
self.info = info | ||
|
||
|
||
class Backend: | ||
@classmethod | ||
def from_info_dict(cls, info): | ||
return cls.from_mapping_and_types(info["name"], info["types"], info["symbol_mapping"]) | ||
|
||
@classmethod | ||
def from_mapping_and_types(cls, name, types, symbol_mapping): | ||
""" | ||
Create a new backend. | ||
""" | ||
self = cls() | ||
self.name = name | ||
self.type_names = types | ||
self.symbol_mapping = symbol_mapping | ||
return self | ||
|
||
def match_types(self, types): | ||
"""See if this backend matches the types, we do this by name. | ||
|
||
Of course, we could use more complicated ways in the future. | ||
E.g. one thing is that we can have to kind of types: | ||
1. Types that must match (at least once). | ||
2. Types that are understood (we do not break for them). | ||
|
||
Returns | ||
------- | ||
matches : boolean | ||
Whether or not the types matches. | ||
unknown_types : sequence of types | ||
A sequence of types the backend did not match/know. | ||
This may be a way to e.g. deal with scalars, that we assume | ||
all backends can convert, but creating an extensive list may | ||
not be desireable? | ||
""" | ||
matches = False | ||
unknown_types = [] | ||
for t in types: | ||
ident = get_identifier(t) | ||
|
||
if ident in self.type_names: | ||
matches = True | ||
unknown_types.append(t) | ||
|
||
return matches, unknown_types | ||
|
||
|
||
class BackendSystem: | ||
def __init__(self, group): | ||
# TODO: Should we use group and name, or is group enough? | ||
# TODO: We could define types of the fallback here, or known "scalar" | ||
# (i.e. unimportant types). | ||
# In a sense, the fallback should maybe itself just be a normal | ||
# backend, except we always try it if all else fails... | ||
self.backends = {} | ||
|
||
eps = importlib_metadata.entry_points(group=group) | ||
for ep in eps: | ||
self.backend_from_dict(ep.load()) | ||
|
||
print(self.backends) | ||
|
||
def backend_from_dict(self, info_dict): | ||
new_backend = Backend.from_info_dict(info_dict) | ||
if new_backend.name in self.backends: | ||
warnings.warn( | ||
UserWarning, | ||
f"Backend of name '{new_backend.name}' already exists. Ignoring second!") | ||
return | ||
self.backends[new_backend.name] = new_backend | ||
|
||
def dispatchable(self, *relevant_args, module=None): | ||
""" | ||
Decorate a Python function with information on how to extract | ||
the "relevant" arguments, i.e. arguments we wish to dispatch for. | ||
|
||
Parameters | ||
---------- | ||
*relevant_args : The names of parameters to extract (we use inspect to | ||
map these correctly). | ||
""" | ||
def wrap_callable(func): | ||
# Overwrite original module (we use it later, could also pass it) | ||
if module is not None: | ||
func.__module__ = module | ||
|
||
disp = Dispatchable(self, func, relevant_args) | ||
|
||
return disp | ||
|
||
return wrap_callable | ||
|
||
class Dispatchable: | ||
"""Dispatchable function object | ||
|
||
""" | ||
def __init__(self, backend_system, func, relevant_args, ident=None): | ||
self._backend_system = backend_system | ||
self._sig = inspect.signature(func) | ||
self._relevant_args = relevant_args | ||
self._default_impl = func | ||
# Keep a list of implementations for this backend | ||
self._implementations = [] | ||
|
||
ident = get_identifier(func) | ||
|
||
functools.update_wrapper(self, func) | ||
|
||
new_doc = [] | ||
for backend in backend_system.backends.values(): | ||
info = backend.symbol_mapping.get(ident, None) | ||
print(backend.symbol_mapping, ident) | ||
if info is None: | ||
continue # not implemented by backend (apparently) | ||
|
||
self._implementations.append((backend, info["impl_symbol"])) | ||
|
||
impl_symbol = info["impl_symbol"] | ||
doc_blurp = info.get("doc_blurp", "No backend documentation available.") | ||
new_doc.append(f"backend.name :\n" + textwrap.indent(doc_blurp, " ")) | ||
|
||
if not new_doc: | ||
new_doc = ["No backends found for this function."] | ||
|
||
new_doc = "\n\n".join(new_doc) | ||
new_doc = "\n\nBackends\n--------\n" + new_doc | ||
|
||
# Just dedent, so it makes sense to append (should be fine): | ||
self.__doc__ = textwrap.dedent(self.__doc__) + new_doc | ||
|
||
def __get__(self, ): | ||
raise NotImplementedError( | ||
"Need to implement this eventually to act like functions.") | ||
|
||
@property | ||
def _backends(self): | ||
# Extract the backends: | ||
return [impl[0] for impl in self._implementations] | ||
|
||
def _find_matching_backends(self, relevant_types): | ||
"""Find all matching backends. | ||
""" | ||
matching = [] | ||
unknown_types = relevant_types | ||
for backend, impl in self._implementations: | ||
matches, unknown_types_backend = backend.match_types(relevant_types) | ||
unknown_types = unknown_types.union(unknown_types_backend) | ||
|
||
if matches: | ||
matching.append((backend, impl, unknown_types)) | ||
|
||
match_with_unknown = [] | ||
for backend, impl, unknown_types_backend in matching: | ||
# If the types the backend didn't know are also not known by | ||
# any other backend, we just ignore them | ||
if unknown_types_backend.issubset(unknown_types): | ||
match_with_unknown.append((backend, impl)) | ||
|
||
return match_with_unknown | ||
|
||
def __call__(self, *args, **kwargs): | ||
partial = self._sig.bind_partial(*args, **kwargs) | ||
|
||
relevant_types = set() | ||
for arg in self._relevant_args: | ||
val = partial.arguments.get(arg, None) | ||
if val is not None: | ||
relevant_types.add(type(val)) | ||
|
||
matching_impls = self._find_matching_backends(relevant_types) | ||
|
||
# TODO: When more than one backend matches, we could: | ||
# 1. Ensure e.g. an alphabetic order early on during registration. | ||
# 2. Inspect types, to see if one backend is clearly more specific | ||
# than another one. | ||
reasons = [] | ||
for backend, impl in matching_impls + [(None, self._default_impl)]: | ||
# Call the backend: | ||
if isinstance(impl, str): | ||
# TODO: Should update the impl we store, to do this only once! | ||
impl = from_identifier(impl) | ||
|
||
result = impl(*args, **kwargs) | ||
if isinstance(result, WillNotHandle): | ||
# The backend indicated it cannot/does not want to handle | ||
# this. | ||
reasons.append((backend, result)) | ||
else: | ||
return result | ||
|
||
if len(reasons) == 1: | ||
backends = self._backends | ||
msg = (f"No backend matched out of {backends} and the default " | ||
f"did not work because of: {reasons[0][1].info}") | ||
else: | ||
msg = f"Of the available backends, the following were tried but failed:" | ||
for backend, reason in reasons: | ||
name = "default" if backend is None else backend.name | ||
msg += f"\n - {name}: {reason}" | ||
|
||
raise TypeError(msg) | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, this uses the string types, and @eriknw correctly pointed out (independent of this PR), that projects may sometimes be sloppy about which
__module__
they define types.I am not sure this is a problem, but we could work around it, with something like:
I.e. I think rather than matching only by strings, we could check whether the type may be defined, by testing if the module exists.