Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start of a prototype for string-based, very strict, type dispatching #2

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions prototype_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import textwrap

# This should be fetched via an entry-point, but lets do it here for now!

# This type of dictionary could be all there is to a backend for now.
# Of course... the backend will need machinery to conveniently build it.
# The backend-info dict, after build needs to live in a minimal/cheap to
# import module.
backend_info = {
"name": "my_backend",
"types": ["numpy:matrix"],
"symbol_mapping": {},
}


from spatch import WillNotHandle


def implements(func):
"""Helper decorator. Since/if we name our modules identically to the
main module, we can just do a simple replace the module and be done.
"""
mod = func.__module__
# TODO: May want to make sure to replace the start exactly, but OK...
orig_mod = mod.replace("prototype_backend", "prototype_module")
name = func.__qualname__

backend_info["symbol_mapping"][f"{orig_mod}:{name}"] = {
"impl_symbol": f"{mod}:{name}",
"doc_blurp": textwrap.dedent(func.__doc__).strip("\n"),
}

# We don't actually change the function, just keep track of it.
return func


# TODO/NOTE: This function would of course be in a different module!
@implements
def func1(arg, optional=None, parameter="param"):
"""
This text is added by `my_backend`!
"""
if parameter != "param":
return WillNotHandle("Don't know how to do param.")
return "my_backend called"
14 changes: 14 additions & 0 deletions prototype_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np

import prototype_module


print(prototype_module.func1.__doc__)

print("\n\nTrivial call examples:\n")

print(prototype_module.func1([1, 2, 3]))
print(prototype_module.func1(np.matrix([1, 2, 3])))

print("\nNobody can handle this example:")
prototype_module.func1(np.matrix([1, 2, 3]), parameter="uhoh!")
22 changes: 22 additions & 0 deletions prototype_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# This is a silly module that might be used by a user.
from spatch import BackendSystem, WillNotHandle

_backend_sys = BackendSystem("prototype_modules_backends")


@_backend_sys.dispatchable("arg", "optional")
def func1(arg, /, optional=None, parameter="param"):
"""
This is my function

Parameters
----------
...

"""
if parameter != "param":
# I suppose even the fallback/default can refuse to handle it?
return WillNotHandle("Don't know how to do param.")

return "default implementation called!"

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ authors = [
]
version = "0.0.0"
description = "Coming soon"


[project.entry-points.prototype_modules_backends]
my_backend = 'prototype_backend:backend_info'
232 changes: 232 additions & 0 deletions spatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import inspect
import functools
import importlib
import importlib_metadata
import textwrap
import warnings


def get_identifier(obj):
"""Helper to get any objects identifier. Is there an exiting short-hand?
"""
return f"{obj.__module__}:{obj.__qualname__}"


def from_identifier(ident):
module, qualname = ident.split(":")
obj = importlib.import_module(module)
for name in qualname.split("."):
obj = getattr(obj, name)

return obj


class WillNotHandle:
"""Class to return when an implementation does not want to handle
args/kwargs.
"""
def __init__(self, info="<unknown reason>"):
self.info = info


class Backend:
@classmethod
def from_info_dict(cls, info):
return cls.from_mapping_and_types(info["name"], info["types"], info["symbol_mapping"])

@classmethod
def from_mapping_and_types(cls, name, types, symbol_mapping):
"""
Create a new backend.
"""
self = cls()
self.name = name
self.type_names = types
self.symbol_mapping = symbol_mapping
return self

def match_types(self, types):
"""See if this backend matches the types, we do this by name.

Of course, we could use more complicated ways in the future.
E.g. one thing is that we can have to kind of types:
1. Types that must match (at least once).
2. Types that are understood (we do not break for them).

Returns
-------
matches : boolean
Whether or not the types matches.
unknown_types : sequence of types
A sequence of types the backend did not match/know.
This may be a way to e.g. deal with scalars, that we assume
all backends can convert, but creating an extensive list may
not be desireable?
"""
matches = False
unknown_types = []
for t in types:
ident = get_identifier(t)

if ident in self.type_names:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this uses the string types, and @eriknw correctly pointed out (independent of this PR), that projects may sometimes be sloppy about which __module__ they define types.
I am not sure this is a problem, but we could work around it, with something like:

module, name = self.type_names[0]  # all of them of course
if module in sys.modules:
    t = getattr(module, name, None)  # ignore error?
    ...

I.e. I think rather than matching only by strings, we could check whether the type may be defined, by testing if the module exists.

