diff --git a/benchmarks/bleu.py b/benchmarks/bleu.py
new file mode 100644
index 0000000000..fe4f207e2e
--- /dev/null
+++ b/benchmarks/bleu.py
@@ -0,0 +1,56 @@
+"""
+Implements evaluation metrics based on the BLEU score.
+
+Example:
+    import sacrebleu
+
+    translated_sentences = ['The dog had bit the man.', "It wasn't surprising.", 'The man had bitten the dog.']
+    target_sentences = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
+    bleu_score = sacrebleu.corpus_bleu(translated_sentences, [target_sentences]).score
+    print(f'Test BLEU: {bleu_score}')
+
+"""
+
+import numpy as np
+from typing import List, Optional
+
+import sacrebleu
+
+def corpus_bleu(sys_sents: List[str],
+                refs_sents: List[List[str]],
+                smooth_method: str = 'exp',
+                smooth_value: Optional[float] = None,
+                force: bool = True,
+                lowercase: bool = False,
+                tokenizer: str = '13a',
+                use_effective_order: bool = False):
+
+    return sacrebleu.corpus_bleu(sys_sents, refs_sents, smooth_method, smooth_value, force,
+                                 lowercase=lowercase, tokenize=tokenizer, use_effective_order=use_effective_order).score
+
+
+def sentence_bleu(sys_sent: str,
+                  ref_sents: List[str],
+                  smooth_method: str = 'floor',
+                  smooth_value: Optional[float] = None,
+                  lowercase: bool = False,
+                  tokenizer: str = '13a',
+                  use_effective_order: bool = True):
+
+    return corpus_bleu([sys_sent], [[ref] for ref in ref_sents], smooth_method, smooth_value, force=True,
+                       lowercase=lowercase, tokenizer=tokenizer, use_effective_order=use_effective_order)
+
+
+def corpus_averaged_sentence_bleu(sys_sents: List[str],
+                                  refs_sents: List[List[str]],
+                                  smooth_method: str = 'floor',
+                                  smooth_value: Optional[float] = None,
+                                  lowercase: bool = False,
+                                  tokenizer: str = '13a',
+                                  use_effective_order: bool = True):
+
+    scores = []
+    for sys_sent, *ref_sents in zip(sys_sents, *refs_sents):
+        scores.append(sentence_bleu(sys_sent, ref_sents, smooth_method, smooth_value,
+                                    lowercase=lowercase, tokenizer=tokenizer, use_effective_order=use_effective_order))
+    return np.mean(scores)
diff --git a/data/commonvoice_ko/get_dataset.sh b/data/commonvoice_ko/get_dataset.sh
index 6e02eb8709..0ace984ce5 100644
--- a/data/commonvoice_ko/get_dataset.sh
+++ b/data/commonvoice_ko/get_dataset.sh
@@ -1,7 +1,7 @@
 # !/bin/bash
-# Set strict error handling
-set -euo pipefail
+# Show lines before execution and exit on errors
+set -xe
 
 # Install python dependencies for Hugging face
 pip install -U "huggingface_hub[cli]"
@@ -10,7 +10,13 @@ pip install -U "huggingface_hub[cli]"
 # Replace with your hugging face tokens
 ##### You can find and create your own tokens here: https://huggingface.co/settings/tokens ######
 ##### "Token Type" of "Read" is recommended. ########
-HF_TOKEN=""
+if [[ -f ~/.cache/huggingface/token && -s ~/.cache/huggingface/token ]]; then
+    export HF_TOKEN=$(cat ~/.cache/huggingface/token)
+else
+    echo "Consider running 'python3 ./utils/save_hf_token.py' to automate finding HF_TOKEN"
+    read -s -p "To continue, please enter your Hugging Face token: " HF_TOKEN
+    echo "" # Add a newline for better readability
+fi
 
 # Authenticate with hugging face
 echo "Authenticating with Hugging Face..."
@@ -28,12 +34,12 @@ fi
 
 # Download transcription files under "transcription" directory.
