Skip to content

Conversation

willGraham01
Copy link
Collaborator

@willGraham01 willGraham01 commented Mar 21, 2025

Creates the Translator class, which expands on the BackendAgnostic ABC. A Translator is aware of the backend_obj (backend object), and also aware of the signature of the required _frontend_methods.

On instantiation, one can pass a sequence of Translation objects - essentially dataclass wrappers for the arguments to convert_signature - to a Translator. For each _frontend_method, the Translator stores the map between the arguments the frontend method takes, and the arguments the corresponding backend method takes. This means that the Translator can be used by the frontend, and will have the expected syntax, but also takes care of the necessary mapping of the arguments provided by the frontend to those that the backend takes. There is a simple example of this in action in the tests for the Translator class, however the tests for different backends demonstrate how this functionality is envisioned to be used.

The Distribution class now inherits from Translator. The NativeDistribution class has also been introduced, for us to use as the "base" for any standard distributions that we want to provide via our jax-default backend. This class simply defaults all the "translations" to the identity map (IE, the arguments and method names of the backend are already what we expect of the frontend).

Users should use the Distribution class when they want to use a backend that is different to our "standard" backend (jax). If they envision using multiple distributions from the same backend, they can create a derived class to reduce the number of times they need to provide the mapping information, for example

from numpyro.distributions.continuous import MultivariateNormal

from causalprog.distributions.base import Distribution
from causalprog.backend.translation import Translation

NUMPYRO_TRANSLATIONS = (
    Translation(backend_method="sample", frontend_method="sample", param_map = {"seed": "rng_key"},
    ...
)

class NumpyroDistribution(Distribution):

    def __init__(self, *, backend, label: str) -> None:
        super().__init__(*NUMPYRO_TRANSLATIONS, backend=backend, label=label)      

# Specialized classes are further possible,
# if the user so desires.

class NumpyroNormal(NumpyroDistribution):

    def __init__(self, mean, cov, *, label: str) -> None:
        super().__init__(backend=MultivariateNormal(mean, cov), label=label)

import jax.numpy as jnp

mean, cov = jnp.array(...), jnp.array(...)
normal = NumpyroNormal(mean, cov, label="Numpyro Normal")
normal.sample(...) # Works with frontend syntax, but calls `numpyro` functionality.

Other Changes

  • convert_signature now returns the function that maps the frontend arguments to their backend counterparts, rather than the function that does this and then evaluates the backend function. This is so that we can recycle the static identity map method of Translator.

@willGraham01 willGraham01 force-pushed the wgraham/signature-converting branch from 5b28001 to 775489f Compare March 21, 2025 14:49
Copy link
Collaborator Author

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self-review with some typos

@willGraham01 willGraham01 marked this pull request as ready for review March 24, 2025 10:05
@willGraham01 willGraham01 requested review from mscroggs and removed request for mscroggs March 24, 2025 10:05
@willGraham01 willGraham01 marked this pull request as draft March 24, 2025 10:07
@willGraham01 willGraham01 marked this pull request as ready for review March 24, 2025 10:20
Base automatically changed from wgraham/signature-converting to main March 24, 2025 14:50
@willGraham01
Copy link
Collaborator Author

Closing this as it's outdated, especially in light of our decision - for now - to commit to numpyro for our distribution needs.

@willGraham01 willGraham01 deleted the wgraham/distributions-are-backend-agnostic branch September 1, 2025 13:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant