Skip to content

Commit

Permalink
Merge pull request #32 from tplr-ai/feat/vali_fix
Browse files Browse the repository at this point in the history
Feat/vali fix
  • Loading branch information
distributedstatemachine authored Jan 20, 2025
2 parents 0c53fce + 036ea74 commit cc99d2a
Show file tree
Hide file tree
Showing 7 changed files with 498 additions and 351 deletions.
6 changes: 3 additions & 3 deletions hparams.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"spec_version": 5,
"project": "dough",
"sequence_length": 2048,
"pages_per_window": 5,
"pages_per_window": 10,
"batch_size": 6,
"learning_rate": 4e-4,
"blocks_per_window": 4,
"blocks_per_window": 6,
"windows_per_sync": 100,
"windows_per_weights": 100,
"momentum_decay": 0.999,
Expand All @@ -24,7 +24,7 @@
"warmup_steps": 250,
"alpha_f": 0.1,
"t_max": 20000,
"validator_offset": 2,
"validator_offset": 4,
"checkpoint_frequency": 50,
"topk_peers": 20,
"minimum_peers": 5,
Expand Down
118 changes: 46 additions & 72 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,50 +168,6 @@ def __init__(self):

# Main training loop.
async def run(self):
# Try to load latest checkpoint
result = await self.comms.get_latest_checkpoint()
if result:
checkpoint_data, window = result
try:
# Load state dicts from checkpoint data
self.model.load_state_dict({k: v.to(self.config.device) for k,v in checkpoint_data['model_state_dict'].items()})
self.model.to(self.config.device)

# Load optimizer state
for state in self.optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.to(self.config.device)
self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])

# Load scheduler state
self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])

# Load momentum and global_step
self.momentum = checkpoint_data['momentum']
self.global_step = checkpoint_data['global_step']

# Adjust scheduler to catch up with current window
checkpoint_window = checkpoint_data.get('checkpoint_window', None)
if checkpoint_window is not None:
window_difference = self.current_window - checkpoint_window
if window_difference > 0:
for _ in range(window_difference):
self.scheduler.step()
tplr.logger.info(f"Stepped scheduler {window_difference} times to catch up with current window {self.current_window}")
else:
tplr.logger.warning("Checkpoint does not contain 'checkpoint_window'; cannot adjust scheduler")

tplr.logger.info(f"Loaded checkpoint from window {window}, global_step={self.global_step}")
except KeyError as e:
tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
except Exception as e:
tplr.logger.error(f"Failed to load checkpoint: {e}")
else:
tplr.logger.info("No valid checkpoints found, starting from scratch")
self.global_step = 0
self.model.to(self.config.device)

# Load Peers
if not self.config.peers:
self.peers = self.comms.peers
Expand All @@ -221,6 +177,31 @@ async def run(self):
if self.uid not in self.peers:
self.peers.append(self.uid)

self.comms.commitments = self.comms.get_commitments_sync()
self.comms.update_peers_with_buckets()
tplr.logger.info(f"Loaded commitments: {self.comms.commitments.keys()}")

success, loaded_momentum, loaded_global_step = await self.comms.load_checkpoint(
model=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
transformer=self.transformer,
compressor=self.compressor,
current_window=self.current_window,
device=self.config.device,
peers=self.peers,
uid=self.uid
)
if success:
self.momentum = loaded_momentum
self.global_step = loaded_global_step
tplr.logger.info(f"Loaded checkpoint with global_step={self.global_step}")
else:
tplr.logger.info("Starting from scratch")
self.global_step = 0
self.momentum = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
self.model.to(self.config.device)

