Skip to content

Commit

Permalink
address typing
Browse files (browse the repository at this point in the history)
  • Loading branch information
ulya-tkch committed Apr 30, 2024
1 parent 903ead2 commit 3056611
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cleanlab_studio/utils/data_enrichment/enrich.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def enrich_data(
prompt: str,
data: pd.DataFrame,
*,
regex: Optional[Union[str, re.Pattern, List[re.Pattern]]] = None,
regex: Optional[Union[str, re.Pattern[str], List[re.Pattern[str]]]] = None,
return_values: Optional[List[str]] = None,
optimize_prompt: bool = True,
subset_indices: Optional[Union[Tuple[int, int], List[int]]] = (0, 3),
Expand Down Expand Up @@ -113,7 +113,7 @@ def enrich_data(

def get_regex_matches(
column_data: Union[pd.Series, List[str]],
regex: Union[str, re.Pattern, List[re.Pattern]],
regex: Union[str, re.Pattern[str], List[re.Pattern[str]]],
) -> Union[pd.Series, List[str]]:
"""
Extracts the first match from the response using the provided regex patterns. Return first match if multiple exist.
Expand Down
9 changes: 6 additions & 3 deletions cleanlab_studio/utils/data_enrichment/enrichment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import pandas as pd

from cleanlab_studio.errors import ValidationError
from cleanlab_studio.studio.studio import Studio


def get_prompt_outputs(studio, prompt, data, **kwargs):
def get_prompt_outputs(studio: Studio, prompt: str, data: pd.DataFrame, **kwargs) -> List[dict]:
"""Returns the outputs of the prompt for each row in the dataframe."""
tlm = studio.TLM(**kwargs)
formatted_prompts = data.apply(lambda x: prompt.format(**x), axis=1).to_list()
Expand All @@ -33,7 +34,9 @@ def extract_df_subset(
return subset_df


def get_compiled_regex_list(regex: Union[str, re.Pattern, List[re.Pattern]]) -> List[re.Pattern]:
def get_compiled_regex_list(
regex: Union[str, re.Pattern[str], List[re.Pattern[str]]]
) -> List[re.Pattern[str]]:
"""Compile the regex pattern(s) provided and return a list of compiled regex patterns."""
if isinstance(regex, str):
return [re.compile(rf"{regex}")]
Expand All @@ -47,7 +50,7 @@ def get_compiled_regex_list(regex: Union[str, re.Pattern, List[re.Pattern]]) ->
)


def get_regex_match(response: str, regex_list: List[re.Pattern]) -> Union[str, None]:
def get_regex_match(response: str, regex_list: List[re.Pattern[str]]) -> Union[str, None]:
"""Extract the first match from the response using the provided regex patterns. Return first match if multiple exist.
Note: This function assumes the regex patterns each specify exactly 1 group that is the match group using '(<group>)'."""
for regex_pattern in regex_list:
Expand Down

0 comments on commit 3056611

Please sign in to comment.