Add support for gemini authenticated models endpoint (#2868)
* Add support for gemini authenticated models endpoint
add customModels entry
add un-authed fallback to default listing
separate models by experimental status
resolves #2866

* add back improved logic for apiVersion decision making
timothycarambat authored Dec 17, 2024
1 parent 71cd5e5 commit b082c8e
Showing 5 changed files with 219 additions and 110 deletions.
130 changes: 85 additions & 45 deletions frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx
@@ -1,4 +1,10 @@
import System from "@/models/system";
import { useEffect, useState } from "react";

export default function GeminiLLMOptions({ settings }) {
const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey);
const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey);

return (
<div className="w-full flex flex-col">
<div className="w-full flex items-center gap-[36px] mt-1.5">
@@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) {
required={true}
autoComplete="off"
spellCheck={false}
onChange={(e) => setInputValue(e.target.value)}
onBlur={() => setGeminiApiKey(inputValue)}
/>
</div>

{!settings?.credentialsOnly && (
<>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Stable Models">
{[
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
<optgroup label="Experimental Models">
{[
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
<GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Safety Setting
@@ -91,3 +55,79 @@ export default function GeminiLLMOptions({ settings }) {
</div>
);
}

function GeminiModelSelection({ apiKey, settings }) {
const [groupedModels, setGroupedModels] = useState({});
const [loading, setLoading] = useState(true);

useEffect(() => {
async function findCustomModels() {
setLoading(true);
const { models } = await System.customModels("gemini", apiKey);

if (models?.length > 0) {
const modelsByOrganization = models.reduce((acc, model) => {
acc[model.experimental ? "Experimental" : "Stable"] =
acc[model.experimental ? "Experimental" : "Stable"] || [];
acc[model.experimental ? "Experimental" : "Stable"].push(model);
return acc;
}, {});
setGroupedModels(modelsByOrganization);
}
setLoading(false);
}
findCustomModels();
}, [apiKey]);

if (loading) {
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
disabled={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<option disabled={true} selected={true}>
-- loading available models --
</option>
</select>
</div>
);
}

return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
{Object.keys(groupedModels)
.sort((a, b) => {
if (a === "Stable") return -1;
if (b === "Stable") return 1;
return a.localeCompare(b);
})
.map((organization) => (
<optgroup key={organization} label={organization}>
{groupedModels[organization].map((model) => (
<option
key={model.id}
value={model.id}
selected={settings?.GeminiLLMModelPref === model.id}
>
{model.name}
</option>
))}
</optgroup>
))}
</select>
</div>
);
}
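As an aside, the grouping reduce inside GeminiModelSelection above can be read as a standalone sketch. The sample models here are hypothetical, and the repeated ternary is factored into a variable for readability; this is not code from the commit:

const sampleModels = [
  { id: "gemini-1.5-pro-latest", name: "Gemini 1.5 Pro", experimental: false },
  { id: "gemini-2.0-flash-exp", name: "Gemini 2.0 Flash Experimental", experimental: true },
];

// Group models under "Experimental" or "Stable", same as the component does.
const grouped = sampleModels.reduce((acc, model) => {
  const group = model.experimental ? "Experimental" : "Stable";
  acc[group] = acc[group] || [];
  acc[group].push(model);
  return acc;
}, {});

console.log(Object.keys(grouped)); // ["Stable", "Experimental"]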
46 changes: 46 additions & 0 deletions server/utils/AiProviders/gemini/defaultModals.js
@@ -0,0 +1,46 @@
const { MODEL_MAP } = require("../modelMap");

const stableModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
];

const experimentalModels = [
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];

// Some models are only available in the v1beta API, and some are only
// available in the v1 API. Generally, v1beta models have `exp` in the
// name, but not always, so we also check against a static list.
const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"];

const defaultGeminiModels = [
...stableModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: false,
})),
...experimentalModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: true,
})),
];

module.exports = {
defaultGeminiModels,
v1BetaModels,
};
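The rule described in the comment above (an `exp` substring in the name, or membership in the static list) is applied inline in the constructor in index.js below. A minimal sketch of how it plays out, using a hypothetical helper that is not part of the commit:

const { v1BetaModels } = require("./server/utils/AiProviders/gemini/defaultModals");

// Illustrative helper mirroring the inline apiVersion decision.
const pickApiVersion = (model) =>
  model.includes("exp") || v1BetaModels.includes(model) ? "v1beta" : "v1";

pickApiVersion("gemini-1.5-flash-latest"); // "v1beta" (in the static list)
pickApiVersion("gemini-exp-1206"); // "v1beta" (name contains "exp")
pickApiVersion("gemini-pro"); // "v1"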
115 changes: 72 additions & 43 deletions server/utils/AiProviders/gemini/index.js
@@ -7,6 +7,7 @@ const {
clientAbortedHandler,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");

class GeminiLLM {
constructor(embedder = null, modelPreference = null) {
@@ -21,22 +22,17 @@ class GeminiLLM {
this.gemini = genAI.getGenerativeModel(
{ model: this.model },
{
// Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API.
apiVersion: [
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].includes(this.model)
? "v1beta"
: "v1",
apiVersion:
/**
 * Some models are only available in the v1beta API, and some are only
 * available in the v1 API. Generally, v1beta models have `exp` in the
 * name, but not always, so we also check against a static list.
 * @see {v1BetaModels}
 */
this.model.includes("exp") || v1BetaModels.includes(this.model)
? "v1beta"
: "v1",
}
);
this.limits = {
@@ -48,6 +44,11 @@ class GeminiLLM {
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold();
this.#log(`Initialized with model: ${this.model}`);
}

#log(text, ...args) {
console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
}

#appendContext(contextTexts = []) {
@@ -109,25 +110,63 @@
return MODEL_MAP.gemini[this.model] ?? 30_720;
}

isValidChatCompletionModel(modelName = "") {
const validModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];
return validModels.includes(modelName);
/**
* Fetches Gemini models from the Google Generative AI API
* @param {string} apiKey - The API key to use for the request
* @param {number} limit - The maximum number of models to fetch
* @param {string|null} pageToken - The page token to use for pagination
* @returns {Promise<Array<{id: string, name: string, contextWindow: number, experimental: boolean}>>} A promise that resolves to an array of Gemini models
*/
static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
const url = new URL(
"https://generativelanguage.googleapis.com/v1beta/models"
);
url.searchParams.set("pageSize", limit);
url.searchParams.set("key", apiKey);
if (pageToken) url.searchParams.set("pageToken", pageToken);

return fetch(url.toString(), {
method: "GET",
headers: { "Content-Type": "application/json" },
})
.then((res) => res.json())
.then((data) => {
if (data.error) throw new Error(data.error.message);
return data.models ?? [];
})
.then((models) =>
models
.filter(
(model) => !model.displayName.toLowerCase().includes("tuning")
)
.filter((model) =>
model.supportedGenerationMethods.includes("generateContent")
) // Only generateContent is supported
.map((model) => {
return {
id: model.name.split("/").pop(),
name: model.displayName,
contextWindow: model.inputTokenLimit,
experimental: model.name.includes("exp"),
};
})
)
.catch((e) => {
console.error(`Gemini:getGeminiModels`, e.message);
return defaultGeminiModels;
});
}

/**
* Checks if a model is valid for chat completion (unused)
* @deprecated
* @param {string} modelName - The name of the model to check
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
*/
async isValidChatCompletionModel(modelName = "") {
// fetchModels is static, so call it on the class with the stored key.
const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
return models.some((model) => model.id === modelName);
}

/**
* Generates appropriate content array for a message + attachments.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
@@ -218,11 +257,6 @@
}

async getChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);

const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
@@ -256,11 +290,6 @@
}

async streamGetChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);

const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT"
)?.content;
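A usage sketch for the new static fetcher. The key below is a placeholder, and the require path assumes the repo root; on a missing or rejected key, fetchModels logs the error and returns defaultGeminiModels, which is the un-authed fallback mentioned in the commit message:

const { GeminiLLM } = require("./server/utils/AiProviders/gemini");

(async () => {
  const models = await GeminiLLM.fetchModels("YOUR_GEMINI_API_KEY");
  for (const { id, name, contextWindow, experimental } of models) {
    console.log(`${experimental ? "[exp]" : "[stable]"} ${id} (${name}), ctx: ${contextWindow}`);
  }
})();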
15 changes: 15 additions & 0 deletions server/utils/helpers/customModels.js
@@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
const { fetchNovitaModels } = require("../AiProviders/novita");
const { parseLMStudioBasePath } = require("../AiProviders/lmStudio");
const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim");
const { GeminiLLM } = require("../AiProviders/gemini");

const SUPPORT_CUSTOM_MODELS = [
"openai",
@@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [
"apipie",
"novita",
"xai",
"gemini",
];

async function getCustomModels(provider = "", apiKey = null, basePath = null) {
@@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await getXAIModels(apiKey);
case "nvidia-nim":
return await getNvidiaNimModels(basePath);
case "gemini":
return await getGeminiModels(apiKey);
default:
return { models: [], error: "Invalid provider for custom models" };
}
@@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) {
}
}

async function getGeminiModels(_apiKey = null) {
const apiKey =
_apiKey === true
? process.env.GEMINI_API_KEY
: _apiKey || process.env.GEMINI_API_KEY || null;
const models = await GeminiLLM.fetchModels(apiKey);
// The API key worked, so save it for future use
if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey;
return { models, error: null };
}

module.exports = {
getCustomModels,
};
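And a sketch of exercising the new "gemini" branch through this helper. The key is a placeholder and the require path assumes the repo root; per getGeminiModels above, passing true instead reuses process.env.GEMINI_API_KEY:

const { getCustomModels } = require("./server/utils/helpers/customModels");

(async () => {
  const { models, error } = await getCustomModels("gemini", "YOUR_GEMINI_API_KEY");
  if (!error) console.log(`Found ${models.length} Gemini models`);
})();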
