Skip to content

Commit

Permalink
implement depth & total span limits with tests (#1835)
Browse files Browse the repository at this point in the history
* implement depth & total span limits with tests

* mock time.monotonic in the depth limit test to check for proper trimming on Windows

* name/document fields, mock time monotonic with counter fixture

* make total limit independent from depth limit

* add comments to test asserts

* split magic number into two magic numbers
  • Loading branch information
sfc-gh-mchok authored Nov 7, 2024
1 parent de448a9 commit e559765
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 68 deletions.
102 changes: 88 additions & 14 deletions src/snowflake/cli/api/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from typing import ClassVar, Dict, Iterator, List, Optional
from heapq import nsmallest
from typing import ClassVar, Dict, Iterator, List, Optional, Set


class CLIMetricsInvalidUsageError(RuntimeError):
Expand Down Expand Up @@ -84,24 +85,55 @@ class for holding metrics span data and encapsulating related operations
START_TIME_KEY: ClassVar[str] = "start_time"
EXECUTION_TIME_KEY: ClassVar[str] = "execution_time"
ERROR_KEY: ClassVar[str] = "error"

# total number of spans started under this span, including the span itself and all of its descendants (recursively)
SPAN_COUNT_IN_SUBTREE_KEY: ClassVar[str] = "span_count_in_subtree"
# the number of spans in the path between the current span and the topmost parent span, inclusive of both
SPAN_DEPTH_KEY: ClassVar[str] = "span_depth"
# denotes whether direct children were trimmed from telemetry payload
TRIMMED_KEY: ClassVar[str] = "trimmed"

# constructor vars
name: str
start_time: float # relative to when the command first started executing
parent: Optional[CLIMetricsSpan] = None

# ensure we get unique ids for each step for the parent-child link in case of steps with the same name
step_id: str = field(init=False, default_factory=lambda: uuid.uuid4().hex)
# vars for reporting
span_id: str = field(init=False, default_factory=lambda: uuid.uuid4().hex)
execution_time: Optional[float] = field(init=False, default=None)
error: Optional[BaseException] = field(init=False, default=None)
span_depth: int = field(init=False, default=1)
span_count_in_subtree: int = field(init=False, default=1)

# vars for postprocessing
# spans started directly under this one
children: Set[CLIMetricsSpan] = field(init=False, default_factory=set)

# private state
# start time of the step from the monotonic clock in order to calculate execution time
_monotonic_start: float = field(
init=False, default_factory=lambda: time.monotonic()
)

def __hash__(self) -> int:
    """
    Hash a span by its span_id so that spans can be stored in sets
    (e.g. the children set of a parent span)
    """
    unique_id = self.span_id
    return hash(unique_id)

def __post_init__(self):
    """
    Validates the span and links it to its parent, if one was given

    Raises CLIMetricsInvalidUsageError when the span name is empty.
    When a parent is set, the span registers itself as a child of that parent
    and takes a depth one greater than the parent's depth.
    """
    if not self.name:
        raise CLIMetricsInvalidUsageError("span name must not be empty")

    if self.parent:
        # registering with the parent also bumps the subtree counts up the chain
        self.parent.add_child(self)
        self.span_depth = self.parent.span_depth + 1

def increment_subtree_node_count(self) -> None:
    """
    Adds one to this span's subtree count, then asks the parent (if any)
    to do the same so the increment reaches every ancestor
    """
    self.span_count_in_subtree = self.span_count_in_subtree + 1

    ancestor = self.parent
    if not ancestor:
        return

    ancestor.increment_subtree_node_count()

def add_child(self, child: CLIMetricsSpan) -> None:
    """
    Records the given span as a direct child of this one and
    increments the subtree span count for this span
    """
    direct_children = self.children
    direct_children.add(child)
    self.increment_subtree_node_count()

def finish(self, error: Optional[BaseException] = None) -> None:
"""
Expand All @@ -119,19 +151,21 @@ def finish(self, error: Optional[BaseException] = None) -> None:

def to_dict(self) -> Dict:
    """
    Custom dict conversion function to be used for reporting telemetry

    The returned dict is keyed by the class-level *_KEY constants.
    The parent is referenced by its name and span id (both None when the span has no parent)
    and an error is reported only by the name of its exception class.
    """
    parent = self.parent
    has_parent = parent is not None

    return {
        self.ID_KEY: self.span_id,
        self.NAME_KEY: self.name,
        self.START_TIME_KEY: self.start_time,
        self.PARENT_KEY: parent.name if has_parent else None,
        self.PARENT_ID_KEY: parent.span_id if has_parent else None,
        self.EXECUTION_TIME_KEY: self.execution_time,
        self.ERROR_KEY: type(self.error).__name__ if self.error else None,
        self.SPAN_COUNT_IN_SUBTREE_KEY: self.span_count_in_subtree,
        self.SPAN_DEPTH_KEY: self.span_depth,
    }


Expand All @@ -141,6 +175,10 @@ class CLIMetrics:
Class to track various metrics across the execution of a command
"""

# limits for reporting purposes
SPAN_DEPTH_LIMIT: ClassVar[int] = 5
SPAN_TOTAL_LIMIT: ClassVar[int] = 100

_counters: Dict[str, int] = field(init=False, default_factory=dict)
# stack of in progress spans as command is executing
_in_progress_spans: List[CLIMetricsSpan] = field(init=False, default_factory=list)
Expand Down Expand Up @@ -213,15 +251,51 @@ def counters(self) -> Dict[str, int]:
# return a copy of the original dict to avoid mutating the original
return self._counters.copy()

@property
def num_spans_past_depth_limit(self) -> int:
    """
    Number of completed spans whose depth is greater than SPAN_DEPTH_LIMIT
    """
    depth_limit = self.SPAN_DEPTH_LIMIT
    # count lazily instead of building a list just to measure its length
    return sum(1 for span in self._completed_spans if span.span_depth > depth_limit)

@property
def num_spans_past_total_limit(self) -> int:
    """
    Number of completed spans in excess of SPAN_TOTAL_LIMIT, never negative
    """
    overflow = len(self._completed_spans) - self.SPAN_TOTAL_LIMIT
    return overflow if overflow > 0 else 0

@property
def completed_spans(self) -> List[Dict]:
    """
    Returns the completed spans tracked throughout a command, sorted by start time, for reporting telemetry
    Ensures that the spans we send are within the configured limits and marks
    certain spans as trimmed if their children would bypass the limits we set
    """
    # spans deeper than the depth limit are never candidates for reporting
    spans_within_depth_limit = (
        span
        for span in self._completed_spans
        if span.span_depth <= self.SPAN_DEPTH_LIMIT
    )

    # take spans breadth-first within the depth and total limits
    # since we care more about the big picture than granularity
    spans_to_report = set(
        nsmallest(
            self.SPAN_TOTAL_LIMIT,
            spans_within_depth_limit,
            key=lambda span: (span.span_depth, span.start_time),
        )
    )

    # sort by start time to make reading the payload easier
    sorted_spans_to_report = sorted(
        spans_to_report, key=lambda span: span.start_time
    )

    return [
        {
            **span.to_dict(),
            # trimmed when at least one direct child is missing from the reported set
            CLIMetricsSpan.TRIMMED_KEY: not span.children.issubset(spans_to_report),
        }
        for span in sorted_spans_to_report
    ]
Loading

0 comments on commit e559765

Please sign in to comment.