The current implementation of TRAK calculates the final result using a diag matrix. Introducing this matrix creates a large memory overhead, especially when the number of training data points is large.
A possible fix is as follows:
```python
if train_dataloader is not None:
    return (running_xinv_XTX_XT * running_Q.to(self.device).unsqueeze(0)).T
return (running_xinv_XTX_XT * self.Q.to(self.device).unsqueeze(0)).T
```
I have verified that this is equivalent to the current computation using two different setups.
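For anyone who wants to reproduce the check, here is a minimal sketch on small random tensors. It assumes the current code right-multiplies by `torch.diag(Q)`, which is my reading of "diag matrix" and not a quote from the TRAK source. The names `scores` and `q` and the sizes are hypothetical stand-ins for `running_xinv_XTX_XT` and `running_Q`.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: one column (and one weight) per training point.
n_targets, n_train = 4, 6
scores = torch.randn(n_targets, n_train, dtype=torch.float64)  # stand-in for running_xinv_XTX_XT
q = torch.rand(n_train, dtype=torch.float64)                   # stand-in for running_Q

# Assumed current form: materializes a dense (n_train, n_train) matrix.
via_diag = (scores @ torch.diag(q)).T

# Proposed form: broadcasting scales column j by q[j] without building that matrix.
via_broadcast = (scores * q.unsqueeze(0)).T

same = torch.allclose(via_diag, via_broadcast)
diag_elements = n_train * n_train  # entries held by the dense diag matrix
vector_elements = q.numel()        # entries held by the broadcast vector
print(same, diag_elements, vector_elements)
```

The dense diag matrix grows quadratically with the number of training points, while the broadcast version only needs the length-`n_train` vector, which is where the memory saving would come from under this assumption.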