From e4a543d9c8dce325aadc64961988aed8b62b0407 Mon Sep 17 00:00:00 2001 From: Sam Sucik Date: Fri, 19 Apr 2024 09:07:24 +0200 Subject: [PATCH] Add support for custom output postprocessor DATANG-3786 (#15) * Add optional postprocessing of generated texts using a customisable function * Reformat code * Cover output postprocessing in README * Fix imports' formatting * Make the postprocessor do nothing by default --- README.md | 11 +++++++++++ prompterator/constants.py | 2 ++ prompterator/main.py | 16 +++++++++++++++- prompterator/postprocess_output.py | 5 +++++ prompterator/utils.py | 11 +++++++++++ 5 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 prompterator/postprocess_output.py diff --git a/README.md b/README.md index 3f3d9ad..2290527 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,17 @@ which would lead to this in your prompt: - C ``` +### Postprocessing the model outputs + +When working with LLMs, you often need to postprocess the raw generated text. Prompterator +supports this use case so that you can iterate on your prompts based on inspecting/annotating +postprocessed model outputs. + +By default, no postprocessing is carried out. You can change this by +rewriting the `postprocess` function in `prompterator/postprocess_output.py`. The function will +receive one raw model-generated text at a time and should output its postprocessed version. Both +the raw and the postprocessed text are kept and saved. 
+ ## Paper You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/ diff --git a/prompterator/constants.py b/prompterator/constants.py index 916ec8c..39df748 100644 --- a/prompterator/constants.py +++ b/prompterator/constants.py @@ -61,6 +61,7 @@ def call(self, input, **kwargs): RESPONSE_CREATION_TIMESTAMP_KEY = "created" TEXT_ORIG_COL = "text" +RAW_TEXT_GENERATED_COL = "raw_response" TEXT_GENERATED_COL = "response" COLS_TO_SHOW_KEY = "columns_to_show" SYSTEM_PROMPT_TEMPLATE_COL = "system_prompt_template" @@ -89,6 +90,7 @@ def call(self, input, **kwargs): } # these are the columns that users won't be able to show or inject into their prompts COLS_NOT_FOR_PROMPT_INTERPOLATION = [ + RAW_TEXT_GENERATED_COL, TEXT_GENERATED_COL, SYSTEM_PROMPT_TEMPLATE_COL, USER_PROMPT_TEMPLATE_COL, diff --git a/prompterator/main.py b/prompterator/main.py index ce0dded..9a5dc93 100644 --- a/prompterator/main.py +++ b/prompterator/main.py @@ -12,6 +12,7 @@ import prompterator.constants as c import prompterator.models as m import prompterator.utils as u +from prompterator.postprocess_output import postprocess as postprocess_generated_text # needed to use the simple custom component # from apps.scripts.components_callbacks import register_callback @@ -150,7 +151,19 @@ def run_prompt(progress_ui_area): f'Original text: "{row[c.TEXT_ORIG_COL]}"' ) - row[c.TEXT_GENERATED_COL] = results[i].get("response", "GENERATION ERROR") + row[c.RAW_TEXT_GENERATED_COL] = results[i].get("response", "GENERATION ERROR") + if "response" in results[i]: + try: + postprocessed_text = postprocess_generated_text(results[i]["response"]) + except Exception: + print( + f"Postprocessing the generated text failed. 
Generated text: " + f"'{results[i]}'\nException: {tb.format_exc()}" + ) + postprocessed_text = "POSTPROCESSING ERROR" + else: + postprocessed_text = "GENERATION ERROR" + row[c.TEXT_GENERATED_COL] = postprocessed_text row[c.RESPONSE_DATA_COL] = results[i].get("data") row[c.LABEL_COL] = None st.session_state.df.loc[len(st.session_state.df)] = row @@ -176,6 +189,7 @@ def load_datafiles_into_session(): for file in datafiles: if file not in st.session_state.datafiles: df, metadata = u.load_datafile(os.path.join(c.DATA_STORE_DIR, file)) + df = u.ensure_legacy_datafile_has_all_columns(df) st.session_state.datafiles[file] = { c.DATAFILE_DATA_KEY: df, c.DATAFILE_METADATA_KEY: metadata, diff --git a/prompterator/postprocess_output.py b/prompterator/postprocess_output.py new file mode 100644 index 0000000..b330b6a --- /dev/null +++ b/prompterator/postprocess_output.py @@ -0,0 +1,5 @@ +def postprocess(generated_text: str) -> str: + """ + Rewrite this function to postprocess the generated texts one by one. + """ + return generated_text diff --git a/prompterator/utils.py b/prompterator/utils.py index 218f53f..c5b15e7 100644 --- a/prompterator/utils.py +++ b/prompterator/utils.py @@ -42,6 +42,17 @@ def load_datafile(file_name): return data, contents[c.DATAFILE_METADATA_KEY] +def ensure_legacy_datafile_has_all_columns(df): + if c.RAW_TEXT_GENERATED_COL not in df.columns: + df.insert( + df.columns.get_loc(c.TEXT_GENERATED_COL), + c.RAW_TEXT_GENERATED_COL, + df[c.TEXT_GENERATED_COL], + ) + + return df + + def load_dataframe(file): df = pd.read_csv(file, index_col=0) df[c.TEXT_GENERATED_COL] = df[c.TEXT_GENERATED_COL].apply(lambda val: eval(val)[0])