-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiff_args.py
49 lines (45 loc) · 1.03 KB
/
diff_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class DiffusionTrainingArguments:
    """Configuration options used when training the diffusion model.

    Every field carries a default, so the class can be built with no
    arguments and individual values overridden by keyword (or positionally,
    in declaration order).

    NOTE(review): each field is annotated ``Optional[...]`` even though none
    of the defaults is ``None``; the annotations are kept as-is because an
    argument parser may inspect them.
    """

    # Number of diffusion steps (presumably the length of the noise
    # schedule — confirm against the training loop).
    num_diffusion_steps: Optional[int] = 1000
    # Presumably how many diffusion steps contribute to each loss
    # computation — TODO confirm.
    num_steps_for_loss: Optional[int] = 2
    # Presumably the masking probability reached at the end of the
    # schedule — confirm.
    mask_prob_end: Optional[float] = 0.01
    # Presumably the final probability for the "other" corruption type —
    # confirm.
    other_prob_end: Optional[float] = 0.1
    # Schedule shape used for the "other" probability (default "linear").
    other_prob_scheduler_type: Optional[str] = "linear"
    # Loss selector; defaults to "ctc".
    loss_type: Optional[str] = "ctc"
    # Presumably keeps the (CTC?) blank token out of the masking process —
    # verify against the caller.
    exclude_blank_from_masking: Optional[bool] = True
    # Task / dataset identifier; defaults to "Squad".
    task_name: Optional[str] = "Squad"
    # How often noise is drawn; "every" presumably means on every
    # step/batch — confirm.
    freq_noise_drawing: Optional[str] = "every"
@dataclass
class DiffusionInferenceArguments:
    """Configuration options used when running diffusion inference.

    Both fields have defaults, so an instance can be created without
    arguments and values overridden by keyword or position.
    """

    # Path to a fine-tuned model; ``None`` by default (presumably meaning
    # "no fine-tuned checkpoint supplied" — confirm with the caller).
    finetuned_model_path: Optional[str] = None
    # Number of steps used at inference time (default 20).
    num_inference_steps: Optional[int] = 20