From 673d601fd3f5f223aa7eca7c489e96fff6f6028f Mon Sep 17 00:00:00 2001 From: skytnt Date: Sat, 5 Oct 2024 21:44:07 +0800 Subject: [PATCH] use_cache=False to avoid warning --- midi_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/midi_model.py b/midi_model.py index 78148aa..8aa8517 100644 --- a/midi_model.py +++ b/midi_model.py @@ -29,11 +29,13 @@ def get_config(tokenizer_ver="v2", optimise_midi=True, n_layer=12, n_head=16, n_ net_config = LlamaConfig(vocab_size=tokenizer.vocab_size, hidden_size=n_embd, num_attention_heads=n_head, num_hidden_layers=n_layer, intermediate_size=n_inner, - pad_token_id=tokenizer.pad_id, max_position_embeddings=4096) + pad_token_id=tokenizer.pad_id, max_position_embeddings=4096, + use_cache=False) net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size, hidden_size=n_embd, num_attention_heads=n_head // 4, num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4, - pad_token_id=tokenizer.pad_id, max_position_embeddings=4096) + pad_token_id=tokenizer.pad_id, max_position_embeddings=4096, + use_cache=False) return MIDIModelConfig(tokenizer, net_config, net_token_config) @staticmethod