diff --git a/README.md b/README.md
index 2290527..d304fd5 100644
--- a/README.md
+++ b/README.md
@@ -161,6 +161,20 @@ rewriting the `postprocess` function in `prompterator/postprocess_output.py`. Th
 receive one raw model-generated text at a time and should output its postprocessed version. Both
 the raw and the postprocessed text are kept and saved.
 
+### Reusing labels for repeatedly encountered examples
+
+While iterating your prompt on a dataset, you may find yourself annotating a model output that you
+already annotated in an earlier round. You can choose to automatically reuse such previously
+assigned labels by toggling "reuse past labels". To speed up your annotation process even more,
+you can toggle "skip past label rows" so that you only go through the rows for which no
+previously assigned label was found.
+
+How this feature works:
+- Existing labels are searched for in the current list of files in the sidebar, where a match
+  requires both the `response` and all the input columns' values to match.
+- If multiple different labels are found for a given input+output combination (a sign of
+  inconsistent past annotation work), the most recent label is re-used.
+
 ## Paper
 You can find more information on Prompterator in the associated paper:
 https://aclanthology.org/2023.emnlp-demo.43/
diff --git a/prompterator/constants.py b/prompterator/constants.py
index 39df748..984e4bb 100644
--- a/prompterator/constants.py
+++ b/prompterator/constants.py
@@ -68,6 +68,7 @@ def call(self, input, **kwargs):
 USER_PROMPT_TEMPLATE_COL = "user_prompt_template"
 RESPONSE_DATA_COL = "response_data"
 LABEL_COL = "human_label"
+REUSED_PAST_LABEL_COL = "reused_label"
 TIMESTAMP_COL = "timestamps"
 MODEL_KEY = "model"
 PROMPT_CREATOR_KEY = "creator"
@@ -96,6 +97,7 @@ def call(self, input, **kwargs):
     USER_PROMPT_TEMPLATE_COL,
     RESPONSE_DATA_COL,
     LABEL_COL,
+    REUSED_PAST_LABEL_COL,
 ]
 LABEL_GOOD = "good"
 LABEL_BAD = "bad"
diff --git a/prompterator/main.py b/prompterator/main.py
index 9a5dc93..af4e9fb 100644
--- a/prompterator/main.py
+++ b/prompterator/main.py
@@ -43,14 +43,16 @@ def update_displayed_data_point():
 
 
 def show_next_row():
-    if st.session_state.row_number < st.session_state.n_data_points - 1:
-        st.session_state.row_number = st.session_state.row_number + 1
+    current_idx = st.session_state.rows_for_labelling.index(st.session_state.row_number)
+    if current_idx < len(st.session_state.rows_for_labelling) - 1:
+        st.session_state.row_number = st.session_state.rows_for_labelling[current_idx + 1]
     update_displayed_data_point()
 
 
 def show_prev_row():
-    if st.session_state.row_number > 0:
-        st.session_state.row_number = st.session_state.row_number - 1
+    current_idx = st.session_state.rows_for_labelling.index(st.session_state.row_number)
+    if current_idx > 0:
+        st.session_state.row_number = st.session_state.rows_for_labelling[current_idx - 1]
     update_displayed_data_point()
 
 
@@ -86,6 +88,7 @@ def initialise_labelling():
         text_orig=text_orig,
         text_generated=text_generated,
         n_data_points=len(st.session_state.df),
+        rows_for_labelling=list(range(len(st.session_state.df))),
     )
 
     if st.session_state.responses_generated_externally:
@@ -93,7 +96,13 @@ def initialise_labelling():
 
 
 def set_up_dynamic_session_state_vars():
-    st.session_state.n_checked = len(st.session_state.df.query(f"{c.LABEL_COL}.notnull()"))
+    st.session_state.n_labelled = len(
+        st.session_state.df.query(
+            f"{c.LABEL_COL}.notnull() | {c.REUSED_PAST_LABEL_COL}.notnull()"
+            if st.session_state.get("reuse_past_labels", False)
+            else f"{c.LABEL_COL}.notnull()"
+        )
+    )
 
     # we need to initialise this one, too, because it wouldn't persist in session_state in the
     # cases where no element with key `text_generated` exists -- when the diff viewer is shown.
@@ -202,9 +211,13 @@ def show_selected_datafile(file_name):
     st.session_state.enable_labelling = True
     df = st.session_state.datafiles[file_name][c.DATAFILE_DATA_KEY].copy(deep=True)
     metadata = st.session_state.datafiles[file_name][c.DATAFILE_METADATA_KEY]
