Condense the WorkflowRunner sub-classes (#319)
* Condense the `WorkflowRunner` sub-classes

The workflow output format should now be the same for both TF and Torch, so we can condense their respective sub-classes into the base `WorkflowRunner`, which leaves only the HugeCTR sub-class to deal with later (when we have HugeCTR support). A sketch of the now-shared output contract appears below the commit message.
Move `_convert_to_np` to the HugeCTR `WorkflowRunner`

* Remove `HugeCTRWorkflowRunner`

* Inline `_transform_outputs` hook method

* Remove stray code that got pasted into docstring
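The unifying piece is visible in the `merlin/systems/workflow/base.py` diff below: `run_workflow` now ends with `TensorTable(transformed).to_dict()`, handing every framework the same `{column_name: array}` mapping. A minimal sketch of that contract (the column names and values here are made up; the `TensorTable`-from-dict construction is an assumption about the `merlin.table` API):

```python
import numpy as np

from merlin.table import TensorTable

# Hypothetical post-transform columns; in the runner, `transformed` is the
# workflow output after conversion to Supports.CPU_DICT_ARRAY format.
transformed = {
    "user_id": np.array([1, 2, 3], dtype=np.int64),
    "purchase_prob": np.array([0.1, 0.7, 0.4], dtype=np.float32),
}

# The condensed runner ends with exactly this call: a plain dict of arrays
# comes back, so no TF- or Torch-specific reshaping step remains.
outputs = TensorTable(transformed).to_dict()
```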
karlhigley authored Apr 11, 2023
1 parent 1d6c5c8 commit f8d8808
Showing 7 changed files with 18 additions and 260 deletions.
5 changes: 0 additions & 5 deletions merlin/systems/dag/ops/workflow.py
@@ -34,7 +34,6 @@ def __init__(
sparse_max: dict = None,
max_batch_size: int = None,
label_columns: List[str] = None,
model_framework: str = None,
cats: List[str] = None,
conts: List[str] = None,
):
@@ -51,9 +50,6 @@ def __init__(
Maximum batch size, by default None
label_columns : List[str], optional
List of strings identifying the label columns, by default None
model_framework : str, optional
String representing the target framework
(supported: hugectr, tensorflow, pytorch, python), by default None
cats : List[str], optional
List of strings identifying categorical columns, by default None
conts : List[str], optional
@@ -68,7 +64,6 @@ def __init__(
self.sparse_max = sparse_max or {}
self.max_batch_size = max_batch_size
self.label_columns = label_columns or []
self.model_framework = model_framework or ""
self.cats = cats or []
self.conts = conts or []

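With `model_framework` gone, the op in this module takes only workflow-centric arguments. A hypothetical construction sketch, assuming this module's `TransformWorkflow` class with `workflow` as its first positional argument and a previously fitted NVTabular workflow:

```python
from nvtabular import Workflow

from merlin.systems.dag.ops.workflow import TransformWorkflow

workflow = Workflow.load("/path/to/fitted_workflow")  # hypothetical path

# No more `model_framework` argument -- the same op now serves any framework.
transform_op = TransformWorkflow(
    workflow,
    sparse_max={"item_ids": 20},   # pad/truncate this list column to length 20
    max_batch_size=1024,
    label_columns=["click"],
    cats=["user_id", "item_id"],   # categorical columns
    conts=["price"],               # continuous columns
)
```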
59 changes: 10 additions & 49 deletions merlin/systems/dag/runtimes/triton/ops/workflow.py
@@ -31,7 +31,7 @@
tensor_table_to_triton_request,
triton_response_to_tensor_table,
)
from merlin.systems.triton.export import _add_model_param, _convert_dtype
from merlin.systems.triton.export import _add_model_param


class TransformWorkflowTriton(TritonOperator):
@@ -54,9 +54,6 @@ def __init__(self, op):
Maximum batch size, by default None
label_columns : List[str], optional
List of strings identifying the label columns, by default None
model_framework : str, optional
String representing the target framework
(supported: hugectr, tensorflow, pytorch, python), by default None
cats : List[str], optional
List of strings identifying categorical columns, by default None
conts : List[str], optional
@@ -196,7 +193,6 @@ def _generate_nvtabular_model(
name,
output_path,
version=1,
output_model=None,
max_batch_size=None,
sparse_max=None,
backend="python",
@@ -218,7 +214,6 @@
workflow,
name,
output_path,
output_model,
max_batch_size,
sparse_max=sparse_max,
backend=backend,
@@ -243,7 +238,6 @@ def _generate_nvtabular_config(
workflow,
name,
output_path,
output_model=None,
max_batch_size=None,
sparse_max=None,
backend="python",
@@ -262,7 +256,6 @@
)

config.parameters["python_module"].string_value = "merlin.systems.triton.models.workflow_model"
config.parameters["output_model"].string_value = output_model if output_model else ""

config.parameters["cats"].string_value = json.dumps(cats) if cats else ""
config.parameters["conts"].string_value = json.dumps(conts) if conts else ""
@@ -271,48 +264,16 @@
# this assumes seq_length is same for each list column
config.parameters["sparse_max"].string_value = json.dumps(sparse_max)

if output_model == "hugectr":
config.instance_group.append(model_config.ModelInstanceGroup(kind=2))
for col_name, col_schema in workflow.input_schema.column_schemas.items():
_add_model_param(col_schema, model_config.ModelInput, config.input)

for column in workflow.output_node.input_columns.names:
dtype = workflow.input_dtypes[column]
config.input.append(
model_config.ModelInput(name=column, data_type=_convert_dtype(dtype), dims=[-1])
)

config.output.append(
model_config.ModelOutput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1])
)

config.output.append(
model_config.ModelOutput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1])
)

config.output.append(
model_config.ModelOutput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1])
)
elif output_model == "pytorch":
for col_name, col_schema in workflow.input_schema.column_schemas.items():
_add_model_param(col_schema, model_config.ModelInput, config.input)

for col_name, col_schema in workflow.output_schema.column_schemas.items():
_add_model_param(
col_schema,
model_config.ModelOutput,
config.output,
[-1, 1],
)
else:
for col_name, col_schema in workflow.input_schema.column_schemas.items():
_add_model_param(col_schema, model_config.ModelInput, config.input)

for col_name, col_schema in workflow.output_schema.column_schemas.items():
if sparse_max and col_name in sparse_max.keys():
# this assumes max_sequence_length is equal for all output columns
dim = sparse_max[col_name]
_add_model_param(col_schema, model_config.ModelOutput, config.output, [-1, dim])
else:
_add_model_param(col_schema, model_config.ModelOutput, config.output)
for col_name, col_schema in workflow.output_schema.column_schemas.items():
if sparse_max and col_name in sparse_max.keys():
# this assumes max_sequence_length is equal for all output columns
dim = sparse_max[col_name]
_add_model_param(col_schema, model_config.ModelOutput, config.output, [-1, dim])
else:
_add_model_param(col_schema, model_config.ModelOutput, config.output)

with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
text_format.PrintMessage(config, o)
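With the `output_model` branching removed, export always follows one path. A hedged usage sketch of the trimmed `_generate_nvtabular_model` signature (parameter names come from the diff above; the workflow object and paths are hypothetical, and exact path layout under Triton's model repository is an assumption):

```python
from nvtabular import Workflow

from merlin.systems.dag.runtimes.triton.ops.workflow import _generate_nvtabular_model

workflow = Workflow.load("/path/to/fitted_workflow")  # hypothetical path

# No `output_model` argument: every export gets the same _add_model_param-based
# input/output spec, with dims [-1, dim] for any column listed in sparse_max.
_generate_nvtabular_model(
    workflow,
    name="preprocess_workflow",
    output_path="/models/preprocess_workflow",  # where config.pbtxt is written
    max_batch_size=1024,
    sparse_max={"item_ids": 20},
    backend="python",
)
```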
22 changes: 4 additions & 18 deletions merlin/systems/triton/models/workflow_model.py
@@ -33,9 +33,7 @@
from merlin.core.dispatch import is_list_dtype
from merlin.systems.triton import _convert_tensor
from merlin.systems.triton.utils import triton_error_handling, triton_multi_request
from merlin.systems.workflow.hugectr import HugeCTRWorkflowRunner
from merlin.systems.workflow.pytorch import PyTorchWorkflowRunner
from merlin.systems.workflow.tensorflow import TensorflowWorkflowRunner
from merlin.systems.workflow.base import WorkflowRunner


class TritonPythonModel:
@@ -73,7 +71,6 @@ def initialize(self, args):

# Config loading and parsing
self.model_config = json.loads(args["model_config"])
model_framework = self.model_config["parameters"]["output_model"]["string_value"]

# Dtype parsing
input_dtypes = self.workflow.input_dtypes.items()
@@ -87,14 +84,7 @@
else:
self._set_output_dtype(col_name)

if model_framework == "hugectr":
runner_class = HugeCTRWorkflowRunner
elif model_framework == "pytorch":
runner_class = PyTorchWorkflowRunner
else:
runner_class = TensorflowWorkflowRunner

self.runner = runner_class(
self.runner = WorkflowRunner(
self.workflow, self.output_dtypes, self.model_config, model_device
)

@@ -122,12 +112,8 @@ def execute(self, request):
)
input_tensors[name] = (values, offsets)

# raise pb_utils.TritonModelException("Custom Error To check raised!")
raw_tensor_tuples = self.runner.run_workflow(input_tensors)
if isinstance(raw_tensor_tuples, dict):
raw_tensor_tuples = list(raw_tensor_tuples.items())

result = [pb_utils.Tensor(name, data) for name, data in raw_tensor_tuples]
transformed = self.runner.run_workflow(input_tensors)
result = [pb_utils.Tensor(name, data) for name, data in transformed.items()]

return pb_utils.InferenceResponse(result)

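`execute` now assumes the runner hands back a dict, so response assembly shrinks to a comprehension. A sketch of that final step, mirroring the added lines above (runnable only inside Triton's Python backend; the column data is hypothetical):

```python
import numpy as np
import triton_python_backend_utils as pb_utils  # provided by the Triton runtime

# Stand-in for WorkflowRunner.run_workflow(input_tensors)
transformed = {"user_id": np.array([1, 2, 3], dtype=np.int64)}

# One output tensor per column, then a single response -- the old
# dict-vs-tuple-list special case is gone.
result = [pb_utils.Tensor(name, data) for name, data in transformed.items()]
response = pb_utils.InferenceResponse(result)
```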
21 changes: 4 additions & 17 deletions merlin/systems/workflow/base.py
@@ -27,19 +27,17 @@
import functools
import json
import logging
from abc import ABC, abstractmethod

import numpy as np

from merlin.core.dispatch import concat_columns
from merlin.dag import ColumnSelector, Supports
from merlin.schema import Tags
from merlin.systems.triton.conversions import convert_format
from merlin.table import TensorTable

LOG = logging.getLogger("merlin-systems")


class WorkflowRunner(ABC):
class WorkflowRunner:
def __init__(self, workflow, output_dtypes, model_config, model_device):
self.workflow = workflow
self.output_dtypes = output_dtypes
@@ -56,6 +54,7 @@ def __init__(self, workflow, output_dtypes, model_config, model_device):

self.cats = mc_cats or schema_cats
self.conts = mc_conts or schema_conts
self.offsets = None

workflow_outputs = set(workflow.output_schema.column_names)
requested_cols = set(self.cats + self.conts)
@@ -106,19 +105,7 @@ def run_workflow(self, input_tensors):
if kind != Supports.CPU_DICT_ARRAY:
transformed, kind = convert_format(transformed, kind, Supports.CPU_DICT_ARRAY)

# convert to the format expected by the DL models
return self._transform_outputs(transformed)

@abstractmethod
def _transform_outputs(self, tensors):
pass

def _convert_to_np(self, columns, tensors, dtype, rows):
"""converts outputs to a numpy input compatible with pytorch"""
d = np.empty((rows, len(columns)), dtype=dtype)
for i, name in enumerate(columns):
d[:, i] = tensors[name].astype(dtype)
return d
return TensorTable(transformed).to_dict()

def _transform_tensors(self, input_tensors, workflow_node):
upstream_inputs = []
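For reference, the removed `_convert_to_np` helper packed named columns into one dense 2-D array, a layout only HugeCTR-style dense inputs needed. A self-contained sketch of its behavior, with the function body copied from the deleted lines above and a made-up two-column example:

```python
import numpy as np

def _convert_to_np(columns, tensors, dtype, rows):
    """Pack the named columns into a dense (rows, len(columns)) array."""
    d = np.empty((rows, len(columns)), dtype=dtype)
    for i, name in enumerate(columns):
        d[:, i] = tensors[name].astype(dtype)
    return d

# Two scalar columns become adjacent columns of one float32 matrix.
tensors = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
dense = _convert_to_np(["a", "b"], tensors, np.float32, rows=3)
assert dense.shape == (3, 2)
```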
87 changes: 0 additions & 87 deletions merlin/systems/workflow/hugectr.py

This file was deleted.

46 changes: 0 additions & 46 deletions merlin/systems/workflow/pytorch.py

This file was deleted.

38 changes: 0 additions & 38 deletions merlin/systems/workflow/tensorflow.py

This file was deleted.
