
Commit ee107cb

Authored by Michal-Novomestsky, pre-commit-ci[bot], and jessegrabowski
Use vectorized jacobian in Minimize Op (#1582)
* added identity as alias for tensor_copy and defined No-Op for TensorFromScalar
* refactor: jacobian should use tensorize
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* removed redundant pprint
* refactor: added vectorize=True to all jacobians
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Add option to vectorize jacobian in minimize/root
* pre-commit

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jessegrabowski <[email protected]>
1 parent c932ffb commit ee107cb

File tree

2 files changed: +53, -14 lines changed

pytensor/tensor/basic.py

Lines changed: 7 additions & 0 deletions
@@ -664,6 +664,11 @@ def c_code_cache_version(self):
 tensor_from_scalar = TensorFromScalar()
 
 
+@_vectorize_node.register(TensorFromScalar)
+def vectorize_tensor_from_scalar(op, node, batch_x):
+    return identity(batch_x).owner
+
+
 class ScalarFromTensor(COp):
     __props__ = ()
 
@@ -2046,6 +2051,7 @@ def register_transfer(fn):
 """Create a duplicate of `a` (with duplicated storage)"""
 tensor_copy = Elemwise(ps.identity)
 pprint.assign(tensor_copy, printing.IgnorePrinter())
+identity = tensor_copy
 
 
 class Default(Op):
@@ -4603,6 +4609,7 @@ def ix_(*args):
     "matrix_transpose",
     "default",
     "tensor_copy",
+    "identity",
     "transfer",
     "alloc",
     "identity_like",

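For context, a minimal sketch of the new alias in user code. The vector `x` below is hypothetical, and the sketch assumes `identity` is re-exported on the `pytensor.tensor` namespace alongside `tensor_copy`, as the new `__all__` entry suggests:

import pytensor.tensor as pt

x = pt.dvector("x")   # hypothetical input
y = pt.identity(x)    # same Elemwise(ps.identity) copy Op as pt.tensor_copy(x)
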
pytensor/tensor/optimize.py

Lines changed: 46 additions & 14 deletions
@@ -7,7 +7,7 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ def __init__(
             )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ def __init__(
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
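
The Hessian is now obtained as the jacobian of the gradient output already stored on the inner graph, rather than via `hessian`. A rough equivalent of that idea in user code might look like the following sketch (the objective `f` and input `x` are made-up examples):

import pytensor.tensor as pt
from pytensor.gradient import grad, jacobian

x = pt.dvector("x")
f = (x**3).sum()                    # made-up scalar objective

g = grad(f, x)                      # gradient vector of f w.r.t. x
H = jacobian(g, x, vectorize=True)  # Hessian as the jacobian of the gradient, built with vmap instead of scan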

@@ -561,7 +568,10 @@ def L_op(self, inputs, outputs, output_grads):
         implicit_f = grad(inner_fx, inner_x)
 
         df_dx, *df_dtheta_columns = jacobian(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+            implicit_f,
+            [inner_x, *inner_args],
+            disconnected_inputs="ignore",
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
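
A single `jacobian` call differentiates the optimality condition with respect to several inputs at once, ignoring disconnected ones and honoring the new vectorize flag. A small sketch of that pattern, with made-up variables `x`, `theta`, and `loss`:

import pytensor.tensor as pt
from pytensor.gradient import grad, jacobian

x = pt.dvector("x")
theta = pt.dscalar("theta")
loss = ((x - theta) ** 2).sum()

implicit_f = grad(loss, x)  # optimality condition: gradient is zero at the solution

# One jacobian call over several wrt inputs; disconnected inputs are ignored instead of raising
df_dx, df_dtheta = jacobian(
    implicit_f, [x, theta], disconnected_inputs="ignore", vectorize=True
)
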
@@ -581,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -590,18 +601,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
@@ -624,6 +638,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
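
Putting the new keyword together, a usage sketch of `minimize` with the vectorized jacobian/Hessian. The objective, inputs, and method choice below are illustrative only, and the sketch assumes `minimize` returns the solution and a success flag, per the annotated return type:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.dvector("x")
a = pt.dscalar("a")
objective = ((x - a) ** 2).sum()   # illustrative convex objective

x_star, success = minimize(
    objective,
    x,
    method="Newton-CG",
    jac=True,
    hess=True,
    use_vectorized_jac=True,  # build the gradient's jacobian (Hessian) with vmap instead of scan
)

fn = pytensor.function([x, a], [x_star, success])
x_opt, ok = fn(np.zeros(3), 2.0)   # expect x_opt close to [2., 2., 2.]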

@@ -804,6 +819,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -817,7 +833,11 @@ def __init__(
         self.fgraph = FunctionGraph([variables, *args], [equations])
 
         if jac:
-            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            jac_wrt_x = jacobian(
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
+            )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
         self.jac = jac
@@ -897,8 +917,14 @@ def L_op(
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
 
-        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+        df_dx = (
+            jacobian(inner_fx, inner_x, vectorize=True)
+            if not self.jac
+            else self.fgraph.outputs[1]
+        )
+        df_dtheta_columns = jacobian(
+            inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
+        )
 
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -917,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -935,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
@@ -958,6 +989,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
    )
 
     solution, success = cast(
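
And the analogous sketch for `root`. The system of equations and the starting point are illustrative, and the argument order assumes `equations` comes before `variables`, as the docstring implies:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

v = pt.dvector("v")
equations = v**2 - np.array([1.0, 4.0])   # illustrative system with roots at v = [1, 2]

solution, success = root(
    equations,
    v,
    method="hybr",
    jac=True,
    use_vectorized_jac=True,  # jacobian of the equations built with vmap instead of scan
)

fn = pytensor.function([v], [solution, success])
sol, ok = fn(np.array([0.5, 0.5]))   # expect sol close to [1., 2.]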
