This project mainly uses the PyTorch Lightning framework. These are my notes on it.
==Not yet organized into sections==
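When you build your own net by subclassing pl.LightningModule, keep in mind which steps the framework actually runs, and in what order. Then override the corresponding hooks as needed.
In the sketch below, the function names alone show the run order. For details, see the official documentation: Callback — PyTorch Lightning.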
"""运行顺序
for epoch_i in range(num_epoch):
on_train_epoch_start()
for batch in trainning_dataloader:
on_train_batch_start()
trainning_step()
on_train_batch_end()
on_train_epoch_end()
on_validation_start()
for batch in validation_dataloader:
on_validation_batch_start()
validation_step()
on_validation_batch_end
on_validation_end()
"""
```python
from typing import Optional, Tuple

import pytorch_lightning as pl
import torch


class CustomNet(pl.LightningModule):
    def __init__(self, options: dict):
        super().__init__()
        # Saves `options` to self.hparams and into checkpoints
        self.save_hyperparameters()
        # Project-specific (pydantic) hyperparameter model
        self._options = VqVaePartttenNetHyperParameters.parse_obj(options)
        self._lr = self._options.learning_rate
        ...

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self._lr)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                # Multiply the LR by gamma=0.1 at epochs 30 and 60
                "scheduler": torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, milestones=[30, 60], gamma=0.1
                ),
                "interval": "epoch",  # step the scheduler per epoch (not per batch)
                "frequency": 1,       # ... every 1 epoch
                "strict": True,       # only matters when a "monitor" key is set
                "name": "learning_rate",  # name shown by LearningRateMonitor
            },
        }

    def forward(
        self,
        seq_daliy_industry_lv1_mean: torch.Tensor,
        seq_intraday: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...
        return nearest_neighbor, preds

    def training_step(self, batch, batch_idx):
        ...
        return loss

    def validation_step(self, batch, batch_idx):
        ...

    def on_train_epoch_start(self) -> None:
        ...

    def on_train_epoch_end(self) -> None:
        ...

    def on_validation_epoch_start(self) -> None:
        ...

    def on_validation_epoch_end(self) -> None:
        ...
```
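With `"interval": "epoch"` and `"frequency": 1`, Lightning calls `scheduler.step()` once at the end of every epoch. The plain-PyTorch loop below imitates that schedule so you can see which learning rate each epoch uses. It is a sketch with a dummy single parameter; the base LR of `1e-3` is an arbitrary assumption.

```python
import torch

# Dummy parameter: only the learning-rate schedule matters here
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 60], gamma=0.1
)

lrs = []
for epoch in range(90):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    optimizer.step()   # the epoch's training steps would happen here
    scheduler.step()   # what Lightning does at each epoch end for interval="epoch"

print(lrs[29], lrs[30], lrs[60])
```

Epochs 0 to 29 train at the base LR, epochs 30 to 59 at 0.1× that, and epochs 60 onward at 0.01×. If you set `"interval": "step"` instead, the milestones would count optimizer steps rather than epochs.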
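Before training actually starts (i.e. before the first training_step runs), Lightning performs a sanity check. It pushes a few validation batches (num_sanity_val_steps, default 2) through validation_step so that bugs in the validation loop surface early.

While the sanity check runs, self.trainer.sanity_checking is True. Hooks can test this flag to skip logic that should not run on sanity-check data. Because the sanity check only runs the validation loop, the flag takes effect in validation_step and the validation hooks; training hooks are not called during it.

Example: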
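```python
def on_train_epoch_end(self) -> None:
    result = self._signal_analysis_cache_train.compute()
    # The sanity check only runs the validation loop, so this branch is a
    # defensive guard here; the same pattern matters in on_validation_epoch_end.
    if self.trainer.sanity_checking:
        entropy = torch.tensor(0.0)
    else:
        # Skip the -1 placeholders (samples not yet assigned a code):
        # torch.bincount raises on negative values.
        valid = self.zq_indices_train[self.zq_indices_train >= 0]
        counts = torch.bincount(valid, minlength=32)  # assumes a 32-entry codebook
        probs = counts.float() / max(valid.numel(), 1)
        # Shannon entropy (bits) of codebook usage; 1e-9 avoids log2(0)
        entropy = -(probs * torch.log2(probs + 1e-9)).sum()
```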
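The entropy above measures how evenly the samples spread over the codebook. A low value suggests that only a few codes are in use, as in VQ-VAE codebook collapse. Moving the computation into a standalone function makes it easy to check in isolation.

The sketch below makes two assumptions:
- The helper name `codebook_entropy_bits` is hypothetical.
- The codebook size is passed in as a parameter instead of being hard-coded as 32.

It ignores `-1` placeholders as a defensive choice.

```python
import torch


def codebook_entropy_bits(indices: torch.Tensor, num_codes: int) -> torch.Tensor:
    """Shannon entropy (bits) of codebook usage; -1 placeholders are ignored."""
    valid = indices[indices >= 0]  # drop "not yet assigned" entries
    counts = torch.bincount(valid, minlength=num_codes)
    probs = counts.float() / max(valid.numel(), 1)
    return -(probs * torch.log2(probs + 1e-9)).sum()


# 4 codes used equally often -> about log2(4) = 2 bits
uniform4 = torch.arange(4).repeat(10)
# One code used for every sample -> about 0 bits (collapse)
collapsed = torch.zeros(40, dtype=torch.long)
# Placeholders are skipped instead of crashing torch.bincount
partial = torch.tensor([0, 1, -1, -1])

for name, idx in [("uniform4", uniform4), ("collapsed", collapsed), ("partial", partial)]:
    print(name, codebook_entropy_bits(idx, num_codes=32).item())
```

The maximum possible value is log2(num_codes), i.e. 5 bits for a 32-entry codebook, reached only when every code is used equally often.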
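register_buffer attaches extra tensor state to the model. A buffer is not a parameter:
- it does not appear in parameters();
- the optimizer never updates it (use nn.Parameter for anything that should be trained).

It does move together with the module on .to(device) / .cuda(), so under Lightning it lands on the right device automatically.

persistent controls whether the buffer is saved in state_dict, and therefore in checkpoints. persistent=False makes it a pure runtime cache: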
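```python
# Per-sample codebook index for the training set; -1 = not yet assigned.
# train_dataset_length: number of training samples (defined elsewhere)
self.register_buffer(
    "zq_indices_train",
    torch.full((train_dataset_length,), -1, dtype=torch.long),
    persistent=False,  # runtime cache only, not saved in checkpoints
)
```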
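The plain-PyTorch sketch below shows where each kind of tensor ends up:
- a parameter appears in `parameters()` and in `state_dict`;
- a persistent buffer appears in `state_dict` only;
- a non-persistent buffer appears in neither, though `named_buffers()` still lists it.

It also shows one way to fill the cache for a batch. This assumes the dataset returns each sample's index so that `training_step` can write into the buffer, which the note above does not specify. `BufferDemo`, `running_count`, `sample_idx` and `codes` are hypothetical names.

```python
import torch
from torch import nn


class BufferDemo(nn.Module):
    def __init__(self, num_samples: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))  # trainable parameter
        # Persistent buffer: saved in state_dict, never trained
        self.register_buffer("running_count", torch.zeros(1))
        # Non-persistent buffer: runtime cache, excluded from state_dict
        self.register_buffer(
            "zq_indices_train",
            torch.full((num_samples,), -1, dtype=torch.long),
            persistent=False,
        )


m = BufferDemo(num_samples=6)
param_names = [n for n, _ in m.named_parameters()]
buffer_names = [n for n, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

# Hypothetical batch: dataset positions of the samples and their code indices
sample_idx = torch.tensor([0, 2, 5])
codes = torch.tensor([7, 3, 7])
m.zq_indices_train[sample_idx] = codes  # in-place write; unfilled slots stay -1

print(param_names, buffer_names, state_keys)
print(m.zq_indices_train.tolist())
```

Because the cache is non-persistent, it starts again from -1 after loading a checkpoint. That suits a per-epoch statistic, but not state that must survive a resume.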