Skip to content

Commit

Permalink
Enable reusing past labels
Browse files Browse the repository at this point in the history
  • Loading branch information
samsucik committed Apr 19, 2024
1 parent e4a543d commit 4428106
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 9 deletions.
2 changes: 2 additions & 0 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def call(self, input, **kwargs):
USER_PROMPT_TEMPLATE_COL = "user_prompt_template"
RESPONSE_DATA_COL = "response_data"
LABEL_COL = "human_label"
REUSED_PAST_LABEL_COL = "reused_label"
TIMESTAMP_COL = "timestamps"
MODEL_KEY = "model"
PROMPT_CREATOR_KEY = "creator"
Expand Down Expand Up @@ -96,6 +97,7 @@ def call(self, input, **kwargs):
USER_PROMPT_TEMPLATE_COL,
RESPONSE_DATA_COL,
LABEL_COL,
REUSED_PAST_LABEL_COL,
]
LABEL_GOOD = "good"
LABEL_BAD = "bad"
Expand Down
103 changes: 95 additions & 8 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,21 @@ def update_displayed_data_point():


def show_next_row():
    """Move to the next entry of ``st.session_state.available_rows`` and refresh the view.

    ``row_number`` is left unchanged when the current row is already the last
    available one, or when it is not present in ``available_rows`` at all.
    """
    available_rows = st.session_state.available_rows

    # Look the current row up once; a row number that is missing from the list
    # would otherwise raise ValueError inside the button callback.
    try:
        position = available_rows.index(st.session_state.row_number)
    except ValueError:
        position = None

    if position is not None and position < len(available_rows) - 1:
        st.session_state.row_number = available_rows[position + 1]
    update_displayed_data_point()


def show_prev_row():
    """Move to the previous entry of ``st.session_state.available_rows`` and refresh the view.

    ``row_number`` is left unchanged when the current row is already the first
    available one, or when it is not present in ``available_rows`` at all.
    """
    available_rows = st.session_state.available_rows

    # Look the current row up once; a row number that is missing from the list
    # would otherwise raise ValueError inside the button callback.
    try:
        position = available_rows.index(st.session_state.row_number)
    except ValueError:
        position = None

    if position is not None and position > 0:
        st.session_state.row_number = available_rows[position - 1]
    update_displayed_data_point()


Expand Down Expand Up @@ -86,14 +93,21 @@ def initialise_labelling():
text_orig=text_orig,
text_generated=text_generated,
n_data_points=len(st.session_state.df),
available_rows=list(range(len(st.session_state.df))),
)

if st.session_state.responses_generated_externally:
st.session_state[c.MODEL_KEY] = m.MODELS[c.UNKNOWN_MODEL_NAME].copy()


def set_up_dynamic_session_state_vars():
st.session_state.n_checked = len(st.session_state.df.query(f"{c.LABEL_COL}.notnull()"))
st.session_state.n_labelled = len(
st.session_state.df.query(
f"{c.LABEL_COL}.notnull() | {c.REUSED_PAST_LABEL_COL}.notnull()"
if st.session_state.get("reuse_past_labels", False)
else f"{c.LABEL_COL}.notnull()"
)
)

# we need to initialise this one, too, because it wouldn't persist in session_state in the
# cases where no element with key `text_generated` exists -- when the diff viewer is shown.
Expand Down Expand Up @@ -213,9 +227,12 @@ def show_selected_datafile(file_name):
text_orig=text_orig,
text_generated=text_generated,
n_data_points=len(df),
available_rows=list(range(len(df))),
user_prompt=metadata[c.USER_PROMPT_TEMPLATE_COL],
system_prompt=metadata[c.SYSTEM_PROMPT_TEMPLATE_COL],
columns_to_show=metadata.get(c.COLS_TO_SHOW_KEY, [c.TEXT_ORIG_COL]),
reuse_past_labels=False,
skip_past_label_rows=False,
**{
c.PROMPT_NAME_KEY: metadata[c.PROMPT_NAME_KEY],
c.PROMPT_COMMENT_KEY: metadata[c.PROMPT_COMMENT_KEY],
Expand Down Expand Up @@ -504,6 +521,62 @@ def display_image(st_container, base64_str):
)


