|
| 1 | +from unittest.mock import patch |
| 2 | + |
1 | 3 | import pytest |
2 | 4 | from django.db import connections |
| 5 | +from django.db.utils import OperationalError |
3 | 6 |
|
4 | 7 | from sentry.db.models.query import in_iexact |
5 | 8 | from sentry.models.commit import Commit |
@@ -78,6 +81,96 @@ def test_wrapper_over_values_list(self) -> None: |
78 | 81 | qs = User.objects.all().values_list("id") |
79 | 82 | assert list(qs) == list(self.range_wrapper(qs, result_value_getter=lambda r: r[0])) |
80 | 83 |
|
| 84 | + def test_retry_on_operational_error_success_after_failures(self) -> None: |
| 85 | + """Test that with query_timeout_retries=3, after 2 errors and 1 success it works.""" |
| 86 | + total = 5 |
| 87 | + for _ in range(total): |
| 88 | + self.create_user() |
| 89 | + |
| 90 | + qs = User.objects.all() |
| 91 | + batch_attempts: list[int] = [] |
| 92 | + current_batch_count = 0 |
| 93 | + original_getitem = type(qs).__getitem__ |
| 94 | + |
| 95 | + def mock_getitem(self, slice_obj): |
| 96 | + nonlocal current_batch_count |
| 97 | + current_batch_count += 1 |
| 98 | + if len(batch_attempts) == 0 and current_batch_count <= 2: |
| 99 | + raise OperationalError("canceling statement due to user request") |
| 100 | + if len(batch_attempts) == 0 and current_batch_count == 3: |
| 101 | + batch_attempts.append(current_batch_count) |
| 102 | + return original_getitem(self, slice_obj) |
| 103 | + |
| 104 | + with patch.object(type(qs), "__getitem__", mock_getitem): |
| 105 | + results = list( |
| 106 | + self.range_wrapper(qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01) |
| 107 | + ) |
| 108 | + |
| 109 | + assert len(results) == total |
| 110 | + assert batch_attempts[0] == 3 |
| 111 | + |
| 112 | + def test_retry_exhausted_raises_exception(self) -> None: |
| 113 | + """Test that after exhausting retries, the OperationalError is raised.""" |
| 114 | + total = 5 |
| 115 | + for _ in range(total): |
| 116 | + self.create_user() |
| 117 | + |
| 118 | + qs = User.objects.all() |
| 119 | + |
| 120 | + def always_fail(self, slice_obj): |
| 121 | + raise OperationalError("canceling statement due to user request") |
| 122 | + |
| 123 | + with patch.object(type(qs), "__getitem__", always_fail): |
| 124 | + with pytest.raises(OperationalError, match="canceling statement due to user request"): |
| 125 | + list( |
| 126 | + self.range_wrapper( |
| 127 | + qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01 |
| 128 | + ) |
| 129 | + ) |
| 130 | + |
| 131 | + def test_retry_does_not_catch_other_exceptions(self) -> None: |
| 132 | + """Test that non-OperationalError exceptions are not retried.""" |
| 133 | + total = 5 |
| 134 | + for _ in range(total): |
| 135 | + self.create_user() |
| 136 | + |
| 137 | + qs = User.objects.all() |
| 138 | + |
| 139 | + attempt_count = {"count": 0} |
| 140 | + |
| 141 | + def raise_value_error(self, slice_obj): |
| 142 | + attempt_count["count"] += 1 |
| 143 | + raise ValueError("Some other error") |
| 144 | + |
| 145 | + with patch.object(type(qs), "__getitem__", raise_value_error): |
| 146 | + with pytest.raises(ValueError, match="Some other error"): |
| 147 | + list( |
| 148 | + self.range_wrapper( |
| 149 | + qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01 |
| 150 | + ) |
| 151 | + ) |
| 152 | + assert attempt_count["count"] == 1 |
| 153 | + |
| 154 | + def test_no_retry_when_query_timeout_retries_is_none(self) -> None: |
| 155 | + """Test that when query_timeout_retries is None, no retry logic is applied.""" |
| 156 | + total = 5 |
| 157 | + for _ in range(total): |
| 158 | + self.create_user() |
| 159 | + |
| 160 | + qs = User.objects.all() |
| 161 | + |
| 162 | + attempt_count = {"count": 0} |
| 163 | + |
| 164 | + def fail_once(self, slice_obj): |
| 165 | + attempt_count["count"] += 1 |
| 166 | + raise OperationalError("canceling statement due to user request") |
| 167 | + |
| 168 | + with patch.object(type(qs), "__getitem__", fail_once): |
| 169 | + with pytest.raises(OperationalError, match="canceling statement due to user request"): |
| 170 | + list(self.range_wrapper(qs, step=10, query_timeout_retries=None)) |
| 171 | + |
| 172 | + assert attempt_count["count"] == 1 |
| 173 | + |
81 | 174 |
|
82 | 175 | @no_silo_test |
83 | 176 | class RangeQuerySetWrapperWithProgressBarTest(RangeQuerySetWrapperTest): |
|
0 commit comments