From 76d376f1f05a5ec4aa8e3718922d3df2d9343911 Mon Sep 17 00:00:00 2001
From: bashFish
Date: Wed, 2 Nov 2022 19:39:34 +0000
Subject: [PATCH] fixed training on , tested on new deepspeed, single node

---
 magma/magma.py | 14 +++++++-------
 train.py       |  3 ++-
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/magma/magma.py b/magma/magma.py
index 9862b92..9b57446 100644
--- a/magma/magma.py
+++ b/magma/magma.py
@@ -40,7 +40,7 @@ def __init__(self, config, device=None):
             "cuda" if torch.cuda.is_available() else "cpu"
         )
         self.config = config
-        self.lm = get_gptj().to(self.device)
+        self.lm = get_gptj() #.to(self.device)
         self.seq_len = self.lm.config.max_position_embeddings
         self.tokenizer = get_tokenizer("gpt2", sequence_length=self.seq_len)

@@ -49,16 +49,16 @@ def __init__(self, config, device=None):
         self.eos_token = self.tokenizer.eos_token_id
         self.lm.resize_token_embeddings(len(self.tokenizer))
         self.lm.config.pad_token_id = self.tokenizer.eos_token_id
-        self.word_embedding = self.lm.transformer.wte.to(device)
+        self.word_embedding = self.lm.transformer.wte #.to(device)
         self.transformer = self.lm.transformer.h

         # adapter settings
         self.mlp_adapter_added, self.attn_adapter_added = False, False
-
+
         self.image_prefix = ImagePrefix(
             config=config,
             out_dim=self.lm.config.hidden_size,
-        ).to(self.device)
+        ) #.to(self.device)

         # might change based on the type of image encoder, so get from prefix instead of config
         self.image_prefix_seq_len = self.image_prefix.out_seq_len
@@ -189,7 +189,7 @@ def preprocess_inputs(self, input_list: list, embed = True) -> List[torch.Tensor

         if embed == True:
             return self.embed(input_list)
-        else:
+        else:
             return input_list

     def embed(self, inputs: List[torch.Tensor]) -> TensorType["b", "s", "d"]:
@@ -282,7 +282,7 @@ def from_checkpoint(cls, config_path, checkpoint_path, device = 'cpu'):
         """

         checkpoint_url = 'https://bit.ly/aleph-alpha-magma-download'
-
+
         if exists(checkpoint_path) == False:
             print_main(f'checkpoint: {checkpoint_path} does not exist, downloading model')
             download_checkpoint(checkpoint_url = checkpoint_url, save_as = checkpoint_path)
@@ -298,4 +298,4 @@ def from_checkpoint(cls, config_path, checkpoint_path, device = 'cpu'):
         print_main("magma successfully loaded")

         model.half().to(device).eval()
-        return model
\ No newline at end of file
+        return model

diff --git a/train.py b/train.py
index a004bf1..abc8bdc 100644
--- a/train.py
+++ b/train.py
@@ -77,7 +77,8 @@ def get_pretraining_datasets(config, tokenizer, transforms):

     # load model + tokenizer:
     model = Magma(
-        args.config
+        args.config,
+        device=torch.device("cuda", args.local_rank)
     )  # for finetuning one might want to load the model via Magma.from_checkpoint(...) here
     tokenizer, config, transforms = model.tokenizer, model.config, model.transforms

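
Note (illustrative only, not part of the patch): a minimal sketch of how the updated train.py call site is meant to be driven under a single-node DeepSpeed launch. The argument names (--config, --local_rank), the default config path, and the use of deepspeed.initialize(...) to place parameters on each rank's GPU are assumptions inferred from the diff, not taken from the repository's actual train.py.

    import argparse

    import deepspeed
    import torch

    from magma.magma import Magma

    # Hypothetical CLI, mirroring what the diff implies: the deepspeed launcher
    # injects --local_rank into every worker process on the node.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/base.yml")  # placeholder path
    parser.add_argument("--local_rank", type=int, default=0)
    deepspeed.add_config_arguments(parser)  # adds --deepspeed / --deepspeed_config
    args = parser.parse_args()

    # With this patch, Magma.__init__ no longer calls .to(self.device); the model
    # is built on CPU and only told which CUDA device this rank should use.
    model = Magma(args.config, device=torch.device("cuda", args.local_rank))

    # deepspeed.initialize is expected to move the parameters to the local GPU
    # when the engine is built, so no explicit model.to(...) is needed here.
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
    )

Launched per worker by the deepspeed launcher on one node, this is consistent with the "tested on new deepspeed, single node" wording in the subject line.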