diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a37bf0047..4273254b0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: runs-on: "${{ matrix.os }}" strategy: matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.10', '3.11'] os: [ubuntu-latest] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 209ab4905..ad45e111f 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -14,7 +14,7 @@ jobs: runs-on: "${{ matrix.os }}" strategy: matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.10', '3.11'] os: [ubuntu-latest] steps: - uses: actions/checkout@v2 diff --git a/examples/imagenet/dataset.py b/examples/imagenet/dataset.py index 67e824bdf..0207e608a 100644 --- a/examples/imagenet/dataset.py +++ b/examples/imagenet/dataset.py @@ -18,7 +18,6 @@ import enum import itertools as it import types -from typing import Optional import jax import jax.numpy as jnp @@ -294,7 +293,7 @@ def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor: def _decode_and_center_crop( image_bytes: tf.Tensor, - jpeg_shape: Optional[tf.Tensor] = None, + jpeg_shape: tf.Tensor | None = None, ) -> tf.Tensor: """Crops to center of image with padding then scales.""" if jpeg_shape is None: diff --git a/examples/impala/agent.py b/examples/impala/agent.py index 59bafda6d..98e04d58f 100644 --- a/examples/impala/agent.py +++ b/examples/impala/agent.py @@ -14,8 +14,9 @@ # ============================================================================== """A stateless agent interface.""" import collections +from collections.abc import Callable import functools -from typing import Any, Callable, Optional +from typing import Any import dm_env import haiku as hk @@ -69,7 +70,7 @@ def initial_params(self, rng_key): return self._init_fn(rng_key, dummy_inputs, self.initial_state(1)) @functools.partial(jax.jit, static_argnums=(0, 1)) - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): """Returns agent initial state.""" # We expect that generating the initial_state does not require parameters. return self._initial_state_apply_fn(None, batch_size) diff --git a/examples/impala_lite.py b/examples/impala_lite.py index 2d59fd6f0..50a14ad22 100644 --- a/examples/impala_lite.py +++ b/examples/impala_lite.py @@ -20,10 +20,11 @@ See: https://arxiv.org/abs/1802.01561 """ +from collections.abc import Callable import functools import queue import threading -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple from absl import app from absl import logging diff --git a/examples/mnist_pruning.py b/examples/mnist_pruning.py index e07ae469f..adab47e01 100644 --- a/examples/mnist_pruning.py +++ b/examples/mnist_pruning.py @@ -14,9 +14,8 @@ # ============================================================================== """MNIST classifier with pruning as in https://arxiv.org/abs/1710.01878 .""" -from collections.abc import Iterator, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence import functools -from typing import Callable from absl import app import haiku as hk diff --git a/examples/transformer/model.py b/examples/transformer/model.py index 3f09c51b1..b8c3eea54 100644 --- a/examples/transformer/model.py +++ b/examples/transformer/model.py @@ -23,7 +23,6 @@ """ import dataclasses -from typing import Optional import haiku as hk import jax @@ -45,7 +44,7 @@ class Transformer(hk.Module): attn_size: int # Size of the attention (key, query, value) vectors. dropout_rate: float # Probability with which to apply dropout. widening_factor: int = 4 # Factor by which the MLP hidden layer widens. - name: Optional[str] = None # Optional identifier for the module. + name: str | None = None # Optional identifier for the module. def __call__( self, @@ -98,7 +97,7 @@ class LanguageModel(hk.Module): model_size: int # Embedding size. vocab_size: int # Size of the vocabulary. pad_token: int # Identity of the padding token (used for masking inputs). - name: Optional[str] = None # Optional identifier for the module. + name: str | None = None # Optional identifier for the module. def __call__( self, diff --git a/examples/transformer/train.py b/examples/transformer/train.py index 6950c062b..51c246549 100644 --- a/examples/transformer/train.py +++ b/examples/transformer/train.py @@ -30,7 +30,7 @@ from collections.abc import MutableMapping import time -from typing import Any, NamedTuple, Union +from typing import Any, NamedTuple from absl import app from absl import flags @@ -75,7 +75,7 @@ class TrainingState(NamedTuple): step: jax.Array # Tracks the number of training steps. -def forward_pass(tokens: Union[np.ndarray, jax.Array]) -> jax.Array: +def forward_pass(tokens: np.ndarray | jax.Array) -> jax.Array: """Defines the forward pass of the language model.""" lm = model.LanguageModel( model_size=MODEL_SIZE, diff --git a/haiku/_src/attention.py b/haiku/_src/attention.py index a03949261..bf54dffbd 100644 --- a/haiku/_src/attention.py +++ b/haiku/_src/attention.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """(Multi-Head) Attention module for use in Transformer architectures.""" - -from typing import Optional import warnings from haiku._src import basic @@ -60,14 +58,14 @@ def __init__( num_heads: int, key_size: int, # TODO(b/240019186): Remove `w_init_scale`. - w_init_scale: Optional[float] = None, + w_init_scale: float | None = None, *, - w_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, with_bias: bool = True, - b_init: Optional[hk.initializers.Initializer] = None, - value_size: Optional[int] = None, - model_size: Optional[int] = None, - name: Optional[str] = None, + b_init: hk.initializers.Initializer | None = None, + value_size: int | None = None, + model_size: int | None = None, + name: str | None = None, ): """Initialises the module. @@ -115,7 +113,7 @@ def __call__( query: jax.Array, key: jax.Array, value: jax.Array, - mask: Optional[jax.Array] = None, + mask: jax.Array | None = None, ) -> jax.Array: """Computes (optionally masked) MHA with queries, keys & values. @@ -168,7 +166,7 @@ def _linear_projection( self, x: jax.Array, head_size: int, - name: Optional[str] = None, + name: str | None = None, ) -> jax.Array: y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, with_bias=self.with_bias, b_init=self.b_init, name=name)(x) diff --git a/haiku/_src/base.py b/haiku/_src/base.py index 953d1775c..79d5a35c6 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -15,11 +15,11 @@ """Base Haiku module.""" import collections -from collections.abc import Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence import contextlib import functools import itertools as it -from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union +from typing import Any, NamedTuple, Optional, TypeVar import warnings from haiku._src import config @@ -88,7 +88,7 @@ class Frame(NamedTuple): # JAX values. params: MutableParams - state: Optional[MutableState] + state: MutableState | None rng_stack: Stack[Optional["PRNGSequence"]] # Pure python values. @@ -217,9 +217,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def new_context( *, - params: Optional[Params] = None, - state: Optional[State] = None, - rng: Optional[Union[PRNGKey, int]] = None, + params: Params | None = None, + state: State | None = None, + rng: PRNGKey | int | None = None, ) -> HaikuContext: """Collects the results of hk.{get,set}_{parameter,state} calls. @@ -286,7 +286,7 @@ def safe_get_module_name(module: Module) -> str: return module.module_name -def current_module_state() -> Optional[ModuleState]: +def current_module_state() -> ModuleState | None: frame = current_frame() if frame.module_stack: return frame.module_stack.peek() @@ -295,7 +295,7 @@ def current_module_state() -> Optional[ModuleState]: @contextlib.contextmanager -def maybe_push_module_state(module_state: Optional[ModuleState]): +def maybe_push_module_state(module_state: ModuleState | None): if module_state is not None: frame = current_frame() with frame.module_stack(module_state): @@ -304,7 +304,7 @@ def maybe_push_module_state(module_state: Optional[ModuleState]): yield -def current_module() -> Optional[Module]: +def current_module() -> Module | None: module_state = current_module_state() return module_state.module if module_state is not None else None @@ -346,7 +346,7 @@ def current_name() -> str: return "~" -def get_lift_prefix() -> Optional[str]: +def get_lift_prefix() -> str | None: """Get full lifted prefix from frame_stack if in lifted init.""" if params_frozen(): return None @@ -589,7 +589,7 @@ def dtype(self): DO_NOT_STORE = DoNotStore() -def check_not_none(value: Optional[T], msg: str) -> T: +def check_not_none(value: T | None, msg: str) -> T: if value is None: raise ValueError(msg) return value @@ -604,7 +604,7 @@ def get_parameter( name: str, shape: Sequence[int], dtype: Any = jnp.float32, - init: Optional[Initializer] = None, + init: Initializer | None = None, ) -> jax.Array: """Creates or reuses a parameter for the given transformed function. @@ -721,11 +721,11 @@ class GetterContext(NamedTuple): name: The name of this parameter. """ full_name: str - module: Optional[Module] + module: Module | None original_dtype: Any original_shape: Sequence[int] - original_init: Optional[Initializer] - lifted_prefix_name: Optional[str] + original_init: Initializer | None + lifted_prefix_name: str | None @property def module_name(self): @@ -748,7 +748,7 @@ def run_creators( context: GetterContext, shape: Sequence[int], dtype: Any = jnp.float32, - init: Optional[Initializer] = None, + init: Initializer | None = None, ) -> jax.Array: """See :func:`custom_creator` for usage.""" assert stack @@ -926,10 +926,10 @@ class SetterContext(NamedTuple): name: The name of this state. """ full_name: str - module: Optional[Module] + module: Module | None original_dtype: Any original_shape: Sequence[int] - lifted_prefix_name: Optional[str] + lifted_prefix_name: str | None @property def module_name(self): @@ -1029,7 +1029,7 @@ class PRNGSequence(Iterator[PRNGKey]): """ __slots__ = ("_key", "_subkeys") - def __init__(self, key_or_seed: Union[PRNGKey, int, PRNGSequenceState]): + def __init__(self, key_or_seed: PRNGKey | int | PRNGSequenceState): """Creates a new :class:`PRNGSequence`.""" if isinstance(key_or_seed, tuple): key, subkeys = key_or_seed @@ -1210,7 +1210,7 @@ def next_rng_keys(num: int) -> jax.Array: return jnp.stack(rng_seq.take(num)) -def maybe_next_rng_key() -> Optional[PRNGKey]: +def maybe_next_rng_key() -> PRNGKey | None: """:func:`next_rng_key` if random numbers are available, else ``None``.""" assert_context("maybe_next_rng_key") rng_seq = current_frame().rng_stack.peek() @@ -1221,7 +1221,7 @@ def maybe_next_rng_key() -> Optional[PRNGKey]: return None -def maybe_get_rng_sequence_state() -> Optional[PRNGSequenceState]: +def maybe_get_rng_sequence_state() -> PRNGSequenceState | None: """Returns the internal state of the PRNG sequence. Returns: @@ -1265,9 +1265,9 @@ def extract_state(state: State, *, initial) -> MutableState: def get_state( name: str, - shape: Optional[Sequence[int]] = None, + shape: Sequence[int] | None = None, dtype: Any = jnp.float32, - init: Optional[Initializer] = None, + init: Initializer | None = None, ) -> jax.Array: """Gets the current value for state with an optional initializer. diff --git a/haiku/_src/basic.py b/haiku/_src/basic.py index 1573da0b1..76179f83c 100644 --- a/haiku/_src/basic.py +++ b/haiku/_src/basic.py @@ -14,10 +14,9 @@ # ============================================================================== """Basic Haiku modules and functions.""" -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import functools -from typing import Any, Callable, Optional - +from typing import Any from haiku._src import base from haiku._src import initializers from haiku._src import module @@ -113,7 +112,7 @@ class Sequential(hk.Module): def __init__( self, layers: Iterable[Callable[..., Any]], - name: Optional[str] = None, + name: str | None = None, ): super().__init__(name=name) self.layers = tuple(layers) @@ -136,9 +135,9 @@ def __init__( self, output_size: int, with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, + name: str | None = None, ): """Constructs the Linear module. @@ -162,7 +161,7 @@ def __call__( self, inputs: jax.Array, *, - precision: Optional[lax.Precision] = None, + precision: lax.Precision | None = None, ) -> jax.Array: """Computes a linear transform of the input.""" if not inputs.shape: diff --git a/haiku/_src/batch_norm.py b/haiku/_src/batch_norm.py index dd89462f9..4a92fc8e5 100644 --- a/haiku/_src/batch_norm.py +++ b/haiku/_src/batch_norm.py @@ -15,7 +15,6 @@ """Batch Norm.""" from collections.abc import Sequence -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -68,13 +67,13 @@ def __init__( create_offset: bool, decay_rate: float, eps: float = 1e-5, - scale_init: Optional[hk.initializers.Initializer] = None, - offset_init: Optional[hk.initializers.Initializer] = None, - axis: Optional[Sequence[int]] = None, - cross_replica_axis: Optional[Union[str, Sequence[str]]] = None, - cross_replica_axis_index_groups: Optional[Sequence[Sequence[int]]] = None, + scale_init: hk.initializers.Initializer | None = None, + offset_init: hk.initializers.Initializer | None = None, + axis: Sequence[int] | None = None, + cross_replica_axis: str | Sequence[str] | None = None, + cross_replica_axis_index_groups: Sequence[Sequence[int]] | None = None, data_format: str = "channels_last", - name: Optional[str] = None, + name: str | None = None, ): """Constructs a BatchNorm module. @@ -130,8 +129,8 @@ def __call__( inputs: jax.Array, is_training: bool, test_local_stats: bool = False, - scale: Optional[jax.Array] = None, - offset: Optional[jax.Array] = None, + scale: jax.Array | None = None, + offset: jax.Array | None = None, ) -> jax.Array: """Computes the normalized version of the input. diff --git a/haiku/_src/bias.py b/haiku/_src/bias.py index 30de84a51..28a922a18 100644 --- a/haiku/_src/bias.py +++ b/haiku/_src/bias.py @@ -15,7 +15,6 @@ """Bias module.""" from collections.abc import Sequence -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -76,10 +75,10 @@ class Bias(hk.Module): def __init__( self, - output_size: Optional[Sequence[int]] = None, - bias_dims: Optional[Sequence[int]] = None, - b_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, + output_size: Sequence[int] | None = None, + bias_dims: Sequence[int] | None = None, + b_init: hk.initializers.Initializer | None = None, + name: str | None = None, ): """Constructs a ``Bias`` module that supports broadcasting. @@ -103,7 +102,7 @@ def __init__( def __call__( self, inputs: jax.Array, - multiplier: Optional[Union[float, jax.Array]] = None, + multiplier: float | jax.Array | None = None, ) -> jax.Array: """Adds bias to ``inputs`` and optionally multiplies by ``multiplier``. diff --git a/haiku/_src/config.py b/haiku/_src/config.py index 5fee80d9b..a62fba332 100644 --- a/haiku/_src/config.py +++ b/haiku/_src/config.py @@ -17,7 +17,6 @@ import contextlib import dataclasses import threading -from typing import Optional @dataclasses.dataclass @@ -48,10 +47,10 @@ def write(config, **overrides): # pylint: disable=redefined-outer-name,unused-argument def context( *, - check_jax_usage: Optional[bool] = None, - module_auto_repr: Optional[bool] = None, - restore_flatmap: Optional[bool] = None, - rng_reserve_size: Optional[int] = None, + check_jax_usage: bool | None = None, + module_auto_repr: bool | None = None, + restore_flatmap: bool | None = None, + rng_reserve_size: int | None = None, ): """Context manager for setting config options. @@ -86,10 +85,10 @@ def context( # pylint: disable=redefined-outer-name,unused-argument,redefined-builtin def set( *, - check_jax_usage: Optional[bool] = None, - module_auto_repr: Optional[bool] = None, - restore_flatmap: Optional[bool] = None, - rng_reserve_size: Optional[int] = None, + check_jax_usage: bool | None = None, + module_auto_repr: bool | None = None, + restore_flatmap: bool | None = None, + rng_reserve_size: int | None = None, ): """Sets the given config option(s). diff --git a/haiku/_src/conv.py b/haiku/_src/conv.py index effd59785..f3e2e8379 100644 --- a/haiku/_src/conv.py +++ b/haiku/_src/conv.py @@ -15,7 +15,6 @@ """Convolutional Haiku modules.""" from collections.abc import Sequence -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -71,19 +70,22 @@ def __init__( self, num_spatial_dims: int, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[ - str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn] - ] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: ( + str + | Sequence[tuple[int, int]] + | hk.pad.PadFn + | Sequence[hk.pad.PadFn] + ) = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "channels_last", - mask: Optional[jax.Array] = None, + mask: jax.Array | None = None, feature_group_count: int = 1, - name: Optional[str] = None, + name: str | None = None, ): """Initializes the module. @@ -160,7 +162,7 @@ def __call__( self, inputs: jax.Array, *, - precision: Optional[lax.Precision] = None, + precision: lax.Precision | None = None, ) -> jax.Array: """Connects ``ConvND`` layer. @@ -237,19 +239,22 @@ class Conv1D(ConvND): def __init__( self, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[ - str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn] - ] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: ( + str + | Sequence[tuple[int, int]] + | hk.pad.PadFn + | Sequence[hk.pad.PadFn] + ) = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NWC", - mask: Optional[jax.Array] = None, + mask: jax.Array | None = None, feature_group_count: int = 1, - name: Optional[str] = None, + name: str | None = None, ): """Initializes the module. @@ -305,19 +310,22 @@ class Conv2D(ConvND): def __init__( self, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[ - str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn] - ] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: ( + str + | Sequence[tuple[int, int]] + | hk.pad.PadFn + | Sequence[hk.pad.PadFn] + ) = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NHWC", - mask: Optional[jax.Array] = None, + mask: jax.Array | None = None, feature_group_count: int = 1, - name: Optional[str] = None, + name: str | None = None, ): """Initializes the module. @@ -373,19 +381,22 @@ class Conv3D(ConvND): def __init__( self, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[ - str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn] - ] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: ( + str + | Sequence[tuple[int, int]] + | hk.pad.PadFn + | Sequence[hk.pad.PadFn] + ) = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NDHWC", - mask: Optional[jax.Array] = None, + mask: jax.Array | None = None, feature_group_count: int = 1, - name: Optional[str] = None, + name: str | None = None, ): """Initializes the module. @@ -478,16 +489,16 @@ def __init__( self, num_spatial_dims: int, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - output_shape: Optional[Union[int, Sequence[int]]] = None, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + output_shape: int | Sequence[int] | None = None, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "channels_last", - mask: Optional[jax.Array] = None, - name: Optional[str] = None, + mask: jax.Array | None = None, + name: str | None = None, ): """Initializes the module. @@ -548,7 +559,7 @@ def __call__( self, inputs: jax.Array, *, - precision: Optional[lax.Precision] = None, + precision: lax.Precision | None = None, ) -> jax.Array: """Computes the transposed convolution of the input. @@ -625,16 +636,16 @@ class Conv1DTranspose(ConvNDTranspose): def __init__( self, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - output_shape: Optional[Union[int, Sequence[int]]] = None, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + output_shape: int | Sequence[int] | None = None, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NWC", - mask: Optional[jax.Array] = None, - name: Optional[str] = None, + mask: jax.Array | None = None, + name: str | None = None, ): """Initializes the module. @@ -679,16 +690,16 @@ class Conv2DTranspose(ConvNDTranspose): def __init__( self, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - output_shape: Optional[Union[int, Sequence[int]]] = None, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + output_shape: int | Sequence[int] | None = None, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NHWC", - mask: Optional[jax.Array] = None, - name: Optional[str] = None, + mask: jax.Array | None = None, + name: str | None = None, ): """Initializes the module. @@ -733,16 +744,16 @@ class Conv3DTranspose(ConvNDTranspose): def __init__( self, output_channels: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - output_shape: Optional[Union[int, Sequence[int]]] = None, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + output_shape: int | Sequence[int] | None = None, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NDHWC", - mask: Optional[jax.Array] = None, - name: Optional[str] = None, + mask: jax.Array | None = None, + name: str | None = None, ): """Initializes the module. diff --git a/haiku/_src/data_structures.py b/haiku/_src/data_structures.py index 1c3db70bc..191a5ca2e 100644 --- a/haiku/_src/data_structures.py +++ b/haiku/_src/data_structures.py @@ -20,12 +20,11 @@ # for users. import collections -from collections.abc import Iterator, Mapping, MutableMapping, Sequence +from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence import contextlib import pprint import threading -from typing import Any, Callable, Deque, Generic, NamedTuple, Optional, TypeVar, Union - +from typing import Any, Deque, Generic, NamedTuple, TypeVar from haiku._src import config from haiku._src import utils import jax @@ -332,7 +331,7 @@ def __getattr__(self, key): raise AttributeError( f"x.{key} is not supported on frozendict, use x['{key}'] instead.") - def get(self, key: K, default: Optional[T] = None) -> Union[V, Optional[T]]: + def get(self, key: K, default: T | None = None) -> V | T | None: return self._storage.get(key, default) def __getitem__(self, key: K) -> V: diff --git a/haiku/_src/deferred.py b/haiku/_src/deferred.py index 6ea452f01..fdc19e06b 100644 --- a/haiku/_src/deferred.py +++ b/haiku/_src/deferred.py @@ -14,9 +14,8 @@ # ============================================================================== """Enables module construction to be deferred.""" -from collections.abc import Sequence -from typing import Callable, Generic, TypeVar - +from collections.abc import Callable, Sequence +from typing import Generic, TypeVar from haiku._src import base from haiku._src import module diff --git a/haiku/_src/depthwise_conv.py b/haiku/_src/depthwise_conv.py index 0bd9bf1b4..a6e05b060 100644 --- a/haiku/_src/depthwise_conv.py +++ b/haiku/_src/depthwise_conv.py @@ -15,7 +15,6 @@ """Depthwise Convolutional Haiku module.""" from collections.abc import Sequence -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -56,16 +55,16 @@ class DepthwiseConvND(hk.Module): def __init__( self, channel_multiplier: int, - kernel_shape: Union[int, Sequence[int]], + kernel_shape: int | Sequence[int], num_spatial_dims: int, data_format: str, - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, + name: str | None = None, ): """Construct an ND Depthwise Convolution. @@ -150,14 +149,14 @@ class SeparableDepthwiseConv2D(hk.Module): def __init__( self, channel_multiplier: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NHWC", - name: Optional[str] = None, + name: str | None = None, ): """Construct a Separable 2D Depthwise Convolution module. @@ -211,15 +210,15 @@ class DepthwiseConv1D(DepthwiseConvND): def __init__( self, channel_multiplier: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NWC", - name: Optional[str] = None, + name: str | None = None, ): """Construct a 1D Depthwise Convolution. @@ -264,15 +263,15 @@ class DepthwiseConv2D(DepthwiseConvND): def __init__( self, channel_multiplier: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NHWC", - name: Optional[str] = None, + name: str | None = None, ): """Construct a 2D Depthwise Convolution. @@ -317,15 +316,15 @@ class DepthwiseConv3D(DepthwiseConvND): def __init__( self, channel_multiplier: int, - kernel_shape: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - rate: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[tuple[int, int]]] = "SAME", + kernel_shape: int | Sequence[int], + stride: int | Sequence[int] = 1, + rate: int | Sequence[int] = 1, + padding: str | Sequence[tuple[int, int]] = "SAME", with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, data_format: str = "NDHWC", - name: Optional[str] = None, + name: str | None = None, ): """Construct a 3D Depthwise Convolution. diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index 6e03cd757..8c80e6357 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -15,10 +15,11 @@ """Converts Haiku functions to dot.""" import collections +from collections.abc import Callable import contextlib import functools import html -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, NamedTuple from haiku._src import data_structures from haiku._src import module @@ -51,7 +52,7 @@ class Graph(NamedTuple): subgraphs: list['Graph'] @classmethod - def create(cls, title: Optional[str] = None): + def create(cls, title: str | None = None): return Graph(title=title, nodes=[], edges=[], subgraphs=[]) def evolve(self, **kwargs) -> 'Graph': @@ -295,7 +296,7 @@ def format_path(path): outids = {id(v) for v in jax.tree.leaves(outputs)} outname = {id(v): format_path(p) for p, v in tree.flatten_with_path(outputs)} - def render_graph(g: Graph, parent: Optional[Graph] = None, depth: int = 0): + def render_graph(g: Graph, parent: Graph | None = None, depth: int = 0): """Renders a given graph by appending 'dot' format lines.""" if parent: diff --git a/haiku/_src/embed.py b/haiku/_src/embed.py index 1296b2bb8..52c87ca91 100644 --- a/haiku/_src/embed.py +++ b/haiku/_src/embed.py @@ -16,7 +16,6 @@ from collections.abc import Sequence import enum -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -62,12 +61,12 @@ class Embed(hk.Module): def __init__( self, - vocab_size: Optional[int] = None, - embed_dim: Optional[int] = None, - embedding_matrix: Optional[Union[np.ndarray, jax.Array]] = None, - w_init: Optional[hk.initializers.Initializer] = None, - lookup_style: Union[str, EmbedLookupStyle] = "ARRAY_INDEX", - name: Optional[str] = None, + vocab_size: int | None = None, + embed_dim: int | None = None, + embedding_matrix: np.ndarray | jax.Array | None = None, + w_init: hk.initializers.Initializer | None = None, + lookup_style: str | EmbedLookupStyle = "ARRAY_INDEX", + name: str | None = None, precision: jax.lax.Precision = jax.lax.Precision.HIGHEST, ): """Constructs an Embed module. @@ -141,9 +140,9 @@ def embeddings(self): def __call__( self, - ids: Union[jax.Array, Sequence[int]], - lookup_style: Optional[Union[str, hk.EmbedLookupStyle]] = None, - precision: Optional[jax.lax.Precision] = None, + ids: jax.Array | Sequence[int], + lookup_style: str | hk.EmbedLookupStyle | None = None, + precision: jax.lax.Precision | None = None, ) -> jax.Array: r"""Lookup embeddings. diff --git a/haiku/_src/filtering.py b/haiku/_src/filtering.py index 0e0c26ab5..8cb4542a8 100644 --- a/haiku/_src/filtering.py +++ b/haiku/_src/filtering.py @@ -15,8 +15,8 @@ """Functions for filtering parameters and state in Haiku.""" import collections -from collections.abc import Generator, Mapping, MutableMapping -from typing import (Any, Callable, TypeVar) +from collections.abc import Callable, Generator, Mapping, MutableMapping +from typing import Any, TypeVar from haiku._src import data_structures from haiku._src import utils diff --git a/haiku/_src/filtering_test.py b/haiku/_src/filtering_test.py index e4d86f3c3..72d34e027 100644 --- a/haiku/_src/filtering_test.py +++ b/haiku/_src/filtering_test.py @@ -15,11 +15,11 @@ """Tests for haiku._src.filtering.""" import collections -from collections.abc import Sequence +from collections.abc import Callable, Sequence import itertools import re import types -from typing import Any, Callable +from typing import Any from absl.testing import absltest from absl.testing import parameterized diff --git a/haiku/_src/flax/flax_module.py b/haiku/_src/flax/flax_module.py index 31e62a7a4..5f3b2ec30 100644 --- a/haiku/_src/flax/flax_module.py +++ b/haiku/_src/flax/flax_module.py @@ -15,7 +15,7 @@ """Utilities for converting Haiku modules to Flax modules.""" -from typing import TypeVar, Union +from typing import TypeVar import flax.core import flax.linen as nn @@ -108,7 +108,7 @@ class Module(nn.Module): >>> out = mod.apply(variables, x) """ - transformed: Union[hk.Transformed, hk.TransformedWithState] + transformed: hk.Transformed | hk.TransformedWithState def __post_init__(self): super().__post_init__() diff --git a/haiku/_src/flax/transform_flax.py b/haiku/_src/flax/transform_flax.py index 4b46a5346..ebf1e1521 100644 --- a/haiku/_src/flax/transform_flax.py +++ b/haiku/_src/flax/transform_flax.py @@ -15,8 +15,8 @@ """Utilities for converting Flax modules to use with Haiku.""" -from collections.abc import Mapping -from typing import Any, Callable +from collections.abc import Callable, Mapping +from typing import Any import flax.errors import flax.linen as nn diff --git a/haiku/_src/group_norm.py b/haiku/_src/group_norm.py index 032c1a676..6089156d4 100644 --- a/haiku/_src/group_norm.py +++ b/haiku/_src/group_norm.py @@ -16,7 +16,6 @@ import collections from collections.abc import Sequence -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -73,14 +72,14 @@ class GroupNorm(hk.Module): def __init__( self, groups: int, - axis: Union[int, slice, Sequence[int]] = slice(1, None), + axis: int | slice | Sequence[int] = slice(1, None), create_scale: bool = True, create_offset: bool = True, eps: float = 1e-5, - scale_init: Optional[hk.initializers.Initializer] = None, - offset_init: Optional[hk.initializers.Initializer] = None, + scale_init: hk.initializers.Initializer | None = None, + offset_init: hk.initializers.Initializer | None = None, data_format: str = "channels_last", - name: Optional[str] = None, + name: str | None = None, ): """Constructs a ``GroupNorm`` module. @@ -144,8 +143,8 @@ def __init__( def __call__( self, x: jax.Array, - scale: Optional[jax.Array] = None, - offset: Optional[jax.Array] = None, + scale: jax.Array | None = None, + offset: jax.Array | None = None, ) -> jax.Array: """Returns normalized inputs. diff --git a/haiku/_src/initializers.py b/haiku/_src/initializers.py index cb9b8d9b1..43ca5912f 100644 --- a/haiku/_src/initializers.py +++ b/haiku/_src/initializers.py @@ -15,7 +15,7 @@ """Haiku initializers.""" from collections.abc import Sequence -from typing import Any, Union +from typing import Any from haiku._src import base from haiku._src.typing import Initializer @@ -62,7 +62,7 @@ class Constant(hk.initializers.Initializer): """Initializes with a constant.""" def __init__( - self, constant: Union[float, int, complex, np.ndarray, jax.Array] + self, constant: float | int | complex | np.ndarray | jax.Array ): """Constructs a Constant initializer. @@ -98,10 +98,10 @@ class TruncatedNormal(hk.initializers.Initializer): """Initializes by sampling from a truncated normal distribution.""" def __init__(self, - stddev: Union[float, jax.Array] = 1., - mean: Union[float, complex, jax.Array] = 0.0, - lower: Union[float, jax.Array] = -2.0, - upper: Union[float, jax.Array] = 2.0, + stddev: float | jax.Array = 1., + mean: float | complex | jax.Array = 0.0, + lower: float | jax.Array = -2.0, + upper: float | jax.Array = 2.0, ): """Constructs a :class:`TruncatedNormal` initializer. @@ -305,7 +305,7 @@ class Identity(hk.initializers.Initializer): Constructs a 2D identity matrix or batches of these. """ - def __init__(self, gain: Union[float, np.ndarray, jax.Array] = 1.0): + def __init__(self, gain: float | np.ndarray | jax.Array = 1.0): """Constructs an :class:`Identity` initializer. Args: diff --git a/haiku/_src/integration/descriptors.py b/haiku/_src/integration/descriptors.py index 6b0b4ff31..1071a9f16 100644 --- a/haiku/_src/integration/descriptors.py +++ b/haiku/_src/integration/descriptors.py @@ -14,8 +14,8 @@ # ============================================================================== """Module descriptors programatically describe how to use modules.""" -from collections.abc import Sequence -from typing import Any, Callable, NamedTuple +from collections.abc import Callable, Sequence +from typing import Any, NamedTuple import haiku as hk import jax diff --git a/haiku/_src/jaxpr_info.py b/haiku/_src/jaxpr_info.py index 20dc6f701..e1a2e2d0d 100644 --- a/haiku/_src/jaxpr_info.py +++ b/haiku/_src/jaxpr_info.py @@ -27,13 +27,13 @@ print(jaxpr_info.format_module(mod)) """ -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence import dataclasses import itertools import logging import os import sys -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, NamedTuple from haiku._src import summarise import jax @@ -49,7 +49,7 @@ class Module: # How many flops it takes to compute this module, including all operations # contained by sub-modules. # Only populated if `compute_flops` was passed to `make_model_info`. - flops: Optional[int] = None + flops: int | None = None # Expressions that are directly part of this module, e.g. in the __call__ # function. @@ -78,7 +78,7 @@ class Expression: # Estimated number of flops required to compute this expression. # Only populated if `compute_flops` was passed to `make_model_info`. - flops: Optional[int] = None + flops: int | None = None # Additional details, e.g. input/output shapes. details: str = '' @@ -89,7 +89,7 @@ class Expression: # Some expressions, such as named_call, contain a whole subtree of modules # and expressions. - submodule: Optional[Module] = None + submodule: Module | None = None # For internal use only, the first variable in outvars. first_outvar: str = '' @@ -103,10 +103,10 @@ class Expression: def make_model_info( f: Callable[..., Any], - name: Optional[str] = None, + name: str | None = None, include_module_info: bool = True, - compute_flops: Optional[ComputeFlopsFn] = None, - axis_env: Optional[Sequence[tuple[Any, int]]] = None, + compute_flops: ComputeFlopsFn | None = None, + axis_env: Sequence[tuple[Any, int]] | None = None, ) -> Callable[..., Module]: """Creates a function that computes flop, param and state information. @@ -270,16 +270,16 @@ def _process_eqn( eqn: jax.core.JaxprEqn, seen: set[str], eqns_by_output: Mapping[str, jax.core.JaxprEqn], - compute_flops: Optional[ComputeFlopsFn], + compute_flops: ComputeFlopsFn | None, scope: _ModuleScope, module: Module, binder_idx: dict[jax.core.Var, int], -) -> Optional[int]: +) -> int | None: """Recursive walks the JaxprEqn to compute the flops it takes.""" for out_var in eqn.outvars: _mark_seen(binder_idx, seen, out_var, scope) - outvars = sorted([_var_to_str(binder_idx, e) for e in eqn.outvars], + outvars = sorted((_var_to_str(binder_idx, e) for e in eqn.outvars), key=_var_sort_key) name_stack = str(eqn.source_info.name_stack) expression = Expression( @@ -392,11 +392,11 @@ def _process_eqn( def _process_jaxpr( jaxpr: jax.core.Jaxpr, - compute_flops: Optional[ComputeFlopsFn], + compute_flops: ComputeFlopsFn | None, scope: _ModuleScope, seen: set[str], module: Module, -) -> Optional[int]: +) -> int | None: """Computes the flops used for a JAX expression, tracking module scope.""" if isinstance(jaxpr, jax.core.ClosedJaxpr): return _process_jaxpr(jaxpr.jaxpr, compute_flops, scope, seen, module) diff --git a/haiku/_src/jaxpr_info_test.py b/haiku/_src/jaxpr_info_test.py index 4eee01b1b..049b599fe 100644 --- a/haiku/_src/jaxpr_info_test.py +++ b/haiku/_src/jaxpr_info_test.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Tests for jaxpr_info.""" -from typing import Optional from absl import logging from absl.testing import absltest @@ -28,7 +27,7 @@ class MyModel(module.Module): - def __init__(self, name: Optional[str] = None): + def __init__(self, name: str | None = None): super().__init__(name=name) def __call__(self, x: jax.Array): diff --git a/haiku/_src/layer_norm.py b/haiku/_src/layer_norm.py index 6bc6e9252..1c2311ed7 100644 --- a/haiku/_src/layer_norm.py +++ b/haiku/_src/layer_norm.py @@ -16,7 +16,6 @@ import collections.abc from collections.abc import Sequence -from typing import Optional, Union from haiku._src import base from haiku._src import initializers @@ -37,8 +36,8 @@ class hk: # pylint: enable=invalid-name del base, module, initializers, utils -AxisOrAxes = Union[int, Sequence[int], slice] -AxesOrSlice = Union[tuple[int, ...], slice] +AxisOrAxes = int | Sequence[int] | slice +AxesOrSlice = tuple[int, ...] | slice # TODO(tomhennigan): Update users to `param_axis=-1` and flip + remove this. ERROR_IF_PARAM_AXIS_NOT_EXPLICIT = False @@ -82,12 +81,12 @@ def __init__( create_scale: bool, create_offset: bool, eps: float = 1e-5, - scale_init: Optional[hk.initializers.Initializer] = None, - offset_init: Optional[hk.initializers.Initializer] = None, + scale_init: hk.initializers.Initializer | None = None, + offset_init: hk.initializers.Initializer | None = None, use_fast_variance: bool = False, - name: Optional[str] = None, + name: str | None = None, *, - param_axis: Optional[AxisOrAxes] = None, + param_axis: AxisOrAxes | None = None, ): """Constructs a LayerNorm module. @@ -131,8 +130,8 @@ def __init__( def __call__( self, inputs: jax.Array, - scale: Optional[jax.Array] = None, - offset: Optional[jax.Array] = None, + scale: jax.Array | None = None, + offset: jax.Array | None = None, ) -> jax.Array: """Connects the layer norm. @@ -219,10 +218,10 @@ def __init__( create_scale: bool, create_offset: bool, eps: float = 1e-5, - scale_init: Optional[hk.initializers.Initializer] = None, - offset_init: Optional[hk.initializers.Initializer] = None, + scale_init: hk.initializers.Initializer | None = None, + offset_init: hk.initializers.Initializer | None = None, data_format: str = "channels_last", - name: Optional[str] = None, + name: str | None = None, ): """Constructs an :class:`InstanceNorm` module. diff --git a/haiku/_src/layer_stack.py b/haiku/_src/layer_stack.py index 14b22b88a..a2c3175ba 100644 --- a/haiku/_src/layer_stack.py +++ b/haiku/_src/layer_stack.py @@ -15,9 +15,10 @@ """Function to stack repeats of a layer function without shared parameters.""" import collections +from collections.abc import Callable import functools import inspect -from typing import Any, Callable, Optional, Protocol, Union +from typing import Any, Protocol, Union from haiku._src import base from haiku._src import lift @@ -47,7 +48,7 @@ def _check_no_varargs(f): "arguments") -def _get_rng_stack(count: int) -> Optional[jax.Array]: +def _get_rng_stack(count: int) -> jax.Array | None: rng = base.maybe_next_rng_key() if rng is None: return None @@ -63,7 +64,7 @@ def stacked_to_flat(self, stacked_module_name: str, scan_idx: int) -> str: def flat_to_stacked( self, unstacked_module_name: str - ) -> Optional[tuple[str, int]]: + ) -> tuple[str, int] | None: """Creates stacked module name and scan index from flat name. Returns None when the module is not a part of layer_stack. This happens @@ -136,7 +137,7 @@ def __init__( count: int, unroll: int, pass_reverse_to_layer_fn: bool = False, - transparency_map: Optional[LayerStackTransparencyMapping] = None, + transparency_map: LayerStackTransparencyMapping | None = None, name: str = "", ): """Iterate f count times, with non-shared parameters.""" @@ -218,7 +219,7 @@ def _call_wrapped( self, x: jax.Array, *args, - ) -> tuple[jax.Array, Optional[jax.Array]]: + ) -> tuple[jax.Array, jax.Array | None]: raise NotImplementedError() @@ -231,7 +232,7 @@ def __init__( count: int, unroll: int, pass_reverse_to_layer_fn: bool = False, - transparency_map: Optional[LayerStackTransparencyMapping] = None, + transparency_map: LayerStackTransparencyMapping | None = None, name: str = "", ): super().__init__( @@ -263,7 +264,7 @@ def __init__( count: int, unroll: int, pass_reverse_to_layer_fn: bool = False, - transparency_map: Optional[LayerStackTransparencyMapping] = None, + transparency_map: LayerStackTransparencyMapping | None = None, name: str = "", ): super().__init__( @@ -286,8 +287,8 @@ def layer_stack( unroll: int = 1, pass_reverse_to_layer_fn: bool = False, transparent: bool = False, - transparency_map: Optional[LayerStackTransparencyMapping] = None, - name: Optional[str] = None, + transparency_map: LayerStackTransparencyMapping | None = None, + name: str | None = None, ): """Utility to wrap a Haiku function and recursively apply it to an input. diff --git a/haiku/_src/layer_stack_test.py b/haiku/_src/layer_stack_test.py index 21f123e9d..e3da3cce0 100644 --- a/haiku/_src/layer_stack_test.py +++ b/haiku/_src/layer_stack_test.py @@ -16,7 +16,7 @@ import functools import re -from typing import Optional + from absl.testing import absltest from absl.testing import parameterized from haiku._src import base @@ -565,7 +565,7 @@ def stacked_to_flat(self, stacked_module_name: str, scan_idx: int) -> str: def flat_to_stacked( self, unstacked_module_name: str - ) -> Optional[tuple[str, int]]: + ) -> tuple[str, int] | None: idx = int(re.findall(r"\d+", unstacked_module_name)[0]) return unstacked_module_name.replace(str(idx), "0"), idx @@ -611,7 +611,7 @@ def stacked_to_flat(self, stacked_module_name: str, scan_idx: int) -> str: def flat_to_stacked( self, unstacked_module_name: str - ) -> Optional[tuple[str, int]]: + ) -> tuple[str, int] | None: idx = int(re.findall(r"\d+", unstacked_module_name)[0]) return unstacked_module_name.replace(str(idx), "0"), idx diff --git a/haiku/_src/lift.py b/haiku/_src/lift.py index 7f97e3f83..3966b5e88 100644 --- a/haiku/_src/lift.py +++ b/haiku/_src/lift.py @@ -14,9 +14,9 @@ # ============================================================================== """Lifting parameters in Haiku.""" -from collections.abc import Mapping, MutableMapping +from collections.abc import Callable, Mapping, MutableMapping import functools -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar from haiku._src import base from haiku._src import data_structures @@ -325,7 +325,7 @@ class LiftWithStateUpdater: __slots__ = ("_used", "_name", "_context_id") - def __init__(self, name: Optional[str]): + def __init__(self, name: str | None): self._used = False self._name = name ctx = base.current_context() diff --git a/haiku/_src/mixed_precision.py b/haiku/_src/mixed_precision.py index ee9198b71..b2038f211 100644 --- a/haiku/_src/mixed_precision.py +++ b/haiku/_src/mixed_precision.py @@ -17,7 +17,7 @@ import collections import contextlib import threading -from typing import TypeVar, Optional, Union +from typing import TypeVar, Union from haiku._src import base from haiku._src import data_structures @@ -89,7 +89,7 @@ def set_policy(self, cls: type[hk.Module], policy: jmp.Policy): key = key_for_module(cls) self._cls_policy[key] = policy - def get_policy(self, cls: type[hk.Module]) -> Optional[jmp.Policy]: + def get_policy(self, cls: type[hk.Module]) -> jmp.Policy | None: key = key_for_module(cls) return self._cls_policy.get(key) @@ -102,7 +102,7 @@ def reset_thread_local_state_for_test(): _thread_local_state = _ThreadState() -def current_policy() -> Optional[jmp.Policy]: +def current_policy() -> jmp.Policy | None: """Retrieves the currently active policy in the current context. Returns: @@ -118,7 +118,7 @@ def current_policy() -> Optional[jmp.Policy]: return tls.current_policy if tls.has_current_policy else None -def get_policy(cls: type[hk.Module]) -> Optional[jmp.Policy]: +def get_policy(cls: type[hk.Module]) -> jmp.Policy | None: """Retrieves the currently active policy for the given class. Note that policies applied explicitly to a top level class (e.g. ``ResNet``) diff --git a/haiku/_src/mixed_precision_test.py b/haiku/_src/mixed_precision_test.py index 4196247e4..116f1629b 100644 --- a/haiku/_src/mixed_precision_test.py +++ b/haiku/_src/mixed_precision_test.py @@ -15,7 +15,6 @@ """Tests for haiku._src.mixed_precision.""" import importlib -from typing import Optional from absl.testing import absltest from haiku._src import base @@ -29,7 +28,7 @@ import jmp -def with_policy(cls: type[module.Module], policy: Optional[jmp.Policy]): +def with_policy(cls: type[module.Module], policy: jmp.Policy | None): def decorator(f): def wrapper(*args, **kwargs): with mixed_precision.push_policy(cls, policy): diff --git a/haiku/_src/module.py b/haiku/_src/module.py index 25c279e11..028711fbd 100644 --- a/haiku/_src/module.py +++ b/haiku/_src/module.py @@ -14,12 +14,12 @@ # ============================================================================== """Base Haiku module.""" -from collections.abc import Mapping +from collections.abc import Callable, Mapping import contextlib import functools import inspect import re -from typing import Any, Callable, ContextManager, NamedTuple, Optional, Protocol, TypeVar +from typing import Any, ContextManager, NamedTuple, Protocol, TypeVar from haiku._src import base from haiku._src import config @@ -484,7 +484,7 @@ def wrapped(self, *args, **kwargs): valid_identifier = lambda name: bool(_VALID_IDENTIFIER_R.match(name)) -def name_and_number(name: str) -> tuple[str, Optional[int]]: +def name_and_number(name: str) -> tuple[str, int | None]: splits = re.split(r"_(0|[1-9]\d*)$", name, 3) if len(splits) > 1: return splits[0], int(splits[1]) @@ -635,7 +635,7 @@ class Module(metaclass=ModuleMetaclass): 2.0 """ - def __init__(self, name: Optional[str] = None): + def __init__(self, name: str | None = None): """Initializes the current module with the given name. Subclasses should call this constructor before creating other modules or diff --git a/haiku/_src/module_test.py b/haiku/_src/module_test.py index 2f23108fc..255a7adc6 100644 --- a/haiku/_src/module_test.py +++ b/haiku/_src/module_test.py @@ -15,11 +15,11 @@ """Tests for haiku._src.module.""" import abc -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import inspect -from typing import Callable, Optional, Protocol, TypeVar, runtime_checkable +from typing import Optional, Protocol, TypeVar, runtime_checkable from absl.testing import absltest from absl.testing import parameterized @@ -923,7 +923,7 @@ def __call__(self): class DataLinear(module.Module): output_size: int - name: Optional[str] = None + name: str | None = None def __call__(self, x): j, k = x.shape[-1], self.output_size @@ -937,7 +937,7 @@ class DataMLP(module.Module): output_sizes: Sequence[int] activation: Callable[[jax.Array], jax.Array] = jax.nn.relu - name: Optional[str] = None + name: str | None = None def __call__(self, x): for i, output_size in enumerate(self.output_sizes): diff --git a/haiku/_src/moving_averages.py b/haiku/_src/moving_averages.py index 7c2dcb2fc..15ee2022c 100644 --- a/haiku/_src/moving_averages.py +++ b/haiku/_src/moving_averages.py @@ -15,7 +15,6 @@ """Moving averages.""" import re -from typing import Optional, Union import warnings from haiku._src import base @@ -50,7 +49,7 @@ def __init__( decay, zero_debias: bool = True, warmup_length: int = 0, - name: Optional[str] = None, + name: str | None = None, ): """Initializes an ExponentialMovingAverage module. @@ -93,7 +92,7 @@ def initialize(self, shape, dtype=jnp.float32): def __call__( self, - value: Union[float, jax.Array], + value: float | jax.Array, update_stats: bool = True, ) -> jax.Array: """Updates the EMA and returns the new value. @@ -173,7 +172,7 @@ def __init__( zero_debias: bool = True, warmup_length: int = 0, ignore_regex: str = "", - name: Optional[str] = None, + name: str | None = None, ): """Initializes an EMAParamsTree module. diff --git a/haiku/_src/multi_transform.py b/haiku/_src/multi_transform.py index 7aa446cc1..98c47ddff 100644 --- a/haiku/_src/multi_transform.py +++ b/haiku/_src/multi_transform.py @@ -16,10 +16,11 @@ # pylint: disable=unnecessary-lambda +from collections.abc import Callable import dataclasses import functools import inspect -from typing import Any, Callable, NamedTuple, Optional, TypeVar +from typing import Any, NamedTuple, Optional, TypeVar from haiku._src import analytics from haiku._src import transform diff --git a/haiku/_src/multi_transform_test.py b/haiku/_src/multi_transform_test.py index 9bfe9af91..fe34f38ac 100644 --- a/haiku/_src/multi_transform_test.py +++ b/haiku/_src/multi_transform_test.py @@ -14,7 +14,7 @@ # ============================================================================== """Tests for haiku._src.multi_transform.""" import inspect -from typing import Optional, Union + from absl.testing import absltest from absl.testing import parameterized from haiku._src import base @@ -136,13 +136,13 @@ def f(pos, key=37) -> int: return 2 def expected_f_init( - rng: Optional[Union[PRNGKey, int]], pos, key=37 + rng: PRNGKey | int | None, pos, key=37 ) -> tuple[Params, State]: del rng, pos, key raise NotImplementedError def expected_f_apply( - params: Optional[Params], state: Optional[State], pos, key=37 + params: Params | None, state: State | None, pos, key=37 ) -> tuple[int, State]: del params, state, pos, key raise NotImplementedError @@ -158,12 +158,12 @@ def test_signature_without_apply_rng_transform(self): def f(pos, *, key: int = 37) -> int: del pos, key return 2 - def expected_f_init(rng: Optional[Union[PRNGKey, int]], + def expected_f_init(rng: PRNGKey | int | None, pos, *, key: int = 37) -> Params: del rng, pos, key raise NotImplementedError def expected_f_apply( - params: Optional[Params], pos, *, key: int = 37) -> int: + params: Params | None, pos, *, key: int = 37) -> int: del params, pos, key raise NotImplementedError self.assertEqual( diff --git a/haiku/_src/nets/mlp.py b/haiku/_src/nets/mlp.py index d419b65e4..fde1d6316 100644 --- a/haiku/_src/nets/mlp.py +++ b/haiku/_src/nets/mlp.py @@ -14,8 +14,7 @@ # ============================================================================== """A minimal interface mlp module.""" -from collections.abc import Iterable -from typing import Callable, Optional +from collections.abc import Callable, Iterable from haiku._src import base from haiku._src import basic @@ -43,12 +42,12 @@ class MLP(hk.Module): def __init__( self, output_sizes: Iterable[int], - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, + w_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, with_bias: bool = True, activation: Callable[[jax.Array], jax.Array] = jax.nn.relu, activate_final: bool = False, - name: Optional[str] = None, + name: str | None = None, ): """Constructs an MLP. @@ -89,7 +88,7 @@ def __init__( def __call__( self, inputs: jax.Array, - dropout_rate: Optional[float] = None, + dropout_rate: float | None = None, rng=None, ) -> jax.Array: """Connects the module to some inputs. @@ -123,8 +122,8 @@ def __call__( def reverse( self, - activate_final: Optional[bool] = None, - name: Optional[str] = None, + activate_final: bool | None = None, + name: str | None = None, ) -> "MLP": """Returns a new MLP which is the layer-wise reverse of this MLP. diff --git a/haiku/_src/nets/mobilenetv1.py b/haiku/_src/nets/mobilenetv1.py index 5f49e6830..6d46cf867 100644 --- a/haiku/_src/nets/mobilenetv1.py +++ b/haiku/_src/nets/mobilenetv1.py @@ -25,7 +25,6 @@ """ from collections.abc import Sequence -from typing import Optional from haiku._src import basic from haiku._src import batch_norm @@ -58,7 +57,7 @@ def __init__( channels: int, stride: int, use_bn: bool = True, - name: Optional[str] = None, + name: str | None = None, ): super().__init__(name=name) self.channels = channels @@ -109,7 +108,7 @@ def __init__( 512, 512, 512, 512, 1024, 1024), num_classes: int = 1000, use_bn: bool = True, - name: Optional[str] = None, + name: str | None = None, ): """Constructs a MobileNetV1 model. diff --git a/haiku/_src/nets/resnet.py b/haiku/_src/nets/resnet.py index 42bce54c6..7932d9f8d 100644 --- a/haiku/_src/nets/resnet.py +++ b/haiku/_src/nets/resnet.py @@ -15,7 +15,7 @@ """Resnet.""" from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Union from haiku._src import basic from haiku._src import batch_norm @@ -46,11 +46,11 @@ class BlockV1(hk.Module): def __init__( self, channels: int, - stride: Union[int, Sequence[int]], + stride: int | Sequence[int], use_projection: bool, bn_config: Mapping[str, FloatStrOrBool], bottleneck: bool, - name: Optional[str] = None, + name: str | None = None, ): super().__init__(name=name) self.use_projection = use_projection @@ -128,11 +128,11 @@ class BlockV2(hk.Module): def __init__( self, channels: int, - stride: Union[int, Sequence[int]], + stride: int | Sequence[int], use_projection: bool, bn_config: Mapping[str, FloatStrOrBool], bottleneck: bool, - name: Optional[str] = None, + name: str | None = None, ): super().__init__(name=name) self.use_projection = use_projection @@ -208,12 +208,12 @@ def __init__( self, channels: int, num_blocks: int, - stride: Union[int, Sequence[int]], + stride: int | Sequence[int], bn_config: Mapping[str, FloatStrOrBool], resnet_v2: bool, bottleneck: bool, use_projection: bool, - name: Optional[str] = None, + name: str | None = None, ): super().__init__(name=name) @@ -291,14 +291,14 @@ def __init__( self, blocks_per_group: Sequence[int], num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, bottleneck: bool = True, channels_per_group: Sequence[int] = (256, 512, 1024, 2048), use_projection: Sequence[bool] = (True, True, True, True), - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. @@ -402,11 +402,11 @@ class ResNet18(ResNet): def __init__( self, num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. @@ -440,11 +440,11 @@ class ResNet34(ResNet): def __init__( self, num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. @@ -478,11 +478,11 @@ class ResNet50(ResNet): def __init__( self, num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. @@ -516,11 +516,11 @@ class ResNet101(ResNet): def __init__( self, num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. @@ -554,11 +554,11 @@ class ResNet152(ResNet): def __init__( self, num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. @@ -592,11 +592,11 @@ class ResNet200(ResNet): def __init__( self, num_classes: int, - bn_config: Optional[Mapping[str, FloatStrOrBool]] = None, + bn_config: Mapping[str, FloatStrOrBool] | None = None, resnet_v2: bool = False, - logits_config: Optional[Mapping[str, Any]] = None, - name: Optional[str] = None, - initial_conv_config: Optional[Mapping[str, FloatStrOrBool]] = None, + logits_config: Mapping[str, Any] | None = None, + name: str | None = None, + initial_conv_config: Mapping[str, FloatStrOrBool] | None = None, strides: Sequence[int] = (1, 2, 2, 2), ): """Constructs a ResNet model. diff --git a/haiku/_src/nets/vqvae.py b/haiku/_src/nets/vqvae.py index 63e1b8209..540cc07e6 100644 --- a/haiku/_src/nets/vqvae.py +++ b/haiku/_src/nets/vqvae.py @@ -14,13 +14,12 @@ # ============================================================================== """Haiku implementation of VQ-VAE https://arxiv.org/abs/1711.00937.""" -from typing import Any, Optional +from typing import Any from haiku._src import base from haiku._src import initializers from haiku._src import module from haiku._src import moving_averages - import jax import jax.numpy as jnp @@ -69,8 +68,8 @@ def __init__( num_embeddings: int, commitment_cost: float, dtype: Any = jnp.float32, - name: Optional[str] = None, - cross_replica_axis: Optional[str] = None, + name: str | None = None, + cross_replica_axis: str | None = None, ): """Initializes a VQ-VAE module. @@ -216,8 +215,8 @@ def __init__( decay, epsilon: float = 1e-5, dtype: Any = jnp.float32, - cross_replica_axis: Optional[str] = None, - name: Optional[str] = None, + cross_replica_axis: str | None = None, + name: str | None = None, ): """Initializes a VQ-VAE EMA module. diff --git a/haiku/_src/pad.py b/haiku/_src/pad.py index 2379e78d4..d4e5946b9 100644 --- a/haiku/_src/pad.py +++ b/haiku/_src/pad.py @@ -15,9 +15,9 @@ """Padding module for Haiku.""" from collections import abc -from collections.abc import Sequence +from collections.abc import Callable, Sequence import typing -from typing import Any, Callable, Union +from typing import Any from haiku._src import utils @@ -59,9 +59,9 @@ def reverse_causal(effective_kernel_size: int) -> tuple[int, int]: def create_from_padfn( - padding: Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]], - kernel: Union[int, Sequence[int]], - rate: Union[int, Sequence[int]], + padding: hk.pad.PadFn | Sequence[hk.pad.PadFn], # pylint: disable=g-bare-generic + kernel: int | Sequence[int], + rate: int | Sequence[int], n: int, ) -> Sequence[tuple[int, int]]: """Generates the padding required for a given padding algorithm. @@ -97,7 +97,7 @@ def create_from_padfn( def create_from_tuple( - padding: Union[tuple[int, int], Sequence[tuple[int, int]]], + padding: tuple[int, int] | Sequence[tuple[int, int]], n: int, ) -> Sequence[tuple[int, int]]: """Create a padding tuple using partially specified padding tuple.""" @@ -114,7 +114,7 @@ def create_from_tuple( return padding -def is_padfn(padding: Union[hk.pad.PadFn, Sequence[hk.pad.PadFn], Any]) -> bool: +def is_padfn(padding: hk.pad.PadFn | Sequence[hk.pad.PadFn] | Any) -> bool: # pylint: disable=g-bare-generic """Tests whether the given argument is a single or sequence of PadFns.""" if isinstance(padding, abc.Sequence): padding = padding[0] diff --git a/haiku/_src/pool.py b/haiku/_src/pool.py index bae641761..588edc6eb 100644 --- a/haiku/_src/pool.py +++ b/haiku/_src/pool.py @@ -15,7 +15,6 @@ """Pooling Haiku modules.""" from collections.abc import Sequence -from typing import Optional, Union import warnings from haiku._src import module @@ -35,8 +34,8 @@ class hk: def _infer_shape( x: jax.Array, - size: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, + size: int | Sequence[int], + channel_axis: int | None = -1, ) -> tuple[int, ...]: """Infer shape for pooling window or strides.""" if isinstance(size, int): @@ -75,10 +74,10 @@ def _warn_if_unsafe(window_shape, strides): def max_pool( value: jax.Array, - window_shape: Union[int, Sequence[int]], - strides: Union[int, Sequence[int]], + window_shape: int | Sequence[int], + strides: int | Sequence[int], padding: str, - channel_axis: Optional[int] = -1, + channel_axis: int | None = -1, ) -> jax.Array: """Max pool. @@ -105,10 +104,10 @@ def max_pool( def avg_pool( value: jax.Array, - window_shape: Union[int, Sequence[int]], - strides: Union[int, Sequence[int]], + window_shape: int | Sequence[int], + strides: int | Sequence[int], padding: str, - channel_axis: Optional[int] = -1, + channel_axis: int | None = -1, ) -> jax.Array: """Average pool. @@ -155,11 +154,11 @@ class MaxPool(hk.Module): def __init__( self, - window_shape: Union[int, Sequence[int]], - strides: Union[int, Sequence[int]], + window_shape: int | Sequence[int], + strides: int | Sequence[int], padding: str, - channel_axis: Optional[int] = -1, - name: Optional[str] = None, + channel_axis: int | None = -1, + name: str | None = None, ): """Max pool. @@ -189,11 +188,11 @@ class AvgPool(hk.Module): def __init__( self, - window_shape: Union[int, Sequence[int]], - strides: Union[int, Sequence[int]], + window_shape: int | Sequence[int], + strides: int | Sequence[int], padding: str, - channel_axis: Optional[int] = -1, - name: Optional[str] = None, + channel_axis: int | None = -1, + name: str | None = None, ): """Average pool. diff --git a/haiku/_src/recurrent.py b/haiku/_src/recurrent.py index 27f17b889..4eae00f19 100644 --- a/haiku/_src/recurrent.py +++ b/haiku/_src/recurrent.py @@ -16,7 +16,7 @@ import abc from collections.abc import Sequence -from typing import Any, NamedTuple, Optional, Union +from typing import Any, NamedTuple from haiku._src import base from haiku._src import basic @@ -71,7 +71,7 @@ def __call__(self, inputs, prev_state) -> tuple[Any, Any]: """ @abc.abstractmethod - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): """Constructs an initial state for this core. Args: @@ -218,7 +218,7 @@ def scan_f(prev_state, inputs): return output_sequence, last_state -def add_batch(nest, batch_size: Optional[int]): +def add_batch(nest, batch_size: int | None): """Adds a batch dimension at axis 0 to the leaves of a nested structure.""" broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape) return jax.tree.map(broadcast, nest) @@ -241,7 +241,7 @@ def __init__( self, hidden_size: int, double_bias: bool = True, - name: Optional[str] = None + name: str | None = None ): """Constructs a vanilla RNN core. @@ -263,7 +263,7 @@ def __call__(self, inputs, prev_state): out = jax.nn.relu(input_to_hidden(inputs) + hidden_to_hidden(prev_state)) return out, out - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): state = jnp.zeros([self.hidden_size]) if batch_size is not None: state = add_batch(state, batch_size) @@ -311,7 +311,7 @@ class LSTM(RNNCore): the beginning of the training. """ - def __init__(self, hidden_size: int, name: Optional[str] = None): + def __init__(self, hidden_size: int, name: str | None = None): """Constructs an LSTM. Args: @@ -338,7 +338,7 @@ def __call__( h = jax.nn.sigmoid(o) * jnp.tanh(c) return h, LSTMState(h, c) - def initial_state(self, batch_size: Optional[int]) -> LSTMState: + def initial_state(self, batch_size: int | None) -> LSTMState: state = LSTMState(hidden=jnp.zeros([self.hidden_size]), cell=jnp.zeros([self.hidden_size])) if batch_size is not None: @@ -382,8 +382,8 @@ def __init__( num_spatial_dims: int, input_shape: Sequence[int], output_channels: int, - kernel_shape: Union[int, Sequence[int]], - name: Optional[str] = None, + kernel_shape: int | Sequence[int], + name: str | None = None, ): """Constructs a convolutional LSTM. @@ -427,7 +427,7 @@ def __call__( h = jax.nn.sigmoid(o) * jnp.tanh(c) return h, LSTMState(h, c) - def initial_state(self, batch_size: Optional[int]) -> LSTMState: + def initial_state(self, batch_size: int | None) -> LSTMState: shape = self.input_shape + (self.output_channels,) state = LSTMState(jnp.zeros(shape), jnp.zeros(shape)) if batch_size is not None: @@ -442,8 +442,8 @@ def __init__( self, input_shape: Sequence[int], output_channels: int, - kernel_shape: Union[int, Sequence[int]], - name: Optional[str] = None, + kernel_shape: int | Sequence[int], + name: str | None = None, ): """Constructs a 1-D convolutional LSTM. @@ -470,8 +470,8 @@ def __init__( self, input_shape: Sequence[int], output_channels: int, - kernel_shape: Union[int, Sequence[int]], - name: Optional[str] = None, + kernel_shape: int | Sequence[int], + name: str | None = None, ): """Constructs a 2-D convolutional LSTM. @@ -498,8 +498,8 @@ def __init__( self, input_shape: Sequence[int], output_channels: int, - kernel_shape: Union[int, Sequence[int]], - name: Optional[str] = None, + kernel_shape: int | Sequence[int], + name: str | None = None, ): """Constructs a 3-D convolutional LSTM. @@ -544,10 +544,10 @@ class GRU(RNNCore): def __init__( self, hidden_size: int, - w_i_init: Optional[hk.initializers.Initializer] = None, - w_h_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, + w_i_init: hk.initializers.Initializer | None = None, + w_h_init: hk.initializers.Initializer | None = None, + b_init: hk.initializers.Initializer | None = None, + name: str | None = None, ): super().__init__(name=name) self.hidden_size = hidden_size @@ -582,7 +582,7 @@ def __call__(self, inputs, state): next_state = (1 - z) * state + z * a return next_state, next_state - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): state = jnp.zeros([self.hidden_size]) if batch_size is not None: state = add_batch(state, batch_size) @@ -599,7 +599,7 @@ class IdentityCore(RNNCore): def __call__(self, inputs, state): return inputs, state - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): return () @@ -626,7 +626,7 @@ class ResetCore(RNNCore): ``should_reset`` nest compatible with the state structure. """ - def __init__(self, core: RNNCore, name: Optional[str] = None): + def __init__(self, core: RNNCore, name: str | None = None): super().__init__(name=name) self.core = core @@ -709,7 +709,7 @@ def __call__(self, inputs, state): state = jax.tree.map(jnp.where, should_reset, initial_state, state) return self.core(inputs, state) - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): return self.core.initial_state(batch_size) def _is_batched(self, state): @@ -727,7 +727,7 @@ def __init__( self, layers: Sequence[Any], skip_connections: bool, - name: Optional[str] = None + name: str | None = None ): super().__init__(name=name) self.layers = layers @@ -770,7 +770,7 @@ def __call__(self, inputs, state): return out, tuple(next_states) - def initial_state(self, batch_size: Optional[int]): + def initial_state(self, batch_size: int | None): return tuple( layer.initial_state(batch_size) for layer in self.layers @@ -791,12 +791,12 @@ class DeepRNN(_DeepRNN): tuple. """ - def __init__(self, layers: Sequence[Any], name: Optional[str] = None): + def __init__(self, layers: Sequence[Any], name: str | None = None): super().__init__(layers, skip_connections=False, name=name) def deep_rnn_with_skip_connections(layers: Sequence[RNNCore], - name: Optional[str] = None) -> RNNCore: + name: str | None = None) -> RNNCore: r"""Constructs a :class:`DeepRNN` with skip connections. Skip connections alter the dependency structure within a :class:`DeepRNN`. diff --git a/haiku/_src/reshape.py b/haiku/_src/reshape.py index 72adb0666..733a480e0 100644 --- a/haiku/_src/reshape.py +++ b/haiku/_src/reshape.py @@ -15,7 +15,6 @@ """Reshaping Haiku modules.""" from collections.abc import Sequence -from typing import Optional from haiku._src import module import jax.numpy as jnp @@ -95,7 +94,7 @@ def __init__( self, output_shape: Sequence[int], preserve_dims: int = 1, - name: Optional[str] = None, + name: str | None = None, ): """Constructs a :class:`Reshape` module. @@ -171,7 +170,7 @@ class Flatten(Reshape): def __init__( self, preserve_dims: int = 1, - name: Optional[str] = None, + name: str | None = None, ): super().__init__( output_shape=(-1,), diff --git a/haiku/_src/rms_norm.py b/haiku/_src/rms_norm.py index 0afecde10..72cdeb5f1 100644 --- a/haiku/_src/rms_norm.py +++ b/haiku/_src/rms_norm.py @@ -19,7 +19,7 @@ from collections import abc from collections.abc import Sequence -from typing import Optional, Union +from typing import Union from haiku._src import base from haiku._src import initializers @@ -55,11 +55,11 @@ def __init__( self, axis: AxisOrAxes, eps: float = 1e-5, - scale_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, + scale_init: hk.initializers.Initializer | None = None, + name: str | None = None, create_scale: bool = True, *, - param_axis: Optional[AxisOrAxes] = None, + param_axis: AxisOrAxes | None = None, ): """Constructs a RMSNorm module. diff --git a/haiku/_src/spectral_norm.py b/haiku/_src/spectral_norm.py index adab35543..649c094f3 100644 --- a/haiku/_src/spectral_norm.py +++ b/haiku/_src/spectral_norm.py @@ -21,7 +21,6 @@ """ import re -from typing import Optional from haiku._src import base from haiku._src import data_structures @@ -74,7 +73,7 @@ def __init__( self, eps: float = 1e-4, n_steps: int = 1, - name: Optional[str] = None, + name: str | None = None, ): """Initializes an SpectralNorm module. @@ -170,7 +169,7 @@ def __init__( eps: float = 1e-4, n_steps: int = 1, ignore_regex: str = "", - name: Optional[str] = None, + name: str | None = None, ): """Initializes an SNParamsTree module. diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py index f85c7ab20..1d8aa253f 100644 --- a/haiku/_src/stateful.py +++ b/haiku/_src/stateful.py @@ -16,10 +16,10 @@ import collections import collections.abc -from collections.abc import Mapping, MutableMapping +from collections.abc import Callable, Mapping, MutableMapping import functools import inspect -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar from haiku._src import base import jax @@ -328,8 +328,8 @@ def is_new_state(a: base.StatePair, b: base.StatePair): ) # rng - def is_new_rng(a: Optional[base.PRNGSequenceState], - b: Optional[base.PRNGSequenceState]): + def is_new_rng(a: base.PRNGSequenceState | None, + b: base.PRNGSequenceState | None): if a is None: return True assert len(a) == 2 and len(b) == 2 @@ -698,7 +698,7 @@ def pure_body_fun(i, val): return val -def maybe_get_axis(axis: Optional[int], arrays: Any) -> Optional[int]: +def maybe_get_axis(axis: int | None, arrays: Any) -> int | None: """Returns `array.shape[axis]` for one of the arrays in the input.""" if axis is None: return None shapes = [a.shape for a in jax.tree.leaves(arrays)] @@ -758,8 +758,8 @@ def vmap( fun: Callable[..., Any], in_axes=0, out_axes=0, - axis_name: Optional[str] = None, - axis_size: Optional[int] = None, + axis_name: str | None = None, + axis_size: int | None = None, *, split_rng: bool, ) -> Callable[..., Any]: diff --git a/haiku/_src/summarise.py b/haiku/_src/summarise.py index 266985a47..ee0236dcf 100644 --- a/haiku/_src/summarise.py +++ b/haiku/_src/summarise.py @@ -14,11 +14,11 @@ # ============================================================================== """Summarises Haiku modules.""" -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence import dataclasses import functools import pprint -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar from haiku._src import base from haiku._src import data_structures @@ -165,7 +165,7 @@ def make_hk_transform_ignore_jax_transforms(f): def eval_summary( - f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState], + f: Callable[..., Any] | hk.Transformed | hk.TransformedWithState, ) -> Callable[..., Sequence[MethodInvocation]]: """Records module method calls performed by ``f``. @@ -314,10 +314,10 @@ def format_entry(state: ModuleDetails) -> str: def tabulate( - f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState], + f: Callable[..., Any] | hk.Transformed | hk.TransformedWithState, *, - columns: Optional[Sequence[str]] = DEFAULT_COLUMNS, - filters: Optional[Sequence[str]] = DEFAULT_FILTERS, + columns: Sequence[str] | None = DEFAULT_COLUMNS, + filters: Sequence[str] | None = DEFAULT_FILTERS, tabulate_kwargs={"tablefmt": "grid"}, ) -> Callable[..., str]: # pylint: disable=line-too-long diff --git a/haiku/_src/test_utils.py b/haiku/_src/test_utils.py index 57e23b1ea..60389f006 100644 --- a/haiku/_src/test_utils.py +++ b/haiku/_src/test_utils.py @@ -14,13 +14,13 @@ # ============================================================================== """Testing utilities for Haiku.""" -from collections.abc import Generator, Sequence +from collections.abc import Callable, Generator, Sequence import functools import inspect import itertools import os import types -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar from absl.testing import parameterized from haiku._src import config @@ -33,12 +33,12 @@ def transform_and_run( - f: Optional[Fn] = None, - seed: Optional[int] = 42, + f: Fn | None = None, + seed: int | None = 42, run_apply: bool = True, - jax_transform: Optional[Callable[[Fn], Fn]] = None, + jax_transform: Callable[[Fn], Fn] | None = None, *, - map_rng: Optional[Callable[[Key], Key]] = None, + map_rng: Callable[[Key], Key] | None = None, ) -> T: r"""Transforms the given function and runs init then (optionally) apply. @@ -221,7 +221,7 @@ def named_range(name, stop: int) -> Sequence[tuple[str, int]]: return tuple((f"{name}_{i}", i) for i in range(stop)) -def with_environ(key: str, value: Optional[str]): +def with_environ(key: str, value: str | None): """Runs the given test with envrionment variables set.""" def set_env(new_value): if new_value is None: diff --git a/haiku/_src/transform.py b/haiku/_src/transform.py index 3c48da325..5feaa885e 100644 --- a/haiku/_src/transform.py +++ b/haiku/_src/transform.py @@ -14,9 +14,9 @@ # ============================================================================== """Base Haiku module.""" -from collections.abc import Mapping +from collections.abc import Callable, Mapping import inspect -from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union +from typing import Any, NamedTuple, Optional, TypeVar, Union from haiku._src import analytics from haiku._src import base @@ -120,7 +120,7 @@ class TransformedWithState(NamedTuple): apply: Callable[..., tuple[Any, hk.MutableState]] -def to_prng_sequence(rng, err_msg) -> Optional[hk.PRNGSequence]: +def to_prng_sequence(rng, err_msg) -> hk.PRNGSequence | None: if rng is not None: try: rng = hk.PRNGSequence(rng) @@ -236,7 +236,7 @@ def init_fn(*args, **kwargs) -> tuple[hk.MutableParams, hk.MutableState]: init_fn.__signature__ = sig_add_state(inspect.signature(f.init)) def apply_fn( - params: hk.Params, state: Optional[hk.State], *args, **kwargs + params: hk.Params, state: hk.State | None, *args, **kwargs ) -> tuple[Any, hk.MutableState]: del state out = f.apply(params, *args, **kwargs) @@ -411,7 +411,7 @@ def transform_with_state(f) -> TransformedWithState: f_sig = inspect.signature(f) def init_fn( - rng: Optional[Union[PRNGKey, int]], + rng: PRNGKey | int | None, *args, **kwargs, ) -> tuple[hk.MutableParams, hk.MutableState]: @@ -438,9 +438,9 @@ def init_fn( ) def apply_fn( - params: Optional[hk.Params], - state: Optional[hk.State], - rng: Optional[Union[PRNGKey, int]], + params: hk.Params | None, + state: hk.State | None, + rng: PRNGKey | int | None, *args, **kwargs, ) -> tuple[Any, hk.MutableState]: @@ -484,14 +484,13 @@ def tie_in_original_fn(f, init_fn, apply_fn): apply_fn._original_fn = f # pylint: disable=protected-access -def get_original_fn(f: Union[Transformed, TransformedWithState, Callable[..., - Any]]): +def get_original_fn(f: Transformed | TransformedWithState | Callable[..., Any]): if isinstance(f, (Transformed, TransformedWithState)): f = f.init return getattr(f, "_original_fn") -def check_mapping(name: str, mapping: Optional[T]) -> T: +def check_mapping(name: str, mapping: T | None) -> T: """Cleans inputs to apply_fn, providing better errors.""" if mapping is None: # Convert None to empty dict. diff --git a/haiku/_src/transform_test.py b/haiku/_src/transform_test.py index 6e1dff92f..ff536485f 100644 --- a/haiku/_src/transform_test.py +++ b/haiku/_src/transform_test.py @@ -16,7 +16,6 @@ from collections.abc import Mapping import inspect -from typing import Optional, Union from absl.testing import absltest from absl.testing import parameterized @@ -592,15 +591,15 @@ def f(pos, key=37) -> int: return 2 def expected_f_init( - rng: Optional[Union[PRNGKey, int]], pos, key=37 + rng: PRNGKey | int | None, pos, key=37 ) -> tuple[Params, State]: del rng, pos, key raise NotImplementedError def expected_f_apply( - params: Optional[Params], - state: Optional[State], - rng: Optional[Union[PRNGKey, int]], + params: Params | None, + state: State | None, + rng: PRNGKey | int | None, pos, key=37, ) -> tuple[int, State]: @@ -617,12 +616,12 @@ def test_signature_transform(self): def f(pos, *, key: int = 37) -> int: del pos, key return 2 - def expected_f_init(rng: Optional[Union[PRNGKey, int]], + def expected_f_init(rng: PRNGKey | int | None, pos, *, key: int = 37) -> Params: del rng, pos, key raise NotImplementedError def expected_f_apply( - params: Optional[Params], rng: Optional[Union[PRNGKey, int]], + params: Params | None, rng: PRNGKey | int | None, pos, *, key: int = 37) -> int: del params, rng, pos, key raise NotImplementedError diff --git a/haiku/_src/typing.py b/haiku/_src/typing.py index 4532baf8b..5efc8fa9c 100644 --- a/haiku/_src/typing.py +++ b/haiku/_src/typing.py @@ -15,9 +15,9 @@ """Haiku types.""" import abc -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Callable, Mapping, MutableMapping, Sequence import typing -from typing import Any, Callable, Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable import jax diff --git a/haiku/_src/utils.py b/haiku/_src/utils.py index 6fbf9ad92..2bd86ae5a 100644 --- a/haiku/_src/utils.py +++ b/haiku/_src/utils.py @@ -20,7 +20,7 @@ import inspect import pprint import re -from typing import Any, TypeVar, Union +from typing import Any, TypeVar import jax @@ -114,7 +114,7 @@ def indent(amount: int, s: str) -> str: def replicate( - element: Union[T, Sequence[T]], + element: T | Sequence[T], num_times: int, name: str, ) -> tuple[T, ...]: diff --git a/setup.py b/setup.py index f98cb65e0..5c347024a 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ def _parse_requirements(requirements_txt_path): 'flax': _parse_requirements('requirements-flax.txt'), }, tests_require=_parse_requirements('requirements-test.txt'), - requires_python='>=3.9', + requires_python='>=3.10', include_package_data=True, zip_safe=False, # PyPI package information. @@ -64,7 +64,6 @@ def _parse_requirements(requirements_txt_path): 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Topic :: Scientific/Engineering :: Mathematics',