pushd "${out_dir}" -wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "dev.tsv" "${url}/resolve/main/transcript/ko/dev.tsv?download=true" -wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "invalidated.tsv" "${url}/resolve/main/transcript/ko/validated.tsv?download=true" -wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "other.tsv" "${url}/resolve/main/transcript/ko/other.tsv?download=true" -wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "test.tsv" "${url}/resolve/main/transcript/ko/test.tsv?download=true" -wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "train.tsv" "${url}/resolve/main/transcript/ko/train.tsv?download=true" -wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "validated.tsv" "${url}/resolve/main/transcript/ko/validated.tsv?download=true" +wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "dev.tsv" "${url}/resolve/main/transcript/ko/dev.tsv?download=true" || true +wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "invalidated.tsv" "${url}/resolve/main/transcript/ko/validated.tsv?download=true" || true +wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "other.tsv" "${url}/resolve/main/transcript/ko/other.tsv?download=true" || true +wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "test.tsv" "${url}/resolve/main/transcript/ko/test.tsv?download=true" || true +wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "train.tsv" "${url}/resolve/main/transcript/ko/train.tsv?download=true" || true +wget --header="Authorization: Bearer ${HF_TOKEN}" -nc -O "validated.tsv" "${url}/resolve/main/transcript/ko/validated.tsv?download=true" || true echo "transcripts downloaded and saved to transcription." popd @@ -54,11 +60,11 @@ echo "All .tsv files have been processed." # Run program to convert sentences into IPA format. echo "Converting sentences to IPA..." -python3 ./utils/ko_en_to_ipa.py "$output_file" --input_json_key "sentence" --output_json_key "phonetic" +python3 ./utils/ko_en_to_ipa.py "$output_file" --input_json_key "sentence" --output_json_key "sentence_ipa" output_ipa="ko_ipa.txt" echo "export IPA to txt file" -python3 ./utils/extract_json_values.py "$output_file" "phonetic" "$output_ipa" +python3 ./utils/extract_json_values.py "$output_file" "sentence_ipa" "$output_ipa" echo "IPA conversion finished." diff --git a/data/template/parallel_embeddings/part_of_speech_zh.py b/data/template/parallel_embeddings/part_of_speech_zh.py new file mode 100644 index 0000000000..cda372051d --- /dev/null +++ b/data/template/parallel_embeddings/part_of_speech_zh.py @@ -0,0 +1,8 @@ +import jieba.posseg as pseg + +text = "他今天在北京大学的图书馆里看书,学习非常认真。这本书很有意思,内容包括历史、哲学和科学。" + +words = pseg.cut(text) + +for word, flag in words: + print(f"{word}: {flag}") diff --git a/data/template/phoneme_list.txt b/data/template/phoneme_list.txt index 107a397c1c..ab48b46961 100644 --- a/data/template/phoneme_list.txt +++ b/data/template/phoneme_list.txt @@ -1,87 +1,98 @@ -i: -I -iI -eI + +\n +\t +. +[ +] +_ a -A: -Q -0 -' -O: -U -u: -V -@ -eI -aI -OI -aU -oU -p +ä +æ b -t +c +ç d -k -g +e f -v -T -D -s -z -S -Z +g h +i +j +k +l m n -N -l -r -w -j -iu -i -e o -u -W -A -y -E -ME -O -oo -ou -ye - -\n -\r -: -, -F -C -Y -? -. -B -c -R -M -L -c -; -! 
-H -P +ø +p q - -G -- +r +s +t +u +v +w x -$ -& -3 -J -K -X -_ +y +z +ð +ħ +ŋ +œ +ɐ +ɑ +ɔ +ɕ +ɘ +ə +ɛ +ɡ +ɣ +ɤ +ɥ +ɦ +ɨ +ɪ +ɫ +ɯ +ɴ +ɵ +ɸ +ɻ +ɽ +ɾ +ʁ +ʂ +ʃ +ʈ +ʉ +ʊ +ʌ +ʏ +ʐ +ʑ +ʔ +ʕ +ʰ +ʲ +ʼ +ˈ +ˌ +ː +ˑ +ˤ +˥ +˦ +˧ +˨ +˩ +̂ +̃ +̆ +̌ +̚ +̥ +̬ +β +θ +χ diff --git a/data/template/tests.py b/data/template/tests.py index 5ecc970672..f3c40d9535 100644 --- a/data/template/tests.py +++ b/data/template/tests.py @@ -15,6 +15,7 @@ from rich.console import Console from rich.theme import Theme from rich.table import Table +from rich.text import Text console = Console(theme=Theme({ "pass": "bold green", @@ -23,7 +24,7 @@ "separator": "grey50", "input": "bold cyan", "output": "bold magenta", - "info": "bold blue" + "info": "bold blue", })) @@ -188,10 +189,10 @@ def test_custom_char_tokenizer_with_byte_fallback(self): args = Namespace(custom_chars_file="custom_chars.txt") # Create a custom characters file for testing with open(args.custom_chars_file, 'w', encoding='utf-8') as f: - f.write('a\nb\nc\n') + f.write('a\nb\nc\n\\n') tokenizer = CustomCharTokenizerWithByteFallback(args) - test_string = 'abc😊' + test_string = 'abc😊d\nefg' ids = tokenizer.tokenize(test_string) detokenized = tokenizer.detokenize(ids) @@ -201,6 +202,15 @@ def test_custom_char_tokenizer_with_byte_fallback(self): console.print("[output]Detokenized Output:[/output]") console.print(detokenized, style="output") + console.print("[info]Characters that used byte fallback:[/info]") + bft = [] # Byte Fallback Tokens + for char in detokenized: + if char not in tokenizer.custom_chars: + char = repr(char) + bft.append(char) + + console.print(", ".join(bft), style="info") + self.assertEqual(test_string, detokenized) print("CustomCharTokenizerWithByteFallback test passed.") diff --git a/data/template/utils/meta_util.py b/data/template/utils/meta_util.py index aea0cc1e0b..61f393fa5a 100644 --- a/data/template/utils/meta_util.py +++ b/data/template/utils/meta_util.py @@ -58,7 +58,21 @@ def create_meta_from_text(text_file, output_path, special_chars={"": 0}): pickle.dump(meta, f) print(f"Meta created from text and saved to {output_path}.") - +def export_tokens(meta_path, output_path): + meta = load_meta(meta_path) + with open(output_path, "w") as f: + for i in range(meta["vocab_size"]): + token = meta["itos"][i] + if token == "\n": + token = "\\n" + elif token == "\t": + token = "\\t" + elif token == "\r": + token = "\\r" + # Note: Add more special character handling here as needed + f.write(token + "\n") + print(f"Tokens exported to {output_path}") + def main(): parser = argparse.ArgumentParser(description="Utility for handling token metadata.") @@ -73,6 +87,12 @@ def main(): nargs=2, help="Path to the input text file and the output meta.pkl file for creation.", ) + + parser.add_argument( + "--export", + nargs=2, + help="Path to the meta.pkl file and the output text file for exporting tokens.", + ) args = parser.parse_args() @@ -82,7 +102,8 @@ def main(): merge_metas(args.merge[0], args.merge[1], "merged_meta.pkl") elif args.create: create_meta_from_text(args.create[0], args.create[1]) - + elif args.export: + export_tokens(args.export[0], args.export[1]) if __name__ == "__main__": main() diff --git a/explorations/hsnorm.json b/explorations/hsnorm.json new file mode 100644 index 0000000000..83656d74d6 --- /dev/null +++ b/explorations/hsnorm.json @@ -0,0 +1,35 @@ + +[ + { + "parameter_groups": [ + { + "norm_variant_attn" : ["rmsnorm"], + "norm_variant_output" : ["rmsnorm"], + "tensorboard_log_name": ["regular_rmsnorm"] + }, + { + "norm_variant_attn" 
: ["hyperspherenorm"], + "norm_variant_output" : ["hyperspherenorm"], + "hsnorm_radius": ["5", "10", "15", "20", "25"], + "hsnorm_radius_learning": [true, false], + "tensorboard_log_name": ["set_radius"] + }, + { + "norm_variant_attn" : ["hyperspherenorm"], + "norm_variant_output" : ["hyperspherenorm"], + "hsnorm_radius_learning": [true, false], + "tensorboard_log_name": ["root_embd_dim_radius"] + } + ], + "max_iters": ["3500"], + "n_layer": ["6"], + "n_kv_group": ["6"], + "n_head": ["6"], + "n_embd": ["384"], + "block_size":["256"], + "device": ["cuda"], + "dtype": ["float16", "bfloat16", "float32"], + "compile": [true] + } +] + diff --git a/gpt_conf.py b/gpt_conf.py index 6c459cad0a..7fbbdd48f9 100644 --- a/gpt_conf.py +++ b/gpt_conf.py @@ -200,6 +200,9 @@ class GPTConfig: krmsnorm_enable_gain: bool = True krmsnorm_selection_type: str = 'last' krmsnorm_recompute_percentage: float = 0.05 + hsnorm_gain: bool = False + hsnorm_radius: float = None + hsnorm_radius_learning: float = None # Activation Alternatives diff --git a/train.py b/train.py index c4e2147932..43ae7256f7 100644 --- a/train.py +++ b/train.py @@ -25,9 +25,12 @@ from rich.progress import Progress -import matplotlib.pyplot as plt +# GNS Related +import utils.gns_monitoring.gns_utils as gns_utils +from utils.gns_monitoring.hook import (add_hooks_to_model, add_sogns_hooks, + add_exact_hooks, gather_hook_results) + import numpy as np -import plotly.graph_objects as go import torch from torch.distributed import destroy_process_group, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP @@ -50,7 +53,13 @@ def __init__(self, args, model_group, training_group, logging_group): self.training_group = training_group self.logging_group = logging_group - # typically make the decay iters equal to max_iters + # GNS and batch schedule + self.gns = None + self.tokens_trained = 0 + + # Learning Rate Settings + self.lr = self.args.learning_rate + ## Make the decay iters equal to max_iters if not specified if self.args.lr_decay_match_max_iters: self.args.lr_decay_iters = self.args.max_iters @@ -181,6 +190,16 @@ def setup(self): self.model.crop_block_size(self.args.block_size) self.model_args['block_size'] = self.args.block_size + # Add gradient monitoring + if self.args.gns_type is not None: + get_gns_fn = {'sogns': add_sogns_hooks, 'exact': add_exact_hooks} + add_hooks_to_model(self.model, get_gns_fn[self.args.gns_type]) + ema_beta = self.args.gns_ema_beta + self.gns_ema = gns_utils.EMA(beta=ema_beta) + + # Initialize GNS for later + self.gns = None + self.model.to(self.device) # Print the model summary @@ -491,6 +510,14 @@ def get_transitioned_probs(): dataset = self.args.dataset data = self.train_data if split == 'train' else self.val_data + # Adaptive GNS settings + if (self.gns is not None) and (self.args.gns_target is not None): + if self.gns < self.args.gns_target: + if self.args.batch_size < self.args.gns_max_batch: + self.args.batch_size = math.ceil(self.args.batch_size * (1.0 + self.args.gns_batch_pct)) + if self.gns > self.args.gns_target: + self.args.batch_size = math.ceil(self.args.batch_size * (1.0 - self.args.gns_batch_pct)) + # Generate random indices for the batch ix = torch.randint(len(data) - self.args.block_size, (self.args.batch_size,)) @@ -570,32 +597,36 @@ def get_lr(self, it): coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return self.args.min_lr + coeff * (self.args.learning_rate - self.args.min_lr) - def log_metrics(self, losses, lr, running_mfu, vram_allocated, iter_num, 
target_dataset=None):
+    def log_metrics(self, losses, running_mfu, target_dataset=None):
         if self.args.tensorboard_log:
             # Log metrics for each dataset separately
             if target_dataset:
                 self.writer.add_scalars(
                     "loss",
                     {f"{target_dataset}/train": losses['train'].item(),
-                     f"{target_dataset}/val": losses['val'].item()}, iter_num
+                     f"{target_dataset}/val": losses['val'].item()}, self.iter_num
                 )
             else:
                 self.writer.add_scalars(
                     "loss",
                     {"train": losses['train'].item(), "val":
-                     losses['val'].item()}, iter_num
+                     losses['val'].item()}, self.iter_num
                 )
 
