Skip to content

Commit

Permalink
feat(agents): granite 3.1 support (#257)
Browse files Browse the repository at this point in the history
Signed-off-by: MICHAEL DESMOND <[email protected]>
Signed-off-by: Graham White <[email protected]>
Co-authored-by: Graham White <[email protected]>
  • Loading branch information
michael-desmond and grahamwhiteuk authored Dec 18, 2024
1 parent 747a052 commit 56045d6
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 75 deletions.
12 changes: 6 additions & 6 deletions examples/agents/granite/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

The [IBM Granite](https://www.ibm.com/granite) family of models can be used as the underlying LLM within Bee Agents. Granite™ is IBM's family of open, performant, and trusted AI models tailored for business and optimized to scale your AI applications.

This guide and the associated examples will help you get started with creating Bee Agents using Granite.
This guide and the associated examples will help you get started with creating Bee Agents using Granite 3.1.

## 📦 Prerequisites

### LLM Services

IBM Granite is supported by [watsonx.ai](https://www.ibm.com/products/watsonx-ai) and [Ollama](https://ollama.com/). Watsonx.ai will allow you to run models in the cloud. Ollama will allow you to download and run models locally.
IBM Granite 3.1 is supported by [watsonx.ai](https://www.ibm.com/products/watsonx-ai) and [Ollama](https://ollama.com/). Watsonx.ai will allow you to run models in the cloud. Ollama will allow you to download and run models locally.

> [!TIP]
> Better performance will be achieved by using larger Granite models.
> [!NOTE]
> If you work for IBM there are additional options to run IBM Granite models with VLLM or RITS.
> If you work for IBM there are additional options to run IBM Granite 3.1 models with VLLM or RITS.
#### Ollama

There are guides available for running Granite with Ollama on [Linux](https://www.ibm.com/granite/docs/run/granite-on-linux/granite/), [Mac](https://www.ibm.com/granite/docs/run/granite-on-mac/granite/) or [Windows](https://www.ibm.com/granite/docs/run/granite-on-windows/granite/).
There are guides available for running Granite 3.1 with Ollama on [Linux](https://www.ibm.com/granite/docs/run/granite-on-linux/granite/), [Mac](https://www.ibm.com/granite/docs/run/granite-on-mac/granite/) or [Windows](https://www.ibm.com/granite/docs/run/granite-on-windows/granite/).

#### Watsonx

Expand Down Expand Up @@ -88,10 +88,10 @@ In this example the wikipedia tool interface is extended so that the agent can s

This example uses Ollama exclusively.

To get started you will need to pull `granite3-dense:8b` and `nomic-embed-text` (to perform text embedding). If you are unfamiliar with using Ollama then check out instructions for getting up and running at the the [Ollama Github repo](https://github.com/ollama/ollama).
To get started you will need to pull `granite3.1-dense:8b` and `nomic-embed-text` (to perform text embedding). If you are unfamiliar with using Ollama then check out instructions for getting up and running at the the [Ollama Github repo](https://github.com/ollama/ollama).

```shell
ollama pull granite3-dense:8b
ollama pull granite3.1-dense:8b
ollama pull nomic-embed-text
ollama serve
```
Expand Down
6 changes: 3 additions & 3 deletions examples/agents/granite/granite_bee.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function getChatLLM(provider?: Provider): ChatLLM<ChatLLMOutput> {
const LLMFactories: Record<Provider, () => ChatLLM<ChatLLMOutput>> = {
[Providers.OLLAMA]: () =>
new OllamaChatLLM({
modelId: getEnv("OLLAMA_MODEL") || "granite3-dense:8b",
modelId: getEnv("OLLAMA_MODEL") || "granite3.1-dense:8b",
parameters: {
temperature: 0,
repeat_penalty: 1,
Expand All @@ -45,7 +45,7 @@ function getChatLLM(provider?: Provider): ChatLLM<ChatLLMOutput> {
projectId: getEnv("WATSONX_PROJECT_ID"),
region: getEnv("WATSONX_REGION"),
}),
[Providers.IBMVLLM]: () => IBMVllmChatLLM.fromPreset(IBMVllmModel.GRANITE_3_0_8B_INSTRUCT),
[Providers.IBMVLLM]: () => IBMVllmChatLLM.fromPreset(IBMVllmModel.GRANITE_3_1_8B_INSTRUCT),
[Providers.IBMRITS]: () =>
new OpenAIChatLLM({
client: new OpenAI({
Expand All @@ -55,7 +55,7 @@ function getChatLLM(provider?: Provider): ChatLLM<ChatLLMOutput> {
RITS_API_KEY: process.env.IBM_RITS_API_KEY,
},
}),
modelId: getEnv("IBM_RITS_MODEL") || "ibm-granite/granite-3.0-8b-instruct",
modelId: getEnv("IBM_RITS_MODEL") || "ibm-granite/granite-3.1-8b-instruct",
parameters: {
temperature: 0,
max_tokens: 2048,
Expand Down
4 changes: 2 additions & 2 deletions examples/agents/granite/granite_wiki_bee.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function wikipediaRetrivalTool(passageSize: number, overlap: number, maxResults:

// Agent LLM
const llm = new OllamaChatLLM({
modelId: "granite3-dense:8b",
modelId: "granite3.1-dense:8b",
parameters: {
temperature: 0,
num_ctx: 4096,
Expand All @@ -87,7 +87,7 @@ const llm = new OllamaChatLLM({
const agent = new BeeAgent({
llm,
memory: new TokenMemory({ llm }),
tools: [wikipediaRetrivalTool(200, 50, 3)],
tools: [wikipediaRetrivalTool(400, 50, 3)],
});

const reader = createConsoleReader();
Expand Down
22 changes: 2 additions & 20 deletions src/adapters/ibm-vllm/chatPreset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ export const IBMVllmModel = {
LLAMA_3_1_405B_INSTRUCT_FP8: "meta-llama/llama-3-1-405b-instruct-fp8",
LLAMA_3_1_70B_INSTRUCT: "meta-llama/llama-3-1-70b-instruct",
LLAMA_3_1_8B_INSTRUCT: "meta-llama/llama-3-1-8b-instruct",
GRANITE_3_0_8B_INSTRUCT: "ibm-granite/granite-3-0-8b-instruct",
GRANITE_3_1_8B_INSTRUCT: "ibm-granite/granite-3-1-8b-instruct",
} as const;
export type IBMVllmModel = (typeof IBMVllmModel)[keyof typeof IBMVllmModel];
Expand Down Expand Up @@ -119,26 +118,8 @@ export const IBMVllmChatLLMPreset = {
},
};
},
[IBMVllmModel.GRANITE_3_0_8B_INSTRUCT]: (): IBMVllmChatLLMPreset => {
const { template, parameters, messagesToPrompt } = LLMChatTemplates.get("granite3Instruct");
return {
base: {
modelId: IBMVllmModel.GRANITE_3_0_8B_INSTRUCT,
parameters: {
method: "GREEDY",
stopping: {
stop_sequences: [...parameters.stop_sequence],
include_stop_sequence: false,
},
},
},
chat: {
messagesToPrompt: messagesToPrompt(template),
},
};
},
[IBMVllmModel.GRANITE_3_1_8B_INSTRUCT]: (): IBMVllmChatLLMPreset => {
const { template, parameters, messagesToPrompt } = LLMChatTemplates.get("granite3Instruct");
const { template, parameters, messagesToPrompt } = LLMChatTemplates.get("granite3.1-Instruct");
return {
base: {
modelId: IBMVllmModel.GRANITE_3_1_8B_INSTRUCT,
Expand All @@ -147,6 +128,7 @@ export const IBMVllmChatLLMPreset = {
stopping: {
stop_sequences: [...parameters.stop_sequence],
include_stop_sequence: false,
max_new_tokens: 2048,
},
},
},
Expand Down
14 changes: 7 additions & 7 deletions src/adapters/shared/llmChatTemplates.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,21 @@ const llama3: LLMChatTemplate = {
},
};

const granite3Instruct: LLMChatTemplate = {
const granite31Instruct: LLMChatTemplate = {
template: new PromptTemplate({
schema: templateSchemaFactory([
"system",
"user",
"assistant",
"available_tools",
"tools",
"tool_call",
"tool_response",
] as const),
template: `{{#messages}}{{#system}}<|start_of_role|>system<|end_of_role|>
{{system}}<|end_of_text|>
{{ end }}{{/system}}{{#available_tools}}<|start_of_role|>available_tools<|end_of_role|>
{{available_tools}}<|end_of_text|>
{{ end }}{{/available_tools}}{{#user}}<|start_of_role|>user<|end_of_role|>
{{ end }}{{/system}}{{#tools}}<|start_of_role|>tools<|end_of_role|>
{{tools}}<|end_of_text|>
{{ end }}{{/tools}}{{#user}}<|start_of_role|>user<|end_of_role|>
{{user}}<|end_of_text|>
{{ end }}{{/user}}{{#assistant}}<|start_of_role|>assistant<|end_of_role|>
{{assistant}}<|end_of_text|>
Expand All @@ -142,7 +142,7 @@ const granite3Instruct: LLMChatTemplate = {
`,
}),
messagesToPrompt: messagesToPromptFactory({
available_tools: "available_tools",
tools: "tools",
tool_response: "tool_response",
tool_call: "tool_call",
}),
Expand All @@ -156,7 +156,7 @@ export class LLMChatTemplates {
"llama3.3": llama33,
"llama3.1": llama31,
"llama3": llama3,
"granite3Instruct": granite3Instruct,
"granite3.1-Instruct": granite31Instruct,
};

static register(model: string, template: LLMChatTemplate, override = false) {
Expand Down
23 changes: 2 additions & 21 deletions src/adapters/watsonx/chatPreset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ export const WatsonXChatLLMPreset = {
};
},
"ibm/granite-3-8b-instruct": (): WatsonXChatLLMPreset => {
const { template, parameters, messagesToPrompt } = LLMChatTemplates.get("granite3Instruct");
const { template, parameters, messagesToPrompt } = LLMChatTemplates.get("granite3.1-Instruct");
return {
base: {
parameters: {
decoding_method: "greedy",
max_new_tokens: 512,
max_new_tokens: 2048,
include_stop_sequence: false,
stop_sequences: [...parameters.stop_sequence],
},
Expand All @@ -61,25 +61,6 @@ export const WatsonXChatLLMPreset = {
"ibm/granite-3-2b-instruct"() {
return WatsonXChatLLMPreset["ibm/granite-3-8b-instruct"]();
},
"ibm/granite-3-1-8b-instruct": (): WatsonXChatLLMPreset => {
const { template, parameters, messagesToPrompt } = LLMChatTemplates.get("granite3Instruct");
return {
base: {
parameters: {
decoding_method: "greedy",
max_new_tokens: 512,
include_stop_sequence: false,
stop_sequences: [...parameters.stop_sequence],
},
},
chat: {
messagesToPrompt: messagesToPrompt(template),
},
};
},
"ibm/granite-3-1-2b-instruct"() {
return WatsonXChatLLMPreset["ibm/granite-3-8b-instruct"]();
},
"meta-llama/llama-3-1-70b-instruct": (): WatsonXChatLLMPreset => {
const { template, messagesToPrompt, parameters } = LLMChatTemplates.get("llama3.1");

Expand Down
32 changes: 17 additions & 15 deletions src/agents/bee/runners/granite/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,26 @@ export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => ({
}).format(date);
},
},
template: `# Setting
You are an AI assistant.
template: `You are an AI assistant.
When the user sends a message figure out a solution and provide a final answer.
{{#tools.length}}
You have access to a set of available tools that can be used to retrieve information and perform actions.
You have access to a set of tools that can be used to retrieve information and perform actions.
Pay close attention to the tool description to determine if a tool is useful in a particular context.
{{/tools.length}}
# Communication structure:
- Line starting 'Message: ' The user's question or instruction. This is provided by the user, the assistant does not produce this.
- Line starting 'Thought: ' The assistant's response always starts with a thought, this is free text where the assistant thinks about the user's message and describes in detail what it should do next.
# Communication structure
You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input', or 'Thought' followed by 'Final Answer'
Line starting 'Thought: ' The assistant's response always starts with a thought, this is a single line where the assistant thinks about the user's message and describes in detail what it should do next.
{{#tools.length}}
- In a 'Thought', the assistant should determine if a Tool Call is necessary to get more information or perform an action, or if the available information is sufficient to provide the Final Answer.
- If a tool needs to be called and is available, the assistant will produce a tool call:
- Line starting 'Tool Name: ' name of the tool that you want to use.
- Line starting 'Tool Input: ' JSON formatted tool arguments adhering to the selected tool parameters schema i.e. {"arg1":"value1", "arg2":"value2"}.
- Line starting 'Thought: ', followed by free text where the assistant thinks about the all the information it has available, and what it should do next (e.g. try the same tool with a different input, try a different tool, or proceed with answering the original user question).
In a 'Thought: ', the assistant should determine if a Tool Call is necessary to get more information or perform an action, or if the available information is sufficient to provide the Final Answer.
If a tool needs to be called and is available, the assistant will produce a tool call:
Line starting 'Tool Name: ' name of the tool that you want to use.
Line starting 'Tool Input: ' JSON formatted tool arguments adhering to the selected tool parameters schema i.e. {"arg1":"value1", "arg2":"value2"}.
After a 'Tool Input: ' the next message will contain a tool response. The next output should be a 'Thought: ' where the assistant thinks about the all the information it has available, and what it should do next (e.g. try the same tool with a different input, try a different tool, or proceed with answering the original user question).
{{/tools.length}}
- Once enough information is available to provide the Final Answer, the last line in the message needs to be:
- Line starting 'Final Answer: ' followed by a answer to the original message.
Once enough information is available to provide the Final Answer, the last line in the message needs to be:
Line starting 'Final Answer: ' followed by a concise and clear answer to the original message.
# Best practices
- Use markdown syntax for formatting code snippets, links, JSON, tables, images, files.
Expand All @@ -81,8 +81,10 @@ The current date and time is: {{formatDate}}
You do not need a tool to get the current Date and Time. Use the information available here.
{{/tools.length}}
{{#instructions}}
# Additional instructions
{{instructions}}
{{.}}
{{/instructions}}
`,
}));

Expand All @@ -94,7 +96,7 @@ You communicate only in instruction lines. Valid instruction lines are 'Thought'

export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => ({
...config,
template: `Message: {{input}}`,
template: `{{input}}`,
}));

export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => ({
Expand Down
2 changes: 1 addition & 1 deletion src/agents/bee/runners/granite/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export class GraniteRunner extends DefaultRunner {
const index = memory.messages.findIndex((msg) => msg.role === Role.SYSTEM) + 1;
await memory.add(
BaseMessage.of({
role: "available_tools",
role: "tools",
text: JSON.stringify(
(await this.renderers.system.variables.tools()).map((tool) => ({
name: tool.name,
Expand Down

0 comments on commit 56045d6

Please sign in to comment.