-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer_clip.py
97 lines (77 loc) · 2.58 KB
/
trainer_clip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
@author: Adityam Ghosh
Date: 10-29-2023
"""
from typing import Dict, List, Any, Tuple, Optional
import torch
import torch.nn as nn
import pytorch_lightning as pl
torch.set_float32_matmul_precision("medium")
class LitCLIP(pl.LightningModule):
    """Lightning module that trains a CLIP-style model with its built-in contrastive loss.

    The wrapped ``model`` is expected to behave like a Hugging Face ``CLIPModel``:
    it accepts ``input_ids`` / ``attention_mask`` / ``pixel_values``, returns a
    dict containing ``"loss"`` when called with ``return_loss=True``, and exposes
    ``get_image_features`` / ``get_text_features`` for inference-time encoding.

    Args:
        model: CLIP-style module as described above.
        lr: Initial learning rate for Adam.
        min_lr: Floor learning rate (``eta_min``) for the cosine schedule.
        weight_decay: Adam L2 penalty. Default matches the previously
            hard-coded ``1e-4``.
        t_0: Number of iterations in the first cosine warm-restart cycle.
            Default matches the previously hard-coded ``1000``.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-4,
        min_lr: float = 1e-8,
        weight_decay: float = 1e-4,
        t_0: int = 1000,
    ) -> None:
        super().__init__()
        self.model = model
        self.lr = lr
        self.min_lr = min_lr
        self.weight_decay = weight_decay
        self.t_0 = t_0

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
    ) -> Tuple:
        """Run the wrapped model on one (image, text) batch and return its output dict.

        ``image`` / ``text`` are processor/tokenizer output mappings whose tensors
        carry a collate-added singleton dim, i.e. shape (batch, 1, ...) —
        assumption based on the original squeeze; confirm against the dataloader.
        """
        # BUG FIX: the original used .squeeze() with no dim, which also drops
        # the *batch* axis whenever batch_size == 1. Squeeze only the
        # collate-added dim instead.
        model_out = self.model(
            input_ids=text["input_ids"].squeeze(1),
            attention_mask=text["attention_mask"].squeeze(1),
            pixel_values=image["pixel_values"].squeeze(1),
            return_loss=True,
            return_dict=True,
        )
        return model_out

    def _common_steps(
        self, batch: Dict, batch_idx: int
    ) -> torch.Tensor:
        """Shared train/val forward pass; returns the model's contrastive loss."""
        img, txt = batch["img"], batch["txt"]
        out = self(img, txt)
        return out["loss"]

    def training_step(self, batch: Dict, batch_idx: int) -> Dict:
        """One optimization step; logs ``train_loss`` per step and per epoch."""
        loss = self._common_steps(batch, batch_idx)
        # NOTE: rank_zero_only=True was removed from this log call — combining
        # it with sync_dist=True can deadlock under DDP, because the cross-rank
        # reduction requires every rank to reach the log call.
        self.log(
            "train_loss",
            loss.item(),
            prog_bar=True,
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        return {"loss": loss}

    def validation_step(self, batch: Dict, batch_idx: int) -> Dict:
        """One validation step; logs ``val_loss`` aggregated per epoch."""
        loss = self._common_steps(batch, batch_idx)
        # Same rank_zero_only removal as training_step (deadlock hazard with
        # sync_dist=True).
        self.log(
            "val_loss",
            loss.item(),
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        return {"val_loss": loss}

    def configure_optimizers(self) -> Dict:
        """Adam with weight decay + cosine annealing warm restarts.

        The first restart cycle spans ``self.t_0`` iterations and the LR anneals
        down to ``self.min_lr``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        # `verbose=True` was dropped: the parameter is deprecated in torch 2.x
        # and removed in later releases; it only printed LR updates.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=self.t_0,
            eta_min=self.min_lr,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def encode_image(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """Return image embeddings from the wrapped model (no loss computed)."""
        return self.model.get_image_features(pixel_values=img_tensor)

    def encode_text(
        self, text_tensor: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return text embeddings from the wrapped model (no loss computed)."""
        return self.model.get_text_features(
            input_ids=text_tensor, attention_mask=attn_mask
        )