Commit

Merge pull request #36 from bashFish/fixed_training
fixed training, tested on new deepspeed, single node
Mayukhdeb authored Nov 3, 2022
2 parents a69a575 + 76d376f commit 4d01e51
Showing 2 changed files with 9 additions and 8 deletions.
14 changes: 7 additions & 7 deletions magma/magma.py

@@ -40,7 +40,7 @@ def __init__(self, config, device=None):
             "cuda" if torch.cuda.is_available() else "cpu"
         )
         self.config = config
-        self.lm = get_gptj().to(self.device)
+        self.lm = get_gptj() #.to(self.device)
         self.seq_len = self.lm.config.max_position_embeddings

         self.tokenizer = get_tokenizer("gpt2", sequence_length=self.seq_len)
@@ -49,16 +49,16 @@ def __init__(self, config, device=None):
         self.eos_token = self.tokenizer.eos_token_id
         self.lm.resize_token_embeddings(len(self.tokenizer))
         self.lm.config.pad_token_id = self.tokenizer.eos_token_id
-        self.word_embedding = self.lm.transformer.wte.to(device)
+        self.word_embedding = self.lm.transformer.wte #.to(device)
         self.transformer = self.lm.transformer.h

         # adapter settings
         self.mlp_adapter_added, self.attn_adapter_added = False, False

         self.image_prefix = ImagePrefix(
             config=config,
             out_dim=self.lm.config.hidden_size,
-        ).to(self.device)
+        ) #.to(self.device)

         # might change based on the type of image encoder, so get from prefix instead of config
         self.image_prefix_seq_len = self.image_prefix.out_seq_len
@@ -189,7 +189,7 @@ def preprocess_inputs(self, input_list: list, embed = True) -> List[torch.Tensor

         if embed == True:
             return self.embed(input_list)
-        else:
+        else:
             return input_list

     def embed(self, inputs: List[torch.Tensor]) -> TensorType["b", "s", "d"]:
@@ -282,7 +282,7 @@ def from_checkpoint(cls, config_path, checkpoint_path, device = 'cpu'):
         """

         checkpoint_url = 'https://bit.ly/aleph-alpha-magma-download'

         if exists(checkpoint_path) == False:
             print_main(f'checkpoint: {checkpoint_path} does not exist, downloading model')
             download_checkpoint(checkpoint_url = checkpoint_url, save_as = checkpoint_path)
@@ -298,4 +298,4 @@ def from_checkpoint(cls, config_path, checkpoint_path, device = 'cpu'):
         print_main("magma successfully loaded")

         model.half().to(device).eval()
-        return model
+        return model
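A brief note on the intent here: commenting out the `.to(self.device)` calls defers device placement, which matters for DeepSpeed training. `deepspeed.initialize(...)` moves the model onto each rank's GPU itself (and applies whatever ZeRO stage and precision the config requests), so eagerly pushing the full GPT-J weights onto one device inside `__init__` can conflict with the engine's own placement. A minimal sketch of the pattern, using a hypothetical stand-in module and an assumed `ds_config.json` rather than the real Magma classes:

```python
import torch.nn as nn
import deepspeed

class ToyModel(nn.Module):
    """Hypothetical stand-in for Magma: submodules are built on CPU."""
    def __init__(self):
        super().__init__()
        # Note: no .to(device) here; placement is deferred to DeepSpeed.
        self.backbone = nn.Linear(1024, 1024)

model = ToyModel()  # parameters still live on CPU at this point

# deepspeed.initialize moves parameters to the local rank's GPU and
# applies the ZeRO stage / fp16 settings requested in the config.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # assumed DeepSpeed config path
)
```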
3 changes: 2 additions & 1 deletion train.py

@@ -77,7 +77,8 @@ def get_pretraining_datasets(config, tokenizer, transforms):

     # load model + tokenizer:
     model = Magma(
-        args.config
+        args.config,
+        device=torch.device("cuda", args.local_rank)
     ) # for finetuning one might want to load the model via Magma.from_checkpoint(...) here
     tokenizer, config, transforms = model.tokenizer, model.config, model.transforms
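For context, `args.local_rank` is injected by the DeepSpeed launcher: it spawns one process per GPU and passes `--local_rank` to each, so every rank builds its `torch.device` against its own GPU. A hedged sketch of the surrounding argument handling (the flag wiring is assumed, not taken from this repository's train.py):

```python
import argparse
import torch
import deepspeed

parser = argparse.ArgumentParser()
# The deepspeed launcher passes --local_rank to every worker it spawns.
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--config", type=str, help="assumed path to a Magma config")
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed et al.
args = parser.parse_args()

# One process per GPU: rank 0 -> cuda:0, rank 1 -> cuda:1, ...
device = torch.device("cuda", args.local_rank)
torch.cuda.set_device(device)
```

On a single node this would be launched as something like `deepspeed train.py --config <path>` (invocation assumed), after which each worker constructs the model against its own `device`, matching the change above.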