def _get_input_columns_from_df(df):
    """Return the column names of ``df`` not listed in ``c.COLS_NOT_FOR_PROMPT_INTERPOLATION``.

    The original column order is preserved.
    """
    input_columns = []
    for column_name in df.columns.tolist():
        if column_name in c.COLS_NOT_FOR_PROMPT_INTERPOLATION:
            continue
        input_columns.append(column_name)
    return input_columns


def _get_past_labels():
    """Look up labels previously given to rows identical to the current ones.

    A row of ``st.session_state.df`` is identified by the values of its input
    columns plus the generated text. Every datafile in
    ``st.session_state.datafiles`` that has all of those columns is scanned; the
    first non-missing label found for a matching row wins, with the human label
    taking precedence over a reused one within the same past row.

    Returns:
        A list with one entry per row of ``st.session_state.df`` (in row order):
        the past label, or ``None`` when no labelled match was found.
    """
    # Local import so this block does not rely on a module-level pandas alias.
    import pandas as pd

    relevant_columns = _get_input_columns_from_df(st.session_state.df) + [c.TEXT_GENERATED_COL]
    current_rows = [tuple(row) for _, row in st.session_state.df[relevant_columns].iterrows()]
    # The dict doubles as the membership index, so each lookup is a hash probe
    # instead of a scan over the list of current rows.
    past_labels_for_current_rows = {row: None for row in current_rows}
    relevant_columns_set = set(relevant_columns)

    for datafile in st.session_state.datafiles.values():
        df = datafile[c.DATAFILE_DATA_KEY]

        if not relevant_columns_set.issubset(set(df.columns)):
            continue

        for _, row in df.iterrows():
            row_key = tuple(row[relevant_columns])
            if past_labels_for_current_rows.get(row_key, True) is not None:
                # either not one of the current rows, or already resolved
                continue

            # Missing labels may be None or NaN; NaN is truthy, so a plain
            # `a or b` would propagate NaN as if it were a real label.
            # `.get` tolerates datafiles that lack one of the label columns.
            for candidate in (row.get(c.LABEL_COL), row.get(c.REUSED_PAST_LABEL_COL)):
                if pd.notna(candidate) and candidate:
                    past_labels_for_current_rows[row_key] = candidate
                    break

    return [past_labels_for_current_rows[row] for row in current_rows]


def _handle_reuse_past_labels_toggle():
    """Fill the reused-label column from past labels when the toggle is on.

    Does nothing when ``st.session_state.reuse_past_labels`` is falsy.
    """
    if not st.session_state.reuse_past_labels:
        return

    st.session_state.df[c.REUSED_PAST_LABEL_COL] = _get_past_labels()


def _handle_skip_past_label_rows_toggle():
    """Recompute ``st.session_state.available_rows`` when the skip toggle changes.

    With the toggle on, only rows without a reused past label remain available
    and ``row_number`` is moved to the first available row at or after the
    current one (or, if there is none, to the last available row). With the
    toggle off, every row of ``st.session_state.df`` is available again and
    ``row_number`` is left untouched.
    """
    df = st.session_state.df

    if not st.session_state.skip_past_label_rows:
        st.session_state.available_rows = list(range(len(df)))
        return

    # A missing reused label may be stored as None or as NaN, so use a null
    # check rather than an identity comparison with None.
    if c.REUSED_PAST_LABEL_COL in df.columns:
        has_reused_label = df[c.REUSED_PAST_LABEL_COL].notna().tolist()
    else:
        has_reused_label = [False] * len(df)

    available_rows = [row_idx for row_idx, reused in enumerate(has_reused_label) if not reused]

    # ensure we've got at least one available row so we can display something in the UI
    if not available_rows:
        available_rows = [st.session_state.row_number]

    st.session_state.available_rows = available_rows

    # Fall back to the last available row when every available row precedes the
    # current one; indexing an empty list here used to raise IndexError.
    current_row = st.session_state.row_number
    st.session_state.row_number = next(
        (row_idx for row_idx in available_rows if row_idx >= current_row),
        available_rows[-1],
    )
    # NOTE(review): unlike show_next_row/show_prev_row, this handler does not call
    # update_displayed_data_point() after changing row_number -- confirm that the
    # displayed data point is refreshed elsewhere on rerun.


