Skip to content

Commit

Permalink
Enable some ruff rules on tests/ folder (#4212)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbeauchesne authored Feb 28, 2025
1 parent 3276393 commit 2597fc3
Show file tree
Hide file tree
Showing 22 changed files with 55 additions and 57 deletions.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,10 @@ ignore = [
"tests/*" = [
"ANN201", # missing-return-type-undocumented-public-function: TODO
"N801", # invalid-class-name: TODO
"ARG002", # unused-method-argument: TODO
"E501", # line-too-long: TODO
"SIM117", # multiple-with-statements: TODO
"N806", # non-lowercase-variable-in-function: TODO
"TRY002", # raise-vanilla-class: TODO
"FBT002", # boolean-default-value-positional-argument: TODO
"B007", # unused-loop-control-variable: TODO
"INP001", # implicit-namespace-package: TODO
"TRY301", # raise-within-try: TODO

Expand Down Expand Up @@ -218,6 +215,7 @@ ignore = [
]
"tests/parametric/*" = [
"ANN401", # any-type: TODO
"ARG002", # unused-method-argument: TODO
]
"tests/parametric/{test_headers_baggage.py,test_headers_datadog.py,test_library_tracestats.py}" = [
"N802", # invalid-function-name: some tests methods contains code with capital letters
Expand Down
2 changes: 1 addition & 1 deletion tests/appsec/iast/source/test_kafka_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ class TestKafkaKey(BaseSourceTest):
source_type = "kafka.message.key"
source_value = "hello key!"

def get_sources(self, request):
def get_sources(self, request): # noqa: ARG002
iast_event = get_all_iast_events()
return get_iast_sources(iast_event)
2 changes: 1 addition & 1 deletion tests/appsec/iast/source/test_kafka_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ class TestKafkaValue(BaseSourceTest):
source_type = "kafka.message.value"
source_value = "hello value!"

def get_sources(self, request):
def get_sources(self, request): # noqa: ARG002
iast_event = get_all_iast_events()
return get_iast_sources(iast_event)
28 changes: 13 additions & 15 deletions tests/appsec/iast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def assert_iast_vulnerability(
def assert_metric(request, metric, *, expected: bool):
spans_checked = 0
metric_available = False
for data, trace, span in interfaces.library.get_spans(request):
for _, __, span in interfaces.library.get_spans(request):
if metric in span["metrics"]:
metric_available = True
spans_checked += 1
Expand Down Expand Up @@ -181,19 +181,17 @@ def test_secure(self):
def assert_no_iast_event(request, tested_vulnerability_type=None) -> None:
assert request.status_code == 200, f"Request failed with status code {request.status_code}"

for data, _, span in interfaces.library.get_spans(request=request):
logger.info(f"Looking for IAST events in {data['log_filename']}")
meta, meta_struct = _get_span_meta(request=request)
iast_json = meta.get("_dd.iast.json") if meta else meta_struct.get("iast")
if iast_json is not None:
if tested_vulnerability_type is None:
logger.error(json.dumps(iast_json, indent=2))
raise ValueError("Unexpected vulnerability reported")
elif iast_json["vulnerabilities"]:
for vuln in iast_json["vulnerabilities"]:
if vuln["type"] == tested_vulnerability_type:
logger.error(json.dumps(iast_json, indent=2))
raise ValueError(f"Unexpected vulnerability reported: {vuln['type']}")
meta, meta_struct = _get_span_meta(request=request)
iast_json = meta.get("_dd.iast.json") if meta else meta_struct.get("iast")
if iast_json is not None:
if tested_vulnerability_type is None:
logger.error(json.dumps(iast_json, indent=2))
raise ValueError("Unexpected vulnerability reported")
elif iast_json["vulnerabilities"]:
for vuln in iast_json["vulnerabilities"]:
if vuln["type"] == tested_vulnerability_type:
logger.error(json.dumps(iast_json, indent=2))
raise ValueError(f"Unexpected vulnerability reported: {vuln['type']}")


def validate_stack_traces(request):
Expand Down Expand Up @@ -285,7 +283,7 @@ def validate_stack_traces(request):
assert locationFrame is not None, "location not found in stack trace"


def validate_extended_location_data(request, vulnerability_type, is_expected_location_required=True):
def validate_extended_location_data(request, vulnerability_type, *, is_expected_location_required=True):
span = interfaces.library.get_root_span(request)
iast = span.get("meta", {}).get("_dd.iast.json")
assert iast, "Expected at least one vulnerability"
Expand Down
4 changes: 2 additions & 2 deletions tests/appsec/test_asm_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def assert_product_is_enabled(request, product) -> None:
product_enabled = False
tags = "_dd.iast.json" if product == "iast" else "_dd.appsec.json"
meta_struct_key = "iast" if product == "iast" else "appsec"
for data, trace, span in interfaces.library.get_spans(request=request):
for _, __, span in interfaces.library.get_spans(request=request):
# Check if the product is enabled in meta
meta = span["meta"]
if tags in meta:
Expand Down Expand Up @@ -679,7 +679,7 @@ class SCAStandalone_Telemetry_Base:

def assert_standalone_is_enabled(self, request):
# test standalone is enabled and dropping traces
for data, _trace, span in interfaces.library.get_spans(request):
for _, __, span in interfaces.library.get_spans(request):
assert span["metrics"]["_sampling_priority_v1"] <= 0
assert span["metrics"]["_dd.apm.enabled"] == 0

Expand Down
2 changes: 1 addition & 1 deletion tests/appsec/waf/test_addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class Test_BodyXml:
ATTACK = '<vmlframe src="xss">'
ENCODED_ATTACK = "&lt;vmlframe src=&quot;xss&quot;&gt;"

def weblog_post(self, path="/", params=None, data=None, headers=None, **kwargs):
def weblog_post(self, path="/", params=None, data=None, headers=None):
headers = headers or {}
headers["Content-Type"] = "application/xml"
data = f"<?xml version='1.0' encoding='utf-8'?>{data}"
Expand Down
2 changes: 1 addition & 1 deletion tests/debugger/test_debugger_exception_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from utils.tools import logger


def get_env_bool(env_var_name, default=False):
def get_env_bool(env_var_name, *, default=False) -> bool:
value = os.getenv(env_var_name, str(default)).lower()
return value in {"true", "True", "1"}

Expand Down
4 changes: 2 additions & 2 deletions tests/debugger/test_debugger_pii.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
@scenarios.debugger_pii_redaction
class Test_Debugger_PII_Redaction(debugger.Base_Debugger_Test):
############ setup ############
def _setup(self, line_probe=False):
def _setup(self, *, line_probe=False):
self.initialize_weblog_remote_config()

if line_probe:
Expand All @@ -130,7 +130,7 @@ def _setup(self, line_probe=False):
self.wait_for_all_probes_emitting()

############ assert ############
def _assert(self, redacted_keys, redacted_types, line_probe=False):
def _assert(self, redacted_keys, redacted_types, *, line_probe=False):
self.collect()
self.assert_setup_ok()
self.assert_rc_state_not_error()
Expand Down
3 changes: 2 additions & 1 deletion tests/fuzzer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def value(self):


class _RequestDumper:
def __init__(self, name=None, enabled=True):
def __init__(self, name=None, *, enabled=True):
self.enabled = enabled
self.logger = None
if name:
Expand Down Expand Up @@ -70,6 +70,7 @@ def __init__(
request_count=None,
max_time=None,
dump_on_status=("500",),
*,
debug=False,
systematic_export=False,
):
Expand Down
4 changes: 2 additions & 2 deletions tests/fuzzer/request_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class RequestMutator:
invalid_methods = tuple()
invalid_header_keys = tuple()

def __init__(self, no_mutation=False):
def __init__(self, *, no_mutation=False):
self.methods = tuple(method for method in self.methods if method not in self.invalid_methods)

self.invalid_header_keys = tuple(key.lower() for key in self.invalid_header_keys)
Expand Down Expand Up @@ -530,7 +530,7 @@ def get_random_payload(self, payload_type):
def get_payload_key(self):
return random.choice(data.blns)

def get_payload_value(self, allow_nested=False):
def get_payload_value(self, *, allow_nested=False):
if not allow_nested:
return random.choice(self.payload_values)

Expand Down
3 changes: 2 additions & 1 deletion tests/fuzzer/tools/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
format_string=None,
display_length=5,
value=0,
*,
has_raw_value=True,
raw_name=None,
):
Expand Down Expand Up @@ -339,7 +340,7 @@ def signal(self, key, value):
def value(self, key, value):
self.logger.info(f"V {key}: {value}")

def pulse(self, metrics_getter, force=False):
def pulse(self, metrics_getter, *, force=False):
if self._is_report_time() or force:
metrics = metrics_getter()

Expand Down
8 changes: 4 additions & 4 deletions tests/integrations/test_db_integrations_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_db_user(self, excluded_operations=()):
"""Username for accessing the database."""
db_container = context.scenario.get_container_by_dd_integration_name(self.db_service)

for db_operation, span in self.get_spans(excluded_operations=excluded_operations):
for _, span in self.get_spans(excluded_operations=excluded_operations):
assert span["meta"]["db.user"].casefold() == db_container.db_user.casefold()

@missing_feature(library="python", reason="not implemented yet")
Expand All @@ -112,7 +112,7 @@ def test_db_instance(self, excluded_operations=()):
"""The name of the database being connected to. Database instance name. Formerly db.name"""
db_container = context.scenario.get_container_by_dd_integration_name(self.db_service)

for db_operation, span in self.get_spans(excluded_operations=excluded_operations):
for _, span in self.get_spans(excluded_operations=excluded_operations):
assert span["meta"]["db.instance"] == db_container.db_instance

@missing_feature(library="python", reason="not implemented yet")
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_db_jdbc_drive_classname(self):
assert span["meta"]["db.jdbc.driver_classname"].strip(), f"Test is failing for {db_operation}"

def test_error_message(self):
for db_operation, span in self.get_spans(operations=["select_error"]):
for _, span in self.get_spans(operations=["select_error"]):
# A string representing the error message.
assert span["meta"]["error.message"].strip()

Expand Down Expand Up @@ -257,7 +257,7 @@ def test_db_name(self):
super().test_db_name()

@bug(context.library < "[email protected]", reason="APMRP-360")
def test_db_user(self, excluded_operations=()):
def test_db_user(self):
super().test_db_user()


Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/test_inferred_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_span(interface, resource):
return None


def assert_api_gateway_span(test_case, span, path, status_code, is_distributed=False, is_error=False):
def assert_api_gateway_span(test_case, span, path, status_code, *, is_distributed=False, is_error=False):
assert span["name"] == "aws.apigateway", "Inferred AWS API Gateway span name should be 'aws.apigateway'"

# Assertions to check if the span data contains the required keys and values.
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/test_open_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_obfuscate_query(self):

def test_sql_success(self):
"""We check all sql launched for the app work"""
for db_operation, request in self.get_requests(excluded_operations=["select_error"]):
for _, request in self.get_requests(excluded_operations=["select_error"]):
span = self.get_span_from_agent(request)
assert "error" not in span or span["error"] == 0

Expand Down
20 changes: 10 additions & 10 deletions tests/parametric/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _write_log(self, log_type, json_trace):
log.write(f"\n{log_type}>>>>\n")
log.write(json.dumps(json_trace))

def traces(self, clear=False, **kwargs) -> list[Trace]:
def traces(self, *, clear=False, **kwargs) -> list[Trace]:
resp = self._session.get(self._url("/test/session/traces"), **kwargs)
if clear:
self.clear()
Expand Down Expand Up @@ -216,7 +216,7 @@ def set_trace_delay(self, delay):
resp = self._session.post(self._url("/test/settings"), json={"trace_request_delay": delay})
assert resp.status_code == 202

def raw_telemetry(self, clear=False):
def raw_telemetry(self, *, clear=False):
raw_reqs = self.requests()
reqs = []
for req in raw_reqs:
Expand All @@ -226,7 +226,7 @@ def raw_telemetry(self, clear=False):
self.clear()
return reqs

def telemetry(self, clear=False, **kwargs):
def telemetry(self, *, clear=False, **kwargs):
resp = self._session.get(self._url("/test/session/apmtelemetry"), **kwargs)
if clear:
self.clear()
Expand All @@ -244,7 +244,7 @@ def requests(self, **kwargs) -> list[AgentRequest]:
self._write_log("requests", resp_json)
return resp_json

def rc_requests(self, post_only=False):
def rc_requests(self, *, post_only=False):
reqs = self.requests()
rc_reqs = [r for r in reqs if r["url"].endswith("/v0.7/config") and (not post_only or r["method"] == "POST")]
for r in rc_reqs:
Expand Down Expand Up @@ -314,7 +314,7 @@ def wait_for_num_traces(
"""
num_received = None
traces = []
for i in range(wait_loops):
for _ in range(wait_loops):
try:
traces = self.traces(clear=False)
except requests.exceptions.RequestException:
Expand Down Expand Up @@ -344,7 +344,7 @@ def wait_for_num_spans(
When sort_by_start=True returned traces are sorted by the span start time to simplify assertions by knowing that returned traces are in the same order as they have been created.
"""
num_received = None
for i in range(wait_loops):
for _ in range(wait_loops):
try:
traces = self.traces(clear=False)
except requests.exceptions.RequestException:
Expand All @@ -369,7 +369,7 @@ def wait_for_num_spans(

def wait_for_telemetry_event(self, event_name: str, *, clear: bool = False, wait_loops: int = 200):
"""Wait for and return the given telemetry event from the test agent."""
for i in range(wait_loops):
for _ in range(wait_loops):
try:
events = self.telemetry(clear=False)
except requests.exceptions.RequestException:
Expand Down Expand Up @@ -406,7 +406,7 @@ def wait_for_rc_apply_state(
last_known_state = None
for _ in range(wait_loops):
try:
rc_reqs = self.rc_requests(post_only)
rc_reqs = self.rc_requests(post_only=post_only)
except requests.exceptions.RequestException:
logger.exception("Error getting RC requests")
else:
Expand Down Expand Up @@ -436,7 +436,7 @@ def wait_for_rc_apply_state(

def wait_for_rc_capabilities(self, wait_loops: int = 100) -> set[Capabilities]:
"""Wait for the given RemoteConfig apply state to be received by the test agent."""
for i in range(wait_loops):
for _ in range(wait_loops):
try:
rc_reqs = self.rc_requests()
except requests.exceptions.RequestException:
Expand Down Expand Up @@ -485,7 +485,7 @@ def assert_rc_capabilities(self, expected_capabilities: set[Capabilities], wait_

def wait_for_tracer_flare(self, case_id: str | None = None, *, clear: bool = False, wait_loops: int = 100):
"""Wait for the tracer-flare to be received by the test agent."""
for i in range(wait_loops):
for _ in range(wait_loops):
try:
tracer_flares = self.get_tracer_flares()
except requests.exceptions.RequestException:
Expand Down
2 changes: 1 addition & 1 deletion tests/parametric/test_library_tracestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def test_relative_error_TS008(self, library_env, test_agent, test_library):

with test_library:
# Create 10 traces to get more data
for i in range(10):
for _ in range(10):
with test_library.dd_start_span(name="web.request", resource="/users", service="webserver"):
pass

Expand Down
8 changes: 4 additions & 4 deletions tests/parametric/test_span_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_single_rule_rate_limiter_span_sampling_sss008(self, test_agent, test_li
# generate three traces before requesting them to avoid timing issues
trace_span_ids = []
with test_library:
for i in range(6):
for _ in range(6):
with test_library.dd_start_span(name="web.request", service="webserver") as s1:
pass
trace_span_ids.append((s1.trace_id, s1.span_id))
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_sampling_rate_not_absolute_value_sss009(self, test_agent, test_library)
half do not.
"""
# make 100 new traces, each with one span
for i in range(100):
for _ in range(100):
with test_library:
with test_library.dd_start_span(name="web.request", service="webserver"):
pass
Expand Down Expand Up @@ -454,11 +454,11 @@ def test_multi_rule_independent_rate_limiters_sss013(self, test_agent, test_libr
# generate spans before requesting them to avoid timing issues
trace_and_span_ids = []
with test_library:
for i in range(4):
for _ in range(4):
with test_library.dd_start_span(name="web.request", service="webserver") as s1:
pass
trace_and_span_ids.append((s1.trace_id, s1.span_id))
for i in range(6):
for _ in range(6):
with test_library.dd_start_span(name="web.request2", service="webserver2") as s2:
pass
trace_and_span_ids.append((s2.trace_id, s2.span_id))
Expand Down
2 changes: 1 addition & 1 deletion tests/perfs/test_performances.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def nested(size, deep):

for path in TESTED_PATHS:
for header in headers:
for data in datas:
for __ in datas:
for _ in range(5):
self.add_request(
{
Expand Down
2 changes: 1 addition & 1 deletion tests/remote_config/test_remote_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def dict_is_included(sub_dict: dict, main_dict: dict):
return True


def dict_is_in_array(needle: dict, haystack: list, allow_additional_fields=True):
def dict_is_in_array(needle: dict, haystack: list, *, allow_additional_fields=True):
"""Returns true is needle is contained in haystack.
If allow_additional_field is true, needle can contains less field than the one in haystack
"""
Expand Down
Loading

0 comments on commit 2597fc3

Please sign in to comment.