From ecf8abd131404c216b47983cbf20d7f1987c0cb8 Mon Sep 17 00:00:00 2001
From: Keisuke Kamahori
Date: Thu, 25 Apr 2024 18:46:59 +0000
Subject: [PATCH] Single request only

---
 src/fiddler/infer.py   |  2 +-
 src/fiddler/mixtral.py | 32 ++++++++++++++++----------------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/fiddler/infer.py b/src/fiddler/infer.py
index 41310c7..96ccecb 100644
--- a/src/fiddler/infer.py
+++ b/src/fiddler/infer.py
@@ -39,7 +39,7 @@ args = parser.parse_args()
 model = FiddlerMixtral(args)
 prefill_time, decode_time, hit_rate = model.generate(
-    [args.input], output_token=args.n_token
+    args.input, output_token=args.n_token
 )
 print(
     f"prefill_time: {prefill_time}, decode_time: {decode_time}, hit_rate: {hit_rate}"
 )
diff --git a/src/fiddler/mixtral.py b/src/fiddler/mixtral.py
index e2eabf9..df40b6a 100644
--- a/src/fiddler/mixtral.py
+++ b/src/fiddler/mixtral.py
@@ -361,7 +361,7 @@ def calc_n_expert_on_gpu(self):
         return int((free_mem) // (n_param * 2))
 
     def initial_beam_tensor(self, input_tensor):
-        # transfer tensor of shape (batch_size*beam_width, seq_len, beam_width) to (batch_size*beam_width, 1) properly
+        # transpose tensor of shape (beam_width, seq_len, beam_width) to (beam_width, 1) properly
         assert input_tensor.shape[-1] == self.beam_width
         input_tensor = input_tensor[:, -1]
         row_idx = torch.tensor(
@@ -370,7 +370,7 @@ def initial_beam_tensor(self, input_tensor):
         output_tensor = input_tensor[row_idx].view(-1, 1)
         return output_tensor
 
-    def generate(self, texts=None, output_token=20, input_token=None):
+    def generate(self, text=None, output_token=20, input_token=None):
         torch.set_num_threads(16)  # TODO: set appropriately
         self.past_key_value = transformers.cache_utils.DynamicCache.from_legacy_cache()
         self.past_key_values_length = 0
@@ -378,10 +378,7 @@ def generate(self, texts=None, output_token=20, input_token=None):
         self.cnt_expert_hit = 0
         self.cnt_expert_all = 0
 
-        input_ids, position_ids=self.tokenize(texts)
-
-        # input_ids.shape: (batch_size, seq_len)
-        # position_ids.shape: (1,seq_len)
+        input_ids, position_ids = self.tokenize(text)
 
         if input_token is not None:
             input_ids = input_ids[:, :input_token]
@@ -447,30 +444,33 @@ def generate(self, texts=None, output_token=20, input_token=None):
         decode_time = time.time() - tick
         probs = probs.view(-1, self.beam_width)
         max_ids = torch.argmax(probs, dim=-1)
-        for i in range(max_ids.shape[0]):
-            print("--------------------")
-            print(f"Input: {texts[i]}")
-            print(f"Output: {decode_strings[i * self.beam_width + max_ids[i]]}")
+
+        print("--------------------")
+        print(f"Input: {text}")
+        print(f"Output: {decode_strings[max_ids[0]]}")
+
         return (
             prefill_time,
             decode_time,
             self.cnt_expert_hit / self.cnt_expert_all,
         )
 
-    def tokenize(self, texts):
+    def tokenize(self, text):
         input_ids = []
-        for text in texts:
-            encodings = self.tokenizer(text, return_tensors="pt")
-            input_id = encodings.input_ids.to(self.dev)
-            for i in range(self.beam_width):
-                input_ids.append(input_id[0])
+        encodings = self.tokenizer(text, return_tensors="pt")
+        input_id = encodings.input_ids.to(self.dev)
+        for i in range(self.beam_width):
+            input_ids.append(input_id[0])
+
         input_ids = pad_sequence(
             input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
         ).to(self.dev)
+
         position_ids = torch.arange(
             0, input_ids.shape[-1], dtype=torch.long, device=self.dev
         )
         position_ids = position_ids.unsqueeze(0).view(-1, input_ids.shape[-1])
+
         return input_ids, position_ids
 
     @torch.no_grad()
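
A minimal caller-side sketch of the API after this patch: generate() now takes a single prompt string instead of a list of prompts. The import path is assumed from the repo layout (src/fiddler/mixtral.py), args stands in for the argparse namespace built in src/fiddler/infer.py, and the prompt string and token count are illustrative only:

    from fiddler.mixtral import FiddlerMixtral

    model = FiddlerMixtral(args)  # args: parsed CLI options, as in infer.py (hypothetical here)
    prefill_time, decode_time, hit_rate = model.generate(
        "Write a short haiku about GPUs.",  # a single prompt string; previously a list like [prompt]
        output_token=20,                    # number of tokens to decode
    )
    print(prefill_time, decode_time, hit_rate)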