Skip to content

Commit

Permalink
delegate add_graph to pytorch (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa authored Jul 1, 2020
1 parent 56a9708 commit d7238f5
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,23 @@ def add_openvino_graph(self, xmlname):
"""
self._get_file_writer().add_openvino_graph(load_openvino_graph(xmlname))

def add_graph(self, model, input_to_model=None, verbose=False, profile_with_cuda=False, **kwargs):
def add_graph(self, model, input_to_model=None, verbose=False):
"""Add graph data to summary. The graph is actually processed by `torch.utils.tensorboard.add_graph()`
Args:
model (torch.nn.Module): Model to draw.
input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
variables to be fed.
verbose (bool): Whether to print graph structure in console.
"""
from torch.utils.tensorboard._pytorch_graph import graph
self._get_file_writer().add_graph(graph(model, input_to_model, verbose))

def add_graph_deprecated(self, model, input_to_model=None, verbose=False, profile_with_cuda=False, **kwargs):
# prohibit second call?
# no, let tensorboard handle it and show its warning message.
"""Add graph data to summary.
"""[deprecated] Add graph data to summary. This was used in tensorboardX <= 2.0
Args:
model (torch.nn.Module): Model to draw.
Expand Down

0 comments on commit d7238f5

Please sign in to comment.