Skip to content

Commit

Permalink
Fix and improve experience when using Jinja to iterate over objects D…
Browse files Browse the repository at this point in the history
…ATANG-3679 (#14)

* Add another Jinja function for parsing AST strings that aren't valid JSON strings

* Show meaningful error message instead of breaking Prompterator when there are issues with usage of fromjson or fromAstString in Jinja templates

* Format code

* Add brief documentation in README for using Jinja templates in prompts

* Improve the Jinja example in README

* Further improve the Jinja templating documentation

* Fix reading of backslash-escaped stuff from CSV files

* Improve Jinja docs in README

* Improve naming and remove useless arg

* Simplify Jinja-related docs and add some general usage tips
  • Loading branch information
samsucik authored Mar 27, 2024
1 parent a0ad185 commit 1ec8472
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 9 deletions.
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,46 @@ The image will be rendered inside the displayed dataframe and next to the "gener

(*Note: you also need an `OPENAI_API_KEY` environment variable to use `gpt-4-vision-preview`*)

## Usage guide

### Input format

Prompterator accepts CSV files as input. Additionally, the CSV data should follow these rules:
- be parseable using a
[`pd.read_csv`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html)
call with the default argument values. This means e.g. having **column names** in the first row,
using **comma** as the separator, and enclosing values (where needed) in **double quotes** (`"`)
- have a column named `text`
### Using input data in prompts
The user/system prompt textboxes support [Jinja](https://jinja.palletsprojects.com/) templates.
Given a column named `text` in your uploaded CSV data, you can include values from this column by
writing the simple `{{text}}` template in your prompt.
If the values in your column represent more complex objects, you can still work with them but make
sure they are either valid JSON strings or valid Python expressions accepted by
[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval).
To parse string representations of objects, use:
- `fromjson`: for valid JSON strings, e.g. `'["A", "B"]'`
- `fromAstString`: for Python expressions such as dicts/lists/tuples/... (see the accepted types of
[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)), e.g. `"{'key': 'value'}"`
For example, given a CSV column `texts` with a value `"[""A"", ""B"", ""C""]"`, you can utilise this template to enumerate the individual list items
in your prompt:
```jinja
{% for item in fromjson(texts) -%}
- {{ item }}
{% endfor %}
```
which would lead to this in your prompt:
```
- A
- B
- C
```
## Paper
You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/
Expand Down
22 changes: 16 additions & 6 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import traceback as tb
from collections import OrderedDict
from datetime import datetime

Expand Down Expand Up @@ -113,12 +114,21 @@ def run_prompt(progress_ui_area):
model_instance = m.MODEL_INSTANCES[model.name]
model_params = {param: st.session_state[param] for param in model.configurable_params}
df_old = st.session_state.df.copy()
model_inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row

try:
model_inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
st.error(
f"Couldn't prepare model inputs due to this error: {e}\n\nFull error "
f"message:\n\n{traceback}"
)
for i, row in df_old.iterrows()
}
return

if len(model_inputs) == 0:
st.error("No input data to generate texts from!")
Expand Down Expand Up @@ -604,7 +614,7 @@ def show_dataframe():

def process_uploaded_file():
if st.session_state.uploaded_file is not None:
df = pd.read_csv(st.session_state.uploaded_file, header=0)
df = pd.read_csv(st.session_state.uploaded_file)
assert c.TEXT_ORIG_COL in df.columns
st.session_state.responses_generated_externally = c.TEXT_GENERATED_COL in df.columns
initialise_session_from_uploaded_file(df)
Expand Down
33 changes: 30 additions & 3 deletions prompterator/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import ast
import concurrent.futures
import itertools
import json
import logging
import os
import re
import socket
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from datetime import datetime
from functools import partial
from typing import Any

import jinja2
import openai
Expand Down Expand Up @@ -233,11 +236,30 @@ def create_model_input(

@st.cache_resource
def jinja_env() -> jinja2.Environment:
def from_json(text: str):
return json.loads(text)
def fromjson(text: str) -> Any:
try:
return json.loads(text)
except json.decoder.JSONDecodeError as e:
raise ValueError(
f"The string you passed into `fromjson` is not a valid JSON string: " f"`{text}`"
) from e

def fromAstString(text: str) -> Any:
try:
return ast.literal_eval(text)
except Exception as e:
raise ValueError(
f"The string you passed into `fromAstString` is not a valid "
f"input: `{text}`. Generally, try passing a valid string "
f"representation of a "
f"Python dictionary/list/set/tuple or other simple types. For more "
f"details, refer to "
f"[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)."
) from e

env = jinja2.Environment()
env.globals["fromjson"] = from_json
env.globals["fromjson"] = fromjson
env.globals["fromAstString"] = fromAstString
return env


Expand Down Expand Up @@ -311,3 +333,8 @@ def insert_hidden_html_marker(helper_element_id, target_streamlit_element=None):
""",
unsafe_allow_html=True,
)


def format_traceback_for_markdown(text):
text = re.sub(r" ", " ", text)
return re.sub(r"\n", "\n\n", text)

0 comments on commit 1ec8472

Please sign in to comment.