Impossible to create a module with a parameter that lacks a const shape? #839

Closed
emchristiansen opened this issue Aug 2, 2023 · 1 comment · Fixed by #854

Comments

@emchristiansen

I'd like to create a module from a struct like this:

#[derive(Debug, Clone)]
pub struct MyNet<E, D>
where
  E: Dtype,
  D: Device<E>,
{
  pub metadata: MyMetadata,
  pub logits: Tensor<(usize,), E, D, NoneTape>,
}

However, I run into two issues when impl'ing TensorCollection:

  1. Minor issue: TensorCollection assumes the module can be constructed purely from its Tensor values, so I would need to refactor the struct to remove metadata, which is inconvenient.
  2. Major issue: To register logits in visitor.visit_fields I need to call Self::tensor, which expects a TensorOptions. However, TensorOptions can only be constructed when S: ConstShape: every construction method assumes it, and because TensorOptions is marked non-exhaustive I can't build one by hand. Even if this were relaxed, I couldn't construct logits without knowing its (runtime) shape, which I can't determine from the limited context the signature allows.

Is this a correct diagnosis of the limitations of the module API?
If so, do you suggest a workaround?
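
For readers following along, the second limitation can be sketched without dfdx at all. The toy model below is an illustration under stated assumptions, not the library's API: `ConstShape` here is a simplified stand-in, and `Rank1`, `ToyTensor`, `zeros_static`, and `MyNetConfig` are hypothetical names invented for this sketch. It shows why a constructor that has no `self` and no runtime argument can only produce compile-time shapes, and one possible way around that, which is to carry the runtime length and the metadata in a separate config value.

```rust
// Toy model only: none of these names are dfdx APIs.

/// Simplified stand-in for a shape whose dimensions are known at compile time.
trait ConstShape {
    fn dims() -> Vec<usize>;
}

/// A rank-1 shape whose length is part of the type.
struct Rank1<const N: usize>;

impl<const N: usize> ConstShape for Rank1<N> {
    fn dims() -> Vec<usize> {
        vec![N]
    }
}

/// Stand-in for a tensor: a shape plus zero-initialised data.
#[derive(Debug, Clone, PartialEq)]
struct ToyTensor {
    dims: Vec<usize>,
    data: Vec<f32>,
}

/// Static-style construction: there is no `self` and no runtime argument,
/// so the shape has to come from the type parameter alone.
fn zeros_static<S: ConstShape>() -> ToyTensor {
    let dims = S::dims();
    let len: usize = dims.iter().product();
    ToyTensor { dims, data: vec![0.0; len] }
}

/// Hypothetical workaround: a config value carries the runtime length and
/// the non-tensor metadata, and the module is built from it.
#[derive(Debug, Clone, PartialEq)]
struct MyNetConfig {
    metadata: String,
    num_logits: usize,
}

/// Toy counterpart of the struct in the question.
#[derive(Debug, Clone, PartialEq)]
struct MyNet {
    metadata: String,
    logits: ToyTensor,
}

impl MyNetConfig {
    /// Builds the module; the shape is read from `self` at runtime.
    fn build(&self) -> MyNet {
        MyNet {
            metadata: self.metadata.clone(),
            logits: ToyTensor {
                dims: vec![self.num_logits],
                data: vec![0.0; self.num_logits],
            },
        }
    }
}
```

In this sketch, `zeros_static::<Rank1<3>>()` works because the length 3 is part of the type, whereas a length parsed from user input has nowhere to go in that signature. `MyNetConfig { metadata, num_logits }.build()` accepts such a length because construction happens through a value rather than a type.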

@coreylowman
Owner

This will be addressed with the nn rewrite.
