Skip to content

Commit

Permalink
fea: support for py3.8–3.12 && fix flake8 issues && optimize test run…
Browse files Browse the repository at this point in the history
…time (#48)
  • Loading branch information
aeeeeeep authored Jan 7, 2025
1 parent a69b0f7 commit 3708538
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 24 deletions.
22 changes: 13 additions & 9 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,26 @@ jobs:
build:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- name: Set up Python 3.10
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: "3.10"
python-version: ${{ matrix.python-version }}

- name: Cache Python dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements/*.txt') }}
key: ${{ runner.os }}-pip-${{ hashFiles('requirements/*.txt') }}-python-${{ matrix.python-version }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-pip-python-${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -38,10 +42,10 @@ jobs:
- name: Lint with flake8
run: |
# Stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 ./objwatch ./examples --count --select=E9,F63,F7,F82 --show-source --statistics
# Exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 ./objwatch ./examples --count --exit-zero --max-complexity=16 --max-line-length=127 --statistics
- name: Test with unittest
run: |
python3 -m unittest discover -s tests
python -m unittest discover -s tests
15 changes: 12 additions & 3 deletions objwatch/event_handls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@

import atexit
import xml.etree.ElementTree as ET
from types import NoneType, FunctionType
from types import FunctionType

try:
from types import NoneType
except ImportError:
NoneType = type(None)

from typing import Any, Dict, Optional
from .utils.logger import log_debug
from .utils.logger import log_debug, log_warn
from .events import EventType


Expand Down Expand Up @@ -320,6 +326,9 @@ def save_xml(self) -> None:
"""
if self.output_xml and not self.is_xml_saved:
tree = ET.ElementTree(self.stack_root)
ET.indent(tree)
if hasattr(ET, 'indent'):
ET.indent(tree)
else:
log_warn("Current Python version does not support `xml.etree.ElementTree.indent`. XML formatting is skipped.")
tree.write(self.output_xml, encoding='utf-8', xml_declaration=True)
self.is_xml_saved = True
13 changes: 4 additions & 9 deletions objwatch/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _get_function_info(self, frame: FrameType, event: str) -> Dict[str, Any]:

return func_info

def trace_func_factory(self) -> FunctionType:
def trace_factory(self) -> FunctionType: # noqa: C901
"""
Create the tracing function to be used with sys.settrace.
Expand All @@ -190,19 +190,14 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> Optional[FunctionType]
return trace_func

# Handle multi-GPU ranks if PyTorch is available
rank_info = ""
if self.torch_available:
if (
self.current_rank is None
and torch.distributed
and torch.distributed.is_initialized()
):
if self.current_rank is None and torch.distributed and torch.distributed.is_initialized():
self.current_rank = torch.distributed.get_rank()
if self.current_rank in self.ranks:
rank_info: str = f"[Rank {self.current_rank}] "
elif self.current_rank is not None and self.current_rank not in self.ranks:
return trace_func
else:
rank_info = ""

if event == "call":
func_info = self._get_function_info(frame, event)
Expand Down Expand Up @@ -364,7 +359,7 @@ def start(self) -> None:
Start the tracing process by setting the trace function.
"""
log_info("Starting tracing.")
sys.settrace(self.trace_func_factory())
sys.settrace(self.trace_factory())
if self.torch_available and torch.distributed and torch.distributed.is_initialized():
torch.distributed.barrier()

Expand Down
5 changes: 4 additions & 1 deletion requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
torch==2.3.1
-f https://download.pytorch.org/whl/torch_stable.html

torch==2.3.1+cpu
numpy
2 changes: 1 addition & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_custom_wrapper_call_and_return(self):
mock_frame.f_code.co_name = 'custom_func'
mock_frame.f_locals = {'arg1': 'value1'}

trace_func = self.obj_watch.tracer.trace_func_factory()
trace_func = self.obj_watch.tracer.trace_factory()

trace_func(mock_frame, 'call', None)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_output_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def inner_function(self, lst):

return self.lst

with patch.object(self.tracer, 'trace_func_factory', return_value=self.tracer.trace_func_factory()):
with patch.object(self.tracer, 'trace_factory', return_value=self.tracer.trace_factory()):
self.tracer.start()
try:
t = TestClass()
Expand Down

0 comments on commit 3708538

Please sign in to comment.