Commit e82222d
feature: implemented parallel inference for llama-rs, implemented naive sequential async inference for llama-cpp and rwkv-cpp (#52)

* feat: support parallel inference for llama-rs, support sequential async for llama-cpp and rwkv-cpp
hlhr202 authored May 9, 2023
1 parent a311873 commit e82222d
Showing 32 changed files with 558 additions and 1,415 deletions.
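
The recurring change across the diffs below is that `LLama.create` now returns a Promise and must be awaited, while `inference` stays callback-based for token streaming. The commit title advertises parallel inference for the llama-rs backend; a minimal sketch of what that enables, assuming the callback API shown in the example files below and a hypothetical `complete` promise wrapper (not part of this commit):

import { InferenceResultType, LLama } from "../index";

// Hypothetical helper: collect a streamed completion into one string.
const complete = (llama: LLama, prompt: string) =>
    new Promise<string>((resolve, reject) => {
        let output = "";
        llama.inference(
            {
                prompt,
                numPredict: 128,
                temp: 0.2,
                topP: 1,
                topK: 40,
                repeatPenalty: 1,
                repeatLastN: 64,
                seed: 0,
                feedPrompt: true,
            },
            (response) => {
                switch (response.type) {
                    case InferenceResultType.Data:
                        output += response.data?.token ?? "";
                        break;
                    case InferenceResultType.End:
                        resolve(output);
                        break;
                    case InferenceResultType.Error:
                        reject(response);
                        break;
                }
            }
        );
    });

const run = async () => {
    // `create` is now async in every backend.
    const llama = await LLama.create({
        path: "../../ggml-alpaca-7b-q4.bin", // model path, as in the examples
        numCtxTokens: 128,
    });

    // Per the commit title, llama-rs can serve both requests in parallel;
    // llama-cpp and rwkv-cpp fall back to naive sequential async,
    // answering one after the other.
    const [a, b] = await Promise.all([
        complete(llama, "What is a llama?"),
        complete(llama, "What is a cat?"),
    ]);
    console.log(a);
    console.log(b);
};

run();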
4 changes: 2 additions & 2 deletions packages/cli/src/index.ts
@@ -12,7 +12,7 @@ import { existsSync } from "fs";
 
 const convertType = ["q4_0", "q4_1", "f16", "f32"] as const;
 
-type ConvertType = typeof convertType[number];
+type ConvertType = (typeof convertType)[number];
 
 interface CLIInferenceArguments extends LLamaInferenceArguments, LLamaConfig {
     logger?: boolean;
@@ -75,7 +75,7 @@ class InferenceCommand implements yargs.CommandModule {
         if (logger) {
             LLama.enableLogger();
         }
-        const llama = LLama.create({ path: absolutePath, numCtxTokens });
+        const llama = await LLama.create({ path: absolutePath, numCtxTokens });
         llama.inference(rest, (result) => {
             switch (result.type) {
                 case InferenceResultType.Data:
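
The `ConvertType` edit is formatting-only: the parentheses do not change the type, they just match the style recent formatters emit for indexed access types. Both forms index the readonly tuple type by `number` and yield the same union:

const convertType = ["q4_0", "q4_1", "f16", "f32"] as const;

// Indexing the tuple type by `number` produces the union of its elements;
// the parentheses only change how the expression parses, not its meaning.
type ConvertType = (typeof convertType)[number]; // "q4_0" | "q4_1" | "f16" | "f32"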
2 changes: 1 addition & 1 deletion packages/core/__test__/index.spec.ts
@@ -6,7 +6,7 @@ test(
     async () => {
         LLama.enableLogger();
 
-        const llama = LLama.create({
+        const llama = await LLama.create({
             path: process.env.model?.toString()!,
             numCtxTokens: 128,
         });
68 changes: 36 additions & 32 deletions packages/core/example/cachesession.ts
@@ -6,46 +6,50 @@ const saveSession = path.resolve(process.cwd(), "./tmp/session.bin");
 
 LLama.enableLogger();
 
-const llama = LLama.create({
-    path: model,
-    numCtxTokens: 128,
-});
+const run = async () => {
+    const llama = await LLama.create({
+        path: model,
+        numCtxTokens: 128,
+    });
 
-const template = `how are you`;
+    const template = `how are you`;
 
-const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
+    const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
 ### Instruction:
 ${template}
 ### Response:`;
 
-llama.inference(
-    {
-        prompt,
-        numPredict: 128,
-        temp: 0.2,
-        topP: 1,
-        topK: 40,
-        repeatPenalty: 1,
-        repeatLastN: 64,
-        seed: 0,
-        feedPrompt: true,
-        feedPromptOnly: true,
-        saveSession,
-    },
-    (response) => {
-        switch (response.type) {
-            case InferenceResultType.Data: {
-                process.stdout.write(response.data?.token ?? "");
-                break;
-            }
-            case InferenceResultType.End:
-            case InferenceResultType.Error: {
-                console.log(response);
-                break;
-            }
-        }
-    }
-);
+    llama.inference(
+        {
+            prompt,
+            numPredict: 128,
+            temp: 0.2,
+            topP: 1,
+            topK: 40,
+            repeatPenalty: 1,
+            repeatLastN: 64,
+            seed: 0,
+            feedPrompt: true,
+            feedPromptOnly: true,
+            saveSession,
+        },
+        (response) => {
+            switch (response.type) {
+                case InferenceResultType.Data: {
+                    process.stdout.write(response.data?.token ?? "");
+                    break;
+                }
+                case InferenceResultType.End:
+                case InferenceResultType.Error: {
+                    console.log(response);
+                    break;
+                }
+            }
+        }
+    );
+};
+
+run();
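
Together with loadsession.ts further down, this example demonstrates the session round trip: `feedPrompt` plus `feedPromptOnly` evaluates the prompt without generating a completion, and `saveSession` persists the resulting model state to ./tmp/session.bin, which loadsession.ts then restores with an empty prompt so generation resumes from the cached context.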
74 changes: 35 additions & 39 deletions packages/core/example/embedding.ts
@@ -1,51 +1,47 @@
-import { EmbeddingResultType, LLama } from "../index";
+import { LLama } from "../index";
 import path from "path";
 import fs from "fs";
 
 const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin");
 
 LLama.enableLogger();
 
-const llama = LLama.create({
-    path: model,
-    numCtxTokens: 128,
-});
-
-const getWordEmbeddings = (prompt: string, file: string) => {
-    llama.getWordEmbeddings(
-        {
-            prompt,
-            numPredict: 128,
-            temp: 0.2,
-            topP: 1,
-            topK: 40,
-            repeatPenalty: 1,
-            repeatLastN: 64,
-            seed: 0,
-        },
-        (response) => {
-            switch (response.type) {
-                case EmbeddingResultType.Data: {
-                    fs.writeFileSync(
-                        path.resolve(process.cwd(), file),
-                        JSON.stringify(response.data)
-                    );
-                    break;
-                }
-                case EmbeddingResultType.Error: {
-                    console.log(response);
-                    break;
-                }
-            }
-        }
-    );
-};
+const getWordEmbeddings = async (
+    llama: LLama,
+    prompt: string,
+    file: string
+) => {
+    const response = await llama.getWordEmbeddings({
+        prompt,
+        numPredict: 128,
+        temp: 0.2,
+        topP: 1,
+        topK: 40,
+        repeatPenalty: 1,
+        repeatLastN: 64,
+        seed: 0,
+    });
 
-const dog1 = `My favourite animal is the dog`;
-getWordEmbeddings(dog1, "./example/semantic-compare/dog1.json");
+    fs.writeFileSync(
+        path.resolve(process.cwd(), file),
+        JSON.stringify(response)
+    );
+};
 
-const dog2 = `I have just adopted a cute dog`;
-getWordEmbeddings(dog2, "./example/semantic-compare/dog2.json");
+const run = async () => {
+    const llama = await LLama.create({
+        path: model,
+        numCtxTokens: 128,
+    });
 
-const cat1 = `My favourite animal is the cat`;
-getWordEmbeddings(cat1, "./example/semantic-compare/cat1.json");
+    const dog1 = `My favourite animal is the dog`;
+    getWordEmbeddings(llama, dog1, "./example/semantic-compare/dog1.json");
+
+    const dog2 = `I have just adopted a cute dog`;
+    getWordEmbeddings(llama, dog2, "./example/semantic-compare/dog2.json");
+
+    const cat1 = `My favourite animal is the cat`;
+    getWordEmbeddings(llama, cat1, "./example/semantic-compare/cat1.json");
+};
+
+run();
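
The three JSON files this example writes feed the semantic-compare demo: nearby prompts should yield nearby embedding vectors. A minimal sketch of that comparison, assuming each file deserializes to a flat numeric vector (the new code writes `JSON.stringify(response)`, so adjust the load helper if the resolved value wraps the vector); the cosine helper is illustrative, not part of this commit:

import fs from "fs";
import path from "path";

// Cosine similarity between two embedding vectors of equal length.
const cosine = (a: number[], b: number[]) => {
    let dot = 0, na = 0, nb = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        na += a[i] * a[i];
        nb += b[i] * b[i];
    }
    return dot / (Math.sqrt(na) * Math.sqrt(nb));
};

const load = (file: string): number[] =>
    JSON.parse(fs.readFileSync(path.resolve(process.cwd(), file), "utf8"));

const dog1 = load("./example/semantic-compare/dog1.json");
const dog2 = load("./example/semantic-compare/dog2.json");
const cat1 = load("./example/semantic-compare/cat1.json");

// Expectation: the two dog prompts score closer to each other
// than either does to the cat prompt.
console.log("dog1~dog2", cosine(dog1, dog2));
console.log("dog1~cat1", cosine(dog1, cat1));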
64 changes: 33 additions & 31 deletions packages/core/example/inference.ts
@@ -6,45 +6,47 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin");
 
 LLama.enableLogger();
 
-const llama = LLama.create({
-    path: model,
-    numCtxTokens: 128,
-});
+const run = async () => {
+    const llama = await LLama.create({
+        path: model,
+        numCtxTokens: 128,
+    });
 
-const template = `how are you`;
+    const template = `how are you`;
 
-const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
+    const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
 ### Instruction:
 ${template}
 ### Response:`;
 
-llama.inference(
-    {
-        prompt,
-        numPredict: 128,
-        temp: 0.2,
-        topP: 1,
-        topK: 40,
-        repeatPenalty: 1,
-        repeatLastN: 64,
-        seed: 0,
-        feedPrompt: true,
-        // persistSession,
-    },
-    (response) => {
-        switch (response.type) {
-            case InferenceResultType.Data: {
-                process.stdout.write(response.data?.token ?? "");
-                break;
-            }
-            case InferenceResultType.End:
-            case InferenceResultType.Error: {
-                console.log(response);
-                break;
-            }
-        }
-    }
-);
+    llama.inference(
+        {
+            prompt,
+            numPredict: 128,
+            temp: 0.2,
+            topP: 1,
+            topK: 40,
+            repeatPenalty: 1,
+            repeatLastN: 64,
+            seed: 0,
+            feedPrompt: true,
+        },
+        (response) => {
+            switch (response.type) {
+                case InferenceResultType.Data: {
+                    process.stdout.write(response.data?.token ?? "");
+                    break;
+                }
+                case InferenceResultType.End:
+                case InferenceResultType.Error: {
+                    console.log(response);
+                    break;
+                }
+            }
+        }
+    );
+};
+run();
60 changes: 32 additions & 28 deletions packages/core/example/loadsession.ts
@@ -6,34 +6,38 @@ const loadSession = path.resolve(process.cwd(), "./tmp/session.bin");
 
 LLama.enableLogger();
 
-const llama = LLama.create({
-    path: model,
-    numCtxTokens: 128,
-});
+const run = async () => {
+    const llama = await LLama.create({
+        path: model,
+        numCtxTokens: 128,
+    });
 
-llama.inference(
-    {
-        prompt: "",
-        numPredict: 128,
-        temp: 0.2,
-        topP: 1,
-        topK: 40,
-        repeatPenalty: 1,
-        repeatLastN: 64,
-        seed: 0,
-        loadSession,
-    },
-    (response) => {
-        switch (response.type) {
-            case InferenceResultType.Data: {
-                process.stdout.write(response.data?.token ?? "");
-                break;
-            }
-            case InferenceResultType.End:
-            case InferenceResultType.Error: {
-                console.log(response);
-                break;
-            }
-        }
-    }
-);
+    llama.inference(
+        {
+            prompt: "",
+            numPredict: 128,
+            temp: 0.2,
+            topP: 1,
+            topK: 40,
+            repeatPenalty: 1,
+            repeatLastN: 64,
+            seed: 0,
+            loadSession,
+        },
+        (response) => {
+            switch (response.type) {
+                case InferenceResultType.Data: {
+                    process.stdout.write(response.data?.token ?? "");
+                    break;
+                }
+                case InferenceResultType.End:
+                case InferenceResultType.Error: {
+                    console.log(response);
+                    break;
+                }
+            }
+        }
+    );
+};
+
+run();
21 changes: 12 additions & 9 deletions packages/core/example/tokenize.ts
@@ -5,14 +5,17 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin");
 
 LLama.enableLogger();
 
-const llama = LLama.create({
-    path: model,
-    numCtxTokens: 128,
-});
+const run = async () => {
+    const llama = await LLama.create({
+        path: model,
+        numCtxTokens: 128,
+    });
 
-const prompt = "My favourite animal is the cat";
+    const prompt = "My favourite animal is the cat";
 
-llama.tokenize(prompt, (response) => {
-    console.log(response);
-    console.log(response.data.length); // 7
-});
+    const tokens = await llama.tokenize(prompt);
+
+    console.log(tokens);
+};
+
+run();
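
Note that the promise-based `tokenize` resolves with the tokens directly, so the inline check from the callback version (`response.data.length; // 7`) carries over as `(await llama.tokenize(prompt)).length`, which should still be 7 for this prompt, assuming the promise resolves to the token array.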
(Diffs for the remaining 25 changed files are not shown.)