Semantics for batched Diagrams? #138
Hmm, I guess the more correct way to do this is to treat a batched diagram like a list of diagrams. You could then apply `cat`, `hcat`, or `tcat` along the batched axes. You would then need to do: `concat(draw(np.arange(10)), axis=0) + hcat(draw(np.arange(5)), axis=0)`. Internally, though, this could maybe still be stored as 2 arrays instead of being flattened to 15. I need to think about whether this works for the other cases.
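A minimal sketch of the "store 2 arrays instead of flattening to 15" idea, assuming a batched diagram is nothing more than a list of per-batch transform arrays. `BatchedDiagram` and `translations` are hypothetical names, not chalk's API: composing batches of 10 and 5 keeps two chunks, and a single 15-element array is only built on demand.

```python
from dataclasses import dataclass, field
from typing import List

import numpy as np


@dataclass
class BatchedDiagram:
    # Hypothetical: each chunk is a (batch, 3, 3) array of affine transforms.
    chunks: List[np.ndarray] = field(default_factory=list)

    def __add__(self, other: "BatchedDiagram") -> "BatchedDiagram":
        # Composing just concatenates the chunk lists; no arrays are copied.
        return BatchedDiagram(self.chunks + other.chunks)

    def __len__(self) -> int:
        # Total number of primitives across all chunks.
        return sum(c.shape[0] for c in self.chunks)

    def flatten(self) -> np.ndarray:
        # Build one (N, 3, 3) array only when needed, e.g. at render time.
        return np.concatenate(self.chunks, axis=0)


def translations(xs: np.ndarray) -> BatchedDiagram:
    # A batch of translation matrices, one per x offset: shape (len(xs), 3, 3).
    t = np.tile(np.eye(3), (len(xs), 1, 1))
    t[:, 0, 2] = xs
    return BatchedDiagram([t])


d = translations(np.arange(10.0)) + translations(np.arange(5.0))
print(len(d.chunks), len(d), d.flatten().shape)  # 2 15 (15, 3, 3)
```

Keeping the chunks separate means each one can still be processed as a single batched array op; flattening is deferred until something actually needs one sequence.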
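I ended up going with the following design:

```python
import jax
import numpy as np
from chalk import circle, hcat, square, vcat  # assumed top-level chalk exports


@jax.vmap
def outer(j):
    # Inner vmap: a batch of 3 shapes (i = 2, 3, 4) whose size and color depend on i.
    @jax.vmap
    def inner(i):
        return (circle(0.3 * i / 6).fill_color(np.ones(3) * i / 6) +
                square(0.1).fill_color("white")).scale(1)

    inside = inner(np.arange(2, 5))
    # vcat consumes the inner batch axis and stacks the 3 shapes vertically.
    return vcat(inside).scale(1)


# Outer vmap: a batch of 5 columns; hcat then lays them out horizontally.
out = outer(np.arange(1, 6))
print("My Size", out.size())
d = hcat(out)
```

I think this solves my issue without adding too much complexity to the system. The main difference is that now diagrams can have a There are a couple of core changes to the system to make this work. One is a new node (I'll add it to the PR when it is stable).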
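To see why the nested `vmap` design composes, here is a toy version in JAX where a "diagram" is reduced to one size per shape (an assumption for illustration; chalk's diagrams are full pytrees). The hypothetical `cat_offsets` plays the role of `vcat`/`hcat` along a batch axis: each element is shifted by the total size of the elements before it. Inside `column`, the inner `vmap` produces a batch of 3 that the "vcat" step consumes, and the outer `vmap` produces a batch of 5 columns that the "hcat" step consumes.

```python
import jax
import jax.numpy as jnp


def cat_offsets(sizes):
    # Hypothetical cat along the leading (batch) axis: the offset of each
    # element is the sum of the sizes of all elements before it.
    return jnp.cumsum(sizes) - sizes


@jax.vmap
def column(j):
    # Inner vmap: one size per shape for i = 2, 3, 4 (scaled by j so columns differ).
    sizes = jax.vmap(lambda i: 0.1 * i * j)(jnp.arange(2, 5))
    # "vcat": consume the inner batch axis, giving y offsets within the column.
    ys = cat_offsets(sizes)
    # The column is as wide as its widest shape.
    return ys, sizes.max()


# Outer vmap: a batch of 5 columns, for j = 1..5.
ys, widths = column(jnp.arange(1, 6))
# "hcat": consume the outer batch axis, giving x offsets of the columns.
xs = cat_offsets(widths)
print(ys.shape, widths.shape)  # (5, 3) (5,)
```

Because the cat step is a plain array reduction over the leading axis, it can run inside a `vmap` just as well as outside one, which is what lets `vcat` live inside `outer` while `hcat` runs on its batched result.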
Okay, this is indeed interesting! Let me see if I understood this right:
The original
I'm not completely sure I understand what this change implies, but it will surely become clear once I see the code 🙂
Yup, that's a good summary. Still working out the details, but I think the semantics make sense.
Yeah! I think the trick here is kind of cool. If you render batched primitives, you have a `List[Prim]`, but they may no longer be in order (elements of prim 2 may need to be in front of prim 1). But you can have an array
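Right, good point. This was exactly the hard part, since having functions in the trees is not allowed in JAX. So instead I was thinking an If you want to override the envelope on a

```python
from __future__ import annotations  # lets EnvDistance appear in annotations inside its own class
from dataclasses import dataclass
from typing_extensions import Self
# Monoid, Scalars, tx, Transformable, Diagram, Affine, V2_t, DiagramVisitor and
# Primitive are chalk internals; their imports are omitted here.


@dataclass
class EnvDistance(Monoid):
    d: Scalars

    def __add__(self, other: Self) -> Self:
        # Combining two envelopes keeps the larger distance (elementwise).
        return EnvDistance(tx.X.np.maximum(self.d, other.d))

    @staticmethod
    def empty() -> EnvDistance:
        # Identity for max: a very negative distance.
        return EnvDistance(tx.X.np.asarray(-1e5))

    def reduce(self, axis=0):
        # Combine a whole batch at once by taking the max along `axis`.
        return EnvDistance(tx.X.np.max(self.d, axis=axis))


class Envelope(Transformable, Monoid):
    diagram: Diagram
    affine: Affine

    def __call__(self, direction: V2_t) -> Scalars:
        self.diagram.accept(ApplyEnvelope(), EnvDistance.empty())
        ...


class ApplyEnvelope(DiagramVisitor[EnvDistance, V2_t]):
    A_type = EnvDistance

    def visit_primitive(self, diagram: Primitive, t: V2_t) -> EnvDistance:
```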
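A stand-alone check of why a `reduce` method is handy for a batched subtree. NumPy is swapped in for chalk's `tx.X.np` backend and the `Monoid` base is dropped (both assumptions, so the snippet runs on its own). Taking the max over the batch axis in one array op gives the same envelope distance as folding the unbatched elements with `+`, starting from `empty()`.

```python
from __future__ import annotations

from dataclasses import dataclass
from functools import reduce as fold

import numpy as np


@dataclass
class EnvDistance:
    d: np.ndarray

    def __add__(self, other: EnvDistance) -> EnvDistance:
        # Monoid operation: elementwise max of the two distances.
        return EnvDistance(np.maximum(self.d, other.d))

    @staticmethod
    def empty() -> EnvDistance:
        # Identity element (approximately): a very negative distance.
        return EnvDistance(np.asarray(-1e5))

    def reduce(self, axis: int = 0) -> EnvDistance:
        # Combine a whole batch in one array op.
        return EnvDistance(np.max(self.d, axis=axis))


# Envelope distances of a batch of 4 primitives, each queried in 2 directions.
batched = EnvDistance(np.array([[1.0, 0.5], [2.0, 0.1], [0.3, 3.0], [1.5, 1.5]]))

# One reduction over the batch axis ...
r1 = batched.reduce(axis=0)
# ... matches folding the unbatched rows with + from the identity.
r2 = fold(lambda acc, row: acc + EnvDistance(row), batched.d, EnvDistance.empty())
print(r1.d, r2.d)  # [2. 3.] [2. 3.]
```

The equality holds because `max` is associative with `empty()` as its identity, so the batch axis can be collapsed in any grouping without a Python-level loop over primitives.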
This is actually a pretty interesting question that I'm stuck on. In JAX, I'm thinking of a diagram as a pytree (https://jax.readthedocs.io/en/latest/pytrees.html). A pytree is basically a tree of arrays. When you `vmap` a function that returns a diagram, JAX returns a single diagram in which every array has an extra leading (batch) dimension.
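A small JAX sketch of that behaviour. `ToyPrim` and `make` are hypothetical stand-ins (not chalk's classes): a pytree whose leaves are a 3x3 transform and an RGB color. Mapping `make` over 10 indices with `jax.vmap` returns one `ToyPrim`, not a list, and each of its leaves gains a leading axis of size 10.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class ToyPrim:
    transform: jnp.ndarray  # (3, 3) affine transform
    color: jnp.ndarray      # (3,) RGB fill

    def tree_flatten(self):
        # Both fields are array leaves; there is no static auxiliary data.
        return (self.transform, self.color), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def make(i):
    # A primitive scaled by 0.1 * i with a grey level of i / 10.
    s = 0.1 * i
    return ToyPrim(jnp.diag(jnp.array([s, s, 1.0])), jnp.ones(3) * i / 10)


batched = jax.vmap(make)(jnp.arange(10))
print(type(batched).__name__, batched.transform.shape, batched.color.shape)
# ToyPrim (10, 3, 3) (10, 3)
```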
This object is a `Primitive` where the transform/style has a batch of 10 in front of it. By default I am interpreting this as a concat of 10 primitives. However, one might also interpret it as an animation with 10 frames.

The question is: what happens if you try to compose this with another object? I think in this case you just have 15 composed elements.
But is that the same as this case? Here you have a `Compose` node that also has a batch dimension, which applies to both of its children.
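A sketch of why the two cases can differ, assuming primitives are drawn back to front in list order. If a batched `Compose(A, B)` with batch size 3 is rendered as two batched primitives, its elements come out grouped by child (A0 A1 A2 B0 B1 B2), whereas composing per batch element means A0 B0 A1 B1 A2 B2. One hypothetical remedy is to carry an integer order key per element and sort by it before drawing; the key formula below is illustrative only.

```python
import numpy as np

# Elements of a batched Compose(A, B), grouped by child as they would come
# out of two batched primitives.
names = np.array(["A0", "A1", "A2", "B0", "B1", "B2"])
batch = np.array([0, 1, 2, 0, 1, 2])  # batch index of each element
child = np.array([0, 0, 0, 1, 1, 1])  # which child of Compose it came from

# Hypothetical order key: batch index first, then child index
# (2 children per Compose node).
order = batch * 2 + child

draw = names[np.argsort(order, kind="stable")]
print(draw.tolist())  # ['A0', 'B0', 'A1', 'B1', 'A2', 'B2']
```

In the first case (a batch of 10 composed with a batch of 5), the grouped order is already the intended one, so the two situations only coincide when the batch sits below the `Compose` node.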
You also have the case where there are multiple vmaps. In this case I think the concat should just flatten them and draw them in order.

However, I think it would be nice if whatever we do here works both for animation and for drawing, i.e. there is some notion of a composable sequence of diagrams, either in z-space or in time, that corresponds to this tree idea.
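A sketch of that "one tree, two readings" idea, assuming batch axes always sit in front of each leaf. The hypothetical helpers below read the same (5, 3)-batched leaf either as a single drawing (all batch axes flattened in row-major order, i.e. "flatten them and draw them in order") or as an animation (the first batch axis is time).

```python
import numpy as np

# Hypothetical leaf after two nested vmaps: batch shape (5, 3) in front of
# per-primitive data (here a 2D position).
positions = np.arange(5 * 3 * 2, dtype=float).reshape(5, 3, 2)


def as_drawing(leaf, n_batch):
    # Drawing: flatten the leading n_batch axes in row-major order; for a
    # (5, 3) batch, element (j, i) lands at index j * 3 + i.
    return leaf.reshape(-1, *leaf.shape[n_batch:])


def as_frames(leaf, n_batch):
    # Animation: the first batch axis is time; each frame flattens the
    # remaining batch axes into one drawing.
    return [as_drawing(leaf[t], n_batch - 1) for t in range(leaf.shape[0])]


flat = as_drawing(positions, 2)
frames = as_frames(positions, 2)
print(flat.shape, len(frames), frames[0].shape)  # (15, 2) 5 (3, 2)
```

Concatenating the frames reproduces the flat drawing exactly, so under this convention the z-order and time readings agree on sequence and differ only in where the first axis is cut.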