Quadratic forms are not scalars? Question about scalar-valued functions #637
-
Hi @doctor-phil, thanks for reaching out. When dprinting, you may want to pass the flag `print_type=True`:

```python
import pytensor
import pytensor.tensor as pt

A = pt.dmatrix('A')
x = pt.col('x')
f = -x.T @ A @ x
pytensor.dprint(f, print_type=True)
```

This will show that `f` is a `(1, 1)` matrix, not a 0-d scalar.
Grad explicitly requires the cost to be a scalar, so you can squeeze away the length-1 dimensions:

```python
pytensor.dprint(f.squeeze(), print_type=True)
```

And now you should be able to call grad just fine. Usually we always do

On your second question: the goal of compiling a function is not to simplify an expression, but to optimize it for speed and memory consumption. In this case we have replaced the `Dot` by an efficient `Gemv` BLAS operation. It looks complicated because `Gemv` has a very specific signature (5 inputs in this case). If you are only interested in simplifying an expression, you don't even need to compile, which not only applies rewrites that increase "complexity" but also wastes time on C compilation. Instead you can use `rewrite_graph`:

```python
from pytensor.graph import rewrite_graph

grad_f_wrt_x = pt.grad(f.squeeze(), x)
simpler_grad_f_wrt_x = rewrite_graph(grad_f_wrt_x, include=("canonicalize", "specialize"))
pytensor.dprint(simpler_grad_f_wrt_x)
```
A bit better, although there are still some redundant `dot(x, 1)` that should be rewritten away. However, my hand-picked arguments to … Anyway, let me know if this helps.
-
This is maybe a conceptual issue or I am missing something so I was hoping someone could point me in the right direction.
Consider the following example:
Clearly in this example there is only one possible shape for `f` wherever it is defined. However, returning `f` shows it is a `Blockwise{dot, (m,k),(k,n)->(m,n)}.0`, and trying to find the gradient with `pt.grad(f, x)` throws a `TypeError: Cost must be a scalar.`
What am I doing wrong here?
Edit: So, I got it working by wrapping the expression in `tn.trace()` (so that `f = tn.trace(-x.T @ A @ x)`). This seems like a hacky workaround, but whatever. Anyway, printing the graph for `f` gives:

Meanwhile, if I compile `f` into a function (as `ff = pt.function([x, A], f)`), I am now left with:

I was under the impression that compiling with `pt.function` should simplify the expression, not make it more verbose. For my actual application, I am really just looking for a way to simplify a certain graph and retrieve it as a symbolic expression.