From 1536f832177b65ebbc75f5dcf74eebe06a293f86 Mon Sep 17 00:00:00 2001
From: Daniel O'Connell
Date: Mon, 6 Nov 2023 18:33:23 +0100
Subject: [PATCH] Add page for randomized user testing

---
 api/main.py                                    | 20 +++++++++
 .../versions/5813982e9665_ratings_table.py     | 33 ++++++++++++++
 api/src/stampy_chat/db/models.py               | 23 ++++++++++
 web/src/components/settings.tsx                | 10 ++---
 web/src/hooks/useSettings.ts                   | 43 +++++++++++++++++--
 web/src/styles/globals.css                     |  9 ++++
 6 files changed, 129 insertions(+), 9 deletions(-)
 create mode 100644 api/migrations/versions/5813982e9665_ratings_table.py

diff --git a/api/main.py b/api/main.py
index 51b396c..912778f 100644
--- a/api/main.py
+++ b/api/main.py
@@ -11,6 +11,8 @@
 from stampy_chat.chat import run_query
 from stampy_chat.callbacks import stream_callback
 from stampy_chat.citations import get_top_k_blocks
+from stampy_chat.db.session import make_session
+from stampy_chat.db.models import Rating
 
 # ---------------------------------- web setup ---------------------------------
 
@@ -96,5 +98,23 @@ def human(id):
 
 # ------------------------------------------------------------------------------
 
+@app.route('/ratings', methods=['POST'])
+@cross_origin()
+def ratings():
+    session_id = request.json.get('sessionId')
+    settings = request.json.get('settings', {})
+    score = request.json.get('score')
+
+    if not session_id or score is None:
+        return Response('{"error": "missing params"}', 400, mimetype='application/json')
+
+    with make_session() as s:
+        s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings)))
+        s.commit()
+
+    return jsonify({'status': 'ok'})
+
+
+
 if __name__ == '__main__':
     app.run(debug=True, port=FLASK_PORT)
diff --git a/api/migrations/versions/5813982e9665_ratings_table.py b/api/migrations/versions/5813982e9665_ratings_table.py
new file mode 100644
index 0000000..bf1d4f3
--- /dev/null
+++ b/api/migrations/versions/5813982e9665_ratings_table.py
@@ -0,0 +1,33 @@
+"""Ratings table
+
+Revision ID: 5813982e9665
+Revises: 78806d965229
+Create Date: 2023-11-06 17:31:47.814226
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+from stampy_chat.db.models import UUID
+
+# revision identifiers, used by Alembic.
+revision = '5813982e9665'
+down_revision = '78806d965229'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        'ratings',
+        sa.Column('id', sa.Integer(), nullable=False),
+        sa.Column('session_id', UUID(length=16), nullable=False),
+        sa.Column('score', sa.Integer(), nullable=False),
+        sa.Column('comment', mysql.LONGTEXT(), nullable=True),
+        sa.Column('settings', mysql.LONGTEXT(), nullable=False),
+        sa.Column('date_created', sa.DateTime(), nullable=False),
+        sa.PrimaryKeyConstraint('id')
+    )
+
+def downgrade() -> None:
+    op.drop_table('ratings')
diff --git a/api/src/stampy_chat/db/models.py b/api/src/stampy_chat/db/models.py
index 06bc435..8380115 100644
--- a/api/src/stampy_chat/db/models.py
+++ b/api/src/stampy_chat/db/models.py
@@ -90,3 +90,26 @@ def history(cls):
 
     def __repr__(self) -> str:
         return f"Interaction(session={self.session_id!r}, no={self.interaction_no!r}, query={self.query!r}, response={self.response!r})"
+
+
+class Rating(Base):
+    __tablename__ = "ratings"
+
+    id: Mapped[int] = mapped_column("id", primary_key=True)
+
+    # The session_id is set once per session, so can be easily used to extract whole histories
+    session_id: Mapped[str] = mapped_column(UUID(), default=uuid.uuid4)
+
+    # The user-provided score
+    score: Mapped[int] = mapped_column(Integer)
+
+    # An optional comment
+    comment: Mapped[Optional[str]] = mapped_column(LONGTEXT)
+
+    # The settings object, serialized to JSON
+    settings: Mapped[str] = mapped_column(LONGTEXT)
+
+    date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now())
+
+    def __repr__(self) -> str:
+        return f"Rating(session={self.session_id!r}, score={self.score!r})"
diff --git a/web/src/components/settings.tsx b/web/src/components/settings.tsx
index 2b3ffc5..6ad74e5 100644
--- a/web/src/components/settings.tsx
+++ b/web/src/components/settings.tsx
@@ -37,18 +37,18 @@ export const ChatSettings = ({
         onChange={(event: ChangeEvent) => {
           const value = (event.target as HTMLInputElement).value;
           const { maxNumTokens, topKBlocks } =
-            MODELS[value as keyof typeof MODELS];
+            MODELS[value as keyof typeof MODELS] || {};
           const prevNumTokens =
-            MODELS[settings.completions as keyof typeof MODELS].maxNumTokens;
+            MODELS[settings.completions as keyof typeof MODELS]?.maxNumTokens;
           const prevTopKBlocks =
-            MODELS[settings.completions as keyof typeof MODELS].topKBlocks;
+            MODELS[settings.completions as keyof typeof MODELS]?.topKBlocks;
 
           if (settings.maxNumTokens === prevNumTokens) {
             changeVal("maxNumTokens", maxNumTokens);
           } else {
             changeVal(
               "maxNumTokens",
-              Math.min(settings.maxNumTokens || 0, maxNumTokens)
+              Math.min(settings.maxNumTokens || 0, maxNumTokens || 0)
             );
           }
           if (settings.topKBlocks === prevTopKBlocks) {
@@ -86,7 +86,7 @@ export const ChatSettings = ({
         field="maxNumTokens"
         label="Tokens"
         min="1"
-        max={MODELS[settings.completions as keyof typeof MODELS].maxNumTokens}
+        max={MODELS[settings.completions as keyof typeof MODELS]?.maxNumTokens}
         updater={updateNum("maxNumTokens")}
       />
diff --git a/web/src/hooks/useSettings.ts b/web/src/hooks/useSettings.ts
--- a/web/src/hooks/useSettings.ts
+++ b/web/src/hooks/useSettings.ts
+const randomElement = (array: any[]) =>
+  array[Math.floor(Math.random() * array.length)];
+const randomFloat = (min: number, max: number) =>
+  Math.random() * (max - min) + min;
+const randomInt = (min: number, max: number) =>
+  Math.floor(randomFloat(min, max));
+
 /** Create a settings object in which all items in the `overrides` object will be parsed appropriately
  *
  * `parsers` should be an object mapping settings fields to functions that will return a valid setting.
@@ -115,8 +126,8 @@ const SETTINGS_PARSERS = {
   mode: (v: string | undefined) => (v || "default") as Mode,
   completions: withDefault("gpt-3.5-turbo"),
   encoder: withDefault("cl100k_base"),
-  topKBlocks: withDefault(MODELS["gpt-3.5-turbo"].topKBlocks), // the number of blocks to use as citations
-  maxNumTokens: withDefault(MODELS["gpt-3.5-turbo"].maxNumTokens),
+  topKBlocks: withDefault(MODELS["gpt-3.5-turbo"]?.topKBlocks), // the number of blocks to use as citations
+  maxNumTokens: withDefault(MODELS["gpt-3.5-turbo"]?.maxNumTokens),
   tokensBuffer: withDefault(50), // the number of tokens to leave as a buffer when calculating remaining tokens
   maxHistory: withDefault(10), // the max number of previous items to use as history
   historyFraction: withDefault(0.25), // the (approximate) fraction of num_tokens to use for history text before truncating
@@ -132,6 +143,27 @@ export const makeSettings = (overrides: LLMSettings) =>
     SETTINGS_PARSERS
   );
 
+const randomSettings = () => {
+  const completions = randomElement(Object.keys(MODELS));
+  const model = MODELS[completions] as Model;
+  const maxNumTokens = randomInt(
+    Math.floor(model.maxNumTokens * 0.3),
+    model.maxNumTokens
+  );
+  const historyFraction = randomFloat(0.2, 0.8);
+  const contextFraction = randomFloat(0.2, 0.9 - historyFraction);
+  return makeSettings({
+    completions,
+    maxNumTokens,
+    historyFraction,
+    contextFraction,
+    mode: randomElement(Object.keys(DEFAULT_PROMPTS.modes)) as Mode,
+    topKBlocks: randomInt(Math.floor(model.topKBlocks * 0.3), model.topKBlocks),
+    tokensBuffer: randomInt(10, 200),
+    maxHistory: randomInt(1, 20),
+  });
+};
+
 type ChatSettingsParams = {
   settings: LLMSettings;
   changeSetting: (path: string[], value: any) => void;
@@ -173,10 +205,13 @@ export default function useSettings() {
     setLoaded(router.isReady);
   }, [router]);
 
+  const randomize = useCallback(() => updateSettings(randomSettings()), []);
+
   return {
     settings,
     changeSetting,
     setMode,
     settingsLoaded,
+    randomize,
  };
 }
diff --git a/web/src/styles/globals.css b/web/src/styles/globals.css
index 314f607..854e94e 100644
--- a/web/src/styles/globals.css
+++ b/web/src/styles/globals.css
@@ -56,3 +56,12 @@ ol {
 .glossary-link {
   @apply underline hover:no-underline;
 }
+
+.rate-container {
+  margin-top: 20px;
+}
+
+.rate-button {
+  width: 2.5em;
+  margin: 0.5em;
+}
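
The patch adds the /ratings endpoint, the randomize() hook, and the rate-button styles, but not the widget that actually posts a score, so the sketch below is only an illustration of how the front end might call the endpoint. The sendRating helper and the API_URL constant are assumed names; only the payload keys (sessionId, score, settings) and the response shape come from the handler added in api/main.py.

// Hypothetical TypeScript helper (not part of this patch) mirroring the
// payload read by the new /ratings endpoint.
const API_URL = "http://localhost:3001"; // assumed base URL of the Flask API

const sendRating = async (
  sessionId: string,
  score: number,
  settings: Record<string, unknown>
): Promise<boolean> => {
  const response = await fetch(`${API_URL}/ratings`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ sessionId, score, settings }),
  });
  return response.ok; // the endpoint answers 200 with {"status": "ok"}
};

// Example usage from a widget styled with the new .rate-button class:
// <button className="rate-button" onClick={() => sendRating(sessionId, 4, settings)}>4</button>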