Skip to content

Commit

Permalink
[Benchmarking-Py] Release 1.1.0 - Replacing all print() calls by `l…
Browse files Browse the repository at this point in the history
…ogging.<level>()` calls
  • Loading branch information
DEKHTIARJonathan committed Jul 26, 2022
1 parent 95e097c commit 0f999c4
Show file tree
Hide file tree
Showing 17 changed files with 651 additions and 81 deletions.
4 changes: 4 additions & 0 deletions tftrt/benchmarking-python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ Description of the change

<!-- YOU CAN EDIT FROM HERE -->

## [1.1.0] - 2022.07.25 - @DEKHTIARJonathan

Replacing all `print()` calls by `logging.<level>()` calls

## [1.0.1] - 2022.07.25 - @DEKHTIARJonathan

Removing AutoTuning on `get_dequeue_batch_fn` because DALIDataset was not
Expand Down
16 changes: 15 additions & 1 deletion tftrt/benchmarking-python/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,16 @@
#!/usr/bin/env python
#! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
6 changes: 5 additions & 1 deletion tftrt/benchmarking-python/benchmark_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tensorflow.python.saved_model.signature_constants import \
DEFAULT_SERVING_SIGNATURE_DEF_KEY

from benchmark_logger import logging
from benchmark_utils import print_dict


Expand Down Expand Up @@ -392,6 +393,9 @@ def _post_process_args(self, args):
# Let's fix it to 1 to save memory.
args.total_max_samples = 1

if args.debug or args.debug_data_aggregation or args.debug_performance:
logging.set_verbosity(logging.DEBUG)

return args

def parse_args(self):
Expand All @@ -400,7 +404,7 @@ def parse_args(self):
args = self._post_process_args(args)
self._validate_args(args)

print("\nBenchmark arguments:")
logging.info("Benchmark arguments:")
print_dict(vars(args))
print()

Expand Down
20 changes: 10 additions & 10 deletions tftrt/benchmarking-python/benchmark_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import tensorflow as tf

from benchmark_logger import logging
from benchmark_utils import force_gpu_resync


