From aef89bad7072586fdb1a80095ace27e2515c2e2e Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Thu, 4 Jan 2024 14:07:25 +0800 Subject: [PATCH] feat: add drag/drop model preference feature --- app/package.json | 1 + app/pnpm-lock.yaml | 9 +++ app/src/api/broadcast.ts | 11 +-- app/src/api/types.ts | 1 + app/src/assets/pages/home.less | 26 ++++-- app/src/components/FileProvider.tsx | 7 +- app/src/components/home/ModelFinder.tsx | 8 +- app/src/components/home/ModelMarket.tsx | 102 ++++++++++++++++++++---- app/src/conf.ts | 80 ++++++++++++++----- app/src/main.tsx | 5 +- app/src/store/chat.ts | 5 +- app/src/utils/storage.ts | 27 +++++++ 12 files changed, 217 insertions(+), 65 deletions(-) create mode 100644 app/src/utils/storage.ts diff --git a/app/package.json b/app/package.json index 4370d296..509ac210 100644 --- a/app/package.json +++ b/app/package.json @@ -62,6 +62,7 @@ "@tauri-apps/cli": "^1.5.6", "@types/node": "^20.5.9", "@types/react": "^18.2.15", + "@types/react-beautiful-dnd": "^13.1.8", "@types/react-dom": "^18.2.7", "@types/react-syntax-highlighter": "^15.5.7", "@typescript-eslint/eslint-plugin": "^6.0.0", diff --git a/app/pnpm-lock.yaml b/app/pnpm-lock.yaml index 7b09c870..82aaaa24 100644 --- a/app/pnpm-lock.yaml +++ b/app/pnpm-lock.yaml @@ -151,6 +151,9 @@ devDependencies: '@types/react': specifier: ^18.2.15 version: 18.2.33 + '@types/react-beautiful-dnd': + specifier: ^13.1.8 + version: 13.1.8 '@types/react-dom': specifier: ^18.2.7 version: 18.2.14 @@ -2052,6 +2055,12 @@ packages: /@types/prop-types@15.7.9: resolution: {integrity: sha512-n1yyPsugYNSmHgxDFjicaI2+gCNjsBck8UX9kuofAKlc0h1bL+20oSF72KeNaW2DUlesbEVCFgyV2dPGTiY42g==} + /@types/react-beautiful-dnd@13.1.8: + resolution: {integrity: sha512-E3TyFsro9pQuK4r8S/OL6G99eq7p8v29sX0PM7oT8Z+PJfZvSQTx4zTQbUJ+QZXioAF0e7TGBEcA1XhYhCweyQ==} + dependencies: + '@types/react': 18.2.33 + dev: true + /@types/react-dom@18.2.14: resolution: {integrity: sha512-V835xgdSVmyQmI1KLV2BEIUgqEuinxp9O4G6g3FqO/SqLac049E53aysv0oEFD2kHfejeKU+ZqL2bcFWj9gLAQ==} dependencies: diff --git a/app/src/api/broadcast.ts b/app/src/api/broadcast.ts index 66e66c63..06aebcfa 100644 --- a/app/src/api/broadcast.ts +++ b/app/src/api/broadcast.ts @@ -23,14 +23,15 @@ export type CreateBroadcastResponse = { export async function getRawBroadcast(): Promise { try { const data = await axios.get("/broadcast/view"); - return data.data as Broadcast; + if (data.data) return data.data as Broadcast; } catch (e) { console.warn(e); - return { - content: "", - index: 0, - }; } + + return { + content: "", + index: 0, + }; } export async function getBroadcast(): Promise { diff --git a/app/src/api/types.ts b/app/src/api/types.ts index b39160dd..3e06e501 100644 --- a/app/src/api/types.ts +++ b/app/src/api/types.ts @@ -15,6 +15,7 @@ export type Model = { name: string; free: boolean; auth: boolean; + high_context: boolean; tag?: string[]; }; diff --git a/app/src/assets/pages/home.less b/app/src/assets/pages/home.less index 55b0c567..7ca9f85e 100644 --- a/app/src/assets/pages/home.less +++ b/app/src/assets/pages/home.less @@ -18,6 +18,13 @@ overflow: auto; scrollbar-width: thin; + .market-wrapper { + width: 100%; + height: 100%; + max-width: 768px; + margin: 0 auto; + } + @media (max-width: 768px) { padding: 1rem; } @@ -71,8 +78,7 @@ .model-list { display: flex; - flex-direction: row; - flex-wrap: wrap; + flex-direction: column; width: 100%; } @@ -88,14 +94,11 @@ border: 1px solid hsl(var(--border-hover)); border-radius: var(--radius); transition: 0.25s; + transition-property: border-color, padding, background, box-shadow; cursor: grab; animation: fadein 0.25s forwards ease-in-out; opacity: 0; - width: 100%; - - @media (min-width: 960px) { - width: calc(50% - 1rem); - } + width: calc(100% - 1rem); @keyframes fadein { from { opacity: 0; transform: translateY(2.5rem); } @@ -109,6 +112,10 @@ &:before { width: 100%; } + + .grip-icon { + opacity: 1; + } } &.active { @@ -132,6 +139,11 @@ border-radius: var(--radius); } + .grip-icon { + opacity: 0.6; + transition: 0.25s; + } + .model-avatar { border-radius: var(--radius); width: 3rem; diff --git a/app/src/components/FileProvider.tsx b/app/src/components/FileProvider.tsx index 0f23a30a..2913a3fc 100644 --- a/app/src/components/FileProvider.tsx +++ b/app/src/components/FileProvider.tsx @@ -24,7 +24,7 @@ import { useDraggableInput } from "@/utils/dom.ts"; import { FileObject, FileArray, blobParser } from "@/api/file.ts"; import { Button } from "@/components/ui/button.tsx"; import { useSelector } from "react-redux"; -import { largeContextModels } from "@/conf.ts"; +import { isHighContextModel } from "@/conf.ts"; import { selectModel } from "@/store/chat.ts"; import { ChatAction } from "@/components/home/assemblies/ChatAction.tsx"; @@ -45,10 +45,7 @@ function FileProvider({ value, onChange }: FileProviderProps) { console.debug( `[file] new file was added (filename: ${file.name}, size: ${file.size}, prompt: ${file.content.length})`, ); - if ( - file.content.length > MaxPromptSize && - !largeContextModels.includes(model) - ) { + if (file.content.length > MaxPromptSize && isHighContextModel(model)) { file.content = file.content.slice(0, MaxPromptSize); toast({ title: t("file.max-length"), diff --git a/app/src/components/home/ModelFinder.tsx b/app/src/components/home/ModelFinder.tsx index ddc79261..0c84f5ef 100644 --- a/app/src/components/home/ModelFinder.tsx +++ b/app/src/components/home/ModelFinder.tsx @@ -1,5 +1,5 @@ import SelectGroup, { SelectItemProps } from "@/components/SelectGroup.tsx"; -import { expensiveModels, supportModels } from "@/conf.ts"; +import { supportModels } from "@/conf.ts"; import { getPlanModels, openMarket, @@ -35,12 +35,6 @@ function filterModel(model: Model, level: number) { value: model.name, badge: { variant: "gold", name: "plus" }, } as SelectItemProps; - } else if (expensiveModels.includes(model.id)) { - return { - name: model.id, - value: model.name, - badge: { variant: "gold", name: "expensive" }, - } as SelectItemProps; } return { diff --git a/app/src/components/home/ModelMarket.tsx b/app/src/components/home/ModelMarket.tsx index c786a21c..d481efe2 100644 --- a/app/src/components/home/ModelMarket.tsx +++ b/app/src/components/home/ModelMarket.tsx @@ -3,6 +3,7 @@ import { Input } from "@/components/ui/input.tsx"; import { ChevronLeft, ChevronRight, + GripVertical, Link, Plus, Search, @@ -31,6 +32,13 @@ import { selectAuthenticated } from "@/store/auth.ts"; import { useToast } from "@/components/ui/use-toast.ts"; import { docsEndpoint } from "@/utils/env.ts"; import { goAuth } from "@/utils/app.ts"; +import { + DragDropContext, + Droppable, + Draggable, + DropResult, +} from "react-beautiful-dnd"; +import { savePreferenceModels } from "@/utils/storage.ts"; type SearchBarProps = { value: string; @@ -58,13 +66,23 @@ function SearchBar({ value, onChange }: SearchBarProps) { ); } -type ModelProps = { +type ModelProps = React.DetailedHTMLProps< + React.HTMLAttributes, + HTMLDivElement +> & { model: Model; className?: string; style?: React.CSSProperties; + forwardRef?: React.Ref; }; -function ModelItem({ model, className, style }: ModelProps) { +function ModelItem({ + model, + className, + style, + forwardRef, + ...props +}: ModelProps) { const { t } = useTranslation(); const dispatch = useDispatch(); const { toast } = useToast(); @@ -94,6 +112,8 @@ function ModelItem({ model, className, style }: ModelProps) {
{ dispatch(addModelList(model.id)); @@ -113,6 +133,7 @@ function ModelItem({ model, className, style }: ModelProps) { dispatch(closeMarket()); }} > + {model.name}

{model.name}

@@ -163,7 +184,7 @@ function MarketPlace({ search }: MarketPlaceProps) { const { t } = useTranslation(); const select = useSelector(selectModel); - const arr = useMemo(() => { + const models = useMemo(() => { if (search.length === 0) return supportModels; // fuzzy search const raw = splitList(search.toLowerCase(), [" ", ",", ";", "-"]); @@ -181,18 +202,63 @@ function MarketPlace({ search }: MarketPlaceProps) { tag_translated.includes(item), ); }); - }, [search]); + }, [supportModels, search]); + + const queryIndex = (id: number) => { + const model = models[id]; + if (!model) return -1; + + return supportModels.findIndex((item) => item.id === model.id); + }; + + const onDragEnd = (result: DropResult) => { + const { destination, source } = result; + if ( + !destination || + destination.index === source.index || + destination.index === -1 + ) + return; + + const from = queryIndex(source.index); + const to = queryIndex(destination.index); + if (from === -1 || to === -1) return; + + const list = [...supportModels]; + const [removed] = list.splice(from, 1); + list.splice(to, 0, removed); + + supportModels.splice(0, supportModels.length, ...list); + savePreferenceModels(supportModels); + }; return ( -
- {arr.map((model, index) => ( - - ))} -
+ + + {(provided) => ( +
+ {models.map((model, index) => ( + + {(provided) => ( + + )} + + ))} + {provided.placeholder} +
+ )} +
+
); } @@ -237,10 +303,12 @@ function ModelMarket() { return (
- - - - +
+ + + + +
); } diff --git a/app/src/conf.ts b/app/src/conf.ts index af95b238..f8029273 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -12,20 +12,22 @@ import { getMemory } from "@/utils/memory.ts"; import { Compass, Image, Newspaper } from "lucide-react"; import React from "react"; import { syncSiteInfo } from "@/admin/api/info.ts"; +import { loadPreferenceModels } from "@/utils/storage.ts"; -export const version = "3.8.0"; +export const version = "3.8.0-rc"; export const dev: boolean = getDev(); export const deploy: boolean = true; export let rest_api: string = getRestApi(deploy); export let ws_api: string = getWebsocketApi(deploy); export const tokenField = getTokenField(deploy); -export const supportModels: Model[] = [ +export let supportModels: Model[] = loadPreferenceModels([ // openai models { id: "gpt-3.5-turbo-0613", name: "GPT-3.5", free: true, auth: false, + high_context: false, tag: ["free", "official"], }, { @@ -33,6 +35,7 @@ export const supportModels: Model[] = [ name: "GPT-3.5-16k", free: true, auth: true, + high_context: true, tag: ["free", "official", "high-context"], }, { @@ -40,6 +43,7 @@ export const supportModels: Model[] = [ name: "GPT-3.5 1106", free: true, auth: true, + high_context: true, tag: ["free", "official"], }, { @@ -47,6 +51,7 @@ export const supportModels: Model[] = [ name: "GPT-3.5 Fast", free: false, auth: true, + high_context: false, tag: ["official"], }, { @@ -54,6 +59,7 @@ export const supportModels: Model[] = [ name: "GPT-3.5 16K Fast", free: false, auth: true, + high_context: true, tag: ["official"], }, { @@ -61,6 +67,7 @@ export const supportModels: Model[] = [ name: "GPT-4", free: false, auth: true, + high_context: true, tag: ["official", "high-quality"], }, { @@ -68,6 +75,7 @@ export const supportModels: Model[] = [ name: "GPT-4 Turbo 128k", free: false, auth: true, + high_context: true, tag: ["official", "high-context", "unstable"], }, { @@ -75,6 +83,7 @@ export const supportModels: Model[] = [ name: "GPT-4 Vision 128k", free: false, auth: true, + high_context: true, tag: ["official", "high-context", "multi-modal", "unstable"], }, { @@ -82,6 +91,7 @@ export const supportModels: Model[] = [ name: "GPT-4 Vision", free: false, auth: true, + high_context: true, tag: ["official", "unstable", "multi-modal"], }, { @@ -89,6 +99,7 @@ export const supportModels: Model[] = [ name: "GPT-4 DALLE", free: false, auth: true, + high_context: true, tag: ["official", "unstable", "image-generation"], }, @@ -97,6 +108,7 @@ export const supportModels: Model[] = [ name: "Azure GPT-3.5", free: false, auth: true, + high_context: false, tag: ["official"], }, { @@ -104,6 +116,7 @@ export const supportModels: Model[] = [ name: "Azure GPT-3.5 16K", free: false, auth: true, + high_context: true, tag: ["official"], }, { @@ -111,6 +124,7 @@ export const supportModels: Model[] = [ name: "Azure GPT-4", free: false, auth: true, + high_context: true, tag: ["official", "high-quality"], }, { @@ -118,6 +132,7 @@ export const supportModels: Model[] = [ name: "Azure GPT-4 Turbo 128k", free: false, auth: true, + high_context: true, tag: ["official", "high-context", "unstable"], }, { @@ -125,6 +140,7 @@ export const supportModels: Model[] = [ name: "Azure GPT-4 Vision 128k", free: false, auth: true, + high_context: true, tag: ["official", "high-context", "multi-modal"], }, { @@ -132,6 +148,7 @@ export const supportModels: Model[] = [ name: "Azure GPT-4 32k", free: false, auth: true, + high_context: true, tag: ["official", "multi-modal"], }, @@ -141,6 +158,7 @@ export const supportModels: Model[] = [ name: "讯飞星火 V3", free: false, auth: true, + high_context: false, tag: ["official", "high-quality"], }, { @@ -148,6 +166,7 @@ export const supportModels: Model[] = [ name: "讯飞星火 V2", free: false, auth: true, + high_context: false, tag: ["official"], }, { @@ -155,6 +174,7 @@ export const supportModels: Model[] = [ name: "讯飞星火 V1.5", free: false, auth: true, + high_context: false, tag: ["official"], }, @@ -164,6 +184,7 @@ export const supportModels: Model[] = [ name: "通义千问 Plus Net", free: false, auth: true, + high_context: false, tag: ["official", "high-quality", "web"], }, { @@ -171,6 +192,7 @@ export const supportModels: Model[] = [ name: "通义千问 Plus", free: false, auth: true, + high_context: false, tag: ["official", "high-quality"], }, { @@ -178,6 +200,7 @@ export const supportModels: Model[] = [ name: "通义千问 Turbo Net", free: false, auth: true, + high_context: false, tag: ["official", "web"], }, { @@ -185,6 +208,7 @@ export const supportModels: Model[] = [ name: "通义千问 Turbo", free: false, auth: true, + high_context: false, tag: ["official"], }, @@ -194,6 +218,7 @@ export const supportModels: Model[] = [ name: "腾讯混元 Pro", free: false, auth: true, + high_context: false, tag: ["official"], }, @@ -203,6 +228,7 @@ export const supportModels: Model[] = [ name: "ChatGLM Turbo", free: false, auth: true, + high_context: true, tag: ["official", "open-source", "high-context"], }, @@ -212,6 +238,7 @@ export const supportModels: Model[] = [ name: "百川 Baichuan 53B", free: false, auth: true, + high_context: false, tag: ["official", "open-source"], }, @@ -221,6 +248,7 @@ export const supportModels: Model[] = [ name: "抖音豆包 Skylark", free: false, auth: true, + high_context: false, tag: ["official"], }, @@ -230,6 +258,7 @@ export const supportModels: Model[] = [ name: "360 智脑", free: false, auth: true, + high_context: false, tag: ["official"], }, @@ -238,6 +267,7 @@ export const supportModels: Model[] = [ name: "Claude", free: true, auth: true, + high_context: true, tag: ["free", "unstable"], }, { @@ -245,6 +275,7 @@ export const supportModels: Model[] = [ name: "Claude 100k", free: false, auth: true, + high_context: true, tag: ["official", "high-context"], }, { @@ -252,6 +283,7 @@ export const supportModels: Model[] = [ name: "Claude 200k", free: false, auth: true, + high_context: true, tag: ["official", "high-context"], }, @@ -261,6 +293,7 @@ export const supportModels: Model[] = [ name: "LLaMa-2 70B", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable"], }, { @@ -268,6 +301,7 @@ export const supportModels: Model[] = [ name: "LLaMa-2 13B", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable"], }, { @@ -275,6 +309,7 @@ export const supportModels: Model[] = [ name: "LLaMa-2 7B", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable"], }, @@ -283,6 +318,7 @@ export const supportModels: Model[] = [ name: "Code LLaMa 34B", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable"], }, { @@ -290,6 +326,7 @@ export const supportModels: Model[] = [ name: "Code LLaMa 13B", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable"], }, { @@ -297,6 +334,7 @@ export const supportModels: Model[] = [ name: "Code LLaMa 7B", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable"], }, @@ -306,6 +344,7 @@ export const supportModels: Model[] = [ name: "New Bing", free: true, auth: true, + high_context: true, tag: ["free", "unstable", "web"], }, @@ -315,6 +354,7 @@ export const supportModels: Model[] = [ name: "Google PaLM2", free: true, auth: true, + high_context: false, tag: ["free", "english-model"], }, @@ -324,6 +364,7 @@ export const supportModels: Model[] = [ name: "Gemini Pro", free: true, auth: true, + high_context: true, tag: ["free", "official"], }, { @@ -331,6 +372,7 @@ export const supportModels: Model[] = [ name: "Gemini Pro Vision", free: true, auth: true, + high_context: true, tag: ["free", "official", "multi-modal"], }, @@ -340,6 +382,7 @@ export const supportModels: Model[] = [ name: "Midjourney", free: false, auth: true, + high_context: false, tag: ["official", "image-generation"], }, { @@ -347,6 +390,7 @@ export const supportModels: Model[] = [ name: "Midjourney Fast", free: false, auth: true, + high_context: false, tag: ["official", "fast", "image-generation"], }, { @@ -354,6 +398,7 @@ export const supportModels: Model[] = [ name: "Midjourney Turbo", free: false, auth: true, + high_context: false, tag: ["official", "fast", "image-generation"], }, { @@ -361,6 +406,7 @@ export const supportModels: Model[] = [ name: "Stable Diffusion XL", free: false, auth: true, + high_context: false, tag: ["open-source", "unstable", "image-generation"], }, { @@ -368,6 +414,7 @@ export const supportModels: Model[] = [ name: "DALLE 2", free: true, auth: true, + high_context: false, tag: ["free", "official", "image-generation"], }, { @@ -375,6 +422,7 @@ export const supportModels: Model[] = [ name: "DALLE 3", free: false, auth: true, + high_context: false, tag: ["official", "image-generation"], }, @@ -383,9 +431,10 @@ export const supportModels: Model[] = [ name: "GPT-4-32k", free: false, auth: true, + high_context: true, tag: ["official", "high-quality", "high-price"], }, -]; +]); export const defaultModels = [ "gpt-3.5-turbo-0613", @@ -423,20 +472,6 @@ export const defaultModels = [ export let allModels: string[] = supportModels.map((model) => model.id); -export const largeContextModels = [ - "gpt-3.5-turbo-16k-0613", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4-all", - "gpt-4-32k-0613", - "claude-1", - "claude-1-100k", - "claude-2", - "claude-2.1", - "claude-2-100k", - "zhipu-chatglm-turbo", -]; - export const planModels: PlanModel[] = [ { id: "gpt-4-0613", level: 1 }, { id: "gpt-4-1106-preview", level: 1 }, @@ -450,8 +485,6 @@ export const planModels: PlanModel[] = [ { id: "midjourney-fast", level: 1 }, ]; -export const expensiveModels = ["gpt-4-32k-0613"]; - export const modelAvatars: Record = { "gpt-3.5-turbo-0613": "gpt35turbo.png", "gpt-3.5-turbo-16k-0613": "gpt35turbo16k.webp", @@ -516,6 +549,15 @@ export const subscriptionUsage: SubscriptionUsage = { "claude-100k": { name: "Claude 100k", icon: React.createElement(Newspaper) }, }; +export function getModelFromId(id: string): Model | undefined { + return supportModels.find((model) => model.id === id); +} + +export function isHighContextModel(id: string): boolean { + const model = getModelFromId(id); + return !!model && model.high_context; +} + export function login() { location.href = `${deeptrainEndpoint}/login?app=${ dev ? "dev" : deeptrainAppName diff --git a/app/src/main.tsx b/app/src/main.tsx index 8e8c01ac..b6428c20 100644 --- a/app/src/main.tsx +++ b/app/src/main.tsx @@ -1,4 +1,3 @@ -import React from "react"; import ReactDOM from "react-dom/client"; import App from "./App.tsx"; import "./conf.ts"; @@ -9,8 +8,8 @@ import "./conf.ts"; import ReloadPrompt from "./components/ReloadService.tsx"; ReactDOM.createRoot(document.getElementById("root")!).render( - + <> - , + , ); diff --git a/app/src/store/chat.ts b/app/src/store/chat.ts index ac7ee065..b569ac8e 100644 --- a/app/src/store/chat.ts +++ b/app/src/store/chat.ts @@ -105,8 +105,9 @@ const chatSlice = createSlice({ state.messages[state.messages.length - 1] = action.payload as Message; }, setModelList: (state, action) => { - setMemory("model_list", action.payload as string); - state.model_list = action.payload as string[]; + const models = action.payload as string[]; + state.model_list = models.filter((item) => inModel(item)); + setMemory("model_list", models.join(",")); }, addModelList: (state, action) => { const model = action.payload as string; diff --git a/app/src/utils/storage.ts b/app/src/utils/storage.ts new file mode 100644 index 00000000..120a91d1 --- /dev/null +++ b/app/src/utils/storage.ts @@ -0,0 +1,27 @@ +import { getMemory, setMemory } from "@/utils/memory.ts"; +import { Model } from "@/api/types.ts"; + +export function savePreferenceModels(models: Model[]): void { + setMemory("model_preference", models.map((item) => item.id).join(",")); +} + +export function getPreferenceModels(): string[] { + const memory = getMemory("model_preference"); + return memory.length ? memory.split(",") : []; +} + +export function loadPreferenceModels(models: Model[]): Model[] { + // sort by preference + const preference = getPreferenceModels(); + + return models.sort((a, b) => { + const aIndex = preference.indexOf(a.id); + const bIndex = preference.indexOf(b.id); + + if (aIndex === -1 && bIndex === -1) return 0; + if (aIndex === -1) return 1; + if (bIndex === -1) return -1; + + return aIndex - bIndex; + }); +}