File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -33,7 +33,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
3333 import matplotlib .pyplot as plt
3434 from matplotlib import cm
3535
36- # TODO: use jax to find the gradient.
36+ # TODO: use torch to find the gradient.
3737
3838 nx , ny = (1001 , 1001 )
3939 x = th .linspace (- 3 , 3 , nx )
@@ -57,7 +57,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
5757 step_total = 100
5858
5959 pos_list = [start_pos ]
60- velocity_vec = np . array ((0.0 , 0.0 ))
60+ velocity_vec = th . tensor ((0.0 , 0.0 ))
6161 # TODO: Implement gradient descent with momentum.
6262
6363 for pos in pos_list :
@@ -69,5 +69,5 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
6969 np .array (my ),
7070 np .array (mz ),
7171 pos_list ,
72- "writer_grad_bumpy_plot_jax " ,
72+ "writer_grad_bumpy_plot_torch " ,
7373 )
You can’t perform that action at this time.
0 commit comments