Skip to content

Commit

Permalink
Enable reusing past labels DATANG-3799 (#16)
Browse files Browse the repository at this point in the history
* Enable reusing past labels

* Briefly document the label re-using feature

* Improve code quality after in reaction to review comments

* Show reused labels when loading data files
  • Loading branch information
samsucik committed May 22, 2024
1 parent e4a543d commit 7fb0269
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 9 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,20 @@ rewriting the `postprocess` function in `prompterator/postprocess_output.py`. Th
receive one raw model-generated text at a time and should output its postprocessed version. Both
the raw and the postprocessed text are kept and saved.
### Reusing labels for repeatedly encountered examples
While iterating your prompt on a dataset, you may find yourself annotating a model output that you
already annotated in an earlier round. You can choose to automatically reuse such previously
assigned labels by toggling "reuse past labels". To speed up your annotation process even more,
you can toggle "skip past label rows" so that you only go through the rows for which no
previously assigned label was found.
How this feature works:
- Existing labels are searched for in the current list of files in the sidebar, where a match
requires both the `response` and all the input columns' values to match.
- If multiple different labels are found for a given input+output combination (a sign of
inconsistent past annotation work), the most recent label is re-used.
## Paper
You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/
Expand Down
2 changes: 2 additions & 0 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def call(self, input, **kwargs):
USER_PROMPT_TEMPLATE_COL = "user_prompt_template"
RESPONSE_DATA_COL = "response_data"
LABEL_COL = "human_label"
REUSED_PAST_LABEL_COL = "reused_label"
TIMESTAMP_COL = "timestamps"
MODEL_KEY = "model"
PROMPT_CREATOR_KEY = "creator"
Expand Down Expand Up @@ -96,6 +97,7 @@ def call(self, input, **kwargs):
USER_PROMPT_TEMPLATE_COL,
RESPONSE_DATA_COL,
LABEL_COL,
REUSED_PAST_LABEL_COL,
]
LABEL_GOOD = "good"
LABEL_BAD = "bad"
Expand Down
117 changes: 109 additions & 8 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ def update_displayed_data_point():


def show_next_row():
if st.session_state.row_number < st.session_state.n_data_points - 1:
st.session_state.row_number = st.session_state.row_number + 1
current_idx = st.session_state.rows_for_labelling.index(st.session_state.row_number)
if current_idx < len(st.session_state.rows_for_labelling) - 1:
st.session_state.row_number = st.session_state.rows_for_labelling[current_idx + 1]
update_displayed_data_point()


def show_prev_row():
if st.session_state.row_number > 0:
st.session_state.row_number = st.session_state.row_number - 1
current_idx = st.session_state.rows_for_labelling.index(st.session_state.row_number)
if current_idx > 0:
st.session_state.row_number = st.session_state.rows_for_labelling[current_idx - 1]
update_displayed_data_point()


Expand Down Expand Up @@ -86,14 +88,21 @@ def initialise_labelling():
text_orig=text_orig,
text_generated=text_generated,
n_data_points=len(st.session_state.df),
rows_for_labelling=list(range(len(st.session_state.df))),
)

if st.session_state.responses_generated_externally:
st.session_state[c.MODEL_KEY] = m.MODELS[c.UNKNOWN_MODEL_NAME].copy()


def set_up_dynamic_session_state_vars():
st.session_state.n_checked = len(st.session_state.df.query(f"{c.LABEL_COL}.notnull()"))
st.session_state.n_labelled = len(
st.session_state.df.query(
f"{c.LABEL_COL}.notnull() | {c.REUSED_PAST_LABEL_COL}.notnull()"
if st.session_state.get("reuse_past_labels", False)
else f"{c.LABEL_COL}.notnull()"
)
)

# we need to initialise this one, too, because it wouldn't persist in session_state in the
# cases where no element with key `text_generated` exists -- when the diff viewer is shown.
Expand Down Expand Up @@ -202,9 +211,13 @@ def show_selected_datafile(file_name):
st.session_state.enable_labelling = True
df = st.session_state.datafiles[file_name][c.DATAFILE_DATA_KEY].copy(deep=True)
metadata = st.session_state.datafiles[file_name][c.DATAFILE_METADATA_KEY]

row = df.iloc[c.DEFAULT_ROW_NO]
text_orig = u.get_text_orig(row)
text_generated = u.get_text_generated(row)

has_reused_labels = any(df[c.REUSED_PAST_LABEL_COL].notnull())

