-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
improved autofix strategy #148
base: main
Are you sure you want to change the base?
Changes from 17 commits
4615637
2a7cf91
e7a3d07
72fc919
d67bbc3
fc4bf7c
9f00909
d2a3432
6bcec4c
cc52ce2
62efa2d
e5c4872
02294c8
1d644a0
3ff2507
7235b40
1b99d60
330aa44
a19c88c
69ccda6
19143a3
e5b97f5
20a532c
3bbfc1c
b892e87
b54a0a7
f870e04
eb106d1
a7acfa6
1f0344d
692efe4
afbe4a9
7b96faa
b31674c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -1,25 +1,28 @@ | ||||
""" | ||||
Python API for Cleanlab Studio. | ||||
""" | ||||
from typing import Any, List, Literal, Optional, Union | ||||
import warnings | ||||
from typing import Any, List, Literal, Optional, Union, Dict | ||||
|
||||
import numpy as np | ||||
import numpy.typing as npt | ||||
import pandas as pd | ||||
|
||||
from . import inference | ||||
from . import trustworthy_language_model | ||||
from cleanlab_studio.errors import CleansetError | ||||
from cleanlab_studio.internal import clean_helpers, upload_helpers | ||||
from cleanlab_studio.internal.api import api | ||||
from cleanlab_studio.internal.settings import CleanlabSettings | ||||
from cleanlab_studio.internal.types import FieldSchemaDict | ||||
from cleanlab_studio.internal.util import ( | ||||
init_dataset_source, | ||||
apply_autofixed_cleanset_to_new_dataframe, | ||||
_get_autofix_default_thresholds, | ||||
check_none, | ||||
check_not_none, | ||||
get_autofix_defaults, | ||||
init_dataset_source, | ||||
) | ||||
from cleanlab_studio.internal.settings import CleanlabSettings | ||||
from cleanlab_studio.internal.types import FieldSchemaDict | ||||
|
||||
from . import inference, trustworthy_language_model | ||||
|
||||
_pyspark_exists = api.pyspark_exists | ||||
if _pyspark_exists: | ||||
|
@@ -131,7 +134,7 @@ def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool | |||
label_column = api.get_label_column_of_project(self._api_key, project_id) | ||||
id_col = api.get_id_column(self._api_key, cleanset_id) | ||||
if _pyspark_exists and isinstance(dataset, pyspark.sql.DataFrame): | ||||
from pyspark.sql.functions import row_number, monotonically_increasing_id, when, col | ||||
from pyspark.sql.functions import col, monotonically_increasing_id, row_number, when | ||||
from pyspark.sql.window import Window | ||||
|
||||
cl_cols = self.download_cleanlab_columns( | ||||
|
@@ -383,3 +386,36 @@ def poll_cleanset_status(self, cleanset_id: str, timeout: Optional[int] = None) | |||
|
||||
except (TimeoutError, CleansetError): | ||||
return False | ||||
|
||||
def autofix_dataset( | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. allow string options to passed straight through into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be added now, clarified in the docs:
|
||||
self, | ||||
original_df: pd.DataFrame, | ||||
cleanset_id: str, | ||||
params: Optional[Dict[str, Union[int, float]]] = None, | ||||
strategy="optimized_training_data", | ||||
) -> pd.DataFrame: | ||||
""" | ||||
This method returns the auto-fixed dataset. | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docstring should clarify that Dataset must be a DataFrame (text or tabular dataset only) |
||||
Args: | ||||
cleanset_id (str): ID of cleanset. | ||||
params (dict, optional): Parameter dictionary containing the confidence threshold for auto-relabelling and the | ||||
number of rows to drop for each issue type. If not provided, default values will be used. | ||||
|
||||
Example: | ||||
{ | ||||
'drop_ambiguous': 9, | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to fractions |
||||
'drop_label_issue': 92, | ||||
'drop_near_duplicate': 1, | ||||
'drop_outlier': 3, | ||||
'relabel_confidence_threshold': 0.95 | ||||
} | ||||
|
||||
Returns: | ||||
pd.DataFrame: A new dataframe after applying auto-fixes to the cleanset. | ||||
|
||||
""" | ||||
cleanset_df = self.download_cleanlab_columns(cleanset_id) | ||||
if params is None: | ||||
params = get_autofix_defaults(cleanset_df, strategy) | ||||
print("Using autofix values:", params) | ||||
return apply_autofixed_cleanset_to_new_dataframe(original_df, cleanset_df, params) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import pandas as pd | ||
import pytest | ||
from cleanlab_studio.internal.util import ( | ||
get_autofix_defaults, | ||
_update_label_based_on_confidence, | ||
_get_top_fraction_ids, | ||
) | ||
import numpy as np | ||
|
||
|
||
class TestAutofix: | ||
@pytest.mark.parametrize( | ||
"strategy, expected_results", | ||
[ | ||
( | ||
"optimized_training_data", | ||
{ | ||
"drop_ambiguous": 0, | ||
"drop_label_issue": 2, | ||
"drop_near_duplicate": 2, | ||
"drop_outlier": 3, | ||
"relabel_confidence_threshold": 0.95, | ||
}, | ||
), | ||
( | ||
"drop_all_issues", | ||
{ | ||
"drop_ambiguous": 10, | ||
"drop_label_issue": 3, | ||
"drop_near_duplicate": 6, | ||
"drop_outlier": 6, | ||
}, | ||
), | ||
( | ||
"suggested_actions", | ||
{ | ||
"drop_near_duplicate": 6, | ||
"drop_outlier": 6, | ||
"relabel_confidence_threshold": 0.0, | ||
}, | ||
), | ||
], | ||
ids=["optimized_training_data", "drop_all_issues", "suggested_actions"], | ||
) | ||
def test_get_autofix_defaults(self, strategy, expected_results): | ||
cleanlab_columns = pd.DataFrame() | ||
cleanlab_columns["is_label_issue"] = [True] * 3 + [False] * 7 | ||
cleanlab_columns["is_near_duplicate"] = [True] * 6 + [False] * 4 | ||
cleanlab_columns["is_outlier"] = [True] * 6 + [False] * 4 | ||
cleanlab_columns["is_ambiguous"] = [True] * 10 | ||
|
||
params = get_autofix_defaults(cleanlab_columns, strategy) | ||
assert params == expected_results | ||
|
||
@pytest.mark.parametrize( | ||
"row, expected_updated_row", | ||
[ | ||
( | ||
{ | ||
"is_label_issue": True, | ||
"suggested_label_confidence_score": 0.6, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
{ | ||
"is_label_issue": True, | ||
"suggested_label_confidence_score": 0.6, | ||
"label": "label_1", | ||
"suggested_label": "label_1", | ||
"is_issue": False, | ||
}, | ||
), | ||
( | ||
{ | ||
"is_label_issue": True, | ||
"suggested_label_confidence_score": 0.5, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
{ | ||
"is_label_issue": True, | ||
"suggested_label_confidence_score": 0.5, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
), | ||
( | ||
{ | ||
"is_label_issue": True, | ||
"suggested_label_confidence_score": 0.4, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
{ | ||
"is_label_issue": True, | ||
"suggested_label_confidence_score": 0.4, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
), | ||
( | ||
{ | ||
"is_label_issue": False, | ||
"suggested_label_confidence_score": 0.4, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
{ | ||
"is_label_issue": False, | ||
"suggested_label_confidence_score": 0.4, | ||
"label": "label_0", | ||
"suggested_label": "label_1", | ||
"is_issue": True, | ||
}, | ||
), | ||
], | ||
ids=[ | ||
"is a label issue with confidence score greater than threshold", | ||
"is a label issue with confidence score equal to threshold", | ||
"is a label issue with confidence score less than threshold", | ||
"is not a label issue", | ||
], | ||
) | ||
def test_update_label_based_on_confidence(self, row, expected_updated_row): | ||
conf_threshold = 0.5 | ||
updated_row = _update_label_based_on_confidence(row, conf_threshold) | ||
assert updated_row == expected_updated_row | ||
|
||
def test_get_top_fraction_ids(self): | ||
cleanlab_columns = pd.DataFrame() | ||
|
||
cleanlab_columns["cleanlab_row_ID"] = np.arange(10) | ||
cleanlab_columns["is_dummy"] = [False] * 5 + [True] * 5 | ||
cleanlab_columns["dummy_score"] = np.arange(10) * 0.1 | ||
top_ids = _get_top_fraction_ids(cleanlab_columns, "dummy", 3) | ||
assert set(top_ids) == set([5, 6, 7]) | ||
|
||
def test_get_top_fraction_ids_near_duplicate(self): | ||
cleanlab_columns = pd.DataFrame() | ||
|
||
cleanlab_columns["cleanlab_row_ID"] = np.arange(12) | ||
cleanlab_columns["is_near_duplicate"] = [False] * 6 + [True] * 6 | ||
cleanlab_columns["near_duplicate_score"] = np.arange(12) * 0.1 | ||
cleanlab_columns["near_duplicate_cluster_id"] = [None] * 6 + [0, 0, 1, 1, 1, 1] | ||
|
||
top_ids = _get_top_fraction_ids(cleanlab_columns, "near_duplicate", 5) | ||
assert set(top_ids) == set([6, 8, 10]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In autofix, we can simply multiply the fraction of issues that are the cleanset defaults by the number of datapoints to get this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right — when we spoke originally, we wanted this call to be similar to the Studio web interface call, hence I rewrote it this way; it was a floating-point percentage before.
the function
_get_autofix_defaults
does the multiplication by number of datapoints