def set_up_ui_labelling():
col1_orig, col2_orig = st.columns([1, 1])
text_orig_length = len(st.session_state.get("text_orig", ""))
Expand Down Expand Up @@ -579,15 +652,27 @@ def set_up_ui_labelling():
# on_click=assign_label, kwargs={"label_value": label_good})
col2.button("👎", key="mark_bad", on_click=assign_label, kwargs={"label_value": c.LABEL_BAD})
col3.progress(
st.session_state.n_checked / len(st.session_state.df)
st.session_state.n_labelled / len(st.session_state.df)
if len(st.session_state.df) > 0
else 0,
text=f"{st.session_state.n_checked}/{len(st.session_state.df)} checked",
text=f"{st.session_state.n_labelled}/{len(st.session_state.df)} labelled",
)
col4, col5, col6, col_empty = labelling_container.columns([1, 1, 2, 8])
col4, col5, col6, col7, col8 = labelling_container.columns([1, 1, 2, 4, 4])
col4.button("⬅️", key="prev_data_point", on_click=show_prev_row)
col5.button("➡️", key="next_data_point", on_click=show_next_row)
col6.write(f"#{st.session_state.row_number + 1}: {st.session_state.current_row_label}")
col7.toggle(
label="reuse past labels",
value=False,
key="reuse_past_labels",
on_change=_handle_reuse_past_labels_toggle,
)
col8.toggle(
label="skip past label rows",
value=False,
key="skip_past_label_rows",
on_change=_handle_skip_past_label_rows_toggle,
)
labelling_container.button(
"Save ⤵️", key="save_labelled_data", on_click=u.save_labelled_data, type="primary"
)
Expand All @@ -609,6 +694,8 @@ def show_dataframe():
if st.session_state.get("df") is not None:
columns_to_show = st.session_state.columns_to_show.copy()
columns_to_show.extend([c.TEXT_GENERATED_COL, c.LABEL_COL])
if st.session_state.get("reuse_past_labels", False):
columns_to_show.extend([c.REUSED_PAST_LABEL_COL])
df_to_show = st.session_state.df[columns_to_show]
else:
df_to_show = u.get_dummy_dataframe()
Expand Down
15 changes: 14 additions & 1 deletion prompterator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def ensure_legacy_datafile_has_all_columns(df):
df[c.TEXT_GENERATED_COL],
)

if c.REUSED_PAST_LABEL_COL not in df.columns:
df.insert(
df.columns.get_loc(c.LABEL_COL),
c.REUSED_PAST_LABEL_COL,
None,
)

return df


Expand Down Expand Up @@ -182,7 +189,13 @@ def generate_responses_using_parallelism(

def get_correctness_summary(df):
    """Return a ``"<good>/<all>"`` summary of how many rows of ``df`` are labelled good.

    A row counts as good when its human label equals ``c.LABEL_GOOD``, or when
    it has no human label and its reused past label equals ``c.LABEL_GOOD``.
    An explicit human label therefore always takes precedence over a reused one.

    Args:
        df: dataframe containing the ``c.LABEL_COL`` and
            ``c.REUSED_PAST_LABEL_COL`` columns.

    Returns:
        The summary string, e.g. ``"3/10"``.
    """
    human_labels = df[c.LABEL_COL]
    reused_labels = df[c.REUSED_PAST_LABEL_COL]

    # Only fall back to the reused label where no human label exists; otherwise
    # a row explicitly marked as not good could still be counted as good.
    is_good = (human_labels == c.LABEL_GOOD) | (
        human_labels.isna() & (reused_labels == c.LABEL_GOOD)
    )
    return "{good}/{all}".format(good=int(is_good.sum()), all=len(df))


Expand Down

0 comments on commit 4428106

Please sign in to comment.