+
     row = df.iloc[c.DEFAULT_ROW_NO]
     text_orig = u.get_text_orig(row)
     text_generated = u.get_text_generated(row)
+
+    has_reused_labels = any(df[c.REUSED_PAST_LABEL_COL].notnull())
+
     set_session_state(
         df=df,
         row_number=c.DEFAULT_ROW_NO,
@@ -213,9 +226,12 @@ def show_selected_datafile(file_name):
         text_orig=text_orig,
         text_generated=text_generated,
         n_data_points=len(df),
+        rows_for_labelling=list(range(len(df))),
         user_prompt=metadata[c.USER_PROMPT_TEMPLATE_COL],
         system_prompt=metadata[c.SYSTEM_PROMPT_TEMPLATE_COL],
         columns_to_show=metadata.get(c.COLS_TO_SHOW_KEY, [c.TEXT_ORIG_COL]),
+        reuse_past_labels=has_reused_labels,
+        skip_past_label_rows=False,
         **{
             c.PROMPT_NAME_KEY: metadata[c.PROMPT_NAME_KEY],
             c.PROMPT_COMMENT_KEY: metadata[c.PROMPT_COMMENT_KEY],
@@ -504,6 +520,96 @@ def display_image(st_container, base64_str):
     )
 
 
+def _get_input_columns_from_df(df):
+    """Return the columns of `df` not listed in `c.COLS_NOT_FOR_PROMPT_INTERPOLATION`."""
+    return [col for col in df.columns.tolist() if col not in c.COLS_NOT_FOR_PROMPT_INTERPOLATION]
+
+
+def _get_past_labels():
+    """Find a previously assigned label for every row of the currently loaded dataframe.
+
+    A historical row matches when all input columns and `c.TEXT_GENERATED_COL` hold equal
+    values. Datafiles are searched from the newest to the oldest, so the most recent label
+    wins. Returns one entry per row of `st.session_state.df` (in row order); rows for which no
+    past label exists get `None`.
+    """
+    relevant_columns = _get_input_columns_from_df(st.session_state.df) + [c.TEXT_GENERATED_COL]
+    rows_needing_labels = [
+        tuple(row) for _, row in st.session_state.df[relevant_columns].iterrows()
+    ]
+    past_labels_for_rows = {row: None for row in rows_needing_labels}
+    relevant_columns_set = set(relevant_columns)
+    label_columns = [c.LABEL_COL, c.REUSED_PAST_LABEL_COL]
+
+    datafiles_newest_to_oldest = sorted(
+        st.session_state.datafiles.values(),
+        key=lambda datafile: datafile[c.DATAFILE_METADATA_KEY][c.TIMESTAMP_COL],
+        reverse=True,
+    )
+    for datafile in datafiles_newest_to_oldest:
+        df = datafile[c.DATAFILE_DATA_KEY]
+
+        if not relevant_columns_set.issubset(set(df.columns)):
+            continue
+
+        for _, entire_historical_row in df.iterrows():
+            historical_row = tuple(entire_historical_row[relevant_columns])
+
+            # take the first (most recent) label that we encounter for current row and don't
+            # overwrite it with subsequent labels encountered for the same row; the dict lookup
+            # replaces a linear scan of `rows_needing_labels` for every historical row
+            if (
+                historical_row not in past_labels_for_rows
+                or past_labels_for_rows[historical_row] is not None
+            ):
+                continue
+
+            # `dropna()` discards NaN as well as `None`, so a missing human label can neither
+            # shadow a reused label nor be stored as a "label" itself; when both labels are
+            # present, the human one takes precedence
+            available_labels = entire_historical_row.reindex(label_columns).dropna()
+            if not available_labels.empty:
+                past_labels_for_rows[historical_row] = available_labels.iloc[0]
+
+    return [past_labels_for_rows[row] for row in rows_needing_labels]
+
+
+def _handle_reuse_past_labels_toggle():
+    """Fill `c.REUSED_PAST_LABEL_COL` with past labels when "reuse past labels" is switched on."""
+    if st.session_state.reuse_past_labels:
+        past_labels = _get_past_labels()
+        st.session_state.df[c.REUSED_PAST_LABEL_COL] = past_labels
+
+
+def _handle_skip_past_label_rows_toggle():
+    """Restrict row navigation to rows without a reused label while "skip past label rows" is on.
+
+    Switching the toggle off makes all rows available for labelling again.
+    """
+    df = st.session_state.df
+    if not st.session_state.skip_past_label_rows:
+        st.session_state.rows_for_labelling = list(range(len(df)))
+        return
+
+    # `isnull()` catches NaN as well as `None`; an identity check against `None` would treat
+    # NaN-filled cells as already labelled
+    lacks_past_label = df[c.REUSED_PAST_LABEL_COL].isnull().tolist()
+    rows_for_labelling = [i for i, is_missing in enumerate(lacks_past_label) if is_missing]
+
+    # ensure we've got at least one available row so we can display something in the UI
+    if not rows_for_labelling:
+        rows_for_labelling = [st.session_state.row_number]
+
+    st.session_state.rows_for_labelling = rows_for_labelling
+
+    # move to the closest available row at or after the current one; when every available row
+    # precedes the current one, fall back to the last available row instead of raising
+    st.session_state.row_number = next(
+        (row_idx for row_idx in rows_for_labelling if row_idx >= st.session_state.row_number),
+        rows_for_labelling[-1],
+    )
+
+
 def set_up_ui_labelling():
     col1_orig, col2_orig = st.columns([1, 1])
     text_orig_length = len(st.session_state.get("text_orig", ""))
@@ -579,15 +685,27 @@ def set_up_ui_labelling():
     # on_click=assign_label, kwargs={"label_value": label_good})
     col2.button("👎", key="mark_bad", on_click=assign_label, kwargs={"label_value": c.LABEL_BAD})
     col3.progress(
-        st.session_state.n_checked / len(st.session_state.df)
+        st.session_state.n_labelled / len(st.session_state.df)
         if len(st.session_state.df) > 0
         else 0,
-        text=f"{st.session_state.n_checked}/{len(st.session_state.df)} checked",
+        text=f"{st.session_state.n_labelled}/{len(st.session_state.df)} labelled",
     )
-    col4, col5, col6, col_empty = labelling_container.columns([1, 1, 2, 8])
+    col4, col5, col6, col7, col8 = labelling_container.columns([1, 1, 2, 4, 4])
     col4.button("⬅️", key="prev_data_point", on_click=show_prev_row)
     col5.button("➡️", key="next_data_point", on_click=show_next_row)
     col6.write(f"#{st.session_state.row_number + 1}: {st.session_state.current_row_label}")
+    col7.toggle(
+        label="reuse past labels",
+        value=False,
+        key="reuse_past_labels",
+        on_change=_handle_reuse_past_labels_toggle,
+    )
+    col8.toggle(
+        label="skip past label rows",
+        value=False,
+        key="skip_past_label_rows",
+        on_change=_handle_skip_past_label_rows_toggle,
+    )
     labelling_container.button(
         "Save ⤵️", key="save_labelled_data", on_click=u.save_labelled_data, type="primary"
     )
@@ -609,6 +727,8 @@ def show_dataframe():
     if st.session_state.get("df") is not None:
         columns_to_show = st.session_state.columns_to_show.copy()
         columns_to_show.extend([c.TEXT_GENERATED_COL, c.LABEL_COL])
+        if st.session_state.get("reuse_past_labels", False):
+            columns_to_show.extend([c.REUSED_PAST_LABEL_COL])
         df_to_show = st.session_state.df[columns_to_show]
     else:
         df_to_show = u.get_dummy_dataframe()
diff --git a/prompterator/utils.py b/prompterator/utils.py
index c5b15e7..f22fe04 100644
--- a/prompterator/utils.py
+++ b/prompterator/utils.py
@@ -50,6 +50,13 @@ def ensure_legacy_datafile_has_all_columns(df):
             df[c.TEXT_GENERATED_COL],
         )
 
+    if c.REUSED_PAST_LABEL_COL not in df.columns:
+        df.insert(
+            df.columns.get_loc(c.LABEL_COL),
+            c.REUSED_PAST_LABEL_COL,
+            None,
+        )
+
     return df
 
 
@@ -182,7 +189,13 @@ def generate_responses_using_parallelism(
 
 def get_correctness_summary(df):
     return "{good}/{all}".format(
-        good=len(df.query(f"{c.LABEL_COL} == '{c.LABEL_GOOD}'")), all=len(df)
+        good=len(
+            df.query(
+                f"({c.LABEL_COL} == '{c.LABEL_GOOD}') | "
+                f"(({c.LABEL_COL}.isnull()) & ({c.REUSED_PAST_LABEL_COL} == '{c.LABEL_GOOD}'))"
+            )
+        ),
+        all=len(df),
     )