# tent_come.py
"""
Copyright to COME Authors, ICLR 2025
built upon on Tent code.
# https://github.com/mr-eggplant/SAR/blob/main/tent.py
"""
from copy import deepcopy

import torch
import torch.jit
import torch.nn as nn


class Tent_COME(nn.Module):
"""Tent_COME adapts a model by entropy minimization during testing.
"""
def __init__(self, model, optimizer,args, steps=1, episodic=False):
super().__init__()
self.model = model
self.optimizer = optimizer
self.steps = steps
assert steps > 0, "tent_come requires >= 1 step(s) to forward and update"
self.episodic = episodic
self.args = args
self.model_state, self.optimizer_state = \
copy_model_and_optimizer(self.model, self.optimizer)

    def forward(self, x):
        if self.episodic:
            self.reset()
        for _ in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer, self.args)
        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)


@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from logits."""
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)


@torch.jit.script
def dirichlet_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of the Dirichlet-based opinion from logits (key component of COME)."""
    # Rescale by a detached copy of the logit norm: the values are unchanged,
    # but gradients stop flowing through the logit magnitude.
    norm = torch.norm(x, p=2, dim=-1, keepdim=True)
    x = x / norm * norm.detach()
    # Subjective-logic opinion: per-class belief mass plus one uncertainty mass;
    # 1000 acts as the prior weight (presumably the class count of the
    # 1000-class ImageNet setting this code targets).
    brief = torch.exp(x) / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000)
    uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000)
    probability = torch.cat([brief, uncertainty], dim=1) + 1e-7
    return -(probability * torch.log(probability)).sum(1)
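
# A minimal sanity check of dirichlet_entropy (illustrative only; the batch
# size 8 and class count 1000 below are assumptions, the latter matching the
# constant inside the function):
#
#     logits = torch.randn(8, 1000)
#     ent = dirichlet_entropy(logits)  # shape (8,): one entropy per sample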


@torch.enable_grad()  # ensure grads in possible no-grad context for testing
def forward_and_adapt(x, model, optimizer, args):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.
    """
    outputs = model(x)
    # COME: replace softmax_entropy with dirichlet_entropy
    loss = dirichlet_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs


def collect_params(model):
    """Collect the affine scale + shift parameters from normalization layers.

    Walk the model's modules and collect the parameters of all batch, group,
    and layer normalization layers. Return the parameters and their names.
    """
    params = []
    names = []
    for nm, m in model.named_modules():
        if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            for np, p in m.named_parameters():
                if np in ['weight', 'bias']:
                    params.append(p)
                    names.append(f"{nm}.{np}")
    return params, names


def copy_model_and_optimizer(model, optimizer):
    """Copy the model and optimizer states for resetting after adaptation."""
    model_state = deepcopy(model.state_dict())
    optimizer_state = deepcopy(optimizer.state_dict())
    return model_state, optimizer_state


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore the model and optimizer states from copies."""
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)


def configure_model(model):
    """Configure model for use with tent_come."""
    # Train mode, because adaptation optimizes the model at test time.
    model.train()
    # Disable grad, then re-enable only what tent_come updates.
    model.requires_grad_(False)
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # Force the use of batch statistics in both train and eval modes.
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
        if isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
            m.requires_grad_(True)
    return model


def check_model(model):
    """Check model for compatibility with tent_come."""
    is_training = model.training
    assert is_training, "tent_come needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent_come needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent_come should not update all params: " \
                               "check which require grad"
    # Check for any of the norm layers that configure_model enables, not just
    # batch norm, so that e.g. layer-norm models also pass.
    has_norm = any(isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm))
                   for m in model.modules())
    assert has_norm, "tent_come needs normalization for its optimization"
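

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): wrap a model for
# test-time adaptation. The toy ConvNet, learning rate, batch shape, and the
# `args=None` placeholder below are assumptions for demonstration only; the
# optimizer receives only the norm-layer affine parameters, mirroring Tent.
if __name__ == "__main__":
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    net = configure_model(net)       # train mode, grads only on norm layers
    check_model(net)                 # sanity-check the configuration
    params, _ = collect_params(net)  # affine params of the norm layers
    optimizer = torch.optim.SGD(params, lr=1e-3, momentum=0.9)
    adapted = Tent_COME(net, optimizer, args=None)
    batch = torch.randn(4, 3, 32, 32)  # stand-in for a test batch
    logits = adapted(batch)            # adapts the norm layers, then predicts
    print(logits.shape)                # torch.Size([4, 10])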