Skip to content

Commit

Permalink
Add page for randomized user testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Nov 6, 2023
1 parent be895a4 commit 1536f83
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 9 deletions.
20 changes: 20 additions & 0 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from stampy_chat.chat import run_query
from stampy_chat.callbacks import stream_callback
from stampy_chat.citations import get_top_k_blocks
from stampy_chat.db.session import make_session
from stampy_chat.db.models import Rating


# ---------------------------------- web setup ---------------------------------
Expand Down Expand Up @@ -96,5 +98,23 @@ def human(id):

# ------------------------------------------------------------------------------

@app.route('/ratings', methods=['POST'])
@cross_origin()
def ratings():
session_id = request.json.get('sessionId')
settings = request.json.get('settings', {})
score = request.json.get('score')

if not session_id or score is None:
return Response('{"error": "missing params}', 400, mimetype='application/json')

with make_session() as s:
s.add(Rating(session_id=session_id, score=score, settings=json.dumps(settings)))
s.commit()

return jsonify({'status': 'ok'})



if __name__ == '__main__':
app.run(debug=True, port=FLASK_PORT)
33 changes: 33 additions & 0 deletions api/migrations/versions/5813982e9665_ratings_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Ratings table
Revision ID: 5813982e9665
Revises: 78806d965229
Create Date: 2023-11-06 17:31:47.814226
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from stampy_chat.db.models import UUID

# revision identifiers, used by Alembic.
revision = '5813982e9665'
down_revision = '78806d965229'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
'ratings',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('session_id', UUID(length=16), nullable=False),
sa.Column('score', sa.Integer(), nullable=False),
sa.Column('comment', mysql.LONGTEXT(), nullable=True),
sa.Column('settings', mysql.LONGTEXT(), nullable=False),
sa.Column('date_created', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)

def downgrade() -> None:
op.drop_table('rating')
23 changes: 23 additions & 0 deletions api/src/stampy_chat/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,26 @@ def history(cls):

def __repr__(self) -> str:
return f"Interaction(session={self.session_id!r}, no={self.interaction_no!r}, query={self.query!r}, response={self.response!r})"


class Rating(Base):
__tablename__ = "ratings"

id: Mapped[int] = mapped_column("id", primary_key=True)

# The session_id is set once per session, so can be easily used to extract whole histories
session_id: Mapped[str] = mapped_column(UUID(), default=uuid.uuid4)

# the user provided score
score: Mapped[int] = mapped_column(Integer)

# An optional comment
comment: Mapped[Optional[str]] = mapped_column(LONGTEXT)

# The settings object, serialized to JSON
settings: Mapped[str] = mapped_column(LONGTEXT)

date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now())

