update triton_templates
yorickvP committed Apr 15, 2024
1 parent 5f80297 commit 57b92cf
Showing 10 changed files with 1,049 additions and 499 deletions.
triton_templates/ensemble/config.pbtxt (6 changes: 3 additions & 3 deletions)
@@ -173,8 +173,8 @@ input [
]
output [
  {
-    name: "output_ids"
-    data_type: TYPE_INT32
+    name: "text_output"
+    data_type: TYPE_STRING
    dims: [ -1 ]
  },
  {
@@ -421,7 +421,7 @@ ensemble_scheduling {
      }
      output_map {
        key: "OUTPUT"
-        value: "output_ids"
+        value: "text_output"
      }
      output_map {
        key: "OUT_OUTPUT_LOG_PROBS"
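With the ensemble's final output remapped from token IDs to decoded text, clients request `text_output` and get strings back directly. A minimal sketch with `tritonclient` (the ensemble's input names are not shown in this hunk; `text_input`, `max_tokens`, and the model name `ensemble` are assumed from the stock TensorRT-LLM backend templates):

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# BYTES inputs are sent as object-dtype numpy arrays.
text = np.array([["What is the capital of France?"]], dtype=object)
text_input = httpclient.InferInput("text_input", [1, 1], "BYTES")
text_input.set_data_from_numpy(text)

max_tokens = httpclient.InferInput("max_tokens", [1, 1], "INT32")
max_tokens.set_data_from_numpy(np.array([[64]], dtype=np.int32))

result = client.infer(
    "ensemble",
    inputs=[text_input, max_tokens],
    outputs=[httpclient.InferRequestedOutput("text_output")])

# TYPE_STRING tensors come back as bytes; decode for display.
print(result.as_numpy("text_output").flatten()[0].decode("utf-8"))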
triton_templates/postprocessing/1/model.py (110 changes: 57 additions & 53 deletions)
@@ -28,9 +28,7 @@

import numpy as np
import triton_python_backend_utils as pb_utils
-from transformers import AutoTokenizer, LlamaTokenizerFast, T5Tokenizer
-
-import time
+from transformers import AutoTokenizer


class TritonPythonModel:
@@ -53,32 +51,20 @@ def initialize(self, args):
        * model_version: Model version
        * model_name: Model name
        """
-
-
        # Parse model configs
        model_config = json.loads(args['model_config'])
        tokenizer_dir = model_config['parameters']['tokenizer_dir'][
            'string_value']
-        tokenizer_type = model_config['parameters']['tokenizer_type'][
-            'string_value']
        self.skip_special_tokens = model_config['parameters'].get(
            'skip_special_tokens',
            {'string_value': "true"})['string_value'].lower() in [
                'true', '1', 't', 'y', 'yes'
            ]

-        if tokenizer_type == 't5':
-            self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir,
-                                         padding_side='left')
-        elif tokenizer_type == 'auto':
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_dir, padding_side='left', trust_remote_code=True)
-        elif tokenizer_type == 'llama':
-            self.tokenizer = LlamaTokenizerFast.from_pretrained(
-                tokenizer_dir, legacy=False, padding_side='left')
-        else:
-            raise AttributeError(
-                f'Unexpected tokenizer type: {tokenizer_type}')
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
+                                                       legacy=False,
+                                                       padding_side='left',
+                                                       trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Parse model output configs
@@ -88,7 +74,6 @@ def initialize(self, args):
        # Convert Triton types to numpy types
        self.output_dtype = pb_utils.triton_string_to_numpy(
            output_config['data_type'])
-

def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
@@ -109,6 +94,7 @@ def execute(self, requests):
        A list of pb_utils.InferenceResponse. The length of this list must
        be the same as `requests`
        """
+
responses = []

        # Every Python backend must iterate over every one of the requests
@@ -124,19 +110,19 @@ def execute(self, requests):

            # Get cum log probs
            cum_log_probs = pb_utils.get_input_tensor_by_name(
-                request, 'CUM_LOG_PROBS').as_numpy()
+                request, 'CUM_LOG_PROBS')

            # Get output log probs
            output_log_probs = pb_utils.get_input_tensor_by_name(
-                request, 'OUTPUT_LOG_PROBS').as_numpy()
+                request, 'OUTPUT_LOG_PROBS')

            # Get context logits
            context_logits = pb_utils.get_input_tensor_by_name(
-                request, 'CONTEXT_LOGITS').as_numpy()
+                request, 'CONTEXT_LOGITS')

            # Get generation logits
            generation_logits = pb_utils.get_input_tensor_by_name(
-                request, 'GENERATION_LOGITS').as_numpy()
+                request, 'GENERATION_LOGITS')

            # Reshape Input
            # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]])
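Deferring `.as_numpy()` here is the point of the change: once these inputs are declared `optional: true`, `pb_utils.get_input_tensor_by_name` returns `None` when the client omits them, and calling `.as_numpy()` on that return value would raise before any presence check could run. Conversion now happens only in the branches below, after the tensor is known to exist.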
@@ -147,25 +133,51 @@ def execute(self, requests):

            # Create output tensors. You need pb_utils.Tensor
            # objects to create pb_utils.InferenceResponse.
-            # output_tensor = pb_utils.Tensor(
-            #     'OUTPUT',
-            #     np.array(outputs).astype(self.output_dtype))
-
            output_tensor = pb_utils.Tensor(
                'OUTPUT',
-                tokens_batch)
-
-            out_cum_log_probs = pb_utils.Tensor('OUT_CUM_LOG_PROBS',
-                                                cum_log_probs)
-
-            out_output_log_probs = pb_utils.Tensor('OUT_OUTPUT_LOG_PROBS',
-                                                   output_log_probs)
-
-            out_context_logits = pb_utils.Tensor('OUT_CONTEXT_LOGITS',
-                                                 context_logits)
-
-            out_generation_logits = pb_utils.Tensor('OUT_GENERATION_LOGITS',
-                                                    generation_logits)
+                np.array(outputs).astype(self.output_dtype))
+
+            outputs = []
+            outputs.append(output_tensor)
+
+            # Omitted optional inputs arrive as None, so check presence
+            # with `is not None` and fall back to dummy tensors.
+            if cum_log_probs is not None:
+                out_cum_log_probs = pb_utils.Tensor('OUT_CUM_LOG_PROBS',
+                                                    cum_log_probs.as_numpy())
+            else:
+                out_cum_log_probs = pb_utils.Tensor(
+                    'OUT_CUM_LOG_PROBS', np.array([[0.0]], dtype=np.float32))
+            outputs.append(out_cum_log_probs)
+
+            if output_log_probs is not None:
+                out_output_log_probs = pb_utils.Tensor(
+                    'OUT_OUTPUT_LOG_PROBS', output_log_probs.as_numpy())
+            else:
+                out_output_log_probs = pb_utils.Tensor(
+                    'OUT_OUTPUT_LOG_PROBS',
+                    np.array([[[0.0]]], dtype=np.float32))
+            outputs.append(out_output_log_probs)
+
+            if context_logits is not None:
+                out_context_logits = pb_utils.Tensor(
+                    'OUT_CONTEXT_LOGITS', context_logits.as_numpy())
+            else:
+                out_context_logits = pb_utils.Tensor(
+                    'OUT_CONTEXT_LOGITS',
+                    np.array([[[0.0]]], dtype=np.float32))
+            outputs.append(out_context_logits)
+
+            if generation_logits is not None:
+                out_generation_logits = pb_utils.Tensor(
+                    'OUT_GENERATION_LOGITS', generation_logits.as_numpy())
+            else:
+                out_generation_logits = pb_utils.Tensor(
+                    'OUT_GENERATION_LOGITS',
+                    np.array([[[[0.0]]]], dtype=np.float32))
+            outputs.append(out_generation_logits)

            # Create InferenceResponse. You can set an error here in case
            # there was a problem with handling this inference request.
@@ -174,15 +186,12 @@ def execute(self, requests):
            #
            # pb_utils.InferenceResponse(
            #     output_tensors=..., TritonError("An error occurred"))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[
-                output_tensor, out_cum_log_probs, out_output_log_probs,
-                out_context_logits, out_generation_logits
-            ])
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=outputs)
            responses.append(inference_response)

        # You should return a list of pb_utils.InferenceResponse. Length
        # of this list must match the length of `requests` list.
-
        return responses

def finalize(self):
@@ -193,17 +202,12 @@ def finalize(self):
        print('Cleaning up...')

    def _postprocessing(self, tokens_batch, sequence_lengths):
-        start = time.time()
        outputs = []
        for batch_idx, beam_tokens in enumerate(tokens_batch):
            for beam_idx, tokens in enumerate(beam_tokens):
-                inner_loop_time = time.time()
                seq_len = sequence_lengths[batch_idx][beam_idx]
-                tokens_to_decode = tokens[:seq_len]
-                tokenizer_start_time = time.time()
                output = self.tokenizer.decode(
-                    tokens_to_decode,
+                    tokens[:seq_len],
                    skip_special_tokens=self.skip_special_tokens)
-                tokenizer_output_time = time.time()
                outputs.append(output.encode('utf8'))
-                end_inner_loop = time.time()
        return outputs
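The commit also collapses the per-type tokenizer branching into a single `AutoTokenizer.from_pretrained` call, which dispatches on the checkpoint's own config files. A standalone sketch of the decode path `_postprocessing` now relies on (the checkpoint name and token IDs below are illustrative, not taken from this repo):

import numpy as np
from transformers import AutoTokenizer

# AutoTokenizer picks the tokenizer class from tokenizer_config.json,
# so the old t5/auto/llama branches are unnecessary.
tokenizer = AutoTokenizer.from_pretrained(
    "hf-internal-testing/llama-tokenizer", legacy=False, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

# Shape [batch, beam, max_len], as the postprocessor receives it; entries
# past the per-beam sequence length are padding.
tokens_batch = np.array([[[1, 22172, 3186, 2, 0, 0]]])
sequence_lengths = np.array([[4]])

for batch_idx, beam_tokens in enumerate(tokens_batch):
    for beam_idx, tokens in enumerate(beam_tokens):
        seq_len = sequence_lengths[batch_idx][beam_idx]
        # Decode only the real tokens; special tokens are stripped.
        print(tokenizer.decode(tokens[:seq_len], skip_special_tokens=True))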
triton_templates/postprocessing/config.pbtxt (11 changes: 3 additions & 8 deletions)
@@ -42,11 +42,13 @@ input [
    name: "CUM_LOG_PROBS"
    data_type: TYPE_FP32
    dims: [ -1 ]
+    optional: true
  },
  {
    name: "OUTPUT_LOG_PROBS"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
+    optional: true
  },
  {
    name: "CONTEXT_LOGITS"
@@ -64,7 +66,7 @@ input [
output [
  {
    name: "OUTPUT"
-    data_type: TYPE_INT32
+    data_type: TYPE_STRING
    dims: [ -1 ]
  },
  {
@@ -96,13 +98,6 @@ parameters {
  }
}

-parameters {
-  key: "tokenizer_type"
-  value: {
-    string_value: "${tokenizer_type}"
-  }
-}
-
parameters {
  key: "skip_special_tokens"
  value: {
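Marking `CUM_LOG_PROBS` and `OUTPUT_LOG_PROBS` as `optional: true` is what lets requests omit them entirely; inside the Python backend an omitted optional input simply comes back as `None`. A sketch of the presence-check pattern the new `model.py` uses (`get_optional_input` is a hypothetical helper, not part of `pb_utils`):

import numpy as np
import triton_python_backend_utils as pb_utils

def get_optional_input(request, name, default):
    # For inputs declared `optional: true`, get_input_tensor_by_name
    # returns None when the client omits the tensor, so the check must
    # be `is not None` rather than truthiness.
    tensor = pb_utils.get_input_tensor_by_name(request, name)
    if tensor is not None:
        return tensor.as_numpy()
    return default

# Usage inside execute(), mirroring the fallback shapes above:
#   cum_log_probs = get_optional_input(
#       request, 'CUM_LOG_PROBS', np.array([[0.0]], dtype=np.float32))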