-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_lora_unstructured.py
41 lines (33 loc) · 1.16 KB
/
train_lora_unstructured.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Example script to train a LoRA model on unstructured data."""
import logging
import time
from pathlib import Path
from gerd.config import PROJECT_DIR
from gerd.training.unstructured import train_lora
_LOGGER = logging.getLogger(__name__)

# Root logger stays at WARNING; the "gerd" package logger is raised to DEBUG.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("gerd").setLevel(logging.DEBUG)
# This script's own logger would otherwise inherit WARNING from the root
# logger, and every status message below (all logged at INFO) would be
# silently dropped. Give it an explicit level so they are emitted.
_LOGGER.setLevel(logging.INFO)

# Refuse to run from anywhere other than the project root.
if Path.cwd() != PROJECT_DIR:
    msg = "This example must be run from the project root."
    raise AssertionError(msg)

# NOTE(review): train_lora presumably starts training on `trainer.thread`
# before returning -- confirm against gerd.training.unstructured.
trainer = train_lora("lora_unstructured_example")

# Wait in short sleeps while the training thread is alive. A Ctrl+C raised
# while waiting is turned into an interrupt request on the trainer instead
# of aborting the script.
try:
    while trainer.thread.is_alive():
        time.sleep(0.5)
except KeyboardInterrupt:
    trainer.interrupt()
    _LOGGER.info(
        "Interrupting, please wait... "
        "*(Run will stop after the current training step completes.)*"
    )

# Runs on both paths: returns immediately if the thread has already finished,
# otherwise blocks until it exits after the interrupt request.
trainer.thread.join()

# Save here only if the trainer has not recorded a save of its own.
if not trainer.tracked.did_save:
    trainer.save()

if trainer.tracked.interrupted:
    _LOGGER.info(
        "Interrupted. Incomplete LoRA saved to %s.",
        trainer.config.output_dir,
    )
else:
    _LOGGER.info(
        "Done! LoRA saved to %s.\n\nBefore testing your new LoRA, "
        "make sure to first reload the model, as it is currently dirty from training.",
        trainer.config.output_dir,
    )