From 25349edb65d034f35e29746dd16591bf05883b04 Mon Sep 17 00:00:00 2001 From: Zdenek Kasner Date: Tue, 17 Dec 2024 13:40:39 +0100 Subject: [PATCH] Make computation of campaign statistics more idiomatic and efficient --- factgenie/analysis.py | 206 ++++++++++++++++-------------------------- 1 file changed, 78 insertions(+), 128 deletions(-) diff --git a/factgenie/analysis.py b/factgenie/analysis.py index 52ae9eb..743f346 100644 --- a/factgenie/analysis.py +++ b/factgenie/analysis.py @@ -1,83 +1,55 @@ #!/usr/bin/env python3 -import re -import glob -import json -import random import os -import argparse import pandas as pd from collections import defaultdict from scipy.stats import pearsonr import sys -from pathlib import Path -from slugify import slugify import logging -import coloredlogs import traceback import factgenie.workflows as workflows -from factgenie import CAMPAIGN_DIR - sys.path.append(os.path.join(os.path.dirname(__file__), "..")) logger = logging.getLogger(__name__) -# coloredlogs.install(level="INFO", logger=logger, fmt="%(asctime)s %(levelname)s %(message)s") - - -def create_example_record(line, metadata, annotation_span_categories, annotation_records, jsonl_file): - # a record is created even if there are no annotations - j = json.loads(line) - example_record = workflows.create_annotation_example_record(j, jsonl_file) - for i, category in enumerate(annotation_span_categories): - example_record["cat_" + str(i)] = 0 +def generate_example_index(app, campaign): + logger.info(f"Preparing example index for campaign {campaign.campaign_id}") - for annotation in j["annotations"]: - if int(annotation["type"]) == i: - example_record["cat_" + str(i)] += 1 + annotation_span_categories = campaign.metadata["config"]["annotation_span_categories"] + example_index = workflows.get_annotation_index(app, force_reload=True).copy() - example_record["annotations"] = [ - { - "annotation_type": r["annotation_type"], - "annotation_start": r["annotation_start"], - "annotation_text": r["annotation_text"], - } - for r in annotation_records - ] + # Add category count columns to example index + for i in range(len(annotation_span_categories)): + col_name = f"cat_{i}" + example_index[col_name] = example_index["annotations"].apply( + lambda anns: sum(1 for a in anns if a["type"] == i) + ) - return example_record + return example_index -def load_annotations_for_campaign(campaign): - annotation_index = [] - example_index = [] +def generate_span_index(app, campaign): + logger.info(f"Preparing span index for campaign {campaign.campaign_id}") - annotation_span_categories = campaign.metadata["config"]["annotation_span_categories"] + span_index = workflows.get_annotation_index(app).copy() - jsonl_files = glob.glob(os.path.join(CAMPAIGN_DIR, campaign.metadata["id"], "files", "*.jsonl")) + # Remove examples with no annotations + span_index = span_index[span_index["annotations"].apply(lambda x: len(x) > 0)] - for jsonl_file in jsonl_files: - with open(jsonl_file) as f: - lines = f.readlines() - for line in lines: - try: - annotation_records = workflows.load_annotations_from_record(line, split_spans=True) - annotation_index += annotation_records + # Create a separate row for each annotation + span_index = span_index.explode("annotations").reset_index(drop=True) - example_record = create_example_record( - line, campaign.metadata, annotation_span_categories, annotation_records, jsonl_file - ) - example_index.append(example_record) - except Exception as e: - logger.error(f"Error while processing line: {line}") - logger.error(e) + # Extract annotation fields into separate columns + span_index["annotation_type"] = span_index["annotations"].apply(lambda x: x["type"]) + span_index["annotation_start"] = span_index["annotations"].apply(lambda x: x["start"]) + span_index["annotation_text"] = span_index["annotations"].apply(lambda x: x["text"]) - annotation_index = pd.DataFrame(annotation_index) - example_index = pd.DataFrame(example_index) + # Drop the original annotations column + span_index = span_index.drop("annotations", axis=1) - return annotation_index, example_index + return span_index def preprocess_annotations(df, campaign): @@ -100,89 +72,63 @@ def compute_ann_counts(df): """ Compute annotation counts for each annotation type (separately for each dataset, split, setup_id). """ - results = [] - - all_annotation_types = df["annotation_type"].unique() - all_annotation_types.sort() - - for dataset in df["dataset"].unique(): - for split in df["split"].unique(): - for setup_id in df["setup_id"].unique(): - # filter the dataframe - df_filtered = df[(df["dataset"] == dataset) & (df["split"] == split) & (df["setup_id"] == setup_id)] - - # make sure that all annotation types are present in the dataframe, even with zero counts - ann_counts = ( - df_filtered.groupby("annotation_type") - .size() - .reindex(all_annotation_types, fill_value=0) - .reset_index(name="ann_count") - ) + logger.info("Computing annotation counts") - ann_counts["dataset"] = dataset - ann_counts["split"] = split - ann_counts["setup_id"] = setup_id + # Create multi-index groupby once + grouped = df.groupby(["dataset", "split", "setup_id", "annotation_type"]).size().reset_index(name="ann_count") - results.append(ann_counts) + # Create complete multi-index for all combinations + idx = pd.MultiIndex.from_product( + [df["dataset"].unique(), df["split"].unique(), df["setup_id"].unique(), sorted(df["annotation_type"].unique())], + names=["dataset", "split", "setup_id", "annotation_type"], + ) - # concatenate all results into a single dataframe - results = pd.concat(results, ignore_index=True) + # Reindex to include all combinations with zeros + results = ( + grouped.set_index(["dataset", "split", "setup_id", "annotation_type"]).reindex(idx, fill_value=0).reset_index() + ) return results def compute_avg_ann_counts(ann_counts, example_index): - # for each line in ann_counts, find the corresponding dataset in datasets and add the number of examples - # then compute the average annotation count - - # add a column with the number of examples for each dataset, split - ann_counts["example_count"] = 0 - - for i, row in ann_counts.iterrows(): - dataset = row["dataset"] - split = row["split"] - setup_id = row["setup_id"] - - ann_counts.loc[i, "example_count"] = ( - example_index[ - (example_index["dataset"] == dataset) - & (example_index["split"] == split) - & (example_index["setup_id"] == setup_id) - ] - .example_idx.unique() - .shape[0] - ) + logger.info("Computing average annotation counts") + + # Get example counts through groupby operation + example_counts = ( + example_index.groupby(["dataset", "split", "setup_id"]) + .agg(example_count=("example_idx", "nunique")) + .reset_index() + .astype({"example_count": int}) + ) - ann_counts["avg_count"] = ann_counts["ann_count"] / ann_counts["example_count"] + # Merge counts with original dataframe + ann_counts = ann_counts.merge(example_counts, on=["dataset", "split", "setup_id"], how="left") - # round to three decimal places - ann_counts["avg_count"] = ann_counts["avg_count"].round(3) + # Compute average counts vectorized + ann_counts["avg_count"] = (ann_counts["ann_count"] / ann_counts["example_count"]).round(3) return ann_counts def compute_prevalence(ann_counts, example_index): - # for each combination of dataset, split, setup_id, annotation_type, compute the percentage of examples that are affected by the annotation type and add it to the `ann_counts` dataframe - for i, row in ann_counts.iterrows(): - dataset = row["dataset"] - split = row["split"] - setup_id = row["setup_id"] - annotation_type = row["annotation_type"] - - examples = example_index[ - (example_index["dataset"] == dataset) - & (example_index["split"] == split) - & (example_index["setup_id"] == setup_id) - & (example_index["cat_" + str(annotation_type)] > 0) - ] - - if row["example_count"] == 0: - ann_counts.loc[i, "prevalence"] = 0 - else: - ann_counts.loc[i, "prevalence"] = examples.shape[0] / row["example_count"] - - # round to three decimal places - ann_counts["prevalence"] = ann_counts["prevalence"].round(3) + logger.info("Computing annotation prevalence") + + # Compute affected counts for all rows at once + ann_counts["prevalence"] = ann_counts.apply( + lambda row: ( + ( + (example_index["dataset"] == row["dataset"]) + & (example_index["split"] == row["split"]) + & (example_index["setup_id"] == row["setup_id"]) + & (example_index[f"cat_{row['annotation_type']}"] > 0) + ).sum() + / row["example_count"] + if row["example_count"] > 0 + else 0 + ), + axis=1, + ).round(3) return ann_counts @@ -257,15 +203,19 @@ def compute_extra_fields_stats(example_index): def compute_statistics(app, campaign): statistics = {} - annotation_index, example_index = load_annotations_for_campaign(campaign) + span_index = generate_span_index(app, campaign) + example_index = generate_example_index(app, campaign) - if not annotation_index.empty: - annotation_index = preprocess_annotations(annotation_index, campaign) + if not span_index.empty: + span_index = preprocess_annotations(span_index, campaign) - annotation_counts = compute_ann_counts(annotation_index) + annotation_counts = compute_ann_counts(span_index) annotation_counts = compute_avg_ann_counts(annotation_counts, example_index) annotation_counts = compute_prevalence(annotation_counts, example_index) + # replace NaNs with 0 + annotation_counts = annotation_counts.fillna(0.0) + statistics["ann_counts"] = { "full": annotation_counts.to_dict(orient="records"), "span": aggregate_ann_counts(annotation_counts, "span"), @@ -383,14 +333,14 @@ def compute_span_counts(example_index, annotator_count, combinations, cat_column return dataset_level_counts, example_level_counts -def prepare_example_index(combinations, selected_campaigns, campaigns): +def prepare_example_index(app, combinations, selected_campaigns, campaigns): # gather a list of all examples with some annotations example_index = pd.DataFrame() for campaign_id in selected_campaigns: campaign = campaigns[campaign_id] - _, ei = load_annotations_for_campaign(campaign) + ei = generate_example_index(app, campaign) example_index = pd.concat([example_index, ei], ignore_index=True) # a combination is a tuple (dataset, split, setup_id) @@ -429,7 +379,7 @@ def compute_inter_annotator_agreement(app, selected_campaigns, combinations, cam combinations = [(c["dataset"], c["split"], c["setup_id"]) for c in combinations] example_index, annotator_count, annotator_group_ids, cat_columns = prepare_example_index( - combinations=combinations, selected_campaigns=selected_campaigns, campaigns=campaigns + app, combinations=combinations, selected_campaigns=selected_campaigns, campaigns=campaigns ) dataset_level_counts, example_level_counts = compute_span_counts(