Skip to content

Commit

Permalink
add stuff, dunno what this was
Browse files (browse the repository at this point in the history)
  • Loading branch information
nelhage committed Dec 17, 2023
1 parent 876541f commit cf116da
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ wandb
tqdm
shed
zstandard==0.19.0
plotly
pandas
jupyterlab==3.5.*
jupyter==1.0.*
1 change: 1 addition & 0 deletions python/scripts/train_4x4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tak.alphazero import cli, hooks, trainer, schedule
import os.path
import yaml
import shlex


Expand Down
2 changes: 2 additions & 0 deletions python/tak/alphazero/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def dedup_batch(batch):
out[k][idx] += batch[k][i]

for k in keys:
if not out[k].dtype.is_floating_point:
out[k] = out[k].float()
out[k] /= counts.reshape((-1,) + (1,) * (len(out[k].shape) - 1))
return {k: v[:next] for (k, v) in out.items()}

Expand Down
2 changes: 2 additions & 0 deletions python/tak/self_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,13 @@ def encode_games(logs: list[Transcript]):
all_values = [v for tr in logs for v in tr.values]
all_move_probs = torch.cat([tr.logits for tr in logs])
all_results = [r for tr in logs for r in tr.results]
all_plies = [p.ply for tr in logs for p in tr.positions]
encoded, mask = encoding.encode_batch(all_positions)
return dict(
positions=encoded,
mask=mask,
moves=all_move_probs,
values=torch.tensor(all_values),
results=torch.tensor(all_results, dtype=torch.float32),
plies=torch.tensor(all_plies, dtype=torch.int),
)

0 comments on commit cf116da

Please sign in to comment.