cleanup and fixes for dq.metrics (#325)
Ben Epstein authored Jul 1, 2022
1 parent 4020d6f commit bb789fb
Showing 3 changed files with 92 additions and 21 deletions.
30 changes: 30 additions & 0 deletions dataquality/clients/api.py
@@ -590,6 +590,36 @@ def get_run_metrics(
        body = {"task": task, "filter_params": filter_params or {}}
        return self.make_request(RequestType.POST, url, body=body, params=params)

+    def get_column_distribution(
+        self,
+        project_name: str,
+        run_name: str,
+        split: str,
+        task: Optional[str] = None,
+        inference_name: Optional[str] = None,
+        column: str = "data_error_potential",
+        filter_params: Optional[Dict] = None,
+    ) -> Dict[str, List]:
+        project, run = self._get_project_run_id(project_name, run_name)
+        split = conform_split(split)
+
+        all_meta = self.get_metadata_columns(project_name, run_name, split)
+        continuous_meta = [i["name"] for i in all_meta["meta"] if i["is_continuous"]]
+        avl_cols = continuous_meta + ["data_error_potential"]
+        if column not in avl_cols:
+            raise GalileoException(
+                f"Column must be one of continuous columns {avl_cols} for this run "
+                f"but got {column}"
+            )
+
+        path = Route.content_path(project, run, split)
+        url = f"{config.api_url}/{path}/{Route.distribution}"
+        params = {"col": column}
+        if inference_name:
+            params["inference_name"] = inference_name
+        body = {"task": task, "filter_params": filter_params or {}}
+        return self.make_request(RequestType.POST, url, body=body, params=params)

    def get_xray_cards(
        self, project_name: str, run_name: str, split: str, inference_name: str = None
    ) -> List[Dict[str, str]]:
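For orientation, here is a minimal usage sketch of the new client method. The `ApiClient` import path, the placeholder project/run/column names, and the assumption that the response carries `bins`/`counts` histogram keys (as `display_distribution` below reads them) are illustrative, not part of this commit:

```python
# Hypothetical usage sketch -- ApiClient location and all names are assumptions.
from dataquality.clients.api import ApiClient

api_client = ApiClient()
distribution = api_client.get_column_distribution(
    project_name="my_project",
    run_name="my_run",
    split="training",
    column="confidence",  # any continuous metadata column, or the default DEP
)
# display_distribution (metrics.py below) reads the response like this:
bins, counts = distribution["bins"], distribution["counts"]
```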
82 changes: 61 additions & 21 deletions dataquality/metrics.py
@@ -62,58 +62,74 @@ def get_metrics(
    Can be "gold" for ground truth, "pred" for predicted values, or any metadata
    column logged (or smart feature).
    """
-    return api_client.get_run_metrics(
+    metrics = api_client.get_run_metrics(
        project_name,
        run_name,
        split,
        task=task,
        inference_name=inference_name,
        category=category,
    )
+    # Filter out metrics not available for this request
+    metrics = {k: v for k, v in metrics.items() if v}
+    return metrics


-def display_dep_distribution(
+def display_distribution(
    project_name: str,
    run_name: str,
    split: Split,
    task: Optional[str] = None,
    inference_name: Optional[str] = None,
+    column: str = "data_error_potential",
) -> None:
-    """Displays the DEP distribution for a run. Plotly must be installed
-    Calculates metrics (f1, recall, precision) overall (weighted) and per label.
-    Also returns the top 50 rows of the dataframe (sorted by data_error_potential)
+    """Displays the column distribution for a run. Plotly must be installed
    :param project_name: The project name
    :param run_name: The run name
    :param split: The split (training/test/validation/inference)
    :param task: (If multi-label only) the task name in question
    :param inference_name: (If inference split only) The inference split name
+    :param column: The column to get the distribution for. Default data error potential
    """
    try:
        import plotly.express as px
    except ImportError:
        raise GalileoException(
            "You must install plotly to use this function. Run `pip install plotly`"
        )
-    summary = api_client.get_run_summary(
-        project_name, run_name, split, task, inference_name
-    )["split_run_results"]
-    easy, hard = summary["easy_samples_threshold"], summary["hard_samples_threshold"]
-    dep = summary["model_metrics"]["dep_distribution"]
-    dep_bins, dep_counts = dep["bins"], dep["counts"]
-
-    fig = px.bar(
-        x=dep_bins[1:],
-        y=dep_counts,
-        labels={"x": "DEP", "y": "Count"},
-        color=dep_bins[1:],
-        color_continuous_scale=[
+    distribution = api_client.get_column_distribution(
+        project_name,
+        run_name,
+        split,
+        column=column,
+        task=task,
+        inference_name=inference_name,
+    )
+    bins, counts = distribution["bins"], distribution["counts"]
+    labels = {"x": column, "y": "Count"}
+
+    color_scale, color = None, None
+    if column == "data_error_potential":
+        summary = api_client.get_run_summary(
+            project_name, run_name, split, task, inference_name
+        )["split_run_results"]
+        easy = summary["easy_samples_threshold"]
+        hard = summary["hard_samples_threshold"]
+        color_scale = [
            (0, "green"),
            (easy, "yellow"),
            (hard, "red"),
            (1, "darkred"),
-        ],
+        ]
+        color = bins[1:]
+
+    fig = px.bar(
+        x=bins[1:],
+        y=counts,
+        labels=labels,
+        color=color,
+        color_continuous_scale=color_scale,
+    )
    fig.show()
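A short, hypothetical call pattern for the generalized function (placeholder project/run/column names; per the docstring, plotly must be installed):

```python
from dataquality.metrics import display_distribution

# Default: DEP histogram, colored green/yellow/red by the run's thresholds
display_distribution("my_project", "my_run", split="training")

# Any continuous metadata column logged with the run also works; non-DEP
# columns are plotted without the threshold-based color scale
display_distribution("my_project", "my_run", split="training", column="confidence")
```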

@@ -125,25 +141,39 @@ def get_dataframe(
    file_type: FileType = FileType.arrow,
    include_embs: bool = False,
    include_probs: bool = False,
+    include_token_indices: bool = False,
) -> DataFrame:
    """Gets the dataframe for a run/split
    Downloads an arrow (or specified type) file to your machine and returns a loaded
-    Vaex dataframe
+    Vaex dataframe.
+    Special note for NER. By default, the data will be downloaded at a sample level
+    (1 row per sample text), with spans for each sample in a `spans` column in a
+    spacy-compatible JSON format. If include_embs is True, the data will be expanded
+    into span level (1 row per span, with sample text repeated for each span row), in
+    order to join the span-level embeddings
    :param project_name: The project name
    :param run_name: The run name
    :param split: The split (training/test/validation/inference)
    :param file_type: The file type to download the data as. Default arrow
    :param include_embs: Whether to include the embeddings in the data. Default False
    :param include_probs: Whether to include the probs in the data. Default False
+    :param include_token_indices: (NER only) Whether to include logged
+        text_token_indices in the data. Useful for reconstructing tokens for retraining
    """
    project_id, run_id = api_client._get_project_run_id(project_name, run_name)
    task_type = api_client.get_task_type(project_id, run_id)

    file_name = f"/tmp/{uuid4()}-data.{file_type}"
    api_client.export_run(project_name, run_name, split, file_name=file_name)
    data_df = vaex.open(file_name)
+    # See docstring. In this case, we need span-level data
+    if include_embs and task_type == TaskType.text_ner:
+        # In NER, the `probabilities` contains the span level data
+        span_df = get_probabilities(project_name, run_name, split)
+        data_df = span_df.join(data_df[["text", "sample_id"]], on="sample_id")
+
    tasks = []
    if task_type == TaskType.text_multi_label:

@@ -170,6 +200,16 @@ def get_dataframe(
        prob_cols = prob_df.get_column_names(regex="prob*") + ["id"]
        data_df = data_df.join(prob_df[prob_cols], on="id")
        data_df = _rename_prob_cols(data_df, tasks)
+    if include_token_indices:
+        if task_type != TaskType.text_ner:
+            warnings.warn(
+                "Token indices are only available for NER, ignoring", GalileoWarning
+            )
+        else:
+            raw_tokens = get_raw_data(project_name, run_name, split)
+            raw_tokens = raw_tokens[["id", "text_token_indices"]]
+            raw_tokens.rename("id", "sample_id")
+            data_df = data_df.join(raw_tokens, on="sample_id")
    return data_df
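And a sketch of the new `include_token_indices` flag on `get_dataframe`, using placeholder project/run names for an NER run:

```python
from dataquality.metrics import get_dataframe

df = get_dataframe(
    "my_ner_project",
    "my_ner_run",
    split="training",
    include_embs=True,           # NER: expands data to one row per span
    include_token_indices=True,  # NER: joins logged text_token_indices per sample
)
# For non-NER task types, include_token_indices is ignored with a GalileoWarning
```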


1 change: 1 addition & 0 deletions dataquality/schemas/route.py
@@ -27,6 +27,7 @@ class Route(str, Enum):
    epochs = "epochs"
    summary = "insights/summary"
    groupby = "insights/groupby"
+    distribution = "insights/distribution"
    xray = "insights/xray"

    @staticmethod
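For context, a small sketch of what the new enum member resolves to and where it is used; the URL composition simply restates the client method added in api.py above:

```python
from dataquality.schemas.route import Route

# Route mixes in str, so members format to their string values inside URLs
print(Route.distribution.value)  # "insights/distribution"

# In ApiClient.get_column_distribution above:
#   url = f"{config.api_url}/{path}/{Route.distribution}"
# targets the new insights/distribution endpoint for the run's split.
```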