def __repr__(self) -> str:
return f"Rating(session={self.session_id!r}, score={self.score!r})"
10 changes: 5 additions & 5 deletions web/src/components/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ export const ChatSettings = ({
onChange={(event: ChangeEvent) => {
const value = (event.target as HTMLInputElement).value;
const { maxNumTokens, topKBlocks } =
MODELS[value as keyof typeof MODELS];
MODELS[value as keyof typeof MODELS] || {};
const prevNumTokens =
MODELS[settings.completions as keyof typeof MODELS].maxNumTokens;
MODELS[settings.completions as keyof typeof MODELS]?.maxNumTokens;
const prevTopKBlocks =
MODELS[settings.completions as keyof typeof MODELS].topKBlocks;
MODELS[settings.completions as keyof typeof MODELS]?.topKBlocks;

if (settings.maxNumTokens === prevNumTokens) {
changeVal("maxNumTokens", maxNumTokens);
} else {
changeVal(
"maxNumTokens",
Math.min(settings.maxNumTokens || 0, maxNumTokens)
Math.min(settings.maxNumTokens || 0, maxNumTokens || 0)
);
}
if (settings.topKBlocks === prevTopKBlocks) {
Expand Down Expand Up @@ -86,7 +86,7 @@ export const ChatSettings = ({
field="maxNumTokens"
label="Tokens"
min="1"
max={MODELS[settings.completions as keyof typeof MODELS].maxNumTokens}
max={MODELS[settings.completions as keyof typeof MODELS]?.maxNumTokens}
updater={updateNum("maxNumTokens")}
/>
<NumberInput
Expand Down
43 changes: 39 additions & 4 deletions web/src/hooks/useSettings.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { useRouter } from "next/router";
import { useState, useEffect } from "react";
import { useState, useEffect, useCallback } from "react";

import type { CurrentSearch, Mode, Entry, LLMSettings } from "../types";

Expand Down Expand Up @@ -42,7 +42,11 @@ const DEFAULT_PROMPTS = {
"rather than just giving a formal definition.\n\n",
},
};
export const MODELS = {
interface Model {
maxNumTokens: number;
topKBlocks: number;
}
export const MODELS: { [key: string]: Model } = {
"gpt-3.5-turbo": { maxNumTokens: 4095, topKBlocks: 10 },
"gpt-3.5-turbo-16k": { maxNumTokens: 16385, topKBlocks: 30 },
"gpt-4": { maxNumTokens: 8192, topKBlocks: 20 },
Expand Down Expand Up @@ -74,6 +78,13 @@ export const updateIn = (
return obj;
};

const randomElement = (array: any[]) =>
array[Math.floor(Math.random() * array.length)];
const randomFloat = (min: number, max: number) =>
Math.random() * (max - min) + min;
const randomInt = (min: number, max: number) =>
Math.floor(randomFloat(min, max));

/** Create a settings object in which all items in the `overrides` object will be parsed appropriately
*
* `parsers` should be an object mapping settings fields to functions that will return a valid setting.
Expand Down Expand Up @@ -115,8 +126,8 @@ const SETTINGS_PARSERS = {
mode: (v: string | undefined) => (v || "default") as Mode,
completions: withDefault("gpt-3.5-turbo"),
encoder: withDefault("cl100k_base"),
topKBlocks: withDefault(MODELS["gpt-3.5-turbo"].topKBlocks), // the number of blocks to use as citations
maxNumTokens: withDefault(MODELS["gpt-3.5-turbo"].maxNumTokens),
topKBlocks: withDefault(MODELS["gpt-3.5-turbo"]?.topKBlocks), // the number of blocks to use as citations
maxNumTokens: withDefault(MODELS["gpt-3.5-turbo"]?.maxNumTokens),
tokensBuffer: withDefault(50), // the number of tokens to leave as a buffer when calculating remaining tokens
maxHistory: withDefault(10), // the max number of previous items to use as history
historyFraction: withDefault(0.25), // the (approximate) fraction of num_tokens to use for history text before truncating
Expand All @@ -132,6 +143,27 @@ export const makeSettings = (overrides: LLMSettings) =>
SETTINGS_PARSERS
);

const randomSettings = () => {
const completions = randomElement(Object.keys(MODELS));
const model = MODELS[completions] as Model;
const maxNumTokens = randomInt(
Math.floor(model.maxNumTokens * 0.3),
model.maxNumTokens
);
const historyFraction = randomFloat(0.2, 0.8);
const contextFraction = randomFloat(0.2, 0.9 - historyFraction);
return makeSettings({
completions,
maxNumTokens,
historyFraction,
contextFraction,
mode: randomElement(Object.keys(DEFAULT_PROMPTS.modes)) as Mode,
topKBlocks: randomInt(Math.floor(model.topKBlocks * 0.3), model.topKBlocks),
tokensBuffer: randomInt(10, 200),
maxHistory: randomInt(1, 20),
});
};

type ChatSettingsParams = {
settings: LLMSettings;
changeSetting: (path: string[], value: any) => void;
Expand Down Expand Up @@ -173,10 +205,13 @@ export default function useSettings() {
setLoaded(router.isReady);
}, [router]);

const randomize = useCallback(() => updateSettings(randomSettings()), []);

return {
settings,
changeSetting,
setMode,
settingsLoaded,
randomize,
};
}
9 changes: 9 additions & 0 deletions web/src/styles/globals.css
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ ol {
.glossary-link {
@apply underline hover:no-underline;
}

.rate-container {
margin-top: 20px;
}

.rate-button {
width: 2.5em;
margin: 0.5em;
}

0 comments on commit 1536f83

Please sign in to comment.