Skip to content

Commit

Permalink
refactor models component (#1535)
Browse files Browse the repository at this point in the history
  • Loading branch information
lily-de authored Mar 7, 2025
1 parent 148737c commit 4f9c08a
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 91 deletions.
10 changes: 6 additions & 4 deletions ui/desktop/src/components/BottomMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ export default function BottomMenu({
className="flex items-center cursor-pointer"
onClick={() => setIsModelMenuOpen(!isModelMenuOpen)}
>
<span>{envModelProvider || currentModel?.name || 'Select Model'}</span>
<span>
{envModelProvider || (currentModel?.alias ?? currentModel?.name) || 'Select Model'}
</span>
{isModelMenuOpen ? (
<ChevronDown className="w-4 h-4 ml-1" />
) : (
Expand All @@ -182,14 +184,14 @@ export default function BottomMenu({
<ModelRadioList
className="divide-y divide-borderSubtle"
renderItem={({ model, isSelected, onSelect }) => (
<label key={model.name} className="block cursor-pointer">
<label key={model.alias ?? model.name} className="block cursor-pointer">
<div
className="flex items-center justify-between p-2 text-textStandard hover:bg-bgSubtle transition-colors"
onClick={onSelect}
>
<div>
<p className="text-sm ">{model.name}</p>
<p className="text-xs text-textSubtle">{model.provider}</p>
<p className="text-sm ">{model.alias ?? model.name}</p>
<p className="text-xs text-textSubtle">{model.subtext ?? model.provider}</p>
</div>
<div className="relative">
<input
Expand Down
17 changes: 2 additions & 15 deletions ui/desktop/src/components/settings/SettingsView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -247,23 +247,10 @@ export default function SettingsView({
{/* Content Area */}
<div className="flex-1 py-8 pt-[20px]">
<div className="space-y-8">
{/*Models Section*/}
<section id="models">
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-medium text-textStandard">Models</h2>
<button
onClick={() => {
setView('moreModels');
}}
className="text-indigo-500 hover:text-indigo-600 text-sm"
>
Browse
</button>
</div>
<div className="px-8">
<RecentModelsRadio />
</div>
<RecentModelsRadio setView={setView} />
</section>

<section id="extensions">
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-semibold text-textStandard">Extensions</h2>
Expand Down
4 changes: 2 additions & 2 deletions ui/desktop/src/components/settings/models/AddModelInline.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Select from 'react-select';
import { Plus } from 'lucide-react';
import { createSelectedModel, useHandleModelSelection } from './utils';
import { useActiveKeys } from '../api_keys/ActiveKeysContext';
import { goose_models } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { createDarkSelectStyles, darkSelectTheme } from '../../ui/select-styles';

export function AddModelInline() {
Expand All @@ -31,7 +31,7 @@ export function AddModelInline() {
return;
}

const filtered = goose_models
const filtered = gooseModels
.filter(
(model) =>
model.provider.toLowerCase() === selectedProvider &&
Expand Down
30 changes: 30 additions & 0 deletions ui/desktop/src/components/settings/models/GooseModels.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { Model } from './ModelContext';

// TODO: move into backends / fetch dynamically
// this is used by ModelContext
export const gooseModels: Model[] = [
{ id: 1, name: 'gpt-4o-mini', provider: 'OpenAI' },
{ id: 2, name: 'gpt-4o', provider: 'OpenAI' },
{ id: 3, name: 'gpt-4-turbo', provider: 'OpenAI' },
{ id: 5, name: 'o1', provider: 'OpenAI' },
{ id: 7, name: 'claude-3-5-sonnet-latest', provider: 'Anthropic' },
{ id: 8, name: 'claude-3-5-haiku-latest', provider: 'Anthropic' },
{ id: 9, name: 'claude-3-opus-latest', provider: 'Anthropic' },
{ id: 10, name: 'gemini-1.5-pro', provider: 'Google' },
{ id: 11, name: 'gemini-1.5-flash', provider: 'Google' },
{ id: 12, name: 'gemini-2.0-flash', provider: 'Google' },
{ id: 13, name: 'gemini-2.0-flash-lite-preview-02-05', provider: 'Google' },
{ id: 14, name: 'gemini-2.0-flash-thinking-exp-01-21', provider: 'Google' },
{ id: 15, name: 'gemini-2.0-pro-exp-02-05', provider: 'Google' },
{ id: 16, name: 'llama-3.3-70b-versatile', provider: 'Groq' },
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
{ id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
{ id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
{ id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
{ id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
{ id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
{ id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
];
6 changes: 4 additions & 2 deletions ui/desktop/src/components/settings/models/ModelContext.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import React, { createContext, useContext, useState, ReactNode } from 'react';
import { GOOSE_MODEL, GOOSE_PROVIDER } from '../../../env_vars';
import { goose_models } from './hardcoded_stuff'; // Assuming hardcoded models are here
import { gooseModels } from './GooseModels'; // Assuming hardcoded models are here

// TODO: API keys
export interface Model {
id?: number; // Make `id` optional to allow user-defined models
name: string;
provider: string;
lastUsed?: string;
alias?: string; // optional model display name
subtext?: string; // goes below model name if not the provider
}

interface ModelContextValue {
Expand All @@ -31,7 +33,7 @@ export const ModelProvider = ({ children }: { children: ReactNode }) => {

const switchModel = (model: Model) => {
const newModel = model.id
? goose_models.find((m) => m.id === model.id) || model
? gooseModels.find((m) => m.id === model.id) || model
: { id: Date.now(), ...model }; // Assign unique ID for user-defined models
updateModel(newModel);
};
Expand Down
28 changes: 26 additions & 2 deletions ui/desktop/src/components/settings/models/ModelRadioList.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import React, { useState, useEffect } from 'react';
import { Model } from './ModelContext';
import { useRecentModels } from './RecentModels';
import { useModel } from './ModelContext';
import { useHandleModelSelection } from './utils';
import { useRecentModels } from './RecentModels';
import type { View } from '@/src/App';
import { SettingsViewOptions } from '@/src/components/settings/SettingsView';

export interface Model {
id?: number; // Make `id` optional to allow user-defined models
name: string;
provider: string;
lastUsed?: string;
}

interface ModelRadioListProps {
renderItem: (props: {
Expand All @@ -13,6 +21,22 @@ interface ModelRadioListProps {
className?: string;
}

export function SeeMoreModelsButtons({ setView }: { setView: (view: View) => void }) {
return (
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-medium text-textStandard">Models</h2>
<button
onClick={() => {
setView('moreModels');
}}
className="text-indigo-500 hover:text-indigo-600 text-sm"
>
Browse
</button>
</div>
);
}

export function ModelRadioList({ renderItem, className = '' }: ModelRadioListProps) {
const { recentModels } = useRecentModels();
const { currentModel } = useModel();
Expand Down
5 changes: 3 additions & 2 deletions ui/desktop/src/components/settings/models/ProviderButtons.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import React, { useState, useEffect } from 'react';
import { Button } from '../../ui/button';
import { Switch } from '../../ui/switch';
import { useActiveKeys } from '../api_keys/ActiveKeysContext';
import { model_docs_link, goose_models } from './hardcoded_stuff';
import { model_docs_link } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { useModel } from './ModelContext';
import { useHandleModelSelection } from './utils';

Expand Down Expand Up @@ -31,7 +32,7 @@ export function ProviderButtons() {

// Filter models by provider
const providerModels = selectedProvider
? goose_models.filter((model) => model.provider === selectedProvider)
? gooseModels.filter((model) => model.provider === selectedProvider)
: [];

return (
Expand Down
63 changes: 34 additions & 29 deletions ui/desktop/src/components/settings/models/RecentModels.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import React, { useState, useEffect } from 'react';
import { Clock } from 'lucide-react';
import { Model } from './ModelContext';
import { useHandleModelSelection } from './utils';
import { ModelRadioList, SeeMoreModelsButtons } from './ModelRadioList';
import { useModel } from './ModelContext';
import { ModelRadioList } from './ModelRadioList';
import { useHandleModelSelection } from './utils';
import type { View } from '../../../App';

const MAX_RECENT_MODELS = 3;

Expand Down Expand Up @@ -129,37 +130,41 @@ export function RecentModels() {
);
}

export function RecentModelsRadio() {
export function RecentModelsRadio({ setView }: { setView: (view: View) => void }) {
return (
<div className="space-y-2">
<h2 className="text-md font-medium text-textStandard">Recently used</h2>
<ModelRadioList
renderItem={({ model, isSelected, onSelect }) => (
<label key={model.name} className="flex items-center py-2 cursor-pointer">
<div className="relative mr-4">
<input
type="radio"
name="recentModels"
value={model.name}
checked={isSelected}
onChange={onSelect}
className="peer sr-only"
/>
<div
className="h-4 w-4 rounded-full border border-gray-400 dark:border-gray-500
<div>
<SeeMoreModelsButtons setView={setView} />
<div className="px-8">
<div className="space-y-2">
<ModelRadioList
renderItem={({ model, isSelected, onSelect }) => (
<label key={model.name} className="flex items-center py-2 cursor-pointer">
<div className="relative mr-4">
<input
type="radio"
name="recentModels"
value={model.name}
checked={isSelected}
onChange={onSelect}
className="peer sr-only"
/>
<div
className="h-4 w-4 rounded-full border border-gray-400 dark:border-gray-500
peer-checked:border-[6px] peer-checked:border-black dark:peer-checked:border-white
peer-checked:bg-white dark:peer-checked:bg-black
transition-all duration-200 ease-in-out"
></div>
</div>

<div className="">
<p className="text-sm text-textStandard">{model.name}</p>
<p className="text-xs text-textSubtle">{model.provider}</p>
</div>
</label>
)}
/>
></div>
</div>

<div className="">
<p className="text-sm text-textStandard">{model.alias ?? model.name}</p>
<p className="text-xs text-textSubtle">{model.subtext ?? model.provider}</p>
</div>
</label>
)}
/>
</div>
</div>
</div>
);
}
4 changes: 2 additions & 2 deletions ui/desktop/src/components/settings/models/Search.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import React, { useState, useEffect, useRef } from 'react';
import { Search } from 'lucide-react';
import { Switch } from '../../ui/switch';
import { goose_models } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { useModel } from './ModelContext';
import { useHandleModelSelection } from './utils';
import { useActiveKeys } from '../api_keys/ActiveKeysContext';
Expand All @@ -22,7 +22,7 @@ export function SearchBar() {
// results set will only include models that have a configured provider
const { activeKeys } = useActiveKeys(); // Access active keys from context

const model_options = goose_models.filter((model) => activeKeys.includes(model.provider));
const model_options = gooseModels.filter((model) => activeKeys.includes(model.provider));

const filteredModels = model_options
.filter((model) => model.name.toLowerCase().includes(search.toLowerCase()))
Expand Down
28 changes: 0 additions & 28 deletions ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
Original file line number Diff line number Diff line change
@@ -1,33 +1,5 @@
import { Model } from './ModelContext';

// TODO: move into backends / fetch dynamically
export const goose_models: Model[] = [
{ id: 1, name: 'gpt-4o-mini', provider: 'OpenAI' },
{ id: 2, name: 'gpt-4o', provider: 'OpenAI' },
{ id: 3, name: 'gpt-4-turbo', provider: 'OpenAI' },
{ id: 5, name: 'o1', provider: 'OpenAI' },
{ id: 7, name: 'claude-3-5-sonnet-latest', provider: 'Anthropic' },
{ id: 8, name: 'claude-3-5-haiku-latest', provider: 'Anthropic' },
{ id: 9, name: 'claude-3-opus-latest', provider: 'Anthropic' },
{ id: 10, name: 'gemini-1.5-pro', provider: 'Google' },
{ id: 11, name: 'gemini-1.5-flash', provider: 'Google' },
{ id: 12, name: 'gemini-2.0-flash', provider: 'Google' },
{ id: 13, name: 'gemini-2.0-flash-lite-preview-02-05', provider: 'Google' },
{ id: 14, name: 'gemini-2.0-flash-thinking-exp-01-21', provider: 'Google' },
{ id: 15, name: 'gemini-2.0-pro-exp-02-05', provider: 'Google' },
{ id: 16, name: 'llama-3.3-70b-versatile', provider: 'Groq' },
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
{ id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
{ id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
{ id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
{ id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
{ id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
{ id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
];

export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];

export const anthropic_models = [
Expand Down
2 changes: 1 addition & 1 deletion ui/desktop/src/components/settings/models/toasts.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export function ToastSuccessModelSwitch(model: Model) {
return toast.success(
<div>
<strong>Model Changed</strong>
<div>Switched to {model.name}</div>
<div>Switched to {model.alias ?? model.name}</div>
</div>,
{
position: 'top-right',
Expand Down
8 changes: 4 additions & 4 deletions ui/desktop/src/components/settings/models/utils.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { useModel } from './ModelContext'; // Import the useModel hook
import { Model } from './ModelContext';
import { useMemo } from 'react';
import { goose_models } from './hardcoded_stuff';
import { gooseModels } from './GooseModels';
import { ToastFailureGeneral, ToastSuccessModelSwitch } from './toasts';
import { initializeSystem } from '../../../utils/providerUtils';
import { useRecentModels } from './RecentModels';
Expand Down Expand Up @@ -43,7 +43,7 @@ export function useHandleModelSelection() {
}

export function createSelectedModel(selectedProvider, modelName) {
let selectedModel = goose_models.find(
let selectedModel = gooseModels.find(
(model) =>
model.provider.toLowerCase() === selectedProvider &&
model.name.toLowerCase() === modelName.toLowerCase()
Expand All @@ -52,7 +52,7 @@ export function createSelectedModel(selectedProvider, modelName) {
if (!selectedModel) {
// Normalize the casing for the provider using the first matching model
const normalizedProvider =
goose_models.find((model) => model.provider.toLowerCase() === selectedProvider)?.provider ||
gooseModels.find((model) => model.provider.toLowerCase() === selectedProvider)?.provider ||
selectedProvider;

// Construct a model object
Expand All @@ -67,7 +67,7 @@ export function createSelectedModel(selectedProvider, modelName) {

export function useFilteredModels(search: string, activeKeys: string[]) {
const filteredModels = useMemo(() => {
const modelOptions = goose_models.filter((model) => activeKeys.includes(model.provider));
const modelOptions = gooseModels.filter((model) => activeKeys.includes(model.provider));

if (!search) {
return modelOptions; // Return all models if no search term
Expand Down

0 comments on commit 4f9c08a

Please sign in to comment.