 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ def __init__(
             )

         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac

         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ def __init__(

         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)

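For context (an editorial aside, not part of the diff): this hunk computes the Hessian as the Jacobian of the gradient, with `self.fgraph.outputs[-1]` being the gradient appended just above when `jac=True`, which lets it reuse the `vectorize` switch. A minimal sketch of that identity, with illustrative names:

```python
import pytensor.tensor as pt
from pytensor.gradient import grad, jacobian

x = pt.vector("x")
f = (x**3).sum()  # toy scalar objective

g = grad(f, x)  # gradient: the last fgraph output when jac=True
# Hessian as the Jacobian of the gradient; vectorize=True builds a
# vectorized graph instead of a Scan, trading memory for speed.
H = jacobian(g, x, vectorize=True)
```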
@@ -561,7 +568,10 @@ def L_op(self, inputs, outputs, output_grads):
         implicit_f = grad(inner_fx, inner_x)

         df_dx, *df_dtheta_columns = jacobian(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+            implicit_f,
+            [inner_x, *inner_args],
+            disconnected_inputs="ignore",
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
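Background on this hunk (editorial, with illustrative names): `implicit_f` is the stationarity condition, the gradient of the objective that is zero at the optimum, and `implict_optimization_grads` pushes gradients through the optimum via the implicit function theorem, dx*/dθ = -(∂F/∂x)⁻¹ ∂F/∂θ with F = ∇ₓf. A minimal NumPy sketch of that formula:

```python
import numpy as np

# Implicit function theorem: if F(x_star(theta), theta) = 0, then
#   dx_star/dtheta = -(dF/dx)^{-1} @ dF/dtheta
def implicit_grad(dF_dx: np.ndarray, dF_dtheta: np.ndarray) -> np.ndarray:
    return -np.linalg.solve(dF_dx, dF_dtheta)

# Toy check: f(x, theta) = 0.5*x**2 - theta*x gives F = x - theta,
# so x_star = theta and dx_star/dtheta should be 1.
print(implicit_grad(np.array([[1.0]]), np.array([[-1.0]])))  # [[1.]]
```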
@@ -581,6 +591,7 @@ def minimize(
    method: str = "BFGS",
    jac: bool = True,
    hess: bool = False,
+    use_vectorized_jac: bool = False,
    optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
    """
@@ -590,18 +601,21 @@ def minimize(
    ----------
    objective : TensorVariable
        The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
        The variable with respect to which the objective function is minimized. It must be an input to the
        computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
        The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
        Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
    optimizer_kwargs
        Additional keyword arguments to pass to scipy.optimize.minimize

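To make the docstring's trade-off concrete, here is a small editorial sketch contrasting the two graph-construction strategies via `pytensor.gradient.jacobian`, the function this PR exercises:

```python
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
y = pt.exp(x) * x.sum()  # vector-valued function of x

J_scan = jacobian(y, x)                  # default: a Scan over rows (lower memory, sequential)
J_vmap = jacobian(y, x, vectorize=True)  # one vectorized graph (faster, more memory)
```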
@@ -624,6 +638,7 @@ def minimize(
        method=method,
        jac=jac,
        hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
        optimizer_kwargs=optimizer_kwargs,
    )

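An end-to-end sketch of the new flag (the import path `pytensor.tensor.optimize` and the toy objective are assumptions for illustration, not shown in the diff):

```python
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize  # assumed module path

x = pt.vector("x")
a = pt.vector("a")
objective = ((x - a) ** 2).sum()  # minimized at x = a

# Newton-type methods need the Hessian; build it with vmap instead of a Scan.
x_star, success = minimize(
    objective, x, method="trust-ncg", hess=True, use_vectorized_jac=True
)
```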
@@ -804,6 +819,7 @@ def __init__(
        method: str = "hybr",
        jac: bool = True,
        optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
    ):
        if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
            raise ValueError(
@@ -817,7 +833,11 @@ def __init__(
        self.fgraph = FunctionGraph([variables, *args], [equations])

        if jac:
-            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            jac_wrt_x = jacobian(
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
+            )
            self.fgraph.add_output(atleast_2d(jac_wrt_x))

        self.jac = jac
@@ -897,8 +917,14 @@ def L_op(
        inner_x, *inner_args = self.fgraph.inputs
        inner_fx = self.fgraph.outputs[0]

-        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+        df_dx = (
+            jacobian(inner_fx, inner_x, vectorize=True)
+            if not self.jac
+            else self.fgraph.outputs[1]
+        )
+        df_dtheta_columns = jacobian(
+            inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
+        )

        grad_wrt_args = implict_optimization_grads(
            df_dx=df_dx,
@@ -917,6 +943,7 @@ def root(
    variables: TensorVariable,
    method: str = "hybr",
    jac: bool = True,
+    use_vectorized_jac: bool = False,
    optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
    """
@@ -935,6 +962,10 @@ def root(
    jac : bool, optional
        Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
        Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
    optimizer_kwargs : dict, optional
        Additional keyword arguments to pass to `scipy.optimize.root`.

@@ -958,6 +989,7 @@ def root(
        method=method,
        jac=jac,
        optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
    )

    solution, success = cast(
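A matching sketch for `root` (again, the import path and the toy system are illustrative assumptions):

```python
import pytensor.tensor as pt
from pytensor.tensor.optimize import root  # assumed module path

v = pt.vector("v")
theta = pt.vector("theta")
equations = v**3 - theta  # F(v) = 0 at v = theta ** (1/3)

# Vectorize the Jacobian of the equations instead of building a Scan.
solution, success = root(equations, v, use_vectorized_jac=True)
```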