Add Datasize Tracking, persistent training loss json and share model … #4

Open
wants to merge 1 commit into base: main
129 changes: 88 additions & 41 deletions main.py
@@ -2,7 +2,7 @@
import shutil
from datetime import datetime
from pathlib import Path

import json
import torch
import torch.nn as nn
import torch.optim as optim
@@ -69,7 +69,7 @@ def init_shared_dirs(client: Client, proj_folder: Path) -> None:

# Give public write permission to the round_weights and agg_weights folder
add_public_write_permission(client, agg_weights_folder)

# Create a state folder to track progress of the project
# and give public read permission to the state folder for the aggregator
create_project_state(client, proj_folder)
@@ -86,60 +86,68 @@ def load_model_class(model_path: Path) -> type:
return model_class


def train_model(proj_folder: Path, round_num: int, dataset_path_files: Path) -> None:
"""
Trains the model for the given round number
"""

def train_model(proj_folder: Path, round_num: int, dataset_path_files: list[Path]) -> None:
round_weights_folder = proj_folder / "round_weights"
agg_weights_folder = proj_folder / "agg_weights"

fl_config_path = proj_folder / "fl_config.json"
fl_config = read_json(fl_config_path)

# Load the Model from the model_arch filename
# Load model and aggregator weights
model_class = load_model_class(proj_folder / fl_config["model_arch"])
model: nn.Module = model_class()

# Load the aggregated weights from the previous round
agg_weights_file = agg_weights_folder / f"agg_model_round_{round_num - 1}.pt"
model.load_state_dict(torch.load(agg_weights_file, weights_only=True))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=fl_config["learning_rate"])

# Load datasets
all_datasets = []
for dataset_path_file in dataset_path_files:
# load the saved mnist subset
images, labels = torch.load(str(dataset_path_file), weights_only=True)

# create a tensordataset
dataset = TensorDataset(images, labels)

all_datasets.append(dataset)

combined_dataset = ConcatDataset(all_datasets)

# create a dataloader for the dataset
# Save dataset size to JSON
dataset_size = len(combined_dataset)
dataset_size_file = proj_folder / "dataset_size.json"
with open(dataset_size_file, "w") as f:
json.dump({"dataset_size": dataset_size}, f, indent=4)

# Create DataLoader
train_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)

# Open log file for writing
# Logging
logs_folder_path = proj_folder / "logs"
logs_folder_path.mkdir(parents=True, exist_ok=True)

# 1) Text log file
output_logs_path = logs_folder_path / f"training_logs_round_{round_num}.txt"
log_file = open(str(output_logs_path), "w")

# Log training start
# 2) JSON loss log file
training_loss_file = logs_folder_path / f"training_loss_round_{round_num}.json"
if training_loss_file.exists():
with open(training_loss_file, "r") as f:
training_loss_data = json.load(f)
else:
training_loss_data = []

start_msg = f"[{datetime.now().isoformat()}] Starting training...\n"
log_file.write(start_msg)
log_file.flush()

update_project_state(
proj_folder,
ProjectStateCols.MODEL_TRAIN_PROGRESS,
f"Training Started for Round {round_num}",
)

# training loop
# Training loop
for epoch in range(fl_config["epoch"]):
running_loss = 0
for images, labels in train_loader:
@@ -148,36 +156,48 @@ def train_model(proj_folder: Path, round_num: int, dataset_path_files: Path) ->
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# accumulate loss
running_loss += loss.item()

# Calculate average loss for the epoch
avg_loss = running_loss / len(train_loader)
log_msg = f"[{datetime.now().isoformat()}] Epoch {epoch + 1:04d}: Loss = {avg_loss:.6f}\n"

# Write to text file
log_msg = f"[{datetime.now().isoformat()}] Epoch {epoch + 1:02d}: Loss = {avg_loss:.6f}\n"
log_file.write(log_msg)
log_file.flush() # Force write to disk
log_file.flush()

# Also track loss in JSON
training_loss_data.append({
"epoch": epoch + 1,
"loss": avg_loss
})

update_project_state(
proj_folder,
ProjectStateCols.MODEL_TRAIN_PROGRESS,
f"Training InProgress for Round {round_num} (Curr Epoc: {epoch}/{fl_config['epoch']})",
f"Training InProgress for Round {round_num} (Curr Epoch: {epoch + 1}/{fl_config['epoch']})",
)

# Serialize the model
output_model_path = round_weights_folder / f"trained_model_round_{round_num}.pt"
torch.save(model.state_dict(), str(output_model_path))

# Log completion
# Final log
final_msg = f"[{datetime.now().isoformat()}] Training completed. Final loss: {avg_loss:.6f}\n"
log_file.write(final_msg)
log_file.flush()
log_file.close()

update_project_state(
proj_folder,
ProjectStateCols.MODEL_TRAIN_PROGRESS,
f"Training Completed for Round {round_num}",
)

# Save the updated JSON loss file
with open(training_loss_file, "w") as f:
json.dump(training_loss_data, f, indent=4)
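# The saved training_loss_round_<n>.json then holds one entry per epoch, e.g. (values illustrative):
# [
#     {"epoch": 1, "loss": 1.234567},
#     {"epoch": 2, "loss": 0.876543}
# ]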




def shift_project_to_done_folder(
client: Client, proj_folder: Path, total_rounds: int
@@ -237,48 +257,54 @@ def perform_model_training(
proj_folder: Path,
dataset_files: list[Path],
) -> None:
"""
Step 2: Has the aggregate sent the weights for the current round x (in the agg_weights folder)
b. The client trains the model on the given round and places the trained model in the round_weights folder
c. It sends the trained model to the aggregator.
d. repeat a until all round completes
"""
round_weights_folder = proj_folder / "round_weights"
agg_weights_folder = proj_folder / "agg_weights"

fl_config_path = proj_folder / "fl_config.json"
fl_config = read_json(fl_config_path)

total_rounds = fl_config["rounds"]
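# Derive the current round from the number of trained-round weight files already present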
current_round = len(list(round_weights_folder.iterdir())) + 1

# Exit if the project has completed all the rounds.
# Check if project completed...
if has_project_completed(client, proj_folder, total_rounds):
return

# Check if the aggregate has sent the weights for the previous round
# We always use the previous round weights to train the model
# from the agg_weights folder to train for the current round
# Check aggregator weights...
agg_weights_file = agg_weights_folder / f"agg_model_round_{current_round - 1}.pt"
if not agg_weights_file.is_file():
raise StateNotReady(
f"Aggregator has not sent the weights for the round {current_round}"
)

# Train the model for the given FL round
# Train the model
train_model(proj_folder, current_round, dataset_files)

# Share the trained model to the aggregator
trained_model_file = (
round_weights_folder / f"trained_model_round_{current_round}.pt"
)
# Share the trained model
trained_model_file = round_weights_folder / f"trained_model_round_{current_round}.pt"
share_model_to_aggregator(
client,
fl_config["aggregator"],
proj_folder,
trained_model_file,
)

# Share dataset size info
share_dataset_info_to_aggregator(
client,
fl_config["aggregator"],
proj_folder,
)
# Copy the training_loss JSON to the public folder
training_loss_file = proj_folder / "logs" / f"training_loss_round_{current_round}.json"
if training_loss_file.exists():
print("Yes it does exists")
# Public folder for this client under the project name
public_folder = client.my_datasite / "public" / "fl" / proj_folder.name
public_folder.mkdir(parents=True, exist_ok=True)

# Copy the file to public so it’s always accessible
shutil.copy(training_loss_file, public_folder)


def share_model_to_aggregator(
client: Client,
@@ -298,6 +324,27 @@ def share_model_to_aggregator(
# Copy the trained model to the aggregator's client folder
shutil.copy(model_file, fl_aggregator_client_path)

def share_dataset_info_to_aggregator(
client: Client,
aggregator_email: str,
proj_folder: Path,
) -> None:
"""Shares the dataset size info to the aggregator."""
fl_aggregator_app_path = (
client.datasites / f"{aggregator_email}/api_data/fl_aggregator"
)
fl_aggregator_running_folder = fl_aggregator_app_path / "running" / proj_folder.name
fl_aggregator_client_path = (
fl_aggregator_running_folder / "fl_clients" / client.email
)

dataset_size_file = proj_folder / "dataset_size.json"
if dataset_size_file.exists():
shutil.copy(dataset_size_file, fl_aggregator_client_path)
else:
raise ValueError("dataset_size.json not found on the client side.")



def _advance_fl_project(client: Client, proj_folder: Path) -> None:
"""