-            self.writer.add_scalar("mfu_pct", running_mfu * 100, iter_num)
-            self.writer.add_scalar("lr", lr, iter_num)
-            self.writer.add_scalar("vram", vram_allocated, iter_num)
+            self.writer.add_scalar("mfu_pct", running_mfu * 100, self.iter_num)
+            self.writer.add_scalar("lr", self.lr, self.iter_num)
+            self.writer.add_scalar("vram", self.vram_allocated, self.iter_num)
+            self.writer.add_scalar("batch_size", self.args.batch_size, self.iter_num)
+            self.writer.add_scalar("tokens_trained", self.tokens_trained, self.iter_num)
+            if self.args.gns_type is not None:
+                self.writer.add_scalar("gns", self.gns, self.iter_num)
 
         if self.args.wandb_log and self.master_process:
             import wandb
             log_data = {
-                "iter": iter_num,
-                "lr": lr,
+                "iter": self.iter_num,
+                "lr": self.lr,
                 "mfu": running_mfu * 100,
-                "vram": vram_allocated,
+                "vram": self.vram_allocated,
             }
             if target_dataset:
                 log_data[f"{dataset}/train/loss"] = losses['train']
@@ -607,18 +638,21 @@ def log_metrics(self, losses, lr, running_mfu, vram_allocated, iter_num, target_
             wandb.log(log_data)
 
         if self.args.csv_log:
+            # Concise training metrics
             if target_dataset:
                 self.write_to_csv(losses['train'].item(), losses['val'].item(), prefix=f"{target_dataset}_")
             else:
                 self.write_to_csv(losses['train'].item(), losses['val'].item())
 
-            # Other metrics
-            self.write_to_csv(iter_num, lr, running_mfu, vram_allocated, prefix="misc_")
-
-
+            # Bulk metrics
+            if target_dataset:
+                self.write_to_csv(target_dataset, losses['train'].item(), losses['val'].item(), running_mfu, prefix="bulk_")
+            else:
+                self.write_to_csv(self.args.dataset, losses['train'].item(), losses['val'].item(), running_mfu, prefix="bulk_")
 
     def write_to_csv(self, *args, prefix=""):
+        args = list(args)
         csv_full_dir = self.args.csv_dir
         if self.args.csv_ckpt_dir:
             csv_full_dir = f"{self.args.csv_dir}/{self.args.csv_ckpt_dir}"
@@ -630,56 +664,62 @@ def write_to_csv(self, *args, prefix=""):
         with open(csv_path, 'a', newline='') as file:
             writer = csv.writer(file)
             # Write arguments as a new row in the CSV
+            args.insert(0, self.iter_num)
+            args.append(self.lr)
+            args.append(self.args.batch_size)
+            args.append(self.tokens_trained)
+            if self.args.gns_type is not None:
+                args.append(self.gns)
             writer.writerow(args)
 
-    def log_gamma_beta(self, gamma, beta, iter_num, layer_num, head_num=None):
+    def log_gamma_beta(self, gamma, beta, layer_num, head_num=None):
         if self.args.tensorboard_log:
             if head_num:
                 self.writer.add_scalars(
                     "gammas",
-                    {"gamma_L" + str(layer_num) + "_H" + head_num: gamma},
-                    iter_num
-                )
+                    {"gamma_L" + str(layer_num) + "_H" + head_num: gamma}, self.iter_num)
                 self.writer.add_scalars(
                     "betas",
-                    {"beta_L" + str(layer_num) + "_H" + head_num: beta},
-                    iter_num
-                )
+                    {"beta_L" + str(layer_num) + "_H" + head_num: beta}, self.iter_num)
             else:
