Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
feat: impl extensions for langchain
Browse files Browse the repository at this point in the history
  • Loading branch information
hlhr202 committed Apr 17, 2023
1 parent 23cfb48 commit 25aa9be
Show file tree
Hide file tree
Showing 8 changed files with 1,091 additions and 89 deletions.
1,086 changes: 1,008 additions & 78 deletions package-lock.json

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@
"@tensorflow/tfjs-node": "^4.2.0",
"@types/node": "^18.15.5",
"@types/semver": "^7.3.13",
"axios": "^1.3.5",
"axios": "*",
"glob": "^9.3.4",
"rimraf": "^4.4.1",
"semver": "^7.3.8",
"tsup": "^6.7.0",
"tsx": "^3.12.6",
"typescript": "^5.0.4",
"vitest": "^0.29.8"
"vitest": "^0.29.8",
"langchain": "^0.0.56"
},
"dependencies": {
"@llama-node/cli": "0.0.27"
Expand All @@ -68,4 +69,4 @@
"@llama-node/core": "0.0.27",
"@llama-node/llama-cpp": "0.0.27"
}
}
}
29 changes: 29 additions & 0 deletions src/extensions/langchain.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { AsyncCaller } from "langchain/dist/util/async_caller";
import { Embeddings, type EmbeddingsParams } from "langchain/embeddings/base";
import type { LLama } from "..";

/**
 * LangChain `Embeddings` adapter backed by a llama-node `LLama` instance.
 *
 * Embedding requests are routed through LangChain's `AsyncCaller` so the
 * `maxConcurrency` / `maxRetries` options in `EmbeddingsParams` are actually
 * honored (previously the caller was constructed but never used).
 */
export class LLamaEmbeddings implements Embeddings {
    caller: AsyncCaller;
    llm: LLama;

    /**
     * @param params Standard LangChain embeddings options (concurrency, retries).
     * @param llm    A loaded llama-node model used to produce embeddings.
     */
    constructor(params: EmbeddingsParams, llm: LLama) {
        // The native binding may not tolerate concurrent invocations; keep the
        // original warning even though AsyncCaller now enforces the limit.
        if ((params.maxConcurrency ?? 1) > 1) {
            console.warn(
                "maxConcurrency > 1 not officially supported for llama-node, use at your own risk"
            );
        }
        this.caller = new AsyncCaller(params);
        this.llm = llm;
    }

    /**
     * Embed a batch of documents. Each document is embedded independently;
     * every call is funneled through the caller for retry/concurrency control.
     */
    embedDocuments(documents: string[]): Promise<number[][]> {
        return Promise.all(
            documents.map((doc) =>
                this.caller.call(() => this.llm.getDefaultEmbeddings(doc))
            )
        );
    }

    /** Embed a single query string via the caller. */
    embedQuery(document: string): Promise<number[]> {
        return this.caller.call(() => this.llm.getDefaultEmbeddings(document));
    }
}
19 changes: 14 additions & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ import { CompletionCallback } from "./llm";
import type { LLM } from "./llm";

export class LLama<
Instance,
LoadConfig,
LLMInferenceArguments,
LLMEmbeddingArguments,
TokenizeArguments
Instance = any,
LoadConfig = any,
LLMInferenceArguments = any,
LLMEmbeddingArguments = any,
TokenizeArguments = any
> {
llm: LLM<
Instance,
Expand Down Expand Up @@ -48,6 +48,15 @@ export class LLama<
}
}

/**
 * Produce an embedding for `text` using the underlying LLM's default
 * embedding settings. Not every backend implements `getDefaultEmbedding`;
 * when absent, warn and yield an empty vector instead of throwing.
 */
async getDefaultEmbeddings(text: string): Promise<number[]> {
    if (this.llm.getDefaultEmbedding) {
        return this.llm.getDefaultEmbedding(text);
    }
    console.warn("getDefaultEmbedding not implemented for current LLM");
    return [];
}

async tokenize(content: TokenizeArguments): Promise<number[]> {
if (!this.llm.tokenize) {
console.warn("tokenize not implemented for current LLM");
Expand Down
4 changes: 3 additions & 1 deletion src/llm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export interface LLM<
LoadConfig,
LLMInferenceArguments,
LLMEmbeddingArguments,
LLMTokenizeArguments,
LLMTokenizeArguments
> {
readonly instance: Instance;

Expand All @@ -20,5 +20,7 @@ export interface LLM<

getEmbedding?(params: LLMEmbeddingArguments): Promise<number[]>;

getDefaultEmbedding?(text: string): Promise<number[]>;

tokenize?(content: LLMTokenizeArguments): Promise<number[]>;
}
20 changes: 19 additions & 1 deletion src/llm/llama-cpp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ export interface TokenizeArguments {

export class LLamaCpp
implements
LLM<LLama, LoadConfig, LlamaInvocation, LlamaInvocation, TokenizeArguments>
LLM<
LLama,
LoadConfig,
LlamaInvocation,
LlamaInvocation,
TokenizeArguments
>
{
instance!: LLama;

Expand Down Expand Up @@ -82,6 +88,18 @@ export class LLamaCpp
});
}

/**
 * Embed `text` with a fixed set of invocation defaults; only the prompt
 * varies per call. Delegates to `getEmbedding` for the actual native call.
 */
async getDefaultEmbedding(text: string): Promise<number[]> {
    // Baked-in defaults for embedding extraction — low temperature,
    // 4 worker threads. NOTE(review): values are hard-coded; confirm they
    // suit all deployments before relying on them.
    const invocation = {
        prompt: text,
        nThreads: 4,
        nTokPredict: 1024,
        topK: 40,
        topP: 0.1,
        temp: 0.1,
        repeatPenalty: 1,
    };
    return this.getEmbedding(invocation);
}

async tokenize(params: TokenizeArguments): Promise<number[]> {
return new Promise<number[]>((res, rej) => {
this.instance.tokenize(params.content, params.nCtx, (response) => {
Expand Down
12 changes: 12 additions & 0 deletions src/llm/llama-rs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ export class LLamaRS
});
}

/**
 * Embed `text` with a fixed set of invocation defaults; only the prompt
 * varies per call. Delegates to `getEmbedding` for the actual native call.
 */
async getDefaultEmbedding(text: string): Promise<number[]> {
    // Baked-in defaults for embedding extraction — low temperature,
    // 4 worker threads. NOTE(review): values are hard-coded; confirm they
    // suit all deployments before relying on them.
    const invocation = {
        prompt: text,
        nThreads: 4,
        numPredict: 1024,
        topK: 40,
        topP: 0.1,
        temp: 0.1,
        repeatPenalty: 1,
    };
    return this.getEmbedding(invocation);
}

async tokenize(params: string): Promise<number[]> {
return new Promise<number[]>((res) => {
this.instance.tokenize(params, (response) => {
Expand Down
3 changes: 2 additions & 1 deletion tsup.config.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { defineConfig } from "tsup";

export default defineConfig({
entry: ["src/index.ts", "src/llm/*.ts"],
entry: ["src/index.ts", "src/llm/*.ts", "src/extensions/*.ts"],
external: ["langchain"],
target: ["es2015"],
format: ["cjs", "esm"],
dts: true,
Expand Down

0 comments on commit 25aa9be

Please sign in to comment.