matches = True
unknown_types.append(t)

return matches, unknown_types


class BackendSystem:
def __init__(self, group):
# TODO: Should we use group and name, or is group enough?
# TODO: We could define types of the fallback here, or known "scalar"
# (i.e. unimportant types).
# In a sense, the fallback should maybe itself just be a normal
# backend, except we always try it if all else fails...
self.backends = {}

eps = importlib_metadata.entry_points(group=group)
for ep in eps:
self.backend_from_dict(ep.load())

print(self.backends)

def backend_from_dict(self, info_dict):
new_backend = Backend.from_info_dict(info_dict)
if new_backend.name in self.backends:
warnings.warn(
UserWarning,
f"Backend of name '{new_backend.name}' already exists. Ignoring second!")
return
self.backends[new_backend.name] = new_backend

def dispatchable(self, *relevant_args, module=None):
"""
Decorate a Python function with information on how to extract
the "relevant" arguments, i.e. arguments we wish to dispatch for.

Parameters
----------
*relevant_args : The names of parameters to extract (we use inspect to
map these correctly).
"""
def wrap_callable(func):
# Overwrite original module (we use it later, could also pass it)
if module is not None:
func.__module__ = module

disp = Dispatchable(self, func, relevant_args)

return disp

return wrap_callable

class Dispatchable:
"""Dispatchable function object

"""
def __init__(self, backend_system, func, relevant_args, ident=None):
self._backend_system = backend_system
self._sig = inspect.signature(func)
self._relevant_args = relevant_args
self._default_impl = func
# Keep a list of implementations for this backend
self._implementations = []

ident = get_identifier(func)

functools.update_wrapper(self, func)

new_doc = []
for backend in backend_system.backends.values():
info = backend.symbol_mapping.get(ident, None)
print(backend.symbol_mapping, ident)
if info is None:
continue # not implemented by backend (apparently)

self._implementations.append((backend, info["impl_symbol"]))

impl_symbol = info["impl_symbol"]
doc_blurp = info.get("doc_blurp", "No backend documentation available.")
new_doc.append(f"backend.name :\n" + textwrap.indent(doc_blurp, " "))

if not new_doc:
new_doc = ["No backends found for this function."]

new_doc = "\n\n".join(new_doc)
new_doc = "\n\nBackends\n--------\n" + new_doc

# Just dedent, so it makes sense to append (should be fine):
self.__doc__ = textwrap.dedent(self.__doc__) + new_doc

def __get__(self, ):
raise NotImplementedError(
"Need to implement this eventually to act like functions.")

@property
def _backends(self):
# Extract the backends:
return [impl[0] for impl in self._implementations]

def _find_matching_backends(self, relevant_types):
"""Find all matching backends.
"""
matching = []
unknown_types = relevant_types
for backend, impl in self._implementations:
matches, unknown_types_backend = backend.match_types(relevant_types)
unknown_types = unknown_types.union(unknown_types_backend)

if matches:
matching.append((backend, impl, unknown_types))

match_with_unknown = []
for backend, impl, unknown_types_backend in matching:
# If the types the backend didn't know are also not known by
# any other backend, we just ignore them
if unknown_types_backend.issubset(unknown_types):
match_with_unknown.append((backend, impl))

return match_with_unknown

def __call__(self, *args, **kwargs):
partial = self._sig.bind_partial(*args, **kwargs)

relevant_types = set()
for arg in self._relevant_args:
val = partial.arguments.get(arg, None)
if val is not None:
relevant_types.add(type(val))

matching_impls = self._find_matching_backends(relevant_types)

# TODO: When more than one backend matches, we could:
# 1. Ensure e.g. an alphabetic order early on during registration.
# 2. Inspect types, to see if one backend is clearly more specific
# than another one.
reasons = []
for backend, impl in matching_impls + [(None, self._default_impl)]:
# Call the backend:
if isinstance(impl, str):
# TODO: Should update the impl we store, to do this only once!
impl = from_identifier(impl)

result = impl(*args, **kwargs)
if isinstance(result, WillNotHandle):
# The backend indicated it cannot/does not want to handle
# this.
reasons.append((backend, result))
else:
return result

if len(reasons) == 1:
backends = self._backends
msg = (f"No backend matched out of {backends} and the default "
f"did not work because of: {reasons[0][1].info}")
else:
msg = f"Of the available backends, the following were tried but failed:"
for backend, reason in reasons:
name = "default" if backend is None else backend.name
msg += f"\n - {name}: {reason}"

raise TypeError(msg)