Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display values of variables used in a prompt #4

Merged
merged 18 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 29 additions & 55 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def call(self, input, **kwargs):
TEXT_DIFF_COLOURS = {"add": "#56E7AB", "delete": "#FE8080"}

DEFAULT_ROW_NO = 0
DATA_POINT_TEXT_AREA_HEIGHT = 130
DATA_POINT_TEXT_AREA_HEIGHT = 180
PROMPT_TEXT_AREA_HEIGHT = 300

DATAFILE_FILTER_ALL = "all"
87 changes: 66 additions & 21 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import pandas as pd
import streamlit as st
import streamlit_toggle as tog
from diff_match_patch import diff_match_patch
from jinja2 import meta

import prompterator.constants as c
import prompterator.models as m
Expand Down Expand Up @@ -447,16 +447,61 @@ def _get_coloured_patch(patch):
)


def set_up_prompt_attrs_area(st_container):
env = u.jinja_env()
parsed_content = env.parse(st.session_state.system_prompt)
vars = meta.find_undeclared_variables(parsed_content)
samsucik marked this conversation as resolved.
Show resolved Hide resolved

if c.TEXT_ORIG_COL in vars:
vars.remove(c.TEXT_ORIG_COL)

if len(vars) > 0:
# create text of used prompt's variables and their values
vars_values = ""
for var in vars:
vars_values += var + ":\n " + st.session_state.row.get(var, "none") + "\n"

st_container.text_area(
label=f"Attributes used in a prompt",
key="attributes",
value=vars_values,
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)


def set_up_ui_labelling():
col1, col2 = st.columns([1, 1])
col1_orig, col2_orig = st.columns([1, 1])
text_orig_length = len(st.session_state.get("text_orig", ""))
col1.text_area(
col1_orig.text_area(
label=f"Original text ({text_orig_length} chars)",
key="text_orig",
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)
labelling_container = col2.container()
set_up_prompt_attrs_area(col2_orig)

labeling_area = st.container()
u.insert_hidden_html_marker(
helper_element_id="labeling-area-marker", target_streamlit_element=labeling_area
)

st.markdown(
"""
<style>
/* use the helper elements of the main UI area and of the labeling area */
/* to create a relatively nice selector */
[data-testid="stVerticalBlock"]:has(div#main-ui-area-marker) [data-testid="stVerticalBlock"]:has(div#labeling-area-marker) {
padding: 10px;
border-radius: 10px;
border: 4px solid rgba(10, 199, 120, 0.68);
}
</style>
""",
unsafe_allow_html=True,
)
col1_label, col2_label = labeling_area.columns([1, 1])
generated_text_area = col1_label.container()
text_generated_length = len(st.session_state.get("text_generated", ""))
length_change_percentage = (text_generated_length - text_orig_length) / text_orig_length * 100
length_change_percentage_str = (
Expand All @@ -467,19 +512,24 @@ def set_up_ui_labelling():
)

if not st.session_state.get("show_diff", False):
labelling_container.text_area(
generated_text_area.text_area(
label=generated_text_label,
key="text_generated",
value=st.session_state.get("text_generated", ""),
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)
else:
labelling_container.markdown(
generated_text_area.markdown(
create_diff_viewer(generated_text_label), unsafe_allow_html=True
)

col1, col2, col3, col4, col5, col6, col7 = labelling_container.columns([1, 1, 5, 1, 1, 1, 2])
with generated_text_area:
st.toggle(label="show diff", value=False, key="show_diff")

labelling_container = col2_label.container()
labelling_container.markdown("##")
samsucik marked this conversation as resolved.
Show resolved Hide resolved
col1, col2, col3 = labelling_container.columns([1, 1, 10])
col1.button(
"👍",
key="mark_good",
Expand All @@ -497,21 +547,13 @@ def set_up_ui_labelling():
else 0,
text=f"{st.session_state.n_checked}/{len(st.session_state.df)} checked",
)
col4, col5, col6, col_empty = labelling_container.columns([1, 1, 2, 8])
col4.button("⬅️", key="prev_data_point", on_click=show_prev_row)
col5.write(f"#{st.session_state.row_number + 1}: {st.session_state.current_row_label}")
col6.button("➡️", key="next_data_point", on_click=show_next_row)
col7.button("Save ⤵️", key="save_labelled_data", on_click=u.save_labelled_data, type="primary")

with labelling_container:
tog.st_toggle_switch(
label="show diff",
key="show_diff",
default_value=False,
label_after=False,
inactive_color="#D3D3D3",
active_color="#11567f",
track_color="#29B5E8",
)
col5.button("➡️", key="next_data_point", on_click=show_next_row)
col6.write(f"#{st.session_state.row_number + 1}: {st.session_state.current_row_label}")
labelling_container.button(
"Save ⤵️", key="save_labelled_data", on_click=u.save_labelled_data, type="primary"
)


def show_col_selection():
Expand Down Expand Up @@ -574,6 +616,9 @@ def process_uploaded_file():
on_change=process_uploaded_file,
)

# create a helper element at the top of the main UI section to later help us target the area in
# selectors
u.insert_hidden_html_marker(helper_element_id="main-ui-area-marker")

u.ensure_datafiles_directory_exists()
load_datafiles_into_session()
Expand Down
25 changes: 25 additions & 0 deletions prompterator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,28 @@ def get_dummy_dataframe():
assert set(df.columns) == set(c.DUMMY_DATA_COLS)

return df


def insert_hidden_html_marker(helper_element_id, target_streamlit_element=None):
"""
Because targeting streamlit elements (e.g. to style them) is hard, we use a trick.

We create a dummy child elements with a known ID that we can easily target.
"""
if target_streamlit_element:
with target_streamlit_element:
st.markdown(f"""<div id='{helper_element_id}'/>""", unsafe_allow_html=True)
else:
st.markdown(f"""<div id='{helper_element_id}'/>""", unsafe_allow_html=True)

st.markdown(
f"""
<style>
/* hide the dummy element */
div:has(> div.stMarkdown > div[data-testid="stMarkdownContainer"] > div#{helper_element_id}) {{
display: none;
}}
</style>
""",
unsafe_allow_html=True,
)
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ packages = [{include = "prompterator"}]
python = "^3.10"
pydantic = "^1.10.7"
openai = "^0.27.6"
streamlit = "^1.22.0"
streamlit-toggle-switch = "^1.0.2"
streamlit = "^1.28.0"
diff-match-patch = "^20230430"
jinja2 = "^3.1.2"

Expand Down
Loading