Commit 0f15a1a: adjusted to use on CPU

haddadanas committed Jan 29, 2025
1 parent 3600aef · commit 0f15a1a

Showing 4 changed files with 8 additions and 8 deletions.
ml_network/models/losses.py (2 changes: 1 addition & 1 deletion)

@@ -66,7 +66,7 @@ def NLL_Focal_Loss(
     # Calculate the NLL loss
     if targets.dim() == 2:
         targets = targets.squeeze(1)
-    nll_loss = F.nll_loss(inputs.log(), targets, reduction="none")
+    nll_loss = F.nll_loss(inputs.log(), targets.long(), reduction="none")

     # Calculate the Focal Loss
     p_t = inputs[:, 1] * targets + inputs[:, 0] * (1 - targets)
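The `.long()` cast is needed because `F.nll_loss` only accepts integer class indices as targets; float-typed labels fail with a dtype error at runtime. A minimal sketch with made-up inputs (not the project's data):

```python
import torch
import torch.nn.functional as F

# Made-up batch: probabilities over 2 classes and float-typed labels,
# as they might arrive from a generic data pipeline.
probs = torch.rand(4, 2).softmax(dim=1)
targets = torch.tensor([0.0, 1.0, 1.0, 0.0])

# F.nll_loss(probs.log(), targets) would raise a RuntimeError, since
# nll_loss expects integer class indices; casting to long fixes it.
nll = F.nll_loss(probs.log(), targets.long(), reduction="none")
print(nll.shape)  # torch.Size([4])
```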
ml_network/models/ml_model_batch_norm.py (6 changes: 3 additions & 3 deletions)

@@ -11,8 +11,8 @@
"embedding_fields": ["channel_id"] + [f"l{n}.tauVS{var}" for n in [1, 2] for var in ["jet", "e", "mu"]],
"optimizer": partial(torch.optim.Adam, lr=0.001, eps=1e-06),
"loss": partial(NLL_Focal_Loss, reduction="mean"),
"epochs": 30,
"batch_size": 265,
"epochs": 20,
"batch_size": 256,
}


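Both value changes read as CPU-friendliness tweaks: fewer epochs keep wall-clock training time manageable without a GPU, and 256 is the conventional power-of-two batch size (265 looks like a typo in the original).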
@@ -72,7 +72,7 @@ def __init__(self, name, input_features, save_path="./models"):
     def forward(self, X_embed, X_num):
         x = [X_num]
         for layer, data in zip(self.embed, X_embed):
-            x.append(layer(data))
+            x.append(layer(data.clamp(0, 9)))
         x = torch.cat(x, dim=1)
         x = x.float()

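The clamp guards the `nn.Embedding` lookup against out-of-range ids: on CPU an invalid index fails immediately with an `IndexError`, while on CUDA the same bug tends to surface later as an opaque device-side assert, which would explain why it was caught when moving to CPU. A sketch assuming a 10-entry table to match the `clamp(0, 9)` range (the real `num_embeddings` is defined elsewhere in the model):

```python
import torch
from torch import nn

embed = nn.Embedding(10, 4)      # hypothetical 10-entry embedding table
ids = torch.tensor([0, 3, 15])   # 15 is out of range

# embed(ids) would raise IndexError on CPU ("index out of range in self");
# clamping keeps every id inside [0, 9], so the lookup always succeeds.
out = embed(ids.clamp(0, 9))
print(out.shape)  # torch.Size([3, 4])
```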
ml_network/src/ml_utils.py (6 changes: 3 additions & 3 deletions)

@@ -11,7 +11,6 @@
 from torch.nn import CrossEntropyLoss, Module
 from torch.optim import SGD
 from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter  # type: ignore

 import matplotlib.pyplot as plt

@@ -201,11 +200,12 @@ def weighted_loss(y_pred, y_true, weight):
     def _split_data_tuple(self, use_weights: bool) -> Callable:
         if use_weights:
             return lambda data: (
-                data[0][0], data[0][1].requires_grad_(True), data[1].requires_grad_(True), data[2].requires_grad_(True)
+                data[0][0], data[0][1].requires_grad_(True), data[1], data[2].requires_grad_(True)
             )
-        return lambda data: (data[0][0], data[0][1].requires_grad_(True), data[1].requires_grad_(True), None)
+        return lambda data: (data[0][0], data[0][1].requires_grad_(True), data[1], None)

     def _tensorboard_setup(self) -> None:
+        from torch.utils.tensorboard import SummaryWriter  # type: ignore
         self.writer = SummaryWriter(self.path)
         print(f"Tensorboard logs are saved to {self.path}")

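Two independent fixes here. The lambdas stop calling `requires_grad_` on the targets (`data[1]`): ground-truth labels never need gradients, and once they become integer tensors (see the `utils.py` change below) the call would fail outright, since only floating-point tensors can require gradients. The `SummaryWriter` import moves inside `_tensorboard_setup`, making tensorboard an optional dependency that is only imported when logging is actually requested. A small demonstration of the gradient point:

```python
import torch

labels = torch.tensor([0, 1, 1, 0], dtype=torch.int32)
try:
    labels.requires_grad_(True)
except RuntimeError as err:
    # Only floating-point (and complex) tensors can require gradients.
    print(err)
```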
ml_network/src/utils.py (2 changes: 1 addition & 1 deletion)

@@ -122,7 +122,7 @@ def __init__(self, device, inp_embed, inp_num, target=None, weight=None):
             else to_tensor_i32(inp_embed).to(device)
         )
         size = self.num_data.size(0)
-        self.target = torch.Tensor(size) if target is None else to_tensor_f32(target).to(device)
+        self.target = torch.Tensor(size) if target is None else to_tensor_i32(target).to(device)

         self.weight = None if weight is None else to_tensor_f32(weight).to(device)
         # Dynamically set the __getitem__ method
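Switching the stored targets from float32 to int32 keeps the dataset consistent with the classification loss, which indexes by class label; the `targets.long()` cast in `NLL_Focal_Loss` above then widens them to int64 at loss time. A hypothetical stand-in for the project's `to_tensor_i32` helper (its real definition lives elsewhere):

```python
import torch

# Hypothetical stand-in; the project's helper is defined elsewhere.
def to_tensor_i32(x):
    return torch.as_tensor(x, dtype=torch.int32)

targets = to_tensor_i32([0, 1, 1, 0])
print(targets.dtype)  # torch.int32
```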
