Skip to content

Commit

Permalink
Add support for custom output postprocessor DATANG-3786 (#15)
Browse files Browse the repository at this point in the history
* Add optional postprocessing of generated texts using a customisable function

* Reformat code

* Cover output postprocessing in README

* Fix imports' formatting

* Make the postprocessor do nothing by default
  • Loading branch information
samsucik committed Apr 19, 2024
1 parent 1ec8472 commit e4a543d
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ which would lead to this in your prompt:
- C
```
### Postprocessing the model outputs
When working with LLMs, you would often postprocess the raw generated text. Prompterator
supports this use case so that you can iterate your prompts based on inspecting/annotating
postprocessed model outputs.
By default, no postprocessing is carried out. You can change this by
rewriting the `postprocess` function in `prompterator/postprocess_output.py`. The function will
receive one raw model-generated text at a time and should output its postprocessed version. Both
the raw and the postprocessed text are kept and saved.
## Paper
You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/
Expand Down
2 changes: 2 additions & 0 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def call(self, input, **kwargs):
RESPONSE_CREATION_TIMESTAMP_KEY = "created"

TEXT_ORIG_COL = "text"
RAW_TEXT_GENERATED_COL = "raw_response"
TEXT_GENERATED_COL = "response"
COLS_TO_SHOW_KEY = "columns_to_show"
SYSTEM_PROMPT_TEMPLATE_COL = "system_prompt_template"
Expand Down Expand Up @@ -89,6 +90,7 @@ def call(self, input, **kwargs):
}
# these are the columns that users won't be able to show or inject into their prompts
COLS_NOT_FOR_PROMPT_INTERPOLATION = [
RAW_TEXT_GENERATED_COL,
TEXT_GENERATED_COL,
SYSTEM_PROMPT_TEMPLATE_COL,
USER_PROMPT_TEMPLATE_COL,
Expand Down
16 changes: 15 additions & 1 deletion prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import prompterator.constants as c
import prompterator.models as m
import prompterator.utils as u
from prompterator.postprocess_output import postprocess as postprocess_generated_text

# needed to use the simple custom component
# from apps.scripts.components_callbacks import register_callback
Expand Down Expand Up @@ -150,7 +151,19 @@ def run_prompt(progress_ui_area):
f'Original text: "{row[c.TEXT_ORIG_COL]}"'
)

row[c.TEXT_GENERATED_COL] = results[i].get("response", "GENERATION ERROR")
row[c.RAW_TEXT_GENERATED_COL] = results[i].get("response", "GENERATION ERROR")
if "response" in results[i]:
try:
postprocessed_text = postprocess_generated_text(results[i]["response"])
except Exception:
print(
f"Postprocessing the generated text failed. Generated text: "
f"'{results[i]}'\nException: {tb.format_exc()}"
)
postprocessed_text = "POSTPROCESSING ERROR"
else:
postprocessed_text = "GENERATION ERROR"
row[c.TEXT_GENERATED_COL] = postprocessed_text
row[c.RESPONSE_DATA_COL] = results[i].get("data")
row[c.LABEL_COL] = None
st.session_state.df.loc[len(st.session_state.df)] = row
Expand All @@ -176,6 +189,7 @@ def load_datafiles_into_session():
for file in datafiles:
if file not in st.session_state.datafiles:
df, metadata = u.load_datafile(os.path.join(c.DATA_STORE_DIR, file))
df = u.ensure_legacy_datafile_has_all_columns(df)
st.session_state.datafiles[file] = {
c.DATAFILE_DATA_KEY: df,
c.DATAFILE_METADATA_KEY: metadata,
Expand Down
5 changes: 5 additions & 0 deletions prompterator/postprocess_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def postprocess(generated_text: str) -> str:
"""
Rewrite this function to postprocess the generated texts one by one.
"""
return generated_text
11 changes: 11 additions & 0 deletions prompterator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def load_datafile(file_name):
return data, contents[c.DATAFILE_METADATA_KEY]


def ensure_legacy_datafile_has_all_columns(df):
if c.RAW_TEXT_GENERATED_COL not in df.columns:
df.insert(
df.columns.get_loc(c.TEXT_GENERATED_COL),
c.RAW_TEXT_GENERATED_COL,
df[c.TEXT_GENERATED_COL],
)

return df


def load_dataframe(file):
df = pd.read_csv(file, index_col=0)
df[c.TEXT_GENERATED_COL] = df[c.TEXT_GENERATED_COL].apply(lambda val: eval(val)[0])
Expand Down

0 comments on commit e4a543d

Please sign in to comment.