-                self.writer.add_scalar( "gamma_L" + str(layer_num), gamma, iter_num)
-                self.writer.add_scalar( "beta_L" + str(layer_num), beta, iter_num)
+                self.writer.add_scalar( "gamma_L" + str(layer_num), gamma, self.iter_num)
+                self.writer.add_scalar( "beta_L" + str(layer_num), beta, self.iter_num)
 
         if self.args.wandb_log and self.master_process:
             import wandb
             wandb.log({
-                "iter": iter_num,
+                "iter": self.iter_num,
                 "train/loss": losses['train'],
                 "val/loss": losses['val'],
-                "lr": lr,
+                "lr": self.lr,
                 "mfu": running_mfu*100,
             })
 
-    def log_metrics_non_validation(self, loss_training, running_mfu, vram_allocated, iter_num, target_dataset=None):
+    def log_metrics_non_validation(self, loss_training, running_mfu, target_dataset=None):
         if self.args.tensorboard_log:
             if target_dataset:
                 self.writer.add_scalars(
-                    "loss", {f"{target_dataset}/train": loss_training}, iter_num
+                    "loss", {f"{target_dataset}/train": loss_training}, self.iter_num
                 )
             else:
                 self.writer.add_scalars(
-                    "loss", { "train": loss_training }, iter_num
+                    "loss", { "train": loss_training }, self.iter_num
                 )
