Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fea: refactor tracer.py #61

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 108 additions & 100 deletions objwatch/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ def __init__(

# Handle multi-GPU support if PyTorch is available
self.torch_available: bool = torch_available
self.rank_info: str = ""
if self.torch_available:
self.current_rank: Optional[int] = None
self.current_rank = None
self.ranks: Set[int] = set(ranks if ranks is not None else [0])
else:
self.ranks: Set[int] = set()

# Load the function wrapper if provided
self.function_wrapper: Optional[FunctionWrapper] = self.load_wrapper(wrapper)
self.function_wrapper: FunctionWrapper = self.load_wrapper(wrapper)
self.call_depth: int = 0

def _process_targets(self, targets: Optional[List[Union[str, ModuleType]]]) -> Set[str]:
Expand Down Expand Up @@ -197,6 +198,72 @@ def filename_not_endswith(self, filename: str) -> bool:
"""
return not filename.endswith(tuple(self.targets))

def _handle_change_type(
    self,
    lineno: int,
    class_name: str,
    key: str,
    old_value: Optional[Any],
    current_value: Any,
    old_value_len: Optional[int],
    current_value_len: Optional[int],
) -> None:
    """
    Classify a tracked change and forward it to the matching event handler.

    Shared by the object-attribute and local-variable tracking paths.

    Args:
        lineno (int): Line number where the change occurred.
        class_name (str): Class name if the change relates to an object attribute.
        key (str): The key (variable or attribute) being tracked.
        old_value (Optional[Any]): The old value of the variable or attribute.
        current_value (Any): The current value of the variable or attribute.
        old_value_len (Optional[int]): The length of the old value (if applicable).
        current_value_len (Optional[int]): The length of the current value (if applicable).
    """
    # A length-based classification is only meaningful when BOTH lengths are
    # known. Callers pass ``current_value_len=None`` when the current value is
    # not a tracked sequence; in that case fall back to a plain update instead
    # of handing ``None`` to ``determine_change_type``.
    if old_value_len is not None and current_value_len is not None:
        change_type: EventType = self.event_handlers.determine_change_type(
            old_value_len, current_value_len
        )
    else:
        change_type = EventType.UPD

    if old_value is current_value:
        # Same object as before: only a length change is reported here.
        if change_type == EventType.APD:
            self.event_handlers.handle_apd(
                lineno,
                class_name,
                key,
                type(current_value),
                old_value_len,
                current_value_len,
                self.call_depth,
                self.rank_info,
            )
        elif change_type == EventType.POP:
            self.event_handlers.handle_pop(
                lineno,
                class_name,
                key,
                type(current_value),
                old_value_len,
                current_value_len,
                self.call_depth,
                self.rank_info,
            )
    elif change_type == EventType.UPD:
        # The name/attribute now refers to a different object.
        # NOTE(review): a rebound value classified as APD/POP emits no event
        # on this path - confirm that is intended.
        self.event_handlers.handle_upd(
            lineno,
            class_name,
            key,
            old_value,
            current_value,
            self.call_depth,
            self.rank_info,
            self.function_wrapper,
        )

def trace_factory(self) -> FunctionType: # noqa: C901
"""
Create the tracing function to be used with sys.settrace.
Expand All @@ -210,19 +277,20 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> Optional[FunctionType]
return trace_func

# Handle multi-GPU ranks if PyTorch is available
rank_info = ""
if self.torch_available:
if self.current_rank is None and torch.distributed and torch.distributed.is_initialized():
self.current_rank = torch.distributed.get_rank()
if self.current_rank in self.ranks:
rank_info: str = f"[Rank {self.current_rank}] "
elif self.current_rank is not None and self.current_rank not in self.ranks:
if self.current_rank is None:
if torch.distributed and torch.distributed.is_initialized():
self.current_rank = torch.distributed.get_rank()
self.rank_info = f"[Rank {self.current_rank}] "
elif self.current_rank not in self.ranks:
return trace_func

lineno = frame.f_lineno
if event == "call":
func_info = self._get_function_info(frame)
self.event_handlers.handle_run(lineno, func_info, self.function_wrapper, self.call_depth, rank_info)
self.event_handlers.handle_run(
lineno, func_info, self.function_wrapper, self.call_depth, self.rank_info
)
self.call_depth += 1

if self.with_locals:
Expand All @@ -241,7 +309,7 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> Optional[FunctionType]
self.call_depth -= 1
func_info = self._get_function_info(frame)
self.event_handlers.handle_end(
lineno, func_info, self.function_wrapper, self.call_depth, rank_info, arg
lineno, func_info, self.function_wrapper, self.call_depth, self.rank_info, arg
)

if self.with_locals and frame in self.tracked_locals:
Expand All @@ -263,48 +331,21 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> Optional[FunctionType]
for key, current_value in current_attrs.items():
old_value = old_attrs.get(key, None)
old_value_len = old_attrs_lens.get(key, None)
if old_value_len is not None:
current_value_len = len(current_value)
change_type: EventType = self.event_handlers.determine_change_type(
old_value_len, current_value_len
)
else:
change_type = EventType.UPD

if id(old_value) == id(current_value):
if change_type == EventType.APD:
self.event_handlers.handle_apd(
lineno,
class_name,
key,
type(current_value),
old_value_len,
current_value_len,
self.call_depth,
rank_info,
)
elif change_type == EventType.POP:
self.event_handlers.handle_pop(
lineno,
class_name,
key,
type(current_value),
old_value_len,
current_value_len,
self.call_depth,
rank_info,
)
elif change_type == EventType.UPD:
self.event_handlers.handle_upd(
lineno,
class_name,
key,
old_value,
current_value,
self.call_depth,
rank_info,
self.function_wrapper,
)
is_current_seq = isinstance(current_value, log_sequence_types)
current_value_len = (
len(current_value) if old_value_len is not None and is_current_seq else None
)

self._handle_change_type(
lineno,
class_name,
key,
old_value,
current_value,
old_value_len,
current_value_len,
)

old_attrs[key] = current_value
if isinstance(current_value, log_sequence_types):
self.tracked_objects_lens[obj][key] = len(current_value)
Expand All @@ -318,68 +359,35 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> Optional[FunctionType]

added_vars: Set[str] = set(current_locals.keys()) - set(old_locals.keys())
for var in added_vars:
current_local: Any = current_locals[var]
current_local = current_locals[var]
self.event_handlers.handle_upd(
lineno,
class_name="_",
key=var,
old_value=None,
current_value=current_local,
call_depth=self.call_depth,
rank_info=rank_info,
rank_info=self.rank_info,
function_wrapper=self.function_wrapper,
)
if isinstance(current_local, log_sequence_types):
self.tracked_locals_lens[frame][var] = len(current_local)

common_vars: Set[str] = set(old_locals.keys()) & set(current_locals.keys())
for var in common_vars:
old_local: Any = old_locals[var]
old_local_len: Optional[int] = old_locals_lens.get(var, None)
current_local: Any = current_locals[var]
if old_local_len is not None and isinstance(current_local, log_sequence_types):
current_local_len: int = len(current_local)
change_type: EventType = self.event_handlers.determine_change_type(
old_local_len, current_local_len
)
else:
change_type = EventType.UPD

if id(old_local) == id(current_local):
if change_type == EventType.APD:
self.event_handlers.handle_apd(
lineno,
"_",
var,
type(current_local),
old_local_len,
current_local_len,
self.call_depth,
rank_info,
)
elif change_type == EventType.POP:
self.event_handlers.handle_pop(
lineno,
"_",
var,
type(current_local),
old_local_len,
current_local_len,
self.call_depth,
rank_info,
)
elif change_type == EventType.UPD:
self.event_handlers.handle_upd(
lineno,
"_",
var,
old_local,
current_local,
self.call_depth,
rank_info,
self.function_wrapper,
)
if isinstance(current_local, log_sequence_types):
old_local = old_locals[var]
old_local_len: int = old_locals_lens.get(var, None)
current_local = current_locals[var]
is_current_seq = isinstance(current_local, log_sequence_types)
current_local_len: int = (
len(current_local) if old_local_len is not None and is_current_seq else None
)

self._handle_change_type(
lineno, "_", var, old_local, current_local, old_local_len, current_local_len
)

if is_current_seq:
self.tracked_locals_lens[frame][var] = len(current_local)

self.tracked_locals[frame] = current_locals
Expand Down
Loading