TypeError: forward() missing 1 required positional argument: 'events' #172

Open
KristofPusztai opened this issue Oct 4, 2023 · 0 comments

KristofPusztai commented Oct 4, 2023

I appreciate this library, but the documentation for the models is quite lackluster, and the Jupyter notebook is not enough to reuse this for other, more complex use cases.

It is pretty frustrating that this error is so unclear and that there is no documentation on how exactly inputs should be formatted, how to fix this, or what is going on. My input data is shaped as follows:

```python
x_train.shape
>>> torch.Size([5720633, 75])

y_train.shape
>>> torch.Size([5720633, 2])
```

I get this error when running the `lr_finder` method:

```python
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
```

Full Error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_517/3901793619.py in <module>
----> 1 lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
      2 _ = lr_finder.plot()

~/.local/lib/python3.7/site-packages/torchtuples/base.py in lr_finder(self, input, target, batch_size, lr_min, lr_max, lr_range, n_steps, tolerance, callbacks, verbose, num_workers, shuffle, **kwargs)
    346                 num_workers,
    347                 shuffle,
--> 348                 **kwargs,
    349             )
    350         return lr_finder

~/.local/lib/python3.7/site-packages/pycox/models/cox.py in fit(self, input, target, batch_size, epochs, callbacks, verbose, num_workers, shuffle, metrics, val_data, val_batch_size, **kwargs)
     51         return super().fit(input, target, batch_size, epochs, callbacks, verbose,
     52                            num_workers, shuffle, metrics, val_data, val_batch_size,
---> 53                            **kwargs)
     54 
     55     def _compute_baseline_hazards(self, input, df, max_duration, batch_size, eval_=True, num_workers=0):

~/.local/lib/python3.7/site-packages/torchtuples/base.py in fit(self, input, target, batch_size, epochs, callbacks, verbose, num_workers, shuffle, metrics, val_data, val_batch_size, **kwargs)
    292                 val_data, val_batch_size, shuffle=False, num_workers=num_workers, **kwargs
    293             )
--> 294         log = self.fit_dataloader(dataloader, epochs, callbacks, verbose, metrics, val_dataloader)
    295         return log
    296 

~/.local/lib/python3.7/site-packages/torchtuples/base.py in fit_dataloader(self, dataloader, epochs, callbacks, verbose, metrics, val_dataloader)
    234                     break
    235                 self.optimizer.zero_grad()
--> 236                 self.batch_metrics = self.compute_metrics(data, self.metrics)
    237                 self.batch_loss = self.batch_metrics["loss"]
    238                 self.batch_loss.backward()

~/.local/lib/python3.7/site-packages/torchtuples/base.py in compute_metrics(self, data, metrics)
    180         out = self.net(*input)
    181         out = tuplefy(out)
--> 182         return {name: metric(*out, *target) for name, metric in metrics.items()}
    183 
    184     def _setup_metrics(self, metrics=None):

~/.local/lib/python3.7/site-packages/torchtuples/base.py in <dictcomp>(.0)
    180         out = self.net(*input)
    181         out = tuplefy(out)
--> 182         return {name: metric(*out, *target) for name, metric in metrics.items()}
    183 
    184     def _setup_metrics(self, metrics=None):

~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() missing 1 required positional argument: 'events'
```
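For what it's worth, the traceback shows the target being star-unpacked into the loss via `metric(*out, *target)`, and the Cox loss's `forward` takes `durations` and `events` as separate positional arguments. A single stacked `[n, 2]` tensor therefore supplies only one positional argument, and `events` goes missing. The plain-Python sketch below reproduces that unpacking behaviour; `cox_ph_loss` is a hypothetical stand-in with the same positional signature, not the real pycox object.

```python
# Sketch of how torchtuples unpacks the target in compute_metrics:
# `metric(*out, *target)`. All names here are hypothetical stand-ins.

def cox_ph_loss(phi, durations, events):
    """Stand-in with the same positional shape as the Cox loss forward."""
    return sum(phi) + sum(durations) + sum(events)

out = ([0.1, 0.2],)  # network output, already wrapped in a tuple

# One stacked [n, 2] array -> `target` has a SINGLE element, so the
# star-unpack fills only `durations`, and `events` is missing.
stacked_target = ([(1.0, 1), (2.0, 0)],)

# A (durations, events) pair -> TWO elements, matching the signature.
tuple_target = ([1.0, 2.0], [1, 0])

try:
    cox_ph_loss(*out, *stacked_target)
except TypeError as err:
    print(err)  # ... missing 1 required positional argument: 'events'

loss = cox_ph_loss(*out, *tuple_target)  # unpacks cleanly
```

So, assuming the usual pycox/torchtuples conventions, the fix would be to pass the survival target as a tuple of two 1-D arrays/tensors, e.g. `y_train = (durations, events)`, rather than one tensor of shape `[5720633, 2]`.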