Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
saandeepa93 committed Jul 20, 2024
1 parent df1601a commit ce22d2b
Show file tree
Hide file tree
Showing 12 changed files with 22 additions and 18 deletions.
2 changes: 1 addition & 1 deletion configs/experiments/cifar10/cifar10_5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ LR:
TEST:
EMP_PARAMS: True
SCORE: True
MAGNITUDE: 0.0024
MAGNITUDE: 0.00
IN_FEATS: [64, 128, 128, 512]

COMMENTS:
Expand Down
Binary file removed figures/cifar10_3.jpg
Binary file not shown.
Binary file removed figures/cifar10_3_cifar100.jpg
Binary file not shown.
Binary file removed figures/cifar10_7.jpg
Binary file not shown.
Binary file removed figures/cifar10_7_cifar100.jpg
Binary file not shown.
Binary file removed figures/intuition_orig.png
Binary file not shown.
Binary file removed figures/new_arch.pdf
Binary file not shown.
Binary file removed figures/new_arch.png
Binary file not shown.
Binary file removed figures/new_arch2.png
Binary file not shown.
28 changes: 19 additions & 9 deletions losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def gaussian_log_p(x, mean, log_sd):


class FlowConLoss:
def __init__(self, cfg, device, p_y=None):
def __init__(self, cfg, device, test=False):
self.cfg = cfg
self.device=device
self.n_bins = cfg.FLOW.N_BINS
Expand All @@ -18,6 +18,7 @@ def __init__(self, cfg, device, p_y=None):

self.tau = cfg.LOSS.TAU
self.tau2 = cfg.LOSS.TAU2
self.test = test


# RAF12
Expand Down Expand Up @@ -49,15 +50,24 @@ def nllLoss(self, z, logdet, mu, log_sd):

logdet = logdet.mean()
loss = self.init_loss + logdet + log_p_nll

if self.test:
score = loss / (log(2) * self.n_pixel), # CONVERTING LOGe to LOG2 |
log_p = log_p_nll / (log(2) * self.n_pixel) # v
return (
score,
log_p,
(logdet / (log(2) * self.n_pixel)).mean(),
(log_p_all/ (log(2) * self.n_pixel))
)
else:
return (

return (
# (-loss / (log(2) * self.n_pixel)).mean(), # CONVERTING LOGe to LOG2 |
# (log_p_nll / (log(2) * self.n_pixel)).mean(), # v
(loss / (log(2) * self.n_pixel)), # CONVERTING LOGe to LOG2 |
(log_p_nll / (log(2) * self.n_pixel)), # v
(logdet / (log(2) * self.n_pixel)).mean(),
(log_p_all/ (log(2) * self.n_pixel))
)
(-loss / (log(2) * self.n_pixel)).mean(), # CONVERTING LOGe to LOG2 |
(log_p_nll / (log(2) * self.n_pixel)).mean(), # v
(logdet / (log(2) * self.n_pixel)).mean(),
(log_p_all/ (log(2) * self.n_pixel))
)


def conLoss(self, log_p_all, labels):
Expand Down
3 changes: 1 addition & 2 deletions tester/emperical_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def calc_emp_params(cfg, args, loader, pretrained, flow, dist_dir, device, label
z = torch.cat(z_all, dim=0)

plot_umap(cfg, z.cpu(), labels_all.cpu(), f"{args.config}", 2, "in_ood", labels_in_ood)
e()

mu_k, std_k , z_k= [], [], []
for cls in range(cfg.DATASET.N_CLASS):
Expand Down Expand Up @@ -276,7 +275,7 @@ def calc_score(cfg, args, loader, pretrained, flow, mu, log_sd, criterion, devic
# ood_datasets = ['cifar10']
result = {}

criterion = FlowConLoss(cfg, device)
criterion = FlowConLoss(cfg, device, test=True)


loss_criterion = nn.CrossEntropyLoss()
Expand Down
7 changes: 1 addition & 6 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,6 @@ def plot_umap(cfg, X_lst_un, y_lst, name, dim, mode, labels_in_ood=None):
df = pd.DataFrame(X_lst, columns=["x", "y"])
df_color = pd.DataFrame(y_lst_label, columns=["class"])
df['legend_group'] = df_color['class'].apply(lambda x: 'ID' if x in id_list else "OOD")
print(df_color.head())
# e()
df = df.join(df_color)
if dim == 3:
fig = px.scatter_3d(df, x='x', y='y', z='z',color='class', title=f"{name}", \
Expand All @@ -276,11 +274,8 @@ def plot_umap(cfg, X_lst_un, y_lst, name, dim, mode, labels_in_ood=None):
# color_discrete_sequence=colors_all
)

# fig.update_traces(marker=dict(size=8))
# fig.update_traces(marker=dict(size=12, line=dict(width=2)))
# fig.update_traces(marker=dict(size=[12 if c == '100' else 8 for c in df['class']]))
fig.for_each_trace(lambda t: t.update(showlegend=False) if t.name in id_list else t.update(name='OOD'))
fig.update_layout(title="CV vs Entropy", margin=dict(l=0, r=0, t=0, b=0), showlegend=True)
fig.update_layout(title=None, margin=dict(l=0, r=0, t=0, b=0), showlegend=True)
fig.update_layout(
legend=dict(
yanchor="top",
Expand Down

0 comments on commit ce22d2b

Please sign in to comment.