Commit 5bb7e24
fix(models): interactions
ido-pluto committed May 10, 2024
1 parent c96435c commit 5bb7e24
Showing 9 changed files with 137 additions and 54 deletions.
41 changes: 37 additions & 4 deletions README.md
@@ -27,7 +27,7 @@ Make sure you have [Node.js](https://nodejs.org/en/) (**download current**) inst
```bash
npm install -g catai

-catai install llama3-8b-openhermes-dpo-q3_k_s
+catai install meta-llama-3-8b-q4_k_m
catai up
```

@@ -118,14 +118,47 @@ const data = await response.text();

For more information, please read the [API guide](https://github.com/withcatai/catai/blob/main/docs/api.md)

-## Development API + Node-llama-cpp@beta integration
+## Development API

+You can also use the development API to interact with the model.
+
+```ts
+import {createChat, downloadModel, initCatAILlama, LlamaJsonSchemaGrammar} from "catai";
+
+// skip downloading the model if you already have it
+await downloadModel("meta-llama-3-8b-q4_k_m");
+
+const llama = await initCatAILlama();
+const chat = await createChat({
+    model: "meta-llama-3-8b-q4_k_m"
+});
+
+const fullResponse = await chat.prompt("Give me array of random numbers (10 numbers)", {
+    grammar: new LlamaJsonSchemaGrammar(llama, {
+        type: "array",
+        items: {
+            type: "number",
+            minimum: 0,
+            maximum: 100
+        },
+    }),
+    topP: 0.8,
+    temperature: 0.8,
+});
+
+console.log(fullResponse); // [10, 2, 3, 4, 6, 9, 8, 1, 7, 5]
+```

+(For the full list of models, run `catai models`)

+### Node-llama-cpp@beta low level integration

You can use the model with [node-llama-cpp@beta](https://github.com/withcatai/node-llama-cpp/pull/105)

CatAI enables you to easily manage the models and chat with them.

```ts
-import {downloadModel, getModelPath} from 'catai';
+import {downloadModel, getModelPath, initCatAILlama, LlamaChatSession} from 'catai';

// download the model, skip if you already have the model
await downloadModel(
@@ -136,7 +169,7 @@ await downloadModel(
// get the model path with catai
const modelPath = getModelPath("llama3");

-const llama = await getLlama();
+const llama = await initCatAILlama();
const model = await llama.loadModel({
modelPath
});
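
The README example above is truncated by the diff view. A minimal sketch of how it could continue, assuming node-llama-cpp@beta's `createContext`/`LlamaChatSession` API (which this commit re-exports from `catai`):

```ts
// Illustrative continuation, not part of the diff: chat with the loaded model.
const context = await model.createContext();
const session = new LlamaChatSession({
    contextSequence: context.getSequence()
});

const response = await session.prompt("Hi there, how are you?");
console.log("AI: " + response);
```
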
88 changes: 66 additions & 22 deletions models.json
@@ -528,70 +528,114 @@
},
"version": 1
},
"alphallama3-8b-q3_k_s": {
"meta-llama-3-8b-q4_k_m": {
"download": {
"files": {
"model": "Alphallama3-8B.Q3_K_S.gguf"
"model": "Meta-Llama-3-8B.Q4_K_M.gguf"
},
"repo": "https://huggingface.co/mradermacher/Alphallama3-8B-GGUF",
"commit": "738ab183a3e2ce92b96c9273e5d78960387ad939",
"repo": "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF-v2",
"commit": "7b15b4f184a48c035fbc5ac2876e5617b64ea885",
"branch": "main"
},
"hardwareCompatibility": {
"ramGB": 5.3,
"ramGB": 5.6,
"cpuCors": 3,
"compressions": "q3_k_s"
"compressions": "q4_k_m"
},
"compatibleCatAIVersionRange": [
"3.1.2"
"3.2.0"
],
"settings": {
"bind": "node-llama-cpp-v2"
},
"version": 1
},
"llama3-8b-dpo-uncensored-q4_k_s": {
"llama-3-8b-lexi-uncensored-q4_k_m": {
"download": {
"files": {
"model": "Llama3-8B-DPO-uncensored.Q4_K_S.gguf"
"model": "Llama-3-8B-Lexi-Uncensored.Q4_K_M.gguf"
},
"repo": "https://huggingface.co/mradermacher/Llama3-8B-DPO-uncensored-GGUF",
"commit": "af5654c362a9967e2f704658f8aad7429cfcffb7",
"repo": "https://huggingface.co/QuantFactory/Llama-3-8B-Lexi-Uncensored-GGUF",
"commit": "5caac86e58458f70d7ff02ad2b7d99a850d61d4b",
"branch": "main"
},
"hardwareCompatibility": {
"ramGB": 5.3,
"ramGB": 5.6,
"cpuCors": 3,
"compressions": "q4_k_s"
"compressions": "q4_k_m"
},
"compatibleCatAIVersionRange": [
"3.1.2"
"3.2.0"
],
"settings": {
"bind": "node-llama-cpp-v2"
},
"version": 1
},
"llama3-8b-openhermes-dpo-q3_k_s": {
"power-wizardlm-2-13b-q5_k_m": {
"download": {
"files": {
"model": "Llama3-8B-OpenHermes-DPO.Q3_K_S.gguf"
"model": "Power-WizardLM-2-13b.Q5_K_M.gguf"
},
"repo": "https://huggingface.co/mradermacher/Llama3-8B-OpenHermes-DPO-GGUF",
"commit": "c0edd26cf8259267807d02ad8903faac593b099d",
"repo": "https://huggingface.co/mradermacher/Power-WizardLM-2-13b-GGUF",
"commit": "15ecbe0d095df08b49017db3b223433cd89153fc",
"branch": "main"
},
"hardwareCompatibility": {
"ramGB": 5.3,
"ramGB": 9.8,
"cpuCors": 5,
"compressions": "q5_k_m"
},
"compatibleCatAIVersionRange": [
"3.2.0"
],
"settings": {
"bind": "node-llama-cpp-v2"
},
"version": 1
},
"power-llama-3-13b-q4_k_m": {
"download": {
"files": {
"model": "Power-Llama-3-13b.Q4_K_M.gguf"
},
"repo": "https://huggingface.co/mradermacher/Power-Llama-3-13b-GGUF",
"commit": "0a61b3cce433745691cb73c5609c249b9b9848e9",
"branch": "main"
},
"hardwareCompatibility": {
"ramGB": 8.4,
"cpuCors": 4,
"compressions": "q4_k_m"
},
"compatibleCatAIVersionRange": [
"3.2.0"
],
"settings": {
"bind": "node-llama-cpp-v2"
},
"version": 1
},
"arrowpro-7b-robinhood-q4_k_m": {
"download": {
"files": {
"model": "ArrowPro-7B-RobinHood.Q4_K_M.gguf\\"
+},
+"repo": "https://huggingface.co/mradermacher/ArrowPro-7B-RobinHood-GGUF",
+"commit": "54be3527006ac83c14d74d25b2573f81285077bc",
+"branch": "main"
+},
+"hardwareCompatibility": {
+"ramGB": 4.6,
"cpuCors": 3,
"compressions": "q3_k_s"
"compressions": "q4_k_m"
},
"compatibleCatAIVersionRange": [
"3.1.2"
"3.2.0"
],
"settings": {
"bind": "node-llama-cpp-v2"
},
"version": 1
}
}
}
1 change: 1 addition & 0 deletions server/scripts/new-model.js
@@ -69,6 +69,7 @@ const fileCompressionParametersToSize = {
},
'q4_k_m': {
7: 4.1,
+8: 5.1,
13: 7.9,
30: 19.6,
34: 20.2,
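
For context, `fileCompressionParametersToSize` maps each quantization level to approximate GGUF file sizes in GB, keyed by parameter count in billions; the new `8: 5.1` entry covers the Llama 3 8B models added in models.json. A hedged sketch of such a lookup (the helper name is illustrative, not the script's actual API):

```ts
// Sketch only: approximate model file sizes in GB, keyed by billions of parameters.
const fileCompressionParametersToSize: Record<string, Record<number, number>> = {
    'q4_k_m': {7: 4.1, 8: 5.1, 13: 7.9, 30: 19.6, 34: 20.2}
};

// Hypothetical helper: estimate a model's download size before fetching it.
function estimateSizeGB(compression: string, paramsBillion: number): number | undefined {
    return fileCompressionParametersToSize[compression]?.[paramsBillion];
}

console.log(estimateSizeGB('q4_k_m', 8)); // 5.1
```
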
4 changes: 4 additions & 0 deletions server/src/index.ts
@@ -4,6 +4,9 @@ import createChat, {getModelPath} from './manage-models/bind-class/bind-class.js
import CatAIDB from './storage/app-db.js';
import ENV_CONFIG from './storage/config.js';
import {CatAIError} from './errors/CatAIError.js';
+import {initCatAILlama} from './manage-models/bind-class/binds/node-llama-cpp/node-llama-cpp-v2/node-llama-cpp-v2.js';
+
+export * from 'node-llama-cpp';

const downloadModel = FetchModels.simpleDownload;

@@ -15,5 +18,6 @@ export {
CatAIDB,
getModelPath,
downloadModel,
+initCatAILlama,
ENV_CONFIG as CATAI_ENV_CONFIG,
};
4 changes: 3 additions & 1 deletion server/src/manage-models/bind-class/bind-class.ts
@@ -6,6 +6,8 @@ import {ModelNotInstalledError} from './errors/ModelNotInstalledError.js';
import {NoActiveModelError} from './errors/NoActiveModelError.js';
import {NoModelBindError} from './errors/NoModelBindError.js';
import {BindNotFoundError} from './errors/BindNotFoundError.js';
+import {ChatContext} from './chat-context.js';
+import type {LLamaChatPromptOptions} from 'node-llama-cpp';

export const ALL_BINDS = [NodeLlamaCppV2];
const cachedBinds: { [key: string]: InstanceType<typeof BaseBindClass> } = {};
@@ -37,7 +39,7 @@ export function getCacheBindClass(modelDetails: ModelSettings<any> = findLocalMo
}

const lockContext = {};
-export default async function createChat(options?: CreateChatOptions) {
+export default async function createChat(options?: CreateChatOptions): Promise<ChatContext<LLamaChatPromptOptions>> {
return await withLock(lockContext, "createChat", async () => {
const modelDetails = findLocalModel(options?.model);
const cachedBindClass = getCacheBindClass(modelDetails);
6 changes: 3 additions & 3 deletions server/src/manage-models/bind-class/binds/base-bind-class.ts
@@ -1,16 +1,16 @@
import {ModelSettings} from '../../../storage/app-db.js';
import {ChatContext} from '../chat-context.js';
import {NodeLlamaCppOptions} from "./node-llama-cpp/node-llama-cpp-v2/node-llama-cpp-v2.js";
import {NodeLlamaCppOptions} from './node-llama-cpp/node-llama-cpp-v2/node-llama-cpp-v2.js';

export type CreateChatOptions = NodeLlamaCppOptions & {
model: string
}

-export default abstract class BaseBindClass<T> {
+export default abstract class BaseBindClass<Settings> {
public static shortName?: string;
public static description?: string;

-public constructor(public modelSettings: ModelSettings<T>) {
+public constructor(public modelSettings: ModelSettings<Settings>) {
}

public abstract initialize(): Promise<void> | void;
server/src/manage-models/bind-class/binds/node-llama-cpp/node-llama-cpp-v2/node-llama-cpp-chat.ts
@@ -1,14 +1,18 @@
import type {LLamaChatPromptOptions, LlamaChatSession, Token} from 'node-llama-cpp';
import {ChatContext} from '../../../chat-context.js';
-import objectAssignDeep from "object-assign-deep";

-export default class NodeLlamaCppChat extends ChatContext {
+export default class NodeLlamaCppChat extends ChatContext<LLamaChatPromptOptions> {

constructor(protected _promptSettings: Partial<LLamaChatPromptOptions>, private _session: LlamaChatSession) {
super();
}

-public async prompt(prompt: string, onTokenText?: (token: string) => void, overrideSettings?: Partial<LLamaChatPromptOptions>): Promise<string | null> {
+public async prompt(prompt: string, onTokenText?: ((token: string) => void) | Partial<LLamaChatPromptOptions>, overrideSettings?: Partial<LLamaChatPromptOptions>): Promise<string | null> {
+if (typeof onTokenText !== 'function') {
+overrideSettings = onTokenText as Partial<LLamaChatPromptOptions>;
+onTokenText = undefined;
+}

+this.emit('abort', 'Aborted by new prompt');
const abort = new AbortController();
const closeCallback = () => {
@@ -19,7 +23,7 @@ export default class NodeLlamaCppChat extends ChatContext {

let response = null;
try {
-const allSettings: LLamaChatPromptOptions = objectAssignDeep({}, this._promptSettings, overrideSettings);
+const allSettings: LLamaChatPromptOptions = Object.assign({}, this._promptSettings, overrideSettings);
response = await this._session.prompt(prompt, {
...allSettings,
signal: abort.signal,
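
With this change, `prompt` accepts either a token callback or a settings object as its second argument, so callers no longer need to pass `undefined` just to override settings. A short usage sketch (the model name is assumed):

```ts
import {createChat} from 'catai';

const chat = await createChat({model: 'meta-llama-3-8b-q4_k_m'});

// Stream tokens via a callback; override settings with the third argument.
await chat.prompt('Tell me a joke', token => process.stdout.write(token), {temperature: 0.7});

// Or pass the settings object directly as the second argument.
const answer = await chat.prompt('Tell me a joke', {temperature: 0.7});
console.log(answer);
```
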
server/src/manage-models/bind-class/binds/node-llama-cpp/node-llama-cpp-v2/node-llama-cpp-v2.ts
@@ -1,36 +1,35 @@
-import type {
-LLamaChatPromptOptions,
-LlamaChatSessionOptions,
-LlamaContextOptions,
-LlamaModel,
-LlamaModelOptions
-} from 'node-llama-cpp';
+import {getLlama, Llama, LLamaChatPromptOptions, LlamaChatSession, LlamaChatSessionOptions, LlamaContextOptions, LlamaModel, LlamaModelOptions, LlamaOptions} from 'node-llama-cpp';
import NodeLlamaCppChat from './node-llama-cpp-chat.js';
import BaseBindClass from '../../base-bind-class.js';
-import objectAssignDeep from "object-assign-deep";

export type NodeLlamaCppOptions =
Omit<LlamaContextOptions, 'model'> &
Omit<LlamaModelOptions, 'modelPath'> &
Omit<LlamaChatSessionOptions, 'contextSequence'> &
LLamaChatPromptOptions;


+let cachedLlama: Llama | null = null;
+
+export async function initCatAILlama(options?: LlamaOptions) {
+return cachedLlama = await getLlama(options);
+}

export default class NodeLlamaCppV2 extends BaseBindClass<NodeLlamaCppOptions> {
public static override shortName = 'node-llama-cpp-v2';
public static override description = 'node-llama-cpp v2, that support GGUF model, and advanced feature such as output format, max tokens and much more';
private _model?: LlamaModel;
-private _package?: typeof import('node-llama-cpp');

async createChat(overrideSettings?: NodeLlamaCppOptions) {
-if (!this._model || !this._package)
+if (!this._model)
throw new Error('Model not initialized');

-const settings= objectAssignDeep({}, this.modelSettings.settings, overrideSettings);
+const settings = Object.assign({}, this.modelSettings.settings, overrideSettings);
const context = await this._model.createContext({
...settings
});

-const session = new this._package.LlamaChatSession({
+const session = new LlamaChatSession({
contextSequence: context.getSequence(),
...settings
});
@@ -39,10 +38,7 @@ export default class NodeLlamaCppV2 extends BaseBindClass<NodeLlamaCppOptions> {
}

async initialize(): Promise<void> {
-const {getLlama, ...others} = await import('node-llama-cpp');
-this._package = others as any;
-
-const llama = await getLlama();
+const llama = cachedLlama ?? await initCatAILlama();
this._model = await llama.loadModel({
modelPath: this.modelSettings.downloadedFiles.model,
...this.modelSettings.settings
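
Caching the `Llama` instance lets callers configure it once, before any bind initializes, and have later chats reuse it instead of calling `getLlama()` again. A sketch, assuming `LlamaOptions` accepts a `gpu` field as in node-llama-cpp@beta:

```ts
import {createChat, initCatAILlama} from 'catai';

// Assumption: `gpu: false` forces CPU-only inference in node-llama-cpp@beta.
await initCatAILlama({gpu: false});

// The bind picks up the cached instance instead of creating a new one.
const chat = await createChat({model: 'meta-llama-3-8b-q4_k_m'});
console.log(await chat.prompt('Hello!'));
```
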
7 changes: 3 additions & 4 deletions server/src/manage-models/bind-class/chat-context.ts
@@ -10,14 +10,13 @@ export interface ChatContextEvents {
emit(event: 'modelResponseEnd'): boolean;
}

-export abstract class ChatContext extends EventEmitter implements ChatContextEvents {
+export abstract class ChatContext<Settings = any> extends EventEmitter implements ChatContextEvents {

/**
* Prompt the model and stream the response
* @param prompt
* @param onToken
*/
-abstract prompt(prompt: string, onToken?: (token: string) => void): Promise<string | null>;
+abstract prompt(prompt: string, overrideSettings?: Partial<Settings>): Promise<string | null>;
+abstract prompt(prompt: string, onToken?: (token: string) => void, overrideSettings?: Partial<Settings>): Promise<string | null>;

/**
* Abort the model response
