Merge pull request #23 from ucb-bar/fix-converter
Fix: handle the case where a Linear module has no bias
T-K-233 authored Jan 1, 2025
2 parents 70503b0 + 6faf2e9 commit 1f129c5
Showing 1 changed file with 51 additions and 19 deletions.
70 changes: 51 additions & 19 deletions converter/src/torchconverter/tracer.py
@@ -104,6 +104,8 @@ def __init__(self, model: torch.nn.Module):
self.gm: torch.fx.GraphModule = gm

# extract node information
+ # this dictionary maps each node name to a tuple of the node's args and kwargs
+ # this is used for getting the input tensor and parameters for each forward function call
self.node_info: Dict[str, Tuple[Any, Any]] = {n.name: (n.args, n.kwargs) for n in self.graph.nodes}

# initialize jinja2 code generation environment
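For readers unfamiliar with torch.fx: the args/kwargs tuples cached in node_info come straight off the traced graph's nodes. A minimal standalone illustration (the Tiny module here is invented for the example, not part of the commit):

import torch
import torch.fx

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.lin(x))

gm = torch.fx.symbolic_trace(Tiny())
# the same dictionary shape as self.node_info above
node_info = {n.name: (n.args, n.kwargs) for n in gm.graph.nodes}
for name, (args, kwargs) in node_info.items():
    print(name, args, kwargs)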
Expand Down Expand Up @@ -172,7 +174,11 @@ def get_module(self, module_name: str) -> torch.nn.Module:

def add_uninitialized_tensor(self, name: str, tensor: torch.Tensor):
"""
- Add an uninitialized tensor to the C code.
+ Adds an uninitialized tensor to the C code.
+
+ Args:
+     name (str): The name of the tensor.
+     tensor (torch.Tensor): The example tensor. This is used to determine the shape and data type of the tensor.
"""
self.tensors[name] = {
"tensor": tensor,
@@ -181,7 +187,11 @@ def add_uninitialized_tensor(self, name: str, tensor: torch.Tensor):

def add_initialized_tensor(self, name: str, tensor: torch.Tensor):
"""
- Add an initialized tensor to the C code.
+ Adds an initialized tensor to the C code.
+
+ Args:
+     name (str): The name of the tensor.
+     tensor (torch.Tensor): The tensor containing the data. This is used to determine the shape, data type, and data of the tensor.
"""
self.tensors[name] = {
"tensor": tensor,
@@ -190,14 +200,14 @@ def add_initialized_tensor(self, name: str, tensor: torch.Tensor):

def add_forward_call(self, function_name: str, out: torch.Tensor, layer_name: str, input_names: List[str], parameters: List[str] = None):
"""
- This method creates the C code for the forward call.
+ Adds a forward function call to the C code.
Args:
- function (Callable): The function to call.
- dim (int): The dimension of the output tensor.
- dtype (torch.dtype): The data type of the output tensor.
+ function_name (str): The name template of the function to call.
+ out (torch.Tensor): The output tensor.
+ layer_name (str): The name of the layer.
+ input_names (List[str]): The names of the input tensors.
+ parameters (List[str]): The additional parameters to pass to the function.
"""

dtype_str = TracedModule.get_dtype_str(out.dtype)
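A small illustration of how these name templates might expand. The exact string returned by get_dtype_str is not shown in this diff, so "f32" below is an assumption:

import torch

out = torch.zeros(2, 3, 4)
dtype_str = "f32"  # assumed result of TracedModule.get_dtype_str(torch.float32)
print("nn_add{dim}d_{dtype}".format(dim=out.dim(), dtype=dtype_str))  # nn_add3d_f32
print("nn_linear_{dtype}".format(dtype=dtype_str))                    # nn_linear_f32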
@@ -229,23 +239,37 @@ def handle_get_attr(self, n: torch.fx.node.Node, out: torch.Tensor):
def handle_call_function(self, n: torch.fx.node.Node, out: torch.Tensor):
"""
Handle the case where the node is a call to a torch function (e.g. relu, elu, etc.)
+ n has the following attributes:
+ - op: the operation that is being performed (here it is "call_function")
+ - name: the name of the layer (e.g. "relu", "linear")
+ - target: the Python function that is being called (e.g. torch.nn.functional.relu, torch.nn.functional.linear)
+ - args: a list of torch.fx.node.Node objects that are the arguments to the function
+ - prev: the previous nodes in the graph
+ - next: the next nodes in the graph
"""
print("call function:", n.name, n.target, n.args)

# get all the related information
function = n.target
layer_name = n.name
- input_names = [n.name for n in self.node_info[n.name][0]]
+ input_names = []
input_args = n.args

+ for arg in self.node_info[n.name][0]:
+     if arg is None:
+         # if e.g. linear has no bias, the bias argument is None
+         continue
+     input_names.append(arg.name)

# Math operations - Pointwise Ops
if function == operator.__add__:
self.add_uninitialized_tensor(layer_name, out)
self.add_forward_call("nn_add_{dtype}", out, layer_name, input_names)
self.add_forward_call("nn_add{dim}d_{dtype}", out, layer_name, input_names)

elif function == operator.__mul__:
self.add_uninitialized_tensor(layer_name, out)
self.add_forward_call("nn_mul_{dtype}", out, layer_name, input_names)
self.add_forward_call("nn_mul{dim}d_{dtype}", out, layer_name, input_names)

# Convolution Layers

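The skipped None can be reproduced in isolation: when a model invokes torch.nn.functional.linear with an explicit None bias (which is what nn.Linear does internally when constructed with bias=False), the None appears verbatim in the node's args. A minimal standalone example:

import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 8))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight, None)  # no bias

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # prints roughly: (x, weight, None); the trailing None must be skipped
        print(node.args)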
@@ -269,11 +293,15 @@ def handle_call_function(self, n: torch.fx.node.Node, out: torch.Tensor):
# Linear Layers
elif function == torch.nn.functional.linear:
weight = self.model.state_dict()[input_args[1].target]
- bias = self.model.state_dict()[input_args[2].target]
+ if input_args[2] is not None:
+     bias = self.model.state_dict()[input_args[2].target]
+ else:
+     bias = None
self.add_uninitialized_tensor(layer_name, out)
self.add_initialized_tensor(f"{input_names[1]}", weight)
- self.add_initialized_tensor(f"{input_names[2]}", bias)
- self.add_forward_call("nn_addmm_{dtype}", out, layer_name, input_names)
+ if bias is not None:
+     self.add_initialized_tensor(f"{input_names[2]}", bias)
+ self.add_forward_call("nn_linear_{dtype}", out, layer_name, input_names)

# Vision Functions

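The guard on input_args[2] matters because a Linear constructed with bias=False has no bias entry in the state dict at all, so the old unconditional lookup would fail:

import torch

lin = torch.nn.Linear(8, 4, bias=False)
print(list(lin.state_dict().keys()))  # ['weight']: there is no 'bias' key to look up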
@@ -334,14 +362,18 @@ def handle_call_module(self, n: torch.fx.node.Node, out: torch.Tensor):
# Linear Layers
elif type(module) == torch.nn.Linear:
weight = module.weight
+ # bias is None when the module was constructed with bias=False
bias = module.bias

input_names.append(f"{layer_name}_weight")
- input_names.append(f"{layer_name}_bias")
+ if bias is not None:
+     input_names.append(f"{layer_name}_bias")

self.add_uninitialized_tensor(layer_name, out)
self.add_initialized_tensor(f"{layer_name}_weight", weight)
- self.add_initialized_tensor(f"{layer_name}_bias", bias)
- self.add_forward_call("nn_addmm_{dtype}", out, layer_name, input_names)
+ if bias is not None:
+     self.add_initialized_tensor(f"{layer_name}_bias", bias)
+ self.add_forward_call("nn_linear_{dtype}", out, layer_name, input_names)

def handle_output(self, n: torch.fx.node.Node, out: torch.Tensor):
print("output:", n.name, out.shape, out.dtype)
@@ -455,11 +487,11 @@ class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.seq = nn.Sequential(
- nn.Linear(48, 128, bias=True),
+ nn.Linear(48, 128, bias=False),
nn.ELU(),
- nn.Linear(128, 5, bias=True),
+ nn.Linear(128, 5, bias=False),
)
- self.lin2 = nn.Linear(5, 12, bias=True)
+ self.lin2 = nn.Linear(5, 12, bias=False)

def forward(self, input):
x = self.seq.forward(input)
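As a quick sanity check, tracing this Net and printing its graph should show the bias-less layers flowing through the handlers patched above; because forward() is called directly on the submodules, torch.fx should inline at least lin2 into a functional linear call carrying a literal None bias. A sketch, assuming the Net class above is in scope:

import torch
import torch.fx

gm = torch.fx.symbolic_trace(Net())
for node in gm.graph.nodes:
    # expect a mix of call_module and call_function nodes; the functional
    # linear calls should carry None where the bias tensor would be
    print(node.op, node.name, node.target)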
