at + regex based matching for sharding #100

Open
ASEM000 opened this issue Jan 30, 2024 · 0 comments
Labels: documentation (Improvements or additions to documentation)

@ASEM000 (Owner) commented Jan 30, 2024

`at` can match paths using regex patterns. Use this feature to write a sample example showing how to set sharding, similar to how the new Keras API implements it [1].

Note that this should work with arbitrary pytrees (e.g., a Flax params dict).

Something along the following lines:

```python
import os

# Simulate 8 host (CPU) devices; this must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import re

import jax
import numpy as np
import serket as sk
from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P


class FeedForward(sk.TreeClass):
    def __init__(self, d_model: int, *, key: jax.Array):
        k1, k2 = jax.random.split(key)
        self.linear1 = sk.nn.Linear(d_model, d_model * 4, key=k1)
        self.linear2 = sk.nn.Linear(4 * d_model, d_model, key=k2)

    def __call__(self, input: jax.Array) -> jax.Array:
        return self.linear2(jax.nn.relu(self.linear1(input)))


ff = FeedForward(d_model=128, key=jax.random.PRNGKey(0))
ff = sk.tree_mask(ff)  # mask non-inexact (non-float) leaves so only float arrays remain as leaves
# 2 x 4 device mesh: a "data" axis of size 2 and a "model" axis of size 4
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), axis_names=["data", "model"])
# select layers whose names match `linear.*` and shard their weights and biases along "model"
sharding = sk.at(ff)[re.compile("linear.*")]["weight"].set(N(mesh, P("model", None)))
sharding = sk.at(sharding)[re.compile("linear.*")]["bias"].set(N(mesh, P("model")))
# `sharding` mirrors the structure of `ff`, with NamedSharding objects as its leaves
ff = jax.device_put(ff, sharding)


def vis_sharding(path, leaf):
    # print the leaf's key path, then a view of how it is laid out across devices
    print(jax.tree_util.keystr(path))
    jax.debug.visualize_array_sharding(leaf)


jax.tree_util.tree_map_with_path(vis_sharding, ff)
```
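To illustrate the "arbitrary pytrees" point, here is a minimal sketch of the same regex-driven idea applied to a plain nested dict shaped like a Flax params tree. It assumes only `jax.tree_util` rather than `sk.at`. The layer names (`Dense_0`, `Dense_1`), the `RULES` list, and the helpers `path_to_str` and `spec_for` are hypothetical, chosen for illustration. The first-match-wins rule ordering is only loosely in the spirit of the Keras approach mentioned above and is not claimed to reproduce its exact semantics.

```python
import os

# Simulate 8 CPU devices; this must run before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import re

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical Flax-style params dict: {layer_name: {param_name: array}}.
params = {
    "Dense_0": {"kernel": jnp.ones((128, 512)), "bias": jnp.zeros((512,))},
    "Dense_1": {"kernel": jnp.ones((512, 128)), "bias": jnp.zeros((128,))},
}

mesh = Mesh(np.array(jax.devices()).reshape(2, 4), axis_names=("data", "model"))

# Hypothetical ordered rules: (regex over "layer/param" path strings, PartitionSpec).
RULES = [
    (r"Dense_\d+/kernel", P("model", None)),  # split the first kernel dim over "model"
    (r"Dense_\d+/bias", P("model")),  # split biases over "model"
]


def path_to_str(path) -> str:
    # Turn a key path such as (DictKey('Dense_0'), DictKey('kernel')) into "Dense_0/kernel".
    parts = []
    for k in path:
        if isinstance(k, jax.tree_util.DictKey):
            parts.append(str(k.key))
        elif isinstance(k, jax.tree_util.GetAttrKey):
            parts.append(k.name)
        elif isinstance(k, jax.tree_util.SequenceKey):
            parts.append(str(k.idx))
        else:
            parts.append(str(k))
    return "/".join(parts)


def spec_for(path, leaf) -> NamedSharding:
    # The first rule whose pattern fully matches the path wins; unmatched leaves are replicated.
    name = path_to_str(path)
    for pattern, spec in RULES:
        if re.fullmatch(pattern, name):
            return NamedSharding(mesh, spec)
    return NamedSharding(mesh, P())


# Build a pytree of shardings with the same structure as `params`, then place the arrays.
shardings = jax.tree_util.tree_map_with_path(spec_for, params)
params = jax.device_put(params, shardings)

for path, leaf in jax.tree_util.tree_flatten_with_path(params)[0]:
    print(path_to_str(path), leaf.sharding.spec, leaf.addressable_shards[0].data.shape)
```

In this sketch, unmatched leaves fall back to `P()` (full replication), so adding a new parameter never raises an error. Whether a silent fallback or an explicit error is preferable is a design choice.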


ASEM000 added the documentation label Jan 30, 2024
ASEM000 self-assigned this Jan 30, 2024
ASEM000 changed the title from "at + regex based filtering for sharding" to "at + regex based matching for sharding" Jan 30, 2024