Commit: Single request only
kamahori committed Apr 25, 2024
1 parent 26fe472 commit ecf8abd
Showing 2 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/fiddler/infer.py
@@ -39,7 +39,7 @@
 args = parser.parse_args()
 model = FiddlerMixtral(args)
 prefill_time, decode_time, hit_rate = model.generate(
-    [args.input], output_token=args.n_token
+    args.input, output_token=args.n_token
 )
 print(
     f"prefill_time: {prefill_time}, decode_time: {decode_time}, hit_rate: {hit_rate}"
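After this hunk, `generate` takes one prompt string rather than a one-element list. A minimal sketch of the new call convention; the flag names and default values below are assumptions inferred from the `args.input` / `args.n_token` attributes visible in the diff, and the real parser defines more options:

import argparse

from fiddler.mixtral import FiddlerMixtral  # import path assumed from the src/fiddler layout

parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default="The University of Washington is")  # assumed flag
parser.add_argument("--n-token", dest="n_token", type=int, default=20)  # assumed flag
args = parser.parse_args()

model = FiddlerMixtral(args)
# The prompt is now passed as a plain str, not wrapped in a list.
prefill_time, decode_time, hit_rate = model.generate(args.input, output_token=args.n_token)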
32 changes: 16 additions & 16 deletions src/fiddler/mixtral.py
@@ -361,7 +361,7 @@ def calc_n_expert_on_gpu(self):
         return int((free_mem) // (n_param * 2))

     def initial_beam_tensor(self, input_tensor):
-        # transfer tensor of shape (batch_size*beam_width, seq_len, beam_width) to (batch_size*beam_width, 1) properly
+        # transpose tensor of shape (beam_width, seq_len, beam_width) to (beam_width, 1) properly
         assert input_tensor.shape[-1] == self.beam_width
         input_tensor = input_tensor[:, -1]
         row_idx = torch.tensor(
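For orientation, the corrected comment describes a shape transformation: keep only the last sequence position, then reduce the resulting (beam_width, beam_width) score grid to one column per beam. A standalone sketch with dummy scores; the `row_idx` construction is collapsed in this view, so the stride-`beam_width` row selection below is an assumption:

import torch

beam_width = 4
seq_len = 7
# Dummy beam scores shaped (beam_width, seq_len, beam_width), as in the comment.
scores = torch.randn(beam_width, seq_len, beam_width)

last = scores[:, -1]  # (beam_width, beam_width): last-position scores per beam
# At the first decoding step every beam holds the same prompt, so the rows are
# identical; keeping every beam_width-th row keeps just row 0 here (assumed).
row_idx = torch.tensor(range(0, last.shape[0], beam_width))
output = last[row_idx].view(-1, 1)  # (beam_width, 1)
print(output.shape)  # torch.Size([4, 1])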
@@ -370,18 +370,15 @@ def initial_beam_tensor(self, input_tensor):
         output_tensor = input_tensor[row_idx].view(-1, 1)
         return output_tensor

-    def generate(self, texts=None, output_token=20, input_token=None):
+    def generate(self, text=None, output_token=20, input_token=None):
         torch.set_num_threads(16)  # TODO: set appropriately
         self.past_key_value = transformers.cache_utils.DynamicCache.from_legacy_cache()
         self.past_key_values_length = 0

         self.cnt_expert_hit = 0
         self.cnt_expert_all = 0

-        input_ids, position_ids = self.tokenize(texts)
-
-        # input_ids.shape: (batch_size, seq_len)
-        # position_ids.shape: (1,seq_len)
+        input_ids, position_ids = self.tokenize(text)

         if input_token is not None:
             input_ids = input_ids[:, :input_token]
@@ -447,30 +444,33 @@ def generate(self, texts=None, output_token=20, input_token=None):
         decode_time = time.time() - tick
         probs = probs.view(-1, self.beam_width)
         max_ids = torch.argmax(probs, dim=-1)
-        for i in range(max_ids.shape[0]):
-            print("--------------------")
-            print(f"Input: {texts[i]}")
-            print(f"Output: {decode_strings[i * self.beam_width + max_ids[i]]}")
+        print("--------------------")
+        print(f"Input: {text}")
+        print(f"Output: {decode_strings[max_ids[0]]}")

         return (
             prefill_time,
             decode_time,
             self.cnt_expert_hit / self.cnt_expert_all,
         )

-    def tokenize(self, texts):
+    def tokenize(self, text):
         input_ids = []
-        for text in texts:
-            encodings = self.tokenizer(text, return_tensors="pt")
-            input_id = encodings.input_ids.to(self.dev)
-            for i in range(self.beam_width):
-                input_ids.append(input_id[0])
+        encodings = self.tokenizer(text, return_tensors="pt")
+        input_id = encodings.input_ids.to(self.dev)
+        for i in range(self.beam_width):
+            input_ids.append(input_id[0])

         input_ids = pad_sequence(
             input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
         ).to(self.dev)

         position_ids = torch.arange(
             0, input_ids.shape[-1], dtype=torch.long, device=self.dev
         )
         position_ids = position_ids.unsqueeze(0).view(-1, input_ids.shape[-1])

         return input_ids, position_ids

     @torch.no_grad()
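The reworked `tokenize` replicates the single prompt `beam_width` times so that each row of `input_ids` drives one beam. A self-contained sketch of that behavior, with a dummy token tensor standing in for the Hugging Face tokenizer output and assumed values for `beam_width` and the pad id:

import torch
from torch.nn.utils.rnn import pad_sequence

beam_width = 4  # the class reads self.beam_width; 4 is an assumed value
# Stands in for self.tokenizer(text, return_tensors="pt").input_ids: (1, seq_len)
input_id = torch.tensor([[1, 22, 333, 4444, 55555]])

# One copy of the prompt per beam, as the updated tokenize() builds it.
input_ids = pad_sequence(
    [input_id[0] for _ in range(beam_width)],
    batch_first=True,
    padding_value=0,  # stands in for self.tokenizer.pad_token_id
)
position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long).unsqueeze(0)

print(input_ids.shape)     # torch.Size([4, 5]) -> (beam_width, seq_len)
print(position_ids.shape)  # torch.Size([1, 5])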
