From 92725c52031fc5b19cc6923dc22ec98a0ffe27ae Mon Sep 17 00:00:00 2001 From: Mcc Lee Date: Mon, 25 Mar 2024 16:34:53 -0400 Subject: [PATCH] Fix the dataconfig changes after 194a712 --- zoology/experiments/examples/basic.py | 18 +++++++----------- zoology/experiments/examples/basic_sweep.py | 20 +++++++------------- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/zoology/experiments/examples/basic.py b/zoology/experiments/examples/basic.py index a22b6a0..9e331bd 100644 --- a/zoology/experiments/examples/basic.py +++ b/zoology/experiments/examples/basic.py @@ -1,19 +1,15 @@ -from zoology.config import TrainConfig, ModelConfig, DataConfig, FunctionConfig, ModuleConfig - +from zoology.config import TrainConfig, ModelConfig, DataConfig, ModuleConfig +from zoology.data.associative_recall import MQARConfig +factory_kwargs = { + "num_kv_pairs": 4, + } config = TrainConfig( data=DataConfig( # cache_dir="/path/to/cache/dir" TODO: add this - vocab_size=256, - input_seq_len=64, - num_train_examples=10_000, - num_test_examples=1_000, - builder=FunctionConfig( - name="zoology.data.associative_recall.multiquery_ar", - kwargs={"num_kv_pairs": 4} - ), - + train_configs=[MQARConfig(num_examples=10_000, vocab_size=256, input_seq_len=64, **factory_kwargs)], + test_configs=[MQARConfig(num_examples=1_000, vocab_size=256, input_seq_len=64, **factory_kwargs)], ), model=ModelConfig( vocab_size=256, diff --git a/zoology/experiments/examples/basic_sweep.py b/zoology/experiments/examples/basic_sweep.py index cf2d05a..ce55d1e 100644 --- a/zoology/experiments/examples/basic_sweep.py +++ b/zoology/experiments/examples/basic_sweep.py @@ -1,26 +1,20 @@ import numpy as np -from zoology.config import TrainConfig -from zoology.config import TrainConfig, ModelConfig, DataConfig, FunctionConfig, ModuleConfig - -from zoology.config import TrainConfig, ModelConfig, DataConfig, FunctionConfig, ModuleConfig +from zoology.config import TrainConfig, ModelConfig, DataConfig, ModuleConfig +from zoology.data.associative_recall import MQARConfig configs = [] +factory_kwargs = { + "num_kv_pairs": 4, + } for lr in np.logspace(-4, -2, 10): config = TrainConfig( data=DataConfig( # cache_dir="/path/to/cache/dir" TODO: add a directory where data will be cached - vocab_size=256, - input_seq_len=64, - num_train_examples=10_000, - num_test_examples=1_000, - builder=FunctionConfig( - name="zoology.data.associative_recall.multiquery_ar", - kwargs={"num_kv_pairs": 4} - ), - + train_configs=[MQARConfig(num_examples=10_000, vocab_size=256, input_seq_len=64, **factory_kwargs)], + test_configs=[MQARConfig(num_examples=1_000, vocab_size=256, input_seq_len=64, **factory_kwargs)], ), model=ModelConfig( vocab_size=256,