set_session_state(
df=df,
row_number=c.DEFAULT_ROW_NO,
Expand All @@ -213,9 +226,12 @@ def show_selected_datafile(file_name):
text_orig=text_orig,
text_generated=text_generated,
n_data_points=len(df),
rows_for_labelling=list(range(len(df))),
user_prompt=metadata[c.USER_PROMPT_TEMPLATE_COL],
system_prompt=metadata[c.SYSTEM_PROMPT_TEMPLATE_COL],
columns_to_show=metadata.get(c.COLS_TO_SHOW_KEY, [c.TEXT_ORIG_COL]),
reuse_past_labels=has_reused_labels,
skip_past_label_rows=False,
**{
c.PROMPT_NAME_KEY: metadata[c.PROMPT_NAME_KEY],
c.PROMPT_COMMENT_KEY: metadata[c.PROMPT_COMMENT_KEY],
Expand Down Expand Up @@ -504,6 +520,77 @@ def display_image(st_container, base64_str):
)


def _get_input_columns_from_df(df):
return [col for col in df.columns.tolist() if col not in c.COLS_NOT_FOR_PROMPT_INTERPOLATION]


def _get_past_labels():
relevant_columns = _get_input_columns_from_df(st.session_state.df) + [c.TEXT_GENERATED_COL]
rows_needing_labels = [
tuple(row) for _, row in st.session_state.df[relevant_columns].iterrows()
]
past_labels_for_rows = {row: None for row in rows_needing_labels}
relevant_columns_set = set(relevant_columns)

datafiles_newest_to_oldest = sorted(
st.session_state.datafiles.values(),
key=lambda datafile: datafile[c.DATAFILE_METADATA_KEY][c.TIMESTAMP_COL],
reverse=True,
)
for datafile in datafiles_newest_to_oldest:
df = datafile[c.DATAFILE_DATA_KEY]

if not relevant_columns_set.issubset(set(df.columns)):
continue

for _, entire_historical_row in df.iterrows():
historical_row = tuple(entire_historical_row[relevant_columns])

if (
historical_row in rows_needing_labels
# take the first (most recent) label that we encounter for current row and don't
# overwrite it with subsequent labels encountered for the same row
and past_labels_for_rows[historical_row] is None
):
past_label = (
entire_historical_row[c.LABEL_COL]
or entire_historical_row[c.REUSED_PAST_LABEL_COL]
)
if past_label is not None:
past_labels_for_rows[historical_row] = past_label

return [past_labels_for_rows[row] for row in rows_needing_labels]


def _handle_reuse_past_labels_toggle():
if st.session_state.reuse_past_labels:
past_labels = _get_past_labels()
st.session_state.df[c.REUSED_PAST_LABEL_COL] = past_labels


def _handle_skip_past_label_rows_toggle():
if st.session_state.skip_past_label_rows:
rows_for_labelling = [
i
for i in range(len(st.session_state.df))
if st.session_state.df.iloc[i][c.REUSED_PAST_LABEL_COL] is None
]

# ensure we've got at least one available row so we can display something in the UI
if not rows_for_labelling:
rows_for_labelling = [st.session_state.row_number]

st.session_state.rows_for_labelling = rows_for_labelling

st.session_state.row_number = [
row_idx
for row_idx in st.session_state.rows_for_labelling
if row_idx >= st.session_state.row_number
][0]
else:
st.session_state.rows_for_labelling = list(range(len(st.session_state.df)))


def set_up_ui_labelling():
col1_orig, col2_orig = st.columns([1, 1])
text_orig_length = len(st.session_state.get("text_orig", ""))
Expand Down Expand Up @@ -579,15 +666,27 @@ def set_up_ui_labelling():
# on_click=assign_label, kwargs={"label_value": label_good})
col2.button("👎", key="mark_bad", on_click=assign_label, kwargs={"label_value": c.LABEL_BAD})
col3.progress(
st.session_state.n_checked / len(st.session_state.df)
st.session_state.n_labelled / len(st.session_state.df)
if len(st.session_state.df) > 0
else 0,
text=f"{st.session_state.n_checked}/{len(st.session_state.df)} checked",
text=f"{st.session_state.n_labelled}/{len(st.session_state.df)} labelled",
)
col4, col5, col6, col_empty = labelling_container.columns([1, 1, 2, 8])
col4, col5, col6, col7, col8 = labelling_container.columns([1, 1, 2, 4, 4])
col4.button("⬅️", key="prev_data_point", on_click=show_prev_row)
col5.button("➡️", key="next_data_point", on_click=show_next_row)
col6.write(f"#{st.session_state.row_number + 1}: {st.session_state.current_row_label}")
col7.toggle(
label="reuse past labels",
value=False,
key="reuse_past_labels",
on_change=_handle_reuse_past_labels_toggle,
)
col8.toggle(
label="skip past label rows",
value=False,
key="skip_past_label_rows",
on_change=_handle_skip_past_label_rows_toggle,
)
labelling_container.button(
"Save ⤵️", key="save_labelled_data", on_click=u.save_labelled_data, type="primary"
)
Expand All @@ -609,6 +708,8 @@ def show_dataframe():
if st.session_state.get("df") is not None:
columns_to_show = st.session_state.columns_to_show.copy()
columns_to_show.extend([c.TEXT_GENERATED_COL, c.LABEL_COL])
if st.session_state.get("reuse_past_labels", False):
columns_to_show.extend([c.REUSED_PAST_LABEL_COL])
df_to_show = st.session_state.df[columns_to_show]
else:
df_to_show = u.get_dummy_dataframe()
Expand Down
15 changes: 14 additions & 1 deletion prompterator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def ensure_legacy_datafile_has_all_columns(df):
df[c.TEXT_GENERATED_COL],
)

if c.REUSED_PAST_LABEL_COL not in df.columns:
df.insert(
df.columns.get_loc(c.LABEL_COL),
c.REUSED_PAST_LABEL_COL,
None,
)

return df


Expand Down Expand Up @@ -182,7 +189,13 @@ def generate_responses_using_parallelism(

def get_correctness_summary(df):
return "{good}/{all}".format(
good=len(df.query(f"{c.LABEL_COL} == '{c.LABEL_GOOD}'")), all=len(df)
good=len(
df.query(
f"({c.LABEL_COL} == '{c.LABEL_GOOD}') | "
f"(({c.LABEL_COL}.isnull()) & ({c.REUSED_PAST_LABEL_COL} == '{c.LABEL_GOOD}'))"
)
),
all=len(df),
)


Expand Down

0 comments on commit 7fb0269

Please sign in to comment.