Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and improve experience when using Jinja to iterate over objects DATANG-3679 #14

Merged
merged 10 commits into from
Mar 27, 2024
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,46 @@ The image will be rendered inside the displayed dataframe and next to the "gener

(*Note: you also need an `OPENAI_API_KEY` environment variable to use `gpt-4-vision-preview`*)

## Usage guide

### Input format

Prompterator accepts CSV files as input. Additionally, the CSV data should follow these rules:
- be parseable using a
[`pd.read_csv`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html)
call with the default argument values. This means e.g. having **column names** in the first row,
using **comma** as the separator, and enclosing values (where needed) in **double quotes** (`"`)
- have a column named `text`

### Using input data in prompts

The user/system prompt textboxes support [Jinja](https://jinja.palletsprojects.com/) templates.
Given a column named `text` in your uploaded CSV data, you can include values from this column by
writing the simple `{{text}}` template in your prompt.

If the values in your column represent more complex objects, you can still work with them, but make
sure the values are either valid JSON strings or valid Python literals accepted by
[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval).

To parse string representations of objects, use:
- `fromjson`: for valid JSON strings, e.g. `'["A", "B"]'`
- `fromAstString`: for string representations of Python literals such as dicts/lists/tuples/... (see the accepted types of
[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)), e.g. `"{'key': 'value'}"`

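For example, given a CSV column `texts` containing the JSON list `["A", "B", "C"]` (written in the CSV file as `"[""A"", ""B"", ""C""]"`, i.e. enclosed in double quotes with the inner double quotes doubled), you can use this template to enumerate the individual list items
in your prompt: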
```jinja
{% for item in fromjson(texts) -%}
- {{ item }}
{% endfor %}
```
which would lead to this in your prompt:
```
- A
- B
- C
```

## Paper

You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/
Expand Down
22 changes: 16 additions & 6 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import traceback as tb
from collections import OrderedDict
from datetime import datetime

Expand Down Expand Up @@ -113,12 +114,21 @@ def run_prompt(progress_ui_area):
model_instance = m.MODEL_INSTANCES[model.name]
model_params = {param: st.session_state[param] for param in model.configurable_params}
df_old = st.session_state.df.copy()
model_inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row

try:
model_inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
st.error(
f"Couldn't prepare model inputs due to this error: {e}\n\nFull error "
f"message:\n\n{traceback}"
)
for i, row in df_old.iterrows()
}
return

if len(model_inputs) == 0:
st.error("No input data to generate texts from!")
Expand Down Expand Up @@ -604,7 +614,7 @@ def show_dataframe():

def process_uploaded_file():
if st.session_state.uploaded_file is not None:
df = pd.read_csv(st.session_state.uploaded_file, header=0)
df = pd.read_csv(st.session_state.uploaded_file)
assert c.TEXT_ORIG_COL in df.columns
st.session_state.responses_generated_externally = c.TEXT_GENERATED_COL in df.columns
initialise_session_from_uploaded_file(df)
Expand Down
33 changes: 30 additions & 3 deletions prompterator/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import ast
import concurrent.futures
import itertools
import json
import logging
import os
import re
import socket
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from datetime import datetime
from functools import partial
from typing import Any

import jinja2
import openai
Expand Down Expand Up @@ -233,11 +236,30 @@ def create_model_input(

@st.cache_resource
def jinja_env() -> jinja2.Environment:
def from_json(text: str):
return json.loads(text)
def fromjson(text: str) -> Any:
try:
return json.loads(text)
except json.decoder.JSONDecodeError as e:
raise ValueError(
f"The string you passed into `fromjson` is not a valid JSON string: " f"`{text}`"
) from e

def fromAstString(text: str) -> Any:
try:
return ast.literal_eval(text)
except Exception as e:
raise ValueError(
f"The string you passed into `fromAstString` is not a valid "
f"input: `{text}`. Generally, try passing a valid string "
f"representation of a "
f"Python dictionary/list/set/tuple or other simple types. For more "
f"details, refer to "
f"[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)."
) from e

env = jinja2.Environment()
env.globals["fromjson"] = from_json
env.globals["fromjson"] = fromjson
env.globals["fromAstString"] = fromAstString
return env


Expand Down Expand Up @@ -311,3 +333,8 @@ def insert_hidden_html_marker(helper_element_id, target_streamlit_element=None):
""",
unsafe_allow_html=True,
)


def format_traceback_for_markdown(text):
text = re.sub(r" ", " ", text)
samsucik marked this conversation as resolved.
Show resolved Hide resolved
return re.sub(r"\n", "\n\n", text)
Loading