Skip to content

Commit

Permalink
add: Display rendered System and User prompts
Browse files Browse the repository at this point in the history
* Ensure that when jinja2 is used, the rendered versions of the System
  and User prompts are also shown in the UI.

Signed-off-by: mrshu <[email protected]>
  • Loading branch information
mrshu authored and samsucik committed Aug 27, 2024
1 parent 7fb0269 commit 607ba92
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 16 deletions.
1 change: 1 addition & 0 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,6 @@ def call(self, input, **kwargs):
DEFAULT_ROW_NO = 0
DATA_POINT_TEXT_AREA_HEIGHT = 180
PROMPT_TEXT_AREA_HEIGHT = 300
PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200

DATAFILE_FILTER_ALL = "all"
88 changes: 72 additions & 16 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def initialise_session_from_uploaded_file(df):

st.session_state["df"] = df
st.session_state[c.COLS_TO_SHOW_KEY] = [c.TEXT_ORIG_COL]
st.session_state.row = st.session_state.df.iloc[c.DEFAULT_ROW_NO]

if st.session_state.responses_generated_externally:
st.session_state.enable_labelling = True
Expand Down Expand Up @@ -432,17 +433,36 @@ def set_up_ui_generation():
height=c.PROMPT_TEXT_AREA_HEIGHT,
disabled=not model_supports_user_prompt,
)
col1, col2 = st.columns([1, 2])

if "df" in st.session_state:
cols_for_interpolation = set(st.session_state.df.columns).difference(
c.COLS_NOT_FOR_PROMPT_INTERPOLATION
prompt_parsing_error_message_area = st.empty()
col1, col2 = st.columns([3, 1])
cols_for_interpolation = list(
set(st.session_state.df.columns).difference(c.COLS_NOT_FOR_PROMPT_INTERPOLATION)
)
col1.write(
f"These are the columns available in the data, feel free to include them in "
f"your prompt: {cols_for_interpolation}"
f"your prompt(s): `{'`, `'.join(cols_for_interpolation)}`."
)
with col2:
st.toggle(label="show prompt preview", value=False, key="show_prompt_preview")

if st.session_state.show_prompt_preview:
col1, col2 = st.columns([3, 2])
set_up_prompt_preview(
col1,
st.session_state.system_prompt,
prompt_parsing_error_message_area,
prompt_kind="system",
)
set_up_prompt_preview(
col2,
st.session_state.user_prompt,
prompt_parsing_error_message_area,
prompt_kind="user",
)

col2.button(
st.button(
label="Run prompt",
on_click=run_prompt,
kwargs={"progress_ui_area": progress_ui_area},
Expand Down Expand Up @@ -488,29 +508,64 @@ def _get_coloured_patch(patch):
)


def set_up_prompt_attrs_area(st_container):
env = u.jinja_env()
parsed_content = env.parse(st.session_state.system_prompt)
vars = meta.find_undeclared_variables(parsed_content)
def set_up_prompt_vars_area(st_container, error_container):
try:
parsed_content = u.jinja_env().parse(
st.session_state.system_prompt + st.session_state.user_prompt
)
used_vars = meta.find_undeclared_variables(parsed_content)
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
error_container.error(
f"Couldn't parse the Jinja templates in the prompt(s). Ensure the "
f"templates are valid. Short error message: {e}\n\n"
f"Full error message:\n\n{traceback}"
)
used_vars = set()

if c.TEXT_ORIG_COL in vars:
vars.remove(c.TEXT_ORIG_COL)
if c.TEXT_ORIG_COL in used_vars:
used_vars.remove(c.TEXT_ORIG_COL)

if len(vars) > 0:
if len(used_vars) > 0:
# create text of used prompt's variables and their values
vars_values = ""
for var in vars:
for var in used_vars:
vars_values += var + ":\n " + st.session_state.row.get(var, "none") + "\n"

st_container.text_area(
label=f"Attributes used in a prompt",
label=f"Attributes other than `{c.TEXT_ORIG_COL}` used in the prompt(s)",
key="attributes",
value=vars_values,
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)


def set_up_prompt_preview(st_container, prompt, error_container, prompt_kind="system"):
if "row" in st.session_state:
try:
prompt = u.jinja_env().from_string(prompt).render(**st.session_state.row.to_dict())
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
error_container.error(
f"Couldn't show {prompt_kind} prompt preview due to an error "
f"when parsing and rendering the Jinja templates. Ensure that "
f"your templates are valid. Short error message: {e}\n\n"
f"Full error message:\n\n{traceback}"
)
prompt = "ERROR"
else:
prompt = ""

st_container.text_area(
label=f"{prompt_kind.title()} prompt preview",
key=f"{prompt_kind}_prompt_jinja2",
value=prompt,
disabled=True,
height=c.PROMPT_PREVIEW_TEXT_AREA_HEIGHT,
)


def display_image(st_container, base64_str):
st_container.markdown(
f"""
Expand Down Expand Up @@ -592,15 +647,16 @@ def _handle_skip_past_label_rows_toggle():


def set_up_ui_labelling():
prompt_parsing_error_message_area = st.empty()
col1_orig, col2_orig = st.columns([1, 1])
text_orig_length = len(st.session_state.get("text_orig", ""))
col1_orig.text_area(
label=f"Original text ({text_orig_length} chars)",
label=f"Original text (`{c.TEXT_ORIG_COL}` column) ({text_orig_length} chars)",
key="text_orig",
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)
set_up_prompt_attrs_area(col2_orig)
set_up_prompt_vars_area(col2_orig, prompt_parsing_error_message_area)

if "image" in st.session_state.row:
display_image(col2_orig, st.session_state.row["image"])
Expand Down

0 comments on commit 607ba92

Please sign in to comment.