-            self.writer.add_scalar("mfu_pct", running_mfu * 100, iter_num)
-            self.writer.add_scalar("vram", vram_allocated, iter_num)
-
+            self.writer.add_scalar("mfu_pct", running_mfu * 100, self.iter_num)
+            self.writer.add_scalar("lr", self.lr, self.iter_num)
+            self.writer.add_scalar("vram", self.vram_allocated, self.iter_num)
+            self.writer.add_scalar("batch_size", self.args.batch_size, self.iter_num)
+            self.writer.add_scalar("tokens_trained", self.tokens_trained, self.iter_num)
+            if self.args.gns_type is not None:
+                self.writer.add_scalar("gns", self.gns, self.iter_num)
         if self.args.wandb_log and self.master_process:
             import wandb
             wandb.log({
-                "iter": iter_num,
+                "iter": self.iter_num,
                 "train/loss": loss_training,
                 "mfu": running_mfu*100,
-                "vram": vram_allocated,
+                "vram": self.vram_allocated,
             })
 
     def save_checkpoint(self, filename):
@@ -710,22 +750,26 @@ def train(self):
         with progress:
             task_id = progress.add_task("[green]Training...", total=(self.args.max_iters - self.iter_num))
             while True:
-                lr = self.get_lr(self.iter_num) if self.args.decay_lr else self.args.learning_rate
+                if self.args.decay_lr:
+                    self.lr = self.get_lr(self.iter_num)
                 for param_group in self.optimizer.param_groups:
-                    param_group['lr'] = lr
+                    param_group['lr'] = self.lr
 
                 if self.iter_num % self.args.eval_interval == 0 and self.master_process:
                     losses = self.estimate_loss()
-                    vram_allocated = get_gpu_memory_info(info_type='used') if self.args.device != "cpu" else 0
+                    if self.args.gns_type is not None:
+                        self.gns = self.gns_ema.get_gns()
+
+                    self.vram_allocated = get_gpu_memory_info(info_type='used') if self.args.device != "cpu" else 0
                     if self.args.dataset_list is not None:
                         # Print loss for each dataset if multiple datasets are used
                         for dataset, dataset_losses in losses['datasets'].items():
-                            print(f"step {self.iter_num}: {dataset} train loss {dataset_losses['train']:.4f}, val loss {dataset_losses['val']:.4f}")
-                            self.log_metrics(dataset_losses, lr, running_mfu, vram_allocated, self.iter_num, target_dataset=dataset)
+                            gns_str = f", gns {self.gns:.2f}" if self.gns is not None else ""
+                            print(f"step {self.iter_num}: {dataset} train loss {dataset_losses['train']:.4f}, val loss {dataset_losses['val']:.4f}{gns_str}, batch_size {self.args.batch_size}, lr {self.lr}, tokens_trained {self.tokens_trained:e}")
+                            self.log_metrics(dataset_losses, running_mfu, target_dataset=dataset)
                     else:
                         # Default behavior for a single dataset
                         print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
-                        self.log_metrics(losses, lr, running_mfu, vram_allocated, self.iter_num)
+                        self.log_metrics(losses, running_mfu)
 
                     if math.isnan(losses["val"]):
                         # If val loss is nan, then exit.
@@ -801,6 +845,11 @@ def train(self):
 
                     self.scaler.scale(loss).backward()
 
+                    if self.args.gns_type is not None:
+                        approx_gns_results = gather_hook_results(self.model)
+                        self.gns_ema.update(*gns_utils.gnsify(approx_gns_results, self.args.batch_size, ddp=self.ddp))
+
+
                 if self.args.grad_clip != 0.0:
                     self.scaler.unscale_(self.optimizer)
                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
@@ -813,20 +862,29 @@ def train(self):
                 t1 = time.time()
                 dt = t1 - t0
                 t0 = t1
+
+                # Update tokens trained
+                self.tokens_trained += self.args.batch_size * self.args.block_size
+
                 if self.iter_num % self.args.log_interval == 0 and self.master_process:
                     lossf = loss.item() * self.args.gradient_accumulation_steps
                     if local_iter_num >= 5:
                         mfu = self.raw_model.estimate_mfu(self.args.batch_size * self.args.gradient_accumulation_steps, dt)
                         running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
