-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
36 lines (33 loc) · 1.07 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
Config class and a helper that fixes random seeds for reproducibility.
Also holds environment-related constants (GCP credentials file path and
MLflow tracking URI).
"""
import random
import numpy as np
import torch
# Credentials file used to access GCP services. This is a relative path, so it
# is resolved against the process's working directory by whatever opens it.
# NOTE(review): Google client libraries look for the environment variable
# GOOGLE_APPLICATION_CREDENTIALS (plural); confirm how this value is exported.
GOOGLE_APPLICATION_CREDENTIAL: str = "./credential.json"

# MLflow tracking server address. The original comment calls this the GCP
# server address; presumably that server is reached via localhost (e.g. a
# tunnel or port-forward) — confirm.
MLFLOW_TRACKING_URI: str = "http://localhost:80"
class Config:
    """Plain container for model/training hyperparameters.

    Each constructor argument is stored unchanged as an attribute of the same
    name; no validation is performed.

    Args:
        dropout1: Stored as ``self.dropout1`` (presumably a dropout
            probability for the first dropout layer — confirm at call site).
        dropout2: Stored as ``self.dropout2`` (presumably a dropout
            probability for the second dropout layer — confirm at call site).
        label_smoothing: Stored as ``self.label_smoothing``.
        epochs: Stored as ``self.epochs``.
        embedding_dim: Stored as ``self.embedding_dim``.
        hidden_size: Stored as ``self.hidden_size``.
    """

    def __init__(
        self,
        dropout1: float,
        dropout2: float,
        label_smoothing: float,
        epochs: int,
        embedding_dim: int,
        hidden_size: int,
    ) -> None:
        self.dropout1 = dropout1
        self.dropout2 = dropout2
        self.label_smoothing = label_smoothing
        self.epochs = epochs
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size

    def __repr__(self) -> str:
        """Return ``Config(name=value, ...)`` listing every instance attribute.

        Without this, configs print as ``<Config object at 0x...>``, which is
        useless when logging or debugging a training run.
        """
        # vars() preserves assignment order, so fields appear as set in __init__
        # (plus any attributes a caller added later).
        params = ", ".join(
            f"{key}={value!r}" for key, value in vars(self).items()
        )
        return f"{type(self).__name__}({params})"
def set_seed(random_seed: int) -> None:
    """Seed every random number generator used in training.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    generator plus its CUDA generators (current device and all devices). Then
    it sets the cuDNN flags so that benchmarking is disabled and deterministic
    algorithms are requested.

    Args:
        random_seed: Seed value applied to every generator.
    """
    # Python and NumPy global generators.
    random.seed(random_seed)
    np.random.seed(random_seed)

    # PyTorch: CPU generator first, then the CUDA generators
    # (manual_seed_all covers multi-GPU setups).
    torch.manual_seed(random_seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(random_seed)

    # cuDNN: request deterministic algorithms and turn off benchmarking.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False