Skip to content

Commit

Permalink
Merge pull request #22 from normal-computing/remove-item
Browse files Browse the repository at this point in the history
Remove .item in favour of .detach
  • Loading branch information
SamDuffield authored Feb 20, 2024
2 parents 7080e3f + 8cbc3ee commit 63ea636
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_archive/
.vscode/
*.pkl
*.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
2 changes: 1 addition & 1 deletion uqlib/ekf/diag_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def update(
grad,
inplace=inplace,
)
return EKFDiagState(update_mean, update_sd_diag, log_liks.mean().item(), aux)
return EKFDiagState(update_mean, update_sd_diag, log_liks.mean().detach(), aux)


def build(
Expand Down
2 changes: 1 addition & 1 deletion uqlib/ekf/diag_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def update(
grad,
inplace=inplace,
)
return EKFDiagState(update_mean, update_sd_diag, log_lik.item(), aux)
return EKFDiagState(update_mean, update_sd_diag, log_lik.detach(), aux)


def build(
Expand Down
4 changes: 2 additions & 2 deletions uqlib/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SGHMCState(NamedTuple):

params: TensorTree
momenta: TensorTree
log_posterior: float = 0.0
log_posterior: torch.tensor = torch.tensor(0.0)
aux: Any = None


Expand Down Expand Up @@ -90,7 +90,7 @@ def transform_momenta(m, g):
)
momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace)

return SGHMCState(params, momenta, log_post.item(), aux)
return SGHMCState(params, momenta, log_post.detach(), aux)


def build(
Expand Down
2 changes: 1 addition & 1 deletion uqlib/vi/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def nelbo_log_sd(m, lsd):
mean, log_sd_diag = torchopt.apply_updates(
(state.mean, state.log_sd_diag), updates, inplace=inplace
)
return VIDiagState(mean, log_sd_diag, optimizer_state, nelbo_val.item(), aux)
return VIDiagState(mean, log_sd_diag, optimizer_state, nelbo_val.detach(), aux)


def build(
Expand Down

0 comments on commit 63ea636

Please sign in to comment.