-                    print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%")
+                    if self.args.gns_type is not None:
+                        self.gns = self.gns_ema.get_gns()
+                        print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%, gns {self.gns:.2f}, batch_size {self.args.batch_size}, lr {self.lr}, tokens_trained {self.tokens_trained:e}")
+                    else:
+                        print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%")
+
                     if math.isnan(lossf):
                         # If training loss is nan, then exit.
                         with open(self.args.out_dir + "/nan_iter_num.txt", 'w') as file:
                             file.write(str(self.iter_num))
                         sys.exit("Exiting training loss is NaN")
-                    vram_allocated = get_gpu_memory_info(info_type='used') if self.args.device != "cpu" else 0
-                    self.log_metrics_non_validation(lossf, running_mfu, vram_allocated, self.iter_num)
+                    self.vram_allocated = get_gpu_memory_info(info_type='used') if self.args.device != "cpu" else 0
+                    self.log_metrics_non_validation(lossf, running_mfu)
 
                 if self.args.create_statistics and local_iter_num % self.args.softmax_io_log_interval == 0:
                     create_statistics(self, graph_y_labels)
diff --git a/train_args.py b/train_args.py
index 592104eaf9..af91634e51 100644
--- a/train_args.py
+++ b/train_args.py
@@ -64,6 +64,13 @@ def parse_args():
     training_group.add_argument('--dataset_sampling_probs_final', default=None, nargs='+', type=float, help="If, set final sampling probabilities for each dataset in dataset_list.")
     training_group.add_argument('--dataset_sampling_probs_transition_method', default=None, type=str, choices=["linear", "cosine", "exponential"])
 
+    # GNS (gradient noise scale) settings
+    training_group.add_argument('--gns_type', type=str, default=None, choices=['sogns', 'exact'], help='Type of gradient noise scale estimator to use (default: None, disables GNS monitoring)')
+    training_group.add_argument('--gns_ema_beta', type=float, default=0.9, help='EMA decay factor (beta) for the gradient noise scale estimate (default: 0.9)')
+    training_group.add_argument('--gns_target', type=float, default=None, help='Target GNS; if set, batch size is adapted toward this value')
+    training_group.add_argument('--gns_max_batch', type=int, default=100, help='Upper bound on batch size when adapting toward gns_target')
+    training_group.add_argument('--gns_batch_pct', type=float, default=0.2, help='Fractional step used when growing or shrinking the batch size')
+
     # Model args
     model_group.add_argument('--block_size', default=256, type=int)
 
@@ -133,16 +140,35 @@ def parse_args():
     model_group.add_argument('--shared_attn_sym', default=False, action=argparse.BooleanOptionalAction, help="symmetrical attention sharing")
 
     # NORM VARIATIONS