Expand Down Expand Up @@ -35,21 +36,20 @@ def _autotune(self, *arg, **kwargs):
output = self._fns[fn_id](*arg, **kwargs)
self._timings[fn_id].append(time.time() - start_t)
except IndexError:
print(
"\n[DEBUG] AutoTuning is over... Collecting timing statistics:"
)
print() # visual spacing
logging.debug("AutoTuning is over... Collecting timing statistics:")
perf_data = []
for idx, fn_stat in enumerate(self._timings):
perf_data.append(np.mean(fn_stat[self._skip_n_first:]))
print(
f"\t- [DEBUG] Function ID: {idx} - "
logging.debug(
f"\t- Function ID: {idx} - "
f"Name: {self._fns[idx].__name__:40s} - "
f"Average Exec Time: {perf_data[-1]}"
)

best_fn_id = np.argmin(perf_data)
print(
f"[DEBUG] Selecting function ID: {best_fn_id}. "
logging.debug(
f"Selecting function ID: {best_fn_id}. "
f"Setting exec path to: `{self._fns[best_fn_id].__name__}`\n"
)

Expand All @@ -71,7 +71,7 @@ def _wrapper(*args, **kwargs):
try:
return context[0](*args, **kwargs)
except IndexError:
print(f"[INFO] Building the concrete function")
logging.info(f"Building the concrete function")
context.append(func.get_concrete_function(*args, **kwargs))
return context[0](*args, **kwargs)

Expand Down Expand Up @@ -106,8 +106,8 @@ def resync_gpu_wrap_fn(_func, str_appended):
funcs2autotune = [eager_function, tf_function]

if use_synthetic_data:
print(
"[INFO] Allowing direct concrete_function call with "
logging.debug(
"Allowing direct concrete_function call with "
"synthetic data loader."
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import subprocess
import shlex

# The `__version__` number shall be updated everytime core benchmarking files
# are updated.
# Please update CHANGELOG.md with a description of what this version changed.
__version__ = "1.1.0"


def get_commit_id():

Expand Down
241 changes: 241 additions & 0 deletions tftrt/benchmarking-python/benchmark_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import inspect
import warnings

from contextlib import contextmanager

from six import add_metaclass

import logging as _logging

from benchmark_info import __version__

from logging_utils.formatters import BaseFormatter
from logging_utils.metaclasses import SingletonMetaClass

__all__ = [
'Logger',
]


class StdOutFormatter(BaseFormatter):
DEFAULT_FORMAT = f"%(color)s[BENCH - v{__version__}] "
DEFAULT_FORMAT += "%(levelname)-8s: %(end_color)s%(message)s"


@add_metaclass(SingletonMetaClass)
class Logger(object):

# Level 0
NOTSET = _logging.NOTSET

# Level 10
DEBUG = _logging.DEBUG

# Level 20
INFO = _logging.INFO

# Level 30
WARNING = _logging.WARNING

# Level 40
ERROR = _logging.ERROR

# Level 50
CRITICAL = _logging.CRITICAL

_level_names = {
0: 'NOTSET',
10: 'DEBUG',
20: 'INFO',
30: 'WARNING',
40: 'ERROR',
50: 'CRITICAL',
}

def __init__(self, capture_io=True):

self._logger = None

self._handlers = dict()

self._define_logger()

def _define_logger(self):

# Use double-checked locking to avoid taking lock unnecessarily.
if self._logger is not None:
return self._logger

try:
# Scope the TensorFlow logger to not conflict with users' loggers.
self._logger = _logging.getLogger('benchmarking_suite')
self.reset_stream_handler()

finally:
self.set_verbosity(verbosity_level=Logger.INFO)

self._logger.propagate = False

def reset_stream_handler(self):

if self._logger is None:
raise RuntimeError(
"Impossible to set handlers if the Logger is not predefined"
)

# ======== Remove Handler if already existing ========

try:
self._logger.removeHandler(self._handlers["stream_stdout"])
except KeyError:
pass

try:
self._logger.removeHandler(self._handlers["stream_stderr"])
except KeyError:
pass

# ================= Streaming Handler =================

# Add the output handler.
self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stdout)
self._handlers["stream_stdout"].addFilter(
lambda record: record.levelno <= _logging.INFO
)

self._handlers["stream_stderr"] = _logging.StreamHandler(sys.stderr)
self._handlers["stream_stderr"].addFilter(
lambda record: record.levelno > _logging.INFO
)

Formatter = StdOutFormatter

self._handlers["stream_stdout"].setFormatter(Formatter())
self._logger.addHandler(self._handlers["stream_stdout"])

try:
self._handlers["stream_stderr"].setFormatter(Formatter())
self._logger.addHandler(self._handlers["stream_stderr"])
except KeyError:
pass

def get_verbosity(self):
"""Return how much logging output will be produced."""
if self._logger is not None:
return self._logger.getEffectiveLevel()

def set_verbosity(self, verbosity_level):
"""Sets the threshold for what messages will be logged."""
if self._logger is not None:
self._logger.setLevel(verbosity_level)

for handler in self._logger.handlers:
handler.setLevel(verbosity_level)

@contextmanager
def temp_verbosity(self, verbosity_level):
"""Sets the a temporary threshold for what messages will be logged."""

if self._logger is not None:

old_verbosity = self.get_verbosity()

try:
self.set_verbosity(verbosity_level)
yield

finally:
self.set_verbosity(old_verbosity)

else:
try:
yield

finally:
pass

def debug(self, msg, *args, **kwargs):
"""
Log 'msg % args' with severity 'DEBUG'.
To pass exception information, use the keyword argument exc_info with
a true value, e.g.
logger.debug("Houston, we have a %s", "thorny problem", exc_info=1)
"""
if self._logger is not None:
self._logger._log(Logger.DEBUG, msg, args, **kwargs)

def info(self, msg, *args, **kwargs):
"""
Log 'msg % args' with severity 'INFO'.
To pass exception information, use the keyword argument exc_info with
a true value, e.g.
logger.info("Houston, we have a %s", "interesting problem", exc_info=1)
"""
if self._logger is not None:
self._logger._log(Logger.INFO, msg, args, **kwargs)

def warning(self, msg, *args, **kwargs):
"""
Log 'msg % args' with severity 'WARNING'.
To pass exception information, use the keyword argument exc_info with
a true value, e.g.
logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1)
"""
if self._logger is not None:
self._logger._log(Logger.WARNING, msg, args, **kwargs)

def error(self, msg, *args, **kwargs):
"""
Log 'msg % args' with severity 'ERROR'.
To pass exception information, use the keyword argument exc_info with
a true value, e.g.
logger.error("Houston, we have a %s", "major problem", exc_info=1)
"""
if self._logger is not None:
self._logger._log(Logger.ERROR, msg, args, **kwargs)

def critical(self, msg, *args, **kwargs):
"""
Log 'msg % args' with severity 'CRITICAL'.
To pass exception information, use the keyword argument exc_info with
a true value, e.g.
logger.critical("Houston, we have a %s", "major disaster", exc_info=1)
"""
if self._logger is not None:
self._logger._log(Logger.CRITICAL, msg, args, **kwargs)


# Necessary to catch the correct caller
_logging._srcfile = os.path.normcase(inspect.getfile(Logger.__class__))

logging = Logger()
Loading

0 comments on commit 0f999c4

Please sign in to comment.