Wrap Minibatch Operation in OpFromGraph
ricardoV94 committed Sep 7, 2024
1 parent e19cd39 commit 623ca42
Showing 9 changed files with 141 additions and 180 deletions.
83 changes: 48 additions & 35 deletions pymc/data.py
@@ -26,18 +26,19 @@
import pytensor.tensor as pt
import xarray as xr

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.pytensorf import convert_data, smarttypeX
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
from pymc.vartypes import isgenerator

__all__ = [
@@ -129,46 +130,47 @@ def __hash__(self):
class MinibatchIndexRV(IntegersRV):
_print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")

# Work-around for https://github.com/pymc-devs/pytensor/issues/97
def make_node(self, rng, *args, **kwargs):
if rng is None:
rng = pytensor.shared(np.random.default_rng())
return super().make_node(rng, *args, **kwargs)


minibatch_index = MinibatchIndexRV()


def is_minibatch(v: TensorVariable) -> bool:
return (
isinstance(v.owner.op, AdvancedSubtensor)
and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
and valid_for_minibatch(v.owner.inputs[0])
)
class MinibatchOp(OpFromGraph):
"""Encapsulate Minibatch random draws in an opaque OFG"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, inline=True)

def __str__(self):
return "Minibatch"


def valid_for_minibatch(v: TensorVariable) -> bool:
def is_valid_observed(v) -> bool:
if not isinstance(v, Variable):
# Non-symbolic constant
return True

if v.owner is None:
# Symbolic root variable (constant or not)
return True

return (
v.owner is None
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
or (
(
isinstance(v.owner.op, Elemwise)
and v.owner.inputs[0].owner is None
and isinstance(v.owner.op.scalar_op, Cast)
and is_valid_observed(v.owner.inputs[0])
)
# Or Minibatch
or (
isinstance(v.owner.op, MinibatchOp)
and all(is_valid_observed(inp) for inp in v.owner.inputs)
)
# Or Generator
or isinstance(v.owner.op, GeneratorOp)
)


def assert_all_scalars_equal(scalar, *scalars):
if len(scalars) == 0:
return scalar
else:
return Assert(
"All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
)(scalar, pt.all([pt.eq(scalar, s) for s in scalars]))


def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
"""Get random slices from variables from the leading dimension.
@@ -188,18 +190,29 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")

tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
slc = minibatch_index(0, upper, size=batch_size)
for i, v in enumerate((tensor, *tensors)):
if not valid_for_minibatch(v):
tensors = tuple(map(pt.as_tensor, (variable, *variables)))
for i, v in enumerate(tensors):
if not is_valid_observed(v):
raise ValueError(
f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
)
result = tuple([v[slc] for v in (tensor, *tensors)])
for i, r in enumerate(result):

upper = tensors[0].shape[0]
if len(tensors) > 1:
upper = Assert(
"All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
)(upper, pt.all([pt.eq(upper, other_tensor.shape[0]) for other_tensor in tensors[1:]]))

rng = pytensor.shared(np.random.default_rng())
rng_update, mb_indices = minibatch_index(0, upper, size=batch_size, rng=rng).owner.outputs
mb_tensors = [tensor[mb_indices] for tensor in tensors]

# Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
*mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)
for i, r in enumerate(mb_tensors):
r.name = f"minibatch.{i}"
return result if tensors else result[0]

return mb_tensors if len(variables) else mb_tensors[0]
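
For context, a short usage sketch of the new behaviour, adapted from the tests further down in this commit: all outputs come from a single MinibatchOp node, so they are sliced with the same random row indices, and every pm.draw call resamples them.

import numpy as np
import pymc as pm

data_a = np.arange(1000)
data_b = -np.arange(1000)

# Both tensors are subsampled with one shared set of row indices
mb_a, mb_b = pm.Minibatch(data_a, data_b, batch_size=10)
draw_a, draw_b = pm.draw([mb_a, mb_b])
assert draw_a.shape == (10,)
np.testing.assert_allclose(draw_a, -draw_b)
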


def determine_coords(
6 changes: 2 additions & 4 deletions pymc/logprob/rewriting.py
@@ -41,9 +41,7 @@
from pytensor import config
from pytensor.compile.mode import optdb
from pytensor.graph.basic import (
Constant,
Variable,
ancestors,
io_toposort,
truncated_graph_inputs,
)
Expand Down Expand Up @@ -400,8 +398,8 @@ def construct_ir_fgraph(
# the old nodes to the new ones; otherwise, we won't be able to use
# `rv_values`.
# We start the `dict` with mappings from the value variables to themselves,
# to prevent them from being cloned. This also includes ancestors
memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
# to prevent them from being cloned.
memo = {v: v for v in rv_values.values()}

# We add `ShapeFeature` because it will get rid of references to the old
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
17 changes: 2 additions & 15 deletions pymc/model/core.py
@@ -39,15 +39,13 @@
from pytensor.compile import DeepCopyOp, Function, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from typing_extensions import Self

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import GenTensorVariable, is_minibatch
from pymc.data import is_valid_observed
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
@@ -1294,18 +1292,7 @@ def register_rv(
self.add_named_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
if (
isinstance(observed, Variable)
and not isinstance(observed, GenTensorVariable)
and observed.owner is not None
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
and not (
isinstance(observed.owner.op, Elemwise)
and isinstance(observed.owner.op.scalar_op, Cast)
)
and not is_minibatch(observed)
):
if not is_valid_observed(observed):
raise TypeError(
"Variables that depend on other nodes cannot be used for observed data."
f"The data variable was: {observed}"
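
A hypothetical sketch of what this simplified check now accepts and rejects, assuming the usual pm.Model / pm.Normal API (the variable names are illustrative only):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

raw = pt.as_tensor(np.ones(10))

with pm.Model():
    # A plain dtype cast of constant data is still a valid observed value
    pm.Normal("ok", observed=raw.astype("float32"))
    try:
        # Anything derived from other nodes (here an Elemwise multiplication) is rejected
        pm.Normal("bad", observed=raw * 2)
    except TypeError as err:
        print(err)
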
24 changes: 15 additions & 9 deletions pymc/pytensorf.py
@@ -156,19 +156,25 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
TypeError
"""
# TODO: These data functions should be in data.py or model/core.py
from pymc.data import MinibatchOp

if isinstance(x, Constant):
return x.data
if isinstance(x, SharedVariable):
return x.get_value()
if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
array_data = extract_obs_data(x.owner.inputs[0])
return array_data.astype(x.type.dtype)
if x.owner and isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
array_data = extract_obs_data(x.owner.inputs[0])
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
mask = np.zeros_like(array_data)
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)
if x.owner is not None:
if isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
array_data = extract_obs_data(x.owner.inputs[0])
return array_data.astype(x.type.dtype)
if isinstance(x.owner.op, MinibatchOp):
return extract_obs_data(x.owner.inputs[x.owner.outputs.index(x)])
if isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
array_data = extract_obs_data(x.owner.inputs[0])
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
mask = np.zeros_like(array_data)
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)

raise TypeError(f"Data cannot be extracted from {x}")

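
A hypothetical sketch of the new MinibatchOp branch in extract_obs_data: each output of the wrapped graph lines up with the input at the same index, so the full underlying data (not a random slice) is recovered.

import numpy as np
import pymc as pm
from pymc.pytensorf import extract_obs_data

data = np.arange(100).reshape(50, 2)
mb = pm.Minibatch(data, batch_size=5)

# The Minibatch output maps back to the original constant it slices
np.testing.assert_array_equal(extract_obs_data(mb), data)
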
4 changes: 2 additions & 2 deletions pymc/variational/opvi.py
@@ -116,13 +116,13 @@ class GroupError(VariationalInferenceError, TypeError):

def _known_scan_ignored_inputs(terms):
# TODO: remove when scan issue with grads is fixed
from pymc.data import MinibatchIndexRV
from pymc.data import MinibatchOp
from pymc.distributions.simulator import SimulatorRV

return [
n.owner.inputs[0]
for n in pytensor.graph.ancestors(terms)
if n.owner is not None and isinstance(n.owner.op, MinibatchIndexRV | SimulatorRV)
if n.owner is not None and isinstance(n.owner.op, MinibatchOp | SimulatorRV)
]


35 changes: 12 additions & 23 deletions tests/test_data.py
@@ -14,7 +14,6 @@

import io
import itertools as it
import re

from os import path

@@ -29,7 +28,7 @@

import pymc as pm

from pymc.data import is_minibatch
from pymc.data import MinibatchOp
from pymc.pytensorf import GeneratorOp, floatX


@@ -593,44 +592,34 @@ class TestMinibatch:

def test_1d(self):
mb = pm.Minibatch(self.data, batch_size=20)
assert is_minibatch(mb)
assert mb.eval().shape == (20, 10)
assert isinstance(mb.owner.op, MinibatchOp)
draw1, draw2 = pm.draw(mb, draws=2)
assert draw1.shape == (20, 10)
assert draw2.shape == (20, 10)
assert not np.all(draw1 == draw2)

def test_allowed(self):
mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
assert is_minibatch(mb)
assert isinstance(mb.owner.op, MinibatchOp)

def test_not_allowed(self):
with pytest.raises(ValueError, match="not valid for Minibatch"):
mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)

def test_not_allowed2(self):
with pytest.raises(ValueError, match="not valid for Minibatch"):
mb = pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)

def test_assert(self):
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
with pytest.raises(
AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
):
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
d1.eval()

def test_multiple_vars(self):
A = np.arange(1000)
B = np.arange(1000)
B = -np.arange(1000)
mA, mB = pm.Minibatch(A, B, batch_size=10)

[draw_mA, draw_mB] = pm.draw([mA, mB])
assert draw_mA.shape == (10,)
np.testing.assert_allclose(draw_mA, draw_mB)

# Check invalid dims
A = np.arange(1000)
C = np.arange(999)
mA, mC = pm.Minibatch(A, C, batch_size=10)

with pytest.raises(
AssertionError,
match=re.escape("All variables shape[0] in Minibatch should be equal"),
):
pm.draw([mA, mC])
np.testing.assert_allclose(draw_mA, -draw_mB)
