Commit c252c9c
Require Python 3.10+ for Haiku (same as JAX) and modernize code to match.

    $ pyupgrade --py310-plus **/*.py
    $ pyfactor fix_imports --remove_unused
    $ # some manual edits

PiperOrigin-RevId: 686854449
tomhennigan authored and copybara-github committed Oct 17, 2024
1 parent d915ddf commit c252c9c
Showing 58 changed files with 446 additions and 457 deletions.
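
The bulk of the diff below is mechanical. As a rough before/after sketch of the rewrites the commands above apply (the function and parameter names here are hypothetical, for illustration only):

    # Before (3.9-compatible spellings):
    from typing import Callable, Optional, Union

    def resize(size: Optional[int] = None,
               scale: Union[int, float] = 1,
               fn: Optional[Callable[[int], int]] = None): ...

    # After (PEP 604 `X | Y` unions; Callable imported from collections.abc,
    # which PEP 585 made subscriptable):
    from collections.abc import Callable

    def resize(size: int | None = None,
               scale: int | float = 1,
               fn: Callable[[int], int] | None = None): ...
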
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -14,7 +14,7 @@ jobs:
     runs-on: "${{ matrix.os }}"
     strategy:
       matrix:
-        python-version: ['3.9', '3.10', '3.11']
+        python-version: ['3.10', '3.11']
         os: [ubuntu-latest]
     steps:
       - uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -14,7 +14,7 @@ jobs:
     runs-on: "${{ matrix.os }}"
     strategy:
       matrix:
-        python-version: ['3.9', '3.10', '3.11']
+        python-version: ['3.10', '3.11']
         os: [ubuntu-latest]
     steps:
       - uses: actions/checkout@v2
3 changes: 1 addition & 2 deletions examples/imagenet/dataset.py
@@ -18,7 +18,6 @@
 import enum
 import itertools as it
 import types
-from typing import Optional
 
 import jax
 import jax.numpy as jnp
@@ -294,7 +293,7 @@ def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor:
 
 def _decode_and_center_crop(
     image_bytes: tf.Tensor,
-    jpeg_shape: Optional[tf.Tensor] = None,
+    jpeg_shape: tf.Tensor | None = None,
 ) -> tf.Tensor:
   """Crops to center of image with padding then scales."""
   if jpeg_shape is None:
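
Worth noting for annotation rewrites like the one above: `tf.Tensor | None` is evaluated eagerly at function definition time, and the `|` operator on types only exists from Python 3.10, which is why the version floor and the syntax change land together. A minimal sketch:

    # PEP 604 unions are real runtime objects on Python 3.10+; on 3.9,
    # `int | None` raises TypeError unless the module opts into
    # `from __future__ import annotations`.
    import types

    ann = int | None
    assert isinstance(ann, types.UnionType)
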
5 changes: 3 additions & 2 deletions examples/impala/agent.py
@@ -14,8 +14,9 @@
 # ==============================================================================
 """A stateless agent interface."""
 import collections
+from collections.abc import Callable
 import functools
-from typing import Any, Callable, Optional
+from typing import Any
 
 import dm_env
 import haiku as hk
@@ -69,7 +70,7 @@ def initial_params(self, rng_key):
     return self._init_fn(rng_key, dummy_inputs, self.initial_state(1))
 
   @functools.partial(jax.jit, static_argnums=(0, 1))
-  def initial_state(self, batch_size: Optional[int]):
+  def initial_state(self, batch_size: int | None):
     """Returns agent initial state."""
     # We expect that generating the initial_state does not require parameters.
     return self._initial_state_apply_fn(None, batch_size)
3 changes: 2 additions & 1 deletion examples/impala_lite.py
@@ -20,10 +20,11 @@
 See: https://arxiv.org/abs/1802.01561
 """
 
+from collections.abc import Callable
 import functools
 import queue
 import threading
-from typing import Any, Callable, NamedTuple
+from typing import Any, NamedTuple
 
 from absl import app
 from absl import logging
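
This file and several below swap `typing.Callable` for `collections.abc.Callable`, which has been directly subscriptable since Python 3.9 (PEP 585 deprecates the `typing` alias). A small sketch with an illustrative alias, not taken from the diff:

    from collections.abc import Callable

    StepFn = Callable[[int], float]  # hypothetical alias, for illustration

    def run(step: StepFn, n: int) -> float:
      # Subscripting the abc class works exactly like typing.Callable did.
      return sum(step(i) for i in range(n))
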
3 changes: 1 addition & 2 deletions examples/mnist_pruning.py
@@ -14,9 +14,8 @@
 # ==============================================================================
 """MNIST classifier with pruning as in https://arxiv.org/abs/1710.01878 ."""
 
-from collections.abc import Iterator, Mapping, Sequence
+from collections.abc import Callable, Iterator, Mapping, Sequence
 import functools
-from typing import Callable
 
 from absl import app
 import haiku as hk
5 changes: 2 additions & 3 deletions examples/transformer/model.py
@@ -23,7 +23,6 @@
 """
 
 import dataclasses
-from typing import Optional
 
 import haiku as hk
 import jax
@@ -45,7 +44,7 @@ class Transformer(hk.Module):
   attn_size: int  # Size of the attention (key, query, value) vectors.
   dropout_rate: float  # Probability with which to apply dropout.
   widening_factor: int = 4  # Factor by which the MLP hidden layer widens.
-  name: Optional[str] = None  # Optional identifier for the module.
+  name: str | None = None  # Optional identifier for the module.
 
   def __call__(
       self,
@@ -98,7 +97,7 @@ class LanguageModel(hk.Module):
   model_size: int  # Embedding size.
   vocab_size: int  # Size of the vocabulary.
   pad_token: int  # Identity of the padding token (used for masking inputs).
-  name: Optional[str] = None  # Optional identifier for the module.
+  name: str | None = None  # Optional identifier for the module.
 
   def __call__(
       self,
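
These transformer modules are dataclass-style `hk.Module` subclasses (note the retained `import dataclasses`), so `name: str | None = None` is a field default rather than a constructor argument. A reduced sketch under that reading, with a hypothetical module:

    import dataclasses
    import haiku as hk

    @dataclasses.dataclass
    class Block(hk.Module):  # hypothetical, reduced from the pattern above
      widening_factor: int = 4  # Factor by which a hidden layer widens.
      name: str | None = None  # Optional identifier for the module.
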
4 changes: 2 additions & 2 deletions examples/transformer/train.py
@@ -30,7 +30,7 @@
 
 from collections.abc import MutableMapping
 import time
-from typing import Any, NamedTuple, Union
+from typing import Any, NamedTuple
 
 from absl import app
 from absl import flags
@@ -75,7 +75,7 @@ class TrainingState(NamedTuple):
   step: jax.Array  # Tracks the number of training steps.
 
 
-def forward_pass(tokens: Union[np.ndarray, jax.Array]) -> jax.Array:
+def forward_pass(tokens: np.ndarray | jax.Array) -> jax.Array:
   """Defines the forward pass of the language model."""
   lm = model.LanguageModel(
       model_size=MODEL_SIZE,
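
A side benefit of the `np.ndarray | jax.Array` spelling in `forward_pass`: on 3.10+, class unions also work in runtime `isinstance` checks. A hedged sketch (the helper is hypothetical, not code from this file):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def as_jax_array(tokens: np.ndarray | jax.Array) -> jax.Array:
      # The same union that annotates the parameter doubles as a check.
      if isinstance(tokens, np.ndarray | jax.Array):
        return jnp.asarray(tokens)
      raise TypeError(f"unsupported token container: {type(tokens)}")
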
18 changes: 8 additions & 10 deletions haiku/_src/attention.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """(Multi-Head) Attention module for use in Transformer architectures."""
-
-from typing import Optional
 import warnings
 
 from haiku._src import basic
@@ -60,14 +58,14 @@ def __init__(
       num_heads: int,
       key_size: int,
       # TODO(b/240019186): Remove `w_init_scale`.
-      w_init_scale: Optional[float] = None,
+      w_init_scale: float | None = None,
       *,
-      w_init: Optional[hk.initializers.Initializer] = None,
+      w_init: hk.initializers.Initializer | None = None,
       with_bias: bool = True,
-      b_init: Optional[hk.initializers.Initializer] = None,
-      value_size: Optional[int] = None,
-      model_size: Optional[int] = None,
-      name: Optional[str] = None,
+      b_init: hk.initializers.Initializer | None = None,
+      value_size: int | None = None,
+      model_size: int | None = None,
+      name: str | None = None,
   ):
     """Initialises the module.
 
@@ -115,7 +113,7 @@ def __call__(
       query: jax.Array,
       key: jax.Array,
       value: jax.Array,
-      mask: Optional[jax.Array] = None,
+      mask: jax.Array | None = None,
   ) -> jax.Array:
     """Computes (optionally masked) MHA with queries, keys & values.
 
@@ -168,7 +166,7 @@ def _linear_projection(
       self,
       x: jax.Array,
       head_size: int,
-      name: Optional[str] = None,
+      name: str | None = None,
   ) -> jax.Array:
     y = hk.Linear(self.num_heads * head_size, w_init=self.w_init,
                   with_bias=self.with_bias, b_init=self.b_init, name=name)(x)
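
For reference, a minimal usage sketch of `hk.MultiHeadAttention` with the signature shown above (the sizes and initializer choice are assumptions, not values from this diff):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def f(x: jax.Array) -> jax.Array:
      mha = hk.MultiHeadAttention(
          num_heads=4,
          key_size=16,
          w_init=hk.initializers.VarianceScaling(),  # keyword-only argument
          model_size=64,
      )
      return mha(query=x, key=x, value=x)  # mask: jax.Array | None = None

    x = jnp.zeros((2, 7, 64))
    params = hk.transform(f).init(jax.random.PRNGKey(0), x)
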
46 changes: 23 additions & 23 deletions haiku/_src/base.py
@@ -15,11 +15,11 @@
 """Base Haiku module."""
 
 import collections
-from collections.abc import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 import contextlib
 import functools
 import itertools as it
-from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
+from typing import Any, NamedTuple, Optional, TypeVar
 import warnings
 
 from haiku._src import config
@@ -88,7 +88,7 @@ class Frame(NamedTuple):
 
   # JAX values.
   params: MutableParams
-  state: Optional[MutableState]
+  state: MutableState | None
   rng_stack: Stack[Optional["PRNGSequence"]]
 
   # Pure python values.
@@ -217,9 +217,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
 def new_context(
     *,
-    params: Optional[Params] = None,
-    state: Optional[State] = None,
-    rng: Optional[Union[PRNGKey, int]] = None,
+    params: Params | None = None,
+    state: State | None = None,
+    rng: PRNGKey | int | None = None,
 ) -> HaikuContext:
   """Collects the results of hk.{get,set}_{parameter,state} calls.
 
@@ -286,7 +286,7 @@ def safe_get_module_name(module: Module) -> str:
   return module.module_name
 
 
-def current_module_state() -> Optional[ModuleState]:
+def current_module_state() -> ModuleState | None:
   frame = current_frame()
   if frame.module_stack:
     return frame.module_stack.peek()
@@ -295,7 +295,7 @@ def current_module_state() -> Optional[ModuleState]:
 
 
 @contextlib.contextmanager
-def maybe_push_module_state(module_state: Optional[ModuleState]):
+def maybe_push_module_state(module_state: ModuleState | None):
   if module_state is not None:
     frame = current_frame()
     with frame.module_stack(module_state):
@@ -304,7 +304,7 @@ def maybe_push_module_state(module_state: Optional[ModuleState]):
     yield
 
 
-def current_module() -> Optional[Module]:
+def current_module() -> Module | None:
   module_state = current_module_state()
   return module_state.module if module_state is not None else None
 
@@ -346,7 +346,7 @@ def current_name() -> str:
   return "~"
 
 
-def get_lift_prefix() -> Optional[str]:
+def get_lift_prefix() -> str | None:
   """Get full lifted prefix from frame_stack if in lifted init."""
   if params_frozen():
     return None
@@ -589,7 +589,7 @@ def dtype(self):
 DO_NOT_STORE = DoNotStore()
 
 
-def check_not_none(value: Optional[T], msg: str) -> T:
+def check_not_none(value: T | None, msg: str) -> T:
   if value is None:
     raise ValueError(msg)
   return value
@@ -604,7 +604,7 @@ def get_parameter(
     name: str,
     shape: Sequence[int],
     dtype: Any = jnp.float32,
-    init: Optional[Initializer] = None,
+    init: Initializer | None = None,
 ) -> jax.Array:
   """Creates or reuses a parameter for the given transformed function.
 
@@ -721,11 +721,11 @@ class GetterContext(NamedTuple):
     name: The name of this parameter.
   """
   full_name: str
-  module: Optional[Module]
+  module: Module | None
   original_dtype: Any
   original_shape: Sequence[int]
-  original_init: Optional[Initializer]
-  lifted_prefix_name: Optional[str]
+  original_init: Initializer | None
+  lifted_prefix_name: str | None
 
   @property
   def module_name(self):
@@ -748,7 +748,7 @@ def run_creators(
     context: GetterContext,
     shape: Sequence[int],
     dtype: Any = jnp.float32,
-    init: Optional[Initializer] = None,
+    init: Initializer | None = None,
 ) -> jax.Array:
   """See :func:`custom_creator` for usage."""
   assert stack
@@ -926,10 +926,10 @@ class SetterContext(NamedTuple):
     name: The name of this state.
   """
   full_name: str
-  module: Optional[Module]
+  module: Module | None
   original_dtype: Any
   original_shape: Sequence[int]
-  lifted_prefix_name: Optional[str]
+  lifted_prefix_name: str | None
 
   @property
   def module_name(self):
@@ -1029,7 +1029,7 @@ class PRNGSequence(Iterator[PRNGKey]):
   """
   __slots__ = ("_key", "_subkeys")
 
-  def __init__(self, key_or_seed: Union[PRNGKey, int, PRNGSequenceState]):
+  def __init__(self, key_or_seed: PRNGKey | int | PRNGSequenceState):
     """Creates a new :class:`PRNGSequence`."""
     if isinstance(key_or_seed, tuple):
       key, subkeys = key_or_seed
@@ -1210,7 +1210,7 @@ def next_rng_keys(num: int) -> jax.Array:
   return jnp.stack(rng_seq.take(num))
 
 
-def maybe_next_rng_key() -> Optional[PRNGKey]:
+def maybe_next_rng_key() -> PRNGKey | None:
   """:func:`next_rng_key` if random numbers are available, else ``None``."""
   assert_context("maybe_next_rng_key")
   rng_seq = current_frame().rng_stack.peek()
@@ -1221,7 +1221,7 @@ def maybe_next_rng_key() -> Optional[PRNGKey]:
     return None
 
 
-def maybe_get_rng_sequence_state() -> Optional[PRNGSequenceState]:
+def maybe_get_rng_sequence_state() -> PRNGSequenceState | None:
   """Returns the internal state of the PRNG sequence.
 
   Returns:
@@ -1265,9 +1265,9 @@ def extract_state(state: State, *, initial) -> MutableState:
 
 def get_state(
     name: str,
-    shape: Optional[Sequence[int]] = None,
+    shape: Sequence[int] | None = None,
     dtype: Any = jnp.float32,
-    init: Optional[Initializer] = None,
+    init: Initializer | None = None,
 ) -> jax.Array:
   """Gets the current value for state with an optional initializer.
 
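
Among the rewrites above, `PRNGSequence.__init__` now reads `PRNGKey | int | PRNGSequenceState`. A quick usage sketch of the two common forms:

    import haiku as hk
    import jax

    seq = hk.PRNGSequence(42)                      # int seed
    k1, k2 = next(seq), next(seq)                  # fresh PRNG keys
    seq2 = hk.PRNGSequence(jax.random.PRNGKey(0))  # existing PRNGKey
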
15 changes: 7 additions & 8 deletions haiku/_src/basic.py
@@ -14,10 +14,9 @@
 # ==============================================================================
 """Basic Haiku modules and functions."""
 
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 import functools
-from typing import Any, Callable, Optional
-
+from typing import Any
 from haiku._src import base
 from haiku._src import initializers
 from haiku._src import module
@@ -113,7 +112,7 @@ class Sequential(hk.Module):
   def __init__(
       self,
       layers: Iterable[Callable[..., Any]],
-      name: Optional[str] = None,
+      name: str | None = None,
   ):
     super().__init__(name=name)
     self.layers = tuple(layers)
@@ -136,9 +135,9 @@ def __init__(
       self,
       output_size: int,
       with_bias: bool = True,
-      w_init: Optional[hk.initializers.Initializer] = None,
-      b_init: Optional[hk.initializers.Initializer] = None,
-      name: Optional[str] = None,
+      w_init: hk.initializers.Initializer | None = None,
+      b_init: hk.initializers.Initializer | None = None,
+      name: str | None = None,
   ):
     """Constructs the Linear module.
 
@@ -162,7 +161,7 @@ def __call__(
       self,
      inputs: jax.Array,
       *,
-      precision: Optional[lax.Precision] = None,
+      precision: lax.Precision | None = None,
   ) -> jax.Array:
     """Computes a linear transform of the input."""
     if not inputs.shape:
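
And a short sketch exercising `hk.Linear` with the keyword-only `precision` argument from the `__call__` signature above (layer size and precision choice are arbitrary here):

    import haiku as hk
    import jax
    import jax.numpy as jnp
    from jax import lax

    def f(x: jax.Array) -> jax.Array:
      return hk.Linear(10)(x, precision=lax.Precision.HIGHEST)

    fwd = hk.transform(f)
    x = jnp.ones((4, 8))
    params = fwd.init(jax.random.PRNGKey(0), x)
    out = fwd.apply(params, None, x)  # no rng needed: f is deterministic
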
(Diffs for the remaining 47 of the 58 changed files were not loaded in this view.)