Skip to content

Commit

Permalink
Filter Repeating Logs (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
NihalHarish authored Feb 18, 2021
1 parent c29e352 commit 64fb95f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
1 change: 1 addition & 0 deletions smdebug/core/config_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LATEST_GLOBAL_STEP_SEEN = "latest-global-step-seen"
LATEST_GLOBAL_STEP_SAVED = "latest-global-step-saved"
LATEST_MODE_STEP = "latest-mode-step"
LOG_DUPLICATION_THRESHOLD = 3
TRAINING_RUN = "training-run"

INCOMPLETE_STEP_WAIT_WINDOW_KEY = "SMDEBUG_INCOMPLETE_STEP_WAIT_WINDOW"
Expand Down
17 changes: 17 additions & 0 deletions smdebug/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import os
import socket
import sys
from collections import defaultdict

# First Party
from smdebug.core.config_constants import LOG_DUPLICATION_THRESHOLD

_logger_initialized = False

Expand All @@ -19,6 +23,18 @@ def filter(self, record):
return record.levelno < self.level


class DuplicateLogFilter:
"""Filters duplicate messages to prevent spamming users"""

def __init__(self):
self.msgs = defaultdict(int)
self.repeat_threshold = LOG_DUPLICATION_THRESHOLD

def filter(self, record):
self.msgs[record.msg] += 1
return self.msgs[record.msg] <= self.repeat_threshold


def _get_log_level():
default = "info"
log_level = os.environ.get("SMDEBUG_LOG_LEVEL", default=default)
Expand Down Expand Up @@ -75,6 +91,7 @@ def get_logger(name="smdebug"):
logger.addHandler(stderr_handler)

logger.addHandler(stdout_handler)
logger.addFilter(DuplicateLogFilter())

# SMDEBUG_LOG_PATH is the full path to log file
# by default, log is only written to stdout&stderr
Expand Down
4 changes: 3 additions & 1 deletion smdebug/core/state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from smdebug.core.logger import get_logger

logger = get_logger()


# This is 'predicate' for sorting the list of states based on seen steps.
def _rule_for_sorting(state):
return state[LATEST_GLOBAL_STEP_SEEN]
Expand Down Expand Up @@ -99,7 +101,7 @@ def is_checkpoint_updated(self):
if self._checkpoint_dir is not None:
checkpoint_files = self._get_checkpoint_files_in_dir(self._checkpoint_dir)
if not checkpoint_files:
logger.info(
logger.debug(
"Checkpoints not updated. There are no checkpoint files created yet, to be updated"
)
return False
Expand Down
33 changes: 33 additions & 0 deletions tests/core/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Standard Library
import types

# First Party
from smdebug.core.config_constants import LOG_DUPLICATION_THRESHOLD
from smdebug.core.logger import DuplicateLogFilter
from smdebug.core.utils import get_logger


def test_dup_filter():
logger = get_logger()
dup_filter = None

for _filter in logger.filters:
if isinstance(_filter, DuplicateLogFilter):
dup_filter = _filter
dup_filter.test_counter = 0

dup_filter.old_filter = dup_filter.filter

def filter(self, record):
if self.old_filter(record):
self.test_counter += 1
return True
else:
return False

dup_filter.filter = types.MethodType(filter, dup_filter)

for _ in range(10):
logger.warning("I love spam musubi")

assert dup_filter.test_counter == LOG_DUPLICATION_THRESHOLD

0 comments on commit 64fb95f

Please sign in to comment.