# Start background block listener
self.loop = asyncio.get_running_loop()
self.listener = threading.Thread(
Expand All @@ -233,13 +214,13 @@ async def run(self):
self.comms.start_background_tasks()

while True:
# 1. Initialize window and update peers
step_window = self.current_window
tplr.logger.info(f"\n{'-' * 40} Window: {step_window} {'-' * 40}")
# self.comms.update_peers_with_buckets()
# Update local references
self.comms.update_peers_with_buckets()
self.peers = self.comms.peers

# Get the pages for this window.
# 2. Load training data for this window
pages = await tplr.dataset.DatasetLoader.next_pages(
offset = step_window,
n_pages = self.hparams.pages_per_window,
Expand All @@ -253,7 +234,7 @@ async def run(self):
)
tplr.logger.info(f"Pages: {[p[1] for p in pages]} for Window: {step_window}")

# Accumulate gradient
# 3. Accumulate gradients over batches
start_time = time.time()
tplr.logger.info("Start accumulating...")
self.optimizer.zero_grad()
Expand All @@ -272,26 +253,27 @@ async def run(self):
total_loss += outputs.loss.item()
outputs.loss.backward()

# Track tokens
batch_tokens += (labels != -100).sum().item()

# TODO: INCREASE LENGTH OF THE WINDOW
tplr.logger.info(f'loss: {outputs.loss.item()}')
if self.current_window != step_window:
tplr.logger.info('<Exhausted window>')
break

# 4. Wait for next window
tplr.logger.info("Wait for next window...")
while self.current_window == step_window:
await asyncio.sleep(0.1)
tplr.logger.info(f"Stopped accumulating: {i+1} batches with {(i+1) * self.hparams.batch_size * self.hparams.sequence_length} tokens")

# Calculate processing metrics
# 5. Calculate and log metrics
duration = time.time() - start_time
self.batch_times.append(duration)
self.total_tokens_processed += batch_tokens

# Log gradient metrics
grad_norms = [p.grad.norm().item() for p in self.model.parameters() if p.grad is not None]
weight_norms = [p.norm().item() for p in self.model.parameters()]
momentum_norms = [m.norm().item() for m in self.momentum.values()]

# Enhanced wandb logging with all metrics
self.wandb.log({
# Training metrics
"miner/loss": total_loss/(i+1),
Expand Down Expand Up @@ -321,7 +303,7 @@ async def run(self):
"miner/mean_momentum_norm": sum(momentum_norms) / len(momentum_norms),
}, step=self.global_step)

# Reduce gradient using DeMo.
# 6. Prepare gradients for sharing using DeMo compression
gradient = {}
xshapes = {}
totalks = {}
Expand Down Expand Up @@ -351,35 +333,35 @@ async def run(self):
xshapes[n] = xshape
totalks[n] = totalk

# Gather gradients from peers
# 7. Gather and process peer gradients
tplr.logger.info(f"Start gather: {self.peers}")
gather_result = await self.comms.gather(
state_dict=gradient,
my_uid=self.uid,
uids=self.peers,
window=step_window,
key='gradient',
timeout=5,
timeout=30,
device=self.config.device,
local=False,
stale_retention=10,
stale_retention=100,
global_step=self.global_step,
)

if gather_result is None:
tplr.logger.error("Failed to gather gradients from peers. Waiting for next window.")
# Wait for next window
while self.current_window == step_window:
await asyncio.sleep(0.1)
continue # Proceed to the next window
continue

# Update self.global_step based on the maximum global_step received
# 8. Update global step based on peer information
max_global_step = max(gather_result.global_steps + [self.global_step])
tplr.logger.info(f"Gather global steps : {gather_result.global_steps}")
if max_global_step > self.global_step:
tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
self.global_step = max_global_step

# Decompress state and apply to grad.
# 9. Apply gathered gradients
for n, p in self.model.named_parameters():
idxs_key = n + 'idxs'
vals_key = n + 'vals'
Expand All @@ -406,27 +388,19 @@ async def run(self):
p.grad = new_grad
else:
p.grad.copy_(new_grad)
# Sign-SGD
# Sign-SGD
p.grad.sign_()
else:
tplr.logger.info(f"Gradient data missing for parameter {n}, skipping.")



# Apply optimizer step
# 10. Optimization step
tplr.logger.info("Finish and step.")
self.optimizer.step()
self.scheduler.step()
self.global_step += 1
self.window_step += 1
tplr.logger.info(f"Total optimization steps: {self.global_step}")

# Wait for next window
tplr.logger.info("Wait for next window...")
while self.current_window == step_window:
await asyncio.sleep(0.1)
self.window_step = 0

# Listens for new blocks and sets self.current_block and self.current_window
def block_listener(self, loop):
def handler(event, _u, _s):
Expand Down
Loading

0 comments on commit cc99d2a

Please sign in to comment.