-    model_group.add_argument("--norm_variant_attn", type=str, default="rmsnorm", choices=["krmsnorm", "prmsnorm", "rmsnorm", "layernorm"])
-
model_group.add_argument("--norm_variant_output", type=str, default="rmsnorm", choices=["krmsnorm", "prmsnorm", "rmsnorm", "layernorm"]) + norm_variations = [ + "krmsnorm", + "prmsnorm", + "rmsnorm", + "layernorm", + "hyperspherenorm", + ] + + model_group.add_argument("--norm_variant_attn", type=str, default="rmsnorm", choices=norm_variations) + model_group.add_argument("--norm_variant_output", type=str, default="rmsnorm", choices=norm_variations) + + ## Layernorm model_group.add_argument('--bias', default=False, action=argparse.BooleanOptionalAction, help="only used for layernorm variation option") + + ## PRMSNorm model_group.add_argument("--prmsnorm_pct", default=0.0625, type=float, help="percentage (1 being 100 percent) of first entries used for partial rms" ) + + ## KRMSNorm model_group.add_argument("--krmsnorm_num", default=10, type=int, help="max number of first entries for partial rms" ) model_group.add_argument("--krmsnorm_quantize_type", type=str, default="none", choices=["int8", "int16", "none"]) model_group.add_argument('--krmsnorm_enable_gain', default=True, action=argparse.BooleanOptionalAction, help="include gain in kRMSNorm") model_group.add_argument("--krmsnorm_selection_type", type=str, default="last", choices=["first", "last", "random"]) model_group.add_argument("--krmsnorm_recompute_percentage", type=float, default=None, help="percentage needed within the total RMS to not trigger recompute") + ## HyperSphereNorm + model_group.add_argument("--hsnorm_gain", default=False, action=argparse.BooleanOptionalAction) + model_group.add_argument("--hsnorm_radius", type=float, default=None) + model_group.add_argument("--hsnorm_radius_learning", default=False, action=argparse.BooleanOptionalAction) + activation_variations = [ "celu", "elu", diff --git a/utils/gns_monitoring/LICENSE b/utils/gns_monitoring/LICENSE new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/utils/gns_monitoring/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/utils/gns_monitoring/gns_utils.py b/utils/gns_monitoring/gns_utils.py new file mode 100644 index 0000000000..a77104bcc3 --- /dev/null +++ b/utils/gns_monitoring/gns_utils.py @@ -0,0 +1,152 @@ +# Copyright 2023 Cerebras Systems Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = ['GradNorm', 'mean_loss_scale', 'EMA'] + +import math +import torch +import numpy as np +from dataclasses import dataclass, asdict +from typing import Union +from collections.abc import Sequence + +@dataclass +class GradNorm: + """ + A GradNorm measurement annotated with loss_scale and batch_size, because + these are necessary to compute GNS using this GradNorm later. + """ + val: float + loss_scale: float + batch_size: int + + def __repr__(self): + return f"GradNorm(val={self.val}, loss_scale={self.loss_scale}, batch_size={self.batch_size})" + +def mean_loss_scale(microbatch_size, minibatch_size): + """Compute the appropriate loss scale when using a loss that has been + reduced by a mean over the batch dimension, for a given minibatch + and microbatch size.""" + return minibatch_size / microbatch_size + +### BEGIN MIT LICENSE ### +# Copyright (c) 2022 Katherine Crowson +# 2023 Gavia Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +class EMA: + """Calculates the gradient noise scale (1 / SNR), or critical batch size, + from _An Empirical Model of Large-Batch Training_, + https://arxiv.org/abs/1812.06162). + + Args: + beta (float): The decay factor for the exponential moving averages used to + calculate the gradient noise scale. + Default: 0.9998 + eps (float): Added for numerical stability. + Default: 1e-8 + """ + def __init__(self, beta=0.9998, eps=1e-8): + self.beta = beta + self.eps = eps + self.ema_sq_norm = 0. + self.ema_var = 0. + self.beta_cumprod = 1. + self.gradient_noise_scale = float('nan') + + def state_dict(self): + """Returns the state of the object as a :class:`dict`.""" + return dict(self.__dict__.items()) + + def load_state_dict(self, state_dict): + """Loads the object's state. + Args: + state_dict (dict): object state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def update(self, norm_small_batch, norm_large_batch): + """Updates the state with a new batch's gradient statistics, and returns the + current gradient noise scale. + + Args: + norm_small_batch (GradNorm): The mean of the 2-norms of microbatch or + per sample gradients. + norm_large_batch (GradNorm): The 2-norm of the mean of the microbatch or + per sample gradients. 
+ """ + sq_norm_small_batch = (norm_small_batch.val * norm_small_batch.loss_scale)**2 + sq_norm_large_batch = (norm_large_batch.val * norm_large_batch.loss_scale)**2 + m, n = norm_small_batch.batch_size, norm_large_batch.batch_size + est_sq_norm = (n * sq_norm_large_batch - m * sq_norm_small_batch) / (n - m) + est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / m - 1 / n) + self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm + self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var + self.beta_cumprod *= self.beta + self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) + return self.gradient_noise_scale + + def get_gns(self): + """Returns the current gradient noise scale.""" + return self.gradient_noise_scale + + def get_stats(self): + """Returns the current (debiased) estimates of the squared mean gradient + and gradient variance.""" + return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) +### END MIT LICENSE ### + +def gnsify(sogns_results, minibatch_size, ddp=False): + # dictionary of approximate per-example gradient norms + # convert to gns format + # accumulate small and large squared gradient norms + total_small = 0. + total_big = 0. + for _, v in sogns_results.items(): + total_small += v.peg_sqnorm + total_big += v.g_sqnorm + if ddp: + # all_reduce AVG + torch.distributed.all_reduce(total_small, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(total_big, op=torch.distributed.ReduceOp.AVG) + small = GradNorm( + math.sqrt(total_small), + mean_loss_scale(1, minibatch_size), + 1 + ) + big = GradNorm( + math.sqrt(total_big), + 1., + minibatch_size + ) + return small, big + diff --git a/utils/gns_monitoring/hook.py b/utils/gns_monitoring/hook.py new file mode 100644 index 0000000000..2e532c5d1b --- /dev/null +++ b/utils/gns_monitoring/hook.py @@ -0,0 +1,179 @@ +# Copyright 2023 Cerebras Systems Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hooks to implement approximate gradient noise scale without touching the actual +model code. +""" + +import math + +import torch + +from dataclasses import dataclass + +def is_any_nan_or_inf(tensor): + """ + Check if any element of a tensor is NaN. + """ + return torch.isnan(tensor).any() or torch.isinf(tensor).any() + +def add_sogns_hooks(module): + """ + Add forward and backward hooks necessary for computing scaled output + gradient noise scale. + """ + assert isinstance(module, torch.nn.Linear) + @torch.no_grad() + def forward_pre_hook(module, activations): + """ + Forward pre-hook to store statistics about the activations. 
+ """ + a = activations[0] + if a.ndim == 2: + _, i = a.shape + l = 1 + a = a.unsqueeze(1) + elif a.ndim == 3: + _, l, i = a.shape + else: + raise ValueError(f'Unsupported activation shape: {a.shape}') + z = 1./(l * i) + a = z * a.float() # this is safer + module.a_sigma = torch.einsum('bli,bli->b', a, a).sqrt().unsqueeze(1) + # assert not is_any_nan_or_inf(module.a_sigma) + module.activation_dim = i + + class TensorHook: + def __init__(self, module): + self.module = module + @torch.no_grad() + def __call__(self, grad): + """ + Backward hook to compute the gradient noise scale. + """ + grad = grad.float() + # assert not is_any_nan_or_inf(grad) + if grad.ndim == 2: + grad = grad.unsqueeze(1) + # comput squared per-example batch gradient contribution + bias_s = 0. + bias_g_sqnorm = 0. + if self.module.bias is not None: + bias_s += (grad**2).sum(1).mean() # scalar + bias_g_sqnorm += (grad.sum(0)**2).sum() # scalar + i = self.module.activation_dim + w_tilde = math.sqrt(i) * self.module.a_sigma * grad.sum(1) + # assert not is_any_nan_or_inf(w_tilde) + self.module.peg_sqnorm = (torch.sum(w_tilde**2, 1).mean() + bias_s) + self.module.g_sqnorm = (torch.sum(w_tilde.sum(0)**2) + bias_g_sqnorm) + # delete a_sigma to make sure our garbage is collected + del self.module.a_sigma + + def forward_post_hook(module, activations, output): + """ + Forward post-hook to store the output tensor. + """ + # add tensor hook to the output tensor if it requires grad + if output.requires_grad: + output.register_hook(TensorHook(module)) + + # add hooks to this module + module.register_forward_pre_hook(forward_pre_hook) + module.register_forward_hook(forward_post_hook) + + +def add_exact_hooks(module): + """ + Add forward and backward hooks necessary for computing regular gradient + noise scale. (Much more expensive). + """ + assert isinstance(module, torch.nn.Linear) + @torch.no_grad() + def forward_pre_hook(module, activations): + """ + Forward pre-hook to store a reference to the input tensor :( + """ + module.input_activations = activations[0] + + class TensorHook: + def __init__(self, module): + self.module = module + @torch.no_grad() + def __call__(self, grad): + """ + Backward hook to compute the gradient noise scale. + """ + a = self.module.input_activations.float() + if a.ndim == 2: + a = a.unsqueeze(1) + if grad.ndim == 2: + g = grad.unsqueeze(1) + else: + g = grad + g = g.float() + # comput squared per-example batch gradient contribution + bias_s = 0. + bias_g_sqnorm = 0. + if self.module.bias is not None: + bias_s += (g**2).sum(1).mean() # scalar + bias_g_sqnorm += (g.sum(0)**2).sum() # scalar + s = torch.einsum('bmk,bnk,bml,bnl->b', a, a, g, g).mean() + module.peg_sqnorm = (s + bias_s) + g_big = torch.einsum('bmk,bml->kl', a, g) + self.module.g_sqnorm = (torch.sum(g_big**2) + + bias_g_sqnorm) + # delete a_sigma to make sure our garbage is collected + del self.module.input_activations + + def forward_post_hook(module, activations, output): + """ + Forward post-hook to store the output tensor. + """ + # add tensor hook to the output tensor if it requires grad + if output.requires_grad: + output.register_hook(TensorHook(module)) + + # add hooks to this module + module.register_forward_pre_hook(forward_pre_hook) + module.register_forward_hook(forward_post_hook) + + +def add_hooks_to_model(model, add_hooks): + """ + Add hooks to all modules in the model. 
+ """ + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + add_hooks(module) + + +@dataclass +class HookResult: + """ + Hook result dataclass. + """ + peg_sqnorm: float + g_sqnorm: float + + +def gather_hook_results(model): + """ + Gather the results from the hooks. + """ + results = {} + for name, module in model.named_modules(): + if hasattr(module, 'peg_sqnorm'): + results[name] = HookResult(module.peg_sqnorm, module.g_sqnorm) + return results diff --git a/variations/norm_variations.py b/variations/norm_variations.py index a59a10d56d..7db62a6eda 100644 --- a/variations/norm_variations.py +++ b/variations/norm_variations.py @@ -28,6 +28,35 @@ def forward(self, x): rms = x.norm(2, dim=-1, keepdim=True) / math.sqrt(x.size(-1)) return x / rms * self.gain +class HyperSphereNorm(nn.Module): + """Normalization to the surface of Hypersphere""" + + def __init__(self, config): + super().__init__() + + ndim = config.n_embd + if config.hsnorm_gain: + self.gain = nn.Parameter(torch.ones(ndim)) + else: + self.gain = 1.0 + + # Determine radius initialization value + radius_init = None + if config.hsnorm_radius is not None: + radius_init = config.hsnorm_radius + else: + radius_init = math.sqrt(ndim) + + # Set as constant or learned param + if config.hsnorm_radius_learning: + self.radius = nn.Parameter(torch.tensor([radius_init])) + else: + self.radius = radius_init + + def forward(self, x): + hypersphere_norm = x.norm(2, dim=-1, keepdim=True) + return x / hypersphere_norm * self.radius + class pRMSNorm(nn.Module): """Partial RMS Normalization""" @@ -136,4 +165,5 @@ def forward(self, x): "rmsnorm": RMSNorm, "prmsnorm": pRMSNorm, "krmsnorm": kRMSNorm, + "hyperspherenorm": HyperSphereNorm, }