Skip to content

Commit cfb6053

Browse files
authored
Do not report feature data if deployment has it disabled. (#1662)
1 parent 3d51161 commit cfb6053

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

custom_model_runner/datarobot_drum/drum/language_predictors/base_language_predictor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@ def __init__(
104104
self._mlops = None
105105
self._schema_validator = None
106106
self._prompt_column_name = DEFAULT_PROMPT_COLUMN_NAME
107+
self._deployment = None
108+
109+
self._tracking_settings = {
110+
"target_drift": {"enabled": True},
111+
"feature_drift": {"enabled": True},
112+
}
113+
self._data_collection = {"enabled": True}
114+
115+
self._settings_refresh_time = time.monotonic()
116+
self._settings_refresh_interval = 60 # sec
107117

108118
def configure(self, params):
109119
"""
@@ -154,6 +164,7 @@ def _init_mlops(self):
154164
if to_bool(self._params.get("allow_dr_api_access")):
155165
try:
156166
self._deployment = dr.Deployment.get(deployment_id)
167+
self._refresh_tracking_settings()
157168
except Exception as e:
158169
logger.warning(f"Failed to get deployment info: {e}", exc_info=True)
159170

@@ -172,6 +183,14 @@ def _init_mlops(self):
172183

173184
self._mlops.init()
174185

186+
def _refresh_tracking_settings(self):
187+
deployment_id = self._params.get("deployment_id", None)
188+
if to_bool(self._params.get("allow_dr_api_access")) and deployment_id is not None:
189+
self._deployment = dr.Deployment.get(deployment_id)
190+
self._tracking_settings = self._deployment.get_drift_tracking_settings()
191+
self._data_collection = self._deployment.get_predictions_data_collection_settings()
192+
self._settings_refresh_time = time.monotonic()
193+
175194
def _configure_mlops(self):
176195
# If monitor_settings were provided (e.g. for testing) use them, otherwise we will
177196
# use the API spooler as the default config.
@@ -318,6 +337,16 @@ def _mlops_report_chat_prediction(
318337
except DRCommonException:
319338
logger.exception("Failed to report deployment stats")
320339

340+
if self._deployment is not None:
341+
if time.monotonic() - self._settings_refresh_time > self._settings_refresh_interval:
342+
self._refresh_tracking_settings()
343+
344+
is_drift = self._tracking_settings["feature_drift"]["enabled"]
345+
is_collection = self._data_collection["enabled"]
346+
347+
if not (is_drift or is_collection):
348+
return
349+
321350
prompt_content = completion_create_params["messages"][-1]["content"]
322351
if isinstance(prompt_content, str):
323352
latest_message = completion_create_params["messages"][-1]["content"]

tests/unit/datarobot_drum/drum/language_predictors/test_base_language_predictor.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _language_predictor_with_mlops_params_dr_api_access(self):
123123

124124
@pytest.fixture
125125
def mock_dr_client(self):
126-
with patch.object(dr, "Client") as _:
126+
with patch.object(dr, "Client") as m:
127127
yield
128128

129129
def test_mlops_init(self, language_predictor_with_mlops, mock_mlops):
@@ -271,7 +271,10 @@ def test_association_id(self, language_predictor_with_mlops, mock_mlops):
271271
mock_chat.assert_called_once_with(ANY, association_id)
272272
hasattr(completion, "datarobot_association_id")
273273

274-
def test_prompt_column_name(self, chat_python_model_adapter, mock_mlops, mock_dr_client):
274+
@pytest.mark.parametrize("row_storage_enabled", [False, True])
275+
def test_prompt_column_name(
276+
self, chat_python_model_adapter, mock_mlops, mock_dr_client, row_storage_enabled
277+
):
275278
language_predictor = TestLanguagePredictor()
276279
language_predictor_with_mlops_params = (
277280
self._language_predictor_with_mlops_params_dr_api_access()
@@ -282,6 +285,13 @@ def test_prompt_column_name(self, chat_python_model_adapter, mock_mlops, mock_dr
282285
deployment_instance.return_value.get_champion_model_package.return_value = Mock()
283286
mock_deployment.get.return_value = deployment_instance
284287

288+
deployment_instance.get_drift_tracking_settings.return_value = {
289+
"target_drift": {"enabled": False},
290+
"feature_drift": {"enabled": False},
291+
}
292+
deployment_instance.get_predictions_data_collection_settings.return_value = {
293+
"enabled": row_storage_enabled
294+
}
285295
language_predictor.configure(language_predictor_with_mlops_params)
286296

287297
def chat_hook(completion_request):
@@ -298,16 +308,19 @@ def chat_hook(completion_request):
298308
}
299309
)
300310

301-
mock_mlops.report_predictions_data.assert_called_once_with(
302-
ANY,
303-
["How are you"],
304-
association_ids=ANY,
305-
)
306-
# Compare features dataframe separately as this doesn't play nice with assert_called
307-
assert (
308-
mock_mlops.report_predictions_data.call_args.args[0]["newPromptName"].values[0]
309-
== "Hello!"
310-
)
311+
if row_storage_enabled:
312+
mock_mlops.report_predictions_data.assert_called_once_with(
313+
ANY,
314+
["How are you"],
315+
association_ids=ANY,
316+
)
317+
# Compare features dataframe separately as this doesn't play nice with assert_called
318+
assert (
319+
mock_mlops.report_predictions_data.call_args.args[0]["newPromptName"].values[0]
320+
== "Hello!"
321+
)
322+
else:
323+
mock_mlops.report_predictions_data.assert_not_called()
311324

312325
@pytest.mark.parametrize("stream", [False, True])
313326
def test_failing_hook_with_mlops(self, language_predictor_with_mlops, mock_mlops, stream):

0 commit comments

Comments
 (0)