-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
78 lines (65 loc) · 3.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import pandas as pd
import argparse
from src.models.model_utils import ModelConfig, setup_model_and_tokenizer
from src.training.trainer import LlamaTrainer
from src.data.data_preprocessing import create_dataloaders
from src.inference.ensemble_inference import EnsembleInference
from src.inference.uncertainty_metrics import calculate_uncertainty_metrics
import os
def train(config_path, seed):
    """Fine-tune a single model for one seed.

    Loads the YAML config, builds the model/tokenizer pair and the
    train/validation dataloaders, then hands everything to
    ``LlamaTrainer`` and runs its training loop.

    Args:
        config_path: Path to the YAML configuration file.
        seed: Seed forwarded both to model setup and to the trainer.
    """
    print(f"Training with seed: {seed}")

    cfg = ModelConfig.from_yaml(config_path)
    print("INFO: Config loaded")

    model, tokenizer = setup_model_and_tokenizer(cfg, seed)
    print("INFO: Model and tokenizer loaded")

    loaders = create_dataloaders(
        tokenizer,
        cfg.batch_size,
        cfg.max_length,
        cfg.num_workers,
    )
    train_loader, val_loader = loaders
    print("INFO: Training data loaded")

    trainer = LlamaTrainer(model, train_loader, val_loader, cfg, seed)
    print("INFO: Trainer loaded")

    trainer.train()
    print("INFO: Training finished")
def inference(config_path):
    """Run ensemble inference on the validation data and save the results.

    Every seed listed in ``config.seeds`` must already have a trained model
    directory at ``<config.output_dir>/seed_<seed>``. This function writes
    ``uncertainty_metrics.csv`` and ``ensemble_results.csv`` into
    ``config.output_dir`` and also prints both objects.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        FileNotFoundError: If the model directory for any configured seed is
            missing. All missing seeds are reported in a single error.
    """
    print("Running inference")
    config = ModelConfig.from_yaml(config_path)
    print("INFO: Config loaded")
    output_dir = config.output_dir

    # Check every seed before any expensive work. Collect all missing seeds
    # so the user learns everything that still needs training in one run.
    missing_seeds = [
        seed
        for seed in config.seeds
        if not os.path.exists(os.path.join(output_dir, f"seed_{seed}"))
    ]
    if missing_seeds:
        missing = ", ".join(str(seed) for seed in missing_seeds)
        raise FileNotFoundError(
            f"Model for seed(s) {missing} not found. "
            "Please train the model before running inference."
        )

    # Only the tokenizer is used here; the returned model is discarded.
    # NOTE(review): this may build a full model just to obtain a tokenizer.
    # Consider a tokenizer-only loader if setup_model_and_tokenizer is costly.
    _, tokenizer = setup_model_and_tokenizer(config)

    # Only the validation loader is needed for inference.
    _, val_loader = create_dataloaders(
        tokenizer, config.batch_size, config.max_length, config.num_workers
    )
    print("INFO: Validation data loaded")

    # EnsembleInference is given the config, which lists the seeds to load.
    ensemble = EnsembleInference(config)
    print("INFO: Ensemble with models loaded based on seeds: ", config.seeds)

    results = ensemble.get_ensemble_predictions(val_loader)
    print("INFO: Ensemble predictions finished")

    metrics = calculate_uncertainty_metrics(results)
    print("INFO: Uncertainty metrics finished")

    # Wrapping the mapping in a list yields a single-row DataFrame.
    metrics_df = pd.DataFrame([metrics])
    print("INFO: Uncertainty metrics DataFrame created")
    metrics_df.to_csv(os.path.join(output_dir, "uncertainty_metrics.csv"), index=False)
    print("INFO: Uncertainty metrics saved to file")
    print(metrics)

    # NOTE(review): presumably `results` is a dict of per-sample arrays. If so,
    # each cell is written via its str() repr, which may be truncated for large
    # arrays. TODO confirm the structure and consider one row per sample.
    results_df = pd.DataFrame([results])
    results_df.to_csv(os.path.join(output_dir, "ensemble_results.csv"), index=False)
    print("INFO: Ensemble results saved to file")
    print(results)
    print("INFO: Inference finished")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", choices=["train", "inference"], help="Mode: train or inference")
    parser.add_argument("--config", type=str, required=True, help="Path to config file")
    parser.add_argument("--seed", type=int, help="Seed for training (required in train mode)")
    args = parser.parse_args()

    # train() forwards the seed to model setup and to the trainer, and
    # inference() later looks for models under seed_<seed>. A train run
    # without --seed would silently pass None, so reject it up front.
    # parser.error() prints usage and exits with status 2.
    if args.mode == "train" and args.seed is None:
        parser.error("--seed is required when mode is 'train'")

    # Print only after parsing and validation succeed, so --help and
    # argument errors do not emit a misleading success message.
    print("INFO: Arguments parsed")

    if args.mode == "train":
        train(args.config, args.seed)
    elif args.mode == "inference":
        inference(args.config)