diff --git a/docs/core_docs/docs/integrations/chat/cerebras.ipynb b/docs/core_docs/docs/integrations/chat/cerebras.ipynb
new file mode 100644
index 000000000000..2d49afcd1b79
--- /dev/null
+++ b/docs/core_docs/docs/integrations/chat/cerebras.ipynb
@@ -0,0 +1,341 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "id": "afaf8039",
+   "metadata": {
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "---\n",
+    "sidebar_label: Cerebras\n",
+    "---"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e49f1e0d",
+   "metadata": {},
+   "source": [
+    "# ChatCerebras\n",
+    "\n",
+    "[Cerebras](https://cerebras.ai/) is a model provider that serves open source models with an emphasis on speed. The Cerebras CS-3 system, powered by the Wafer-Scale Engine-3 (WSE-3), represents a new class of AI supercomputer that sets the standard for generative AI training and inference with unparalleled performance and scalability.\n",
+    "\n",
+    "With Cerebras as your inference provider, you can:\n",
+    "\n",
+    "- Achieve unprecedented speed for AI inference workloads\n",
+    "- Build commercially with high throughput\n",
+    "- Effortlessly scale your AI workloads with our seamless clustering technology\n",
+    "\n",
+    "Our CS-3 systems can be quickly and easily clustered to create the largest AI supercomputers in the world, making it simple to place and run the largest models. Leading corporations, research institutions, and governments are already using Cerebras solutions to develop proprietary models and train popular open-source models.\n",
+    "\n",
+    "Want to experience the power of Cerebras? Check out our [website](https://cerebras.ai/) for more resources and explore options for accessing our technology through the Cerebras Cloud or on-premise deployments!\n",
+    "\n",
+    "For more information about Cerebras Cloud, visit [cloud.cerebras.ai](https://cloud.cerebras.ai/).
Our API reference is available at [inference-docs.cerebras.ai](https://inference-docs.cerebras.ai).\n", + "\n", + "## Overview\n", + "\n", + "### Integration details\n", + "\n", + "| Class | Package | Local | Serializable | [PY support](https://python.langchain.com/docs/integrations/chat/cerebras) | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", + "| [ChatCerebras](https://api.js.langchain.com/classes/langchain_cerebras.ChatCerebras.html) | [`@langchain/cerebras`](https://www.npmjs.com/package/@langchain/cerebras) | ❌ | ❌ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/cerebras?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/cerebras?style=flat-square&label=%20&) |\n", + "\n", + "### Model features\n", + "\n", + "See the links in the table headers below for guides on how to use specific features.\n", + "\n", + "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", + "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", + "| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | \n", + "\n", + "## Setup\n", + "\n", + "To access ChatCerebras models you'll need to create a Cerebras account, get an API key, and install the `@langchain/cerebras` integration package.\n", + "\n", + "### Credentials\n", + "\n", + "Get an API Key from [cloud.cerebras.ai](https://cloud.cerebras.ai) and add it to your environment variables:\n", + "\n", + "```bash\n", + "export CEREBRAS_API_KEY=\"your-api-key\"\n", + "```\n", + "\n", + "If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:\n", + "\n", + "```bash\n", + "# export LANGCHAIN_TRACING_V2=\"true\"\n", + "# export LANGCHAIN_API_KEY=\"your-api-key\"\n", + "```\n", + "\n", + "### Installation\n", + "\n", + "The LangChain ChatCerebras integration lives in the `@langchain/cerebras` package:\n", + "\n", + "```{=mdx}\n", + "\n", + "import IntegrationInstallTooltip from \"@mdx_components/integration_install_tooltip.mdx\";\n", + "import Npm2Yarn from \"@theme/Npm2Yarn\";\n", + "\n", + "\n", + "\n", + "\n", + " @langchain/cerebras @langchain/core\n", + "\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "a38cde65-254d-4219-a441-068766c0d4b5", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae", + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatCerebras } from \"@langchain/cerebras\" \n", + "\n", + "const llm = new ChatCerebras({\n", + " model: \"llama-3.3-70b\",\n", + " temperature: 0,\n", + " maxTokens: undefined,\n", + " maxRetries: 2,\n", + " // other params...\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "2b4f3e15", + "metadata": {}, + "source": [ + "## Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "62e0dbc3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "AIMessage {\n", + " \"id\": 
\"run-17c7d62d-67ac-4677-b33a-18298fc85e35\",\n", + " \"content\": \"J'adore la programmation.\",\n", + " \"additional_kwargs\": {},\n", + " \"response_metadata\": {\n", + " \"id\": \"chatcmpl-2d1e2de5-4239-46fb-af2a-6200d89d7dde\",\n", + " \"created\": 1735785598,\n", + " \"model\": \"llama-3.3-70b\",\n", + " \"system_fingerprint\": \"fp_2e2a2a083c\",\n", + " \"object\": \"chat.completion\",\n", + " \"time_info\": {\n", + " \"queue_time\": 0.00009063,\n", + " \"prompt_time\": 0.002163031,\n", + " \"completion_time\": 0.012339628,\n", + " \"total_time\": 0.01640915870666504,\n", + " \"created\": 1735785598\n", + " }\n", + " },\n", + " \"tool_calls\": [],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 55,\n", + " \"output_tokens\": 9,\n", + " \"total_tokens\": 64\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "const aiMsg = await llm.invoke([\n", + " {\n", + " role: \"system\",\n", + " content: \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", + " },\n", + " { role: \"user\", content: \"I love programming.\" },\n", + "])\n", + "aiMsg" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d86145b3-bfef-46e8-b227-4dda5c9c2705", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "J'adore la programmation.\n" + ] + } + ], + "source": [ + "console.log(aiMsg.content)" + ] + }, + { + "cell_type": "markdown", + "id": "ce0414fe", + "metadata": {}, + "source": [ + "## Json invocation" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3f0a7a2a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ aiInvokeMsgContent: '{\"result\":4}', aiBindMsg: '{\"result\":4}' }\n" + ] + } + ], + "source": [ + "const messages = [\n", + " {\n", + " role: \"system\",\n", + " content: \"You are a math tutor that handles math exercises and makes output in json in format { result: number }.\",\n", + " },\n", + " { role: \"user\", content: \"2 + 2\" },\n", + "];\n", + "\n", + "const aiInvokeMsg = await llm.invoke(messages, { response_format: { type: \"json_object\" } });\n", + "\n", + "// if you want not to pass response_format in every invoke, you can bind it to the instance\n", + "const llmWithResponseFormat = llm.bind({ response_format: { type: \"json_object\" } });\n", + "const aiBindMsg = await llmWithResponseFormat.invoke(messages);\n", + "\n", + "// they are the same\n", + "console.log({ aiInvokeMsgContent: aiInvokeMsg.content, aiBindMsg: aiBindMsg.content });" + ] + }, + { + "cell_type": "markdown", + "id": "18e2bfc0-7e78-4528-a73f-499ac150dca8", + "metadata": {}, + "source": [ + "## Chaining\n", + "\n", + "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "AIMessage {\n", + " \"id\": \"run-5c8a9f25-0f57-499b-9c2b-87bd07135feb\",\n", + " \"content\": \"Ich liebe das Programmieren.\",\n", + " \"additional_kwargs\": {},\n", + " \"response_metadata\": {\n", + " \"id\": \"chatcmpl-abd1e9eb-b873-492e-9e30-0d13dfc3a145\",\n", + " \"created\": 1735785607,\n", + " \"model\": \"llama-3.3-70b\",\n", + " \"system_fingerprint\": \"fp_2e2a2a083c\",\n", + " \"object\": \"chat.completion\",\n", + " \"time_info\": {\n", + " \"queue_time\": 0.00009499,\n", + " 
\"prompt_time\": 0.002095266,\n", + " \"completion_time\": 0.008807576,\n", + " \"total_time\": 0.012718439102172852,\n", + " \"created\": 1735785607\n", + " }\n", + " },\n", + " \"tool_calls\": [],\n", + " \"invalid_tool_calls\": [],\n", + " \"usage_metadata\": {\n", + " \"input_tokens\": 50,\n", + " \"output_tokens\": 7,\n", + " \"total_tokens\": 57\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "import { ChatPromptTemplate } from \"@langchain/core/prompts\"\n", + "\n", + "const prompt = ChatPromptTemplate.fromMessages(\n", + " [\n", + " [\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ],\n", + " [\"human\", \"{input}\"],\n", + " ]\n", + ")\n", + "\n", + "const chain = prompt.pipe(llm);\n", + "await chain.invoke(\n", + " {\n", + " input_language: \"English\",\n", + " output_language: \"German\",\n", + " input: \"I love programming.\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all ChatCerebras features and configurations head to the API reference: https://api.js.langchain.com/classes/langchain_cerebras.ChatCerebras.html" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "TypeScript", + "language": "typescript", + "name": "tslab" + }, + "language_info": { + "codemirror_mode": { + "mode": "typescript", + "name": "javascript", + "typescript": true + }, + "file_extension": ".ts", + "mimetype": "text/typescript", + "name": "typescript", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/package.json b/langchain/package.json index 6329b37aac9f..47d3ccdde02f 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -413,6 +413,7 @@ "@jest/globals": "^29.5.0", "@langchain/anthropic": "*", "@langchain/aws": "*", + "@langchain/cerebras": "*", "@langchain/cohere": "*", "@langchain/core": "workspace:*", "@langchain/google-genai": "*", @@ -460,6 +461,7 @@ "peerDependencies": { "@langchain/anthropic": "*", "@langchain/aws": "*", + "@langchain/cerebras": "*", "@langchain/cohere": "*", "@langchain/core": ">=0.2.21 <0.4.0", "@langchain/google-genai": "*", @@ -480,6 +482,9 @@ "@langchain/aws": { "optional": true }, + "@langchain/cerebras": { + "optional": true + }, "@langchain/cohere": { "optional": true }, diff --git a/langchain/src/chat_models/universal.ts b/langchain/src/chat_models/universal.ts index a61b1a5b4cb6..13311c4cfea5 100644 --- a/langchain/src/chat_models/universal.ts +++ b/langchain/src/chat_models/universal.ts @@ -115,6 +115,10 @@ async function _initChatModelHelper( const { ChatGroq } = await import("@langchain/groq"); return new ChatGroq({ model, ...passedParams }); } + case "cerebras": { + const { ChatCerebras } = await import("@langchain/cerebras"); + return new ChatCerebras({ model, ...passedParams }); + } case "bedrock": { const { ChatBedrockConverse } = await import("@langchain/aws"); return new ChatBedrockConverse({ model, ...passedParams }); @@ -598,6 +602,7 @@ export async function initChatModel< * - mistralai (@langchain/mistralai) * - groq (@langchain/groq) * - ollama (@langchain/ollama) + * - cerebras (@langchain/cerebras) * @param {string[] | "any"} [fields.configurableFields] - Which model parameters are configurable: * - undefined: No configurable fields. * - "any": All fields are configurable. 
(See Security Note in description) diff --git a/libs/langchain-cerebras/.eslintrc.cjs b/libs/langchain-cerebras/.eslintrc.cjs new file mode 100644 index 000000000000..e3033ac0160c --- /dev/null +++ b/libs/langchain-cerebras/.eslintrc.cjs @@ -0,0 +1,74 @@ +module.exports = { + extends: [ + "airbnb-base", + "eslint:recommended", + "prettier", + "plugin:@typescript-eslint/recommended", + ], + parserOptions: { + ecmaVersion: 12, + parser: "@typescript-eslint/parser", + project: "./tsconfig.json", + sourceType: "module", + }, + plugins: ["@typescript-eslint", "no-instanceof"], + ignorePatterns: [ + ".eslintrc.cjs", + "scripts", + "node_modules", + "dist", + "dist-cjs", + "*.js", + "*.cjs", + "*.d.ts", + ], + rules: { + "no-process-env": 2, + "no-instanceof/no-instanceof": 2, + "@typescript-eslint/explicit-module-boundary-types": 0, + "@typescript-eslint/no-empty-function": 0, + "@typescript-eslint/no-shadow": 0, + "@typescript-eslint/no-empty-interface": 0, + "@typescript-eslint/no-use-before-define": ["error", "nofunc"], + "@typescript-eslint/no-unused-vars": ["warn", { args: "none" }], + "@typescript-eslint/no-floating-promises": "error", + "@typescript-eslint/no-misused-promises": "error", + camelcase: 0, + "class-methods-use-this": 0, + "import/extensions": [2, "ignorePackages"], + "import/no-extraneous-dependencies": [ + "error", + { devDependencies: ["**/*.test.ts"] }, + ], + "import/no-unresolved": 0, + "import/prefer-default-export": 0, + "keyword-spacing": "error", + "max-classes-per-file": 0, + "max-len": 0, + "no-await-in-loop": 0, + "no-bitwise": 0, + "no-console": 0, + "no-restricted-syntax": 0, + "no-shadow": 0, + "no-continue": 0, + "no-void": 0, + "no-underscore-dangle": 0, + "no-use-before-define": 0, + "no-useless-constructor": 0, + "no-return-await": 0, + "consistent-return": 0, + "no-else-return": 0, + "func-names": 0, + "no-lonely-if": 0, + "prefer-rest-params": 0, + "new-cap": ["error", { properties: false, capIsNew: false }], + }, + overrides: [ + { + files: ["**/*.test.ts"], + rules: { + "@typescript-eslint/no-unused-vars": "off", + }, + }, + ], +}; diff --git a/libs/langchain-cerebras/.gitignore b/libs/langchain-cerebras/.gitignore new file mode 100644 index 000000000000..c10034e2f1be --- /dev/null +++ b/libs/langchain-cerebras/.gitignore @@ -0,0 +1,7 @@ +index.cjs +index.js +index.d.ts +index.d.cts +node_modules +dist +.yarn diff --git a/libs/langchain-cerebras/.prettierrc b/libs/langchain-cerebras/.prettierrc new file mode 100644 index 000000000000..ba08ff04f677 --- /dev/null +++ b/libs/langchain-cerebras/.prettierrc @@ -0,0 +1,19 @@ +{ + "$schema": "https://json.schemastore.org/prettierrc", + "printWidth": 80, + "tabWidth": 2, + "useTabs": false, + "semi": true, + "singleQuote": false, + "quoteProps": "as-needed", + "jsxSingleQuote": false, + "trailingComma": "es5", + "bracketSpacing": true, + "arrowParens": "always", + "requirePragma": false, + "insertPragma": false, + "proseWrap": "preserve", + "htmlWhitespaceSensitivity": "css", + "vueIndentScriptAndStyle": false, + "endOfLine": "lf" +} diff --git a/libs/langchain-cerebras/.release-it.json b/libs/langchain-cerebras/.release-it.json new file mode 100644 index 000000000000..522ee6abf705 --- /dev/null +++ b/libs/langchain-cerebras/.release-it.json @@ -0,0 +1,10 @@ +{ + "github": { + "release": true, + "autoGenerate": true, + "tokenRef": "GITHUB_TOKEN_RELEASE" + }, + "npm": { + "versionArgs": ["--workspaces-update=false"] + } +} diff --git a/libs/langchain-cerebras/LICENSE b/libs/langchain-cerebras/LICENSE new file 
mode 100644
index 000000000000..ea0e5809eda3
--- /dev/null
+++ b/libs/langchain-cerebras/LICENSE
@@ -0,0 +1,21 @@
+The MIT License
+
+Copyright (c) 2025 LangChain
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/libs/langchain-cerebras/README.md b/libs/langchain-cerebras/README.md
new file mode 100644
index 000000000000..ad3e9e72ae5f
--- /dev/null
+++ b/libs/langchain-cerebras/README.md
@@ -0,0 +1,76 @@
+# @langchain/cerebras
+
+This package contains the LangChain.js integrations for Cerebras via the `@cerebras/cerebras_cloud_sdk` package.
+
+## Installation
+
+```bash npm2yarn
+npm install @langchain/cerebras @langchain/core
+```
+
+## Chat models
+
+This package adds support for Cerebras chat model inference.
+
+Set the necessary environment variable (or pass it in via the constructor):
+
+```bash
+export CEREBRAS_API_KEY=
+```
+
+```typescript
+import { ChatCerebras } from "@langchain/cerebras";
+import { HumanMessage } from "@langchain/core/messages";
+
+const model = new ChatCerebras({
+  apiKey: process.env.CEREBRAS_API_KEY, // Default value.
+});
+
+const message = new HumanMessage("What color is the sky?");
+
+const res = await model.invoke([message]);
+```
+
+## Development
+
+To develop the `@langchain/cerebras` package, follow these instructions:
+
+### Install dependencies
+
+```bash
+yarn install
+```
+
+### Build the package
+
+```bash
+yarn build
+```
+
+Or from the repo root:
+
+```bash
+yarn build --filter=@langchain/cerebras
+```
+
+### Run tests
+
+Test files should live within a `tests/` folder in the `src/` folder. Unit tests should end in `.test.ts` and integration tests should
+end in `.int.test.ts`:
+
+```bash
+$ yarn test
+$ yarn test:int
+```
+
+### Lint & Format
+
+Run the linter & formatter to ensure your code is up to standard:
+
+```bash
+yarn lint && yarn format
+```
+
+### Adding new entrypoints
+
+If you add a new file to be exported, either import & re-export from `src/index.ts`, or add it to the `entrypoints` field in the `config` variable located inside `langchain.config.js` and run `yarn build` to generate the new entrypoint; see the sketch below.
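+
+As a minimal sketch (the `utils` module here is hypothetical, purely for illustration), a new entrypoint would be registered in `langchain.config.js` like this:
+
+```js
+// langchain.config.js (excerpt)
+export const config = {
+  // ...other fields unchanged...
+  entrypoints: {
+    index: "index",
+    // Hypothetical: exposes src/utils.ts as "@langchain/cerebras/utils"
+    utils: "utils",
+  },
+};
+```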
diff --git a/libs/langchain-cerebras/jest.config.cjs b/libs/langchain-cerebras/jest.config.cjs new file mode 100644 index 000000000000..994826496bc5 --- /dev/null +++ b/libs/langchain-cerebras/jest.config.cjs @@ -0,0 +1,21 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: "ts-jest/presets/default-esm", + testEnvironment: "./jest.env.cjs", + modulePathIgnorePatterns: ["dist/", "docs/"], + moduleNameMapper: { + "^(\\.{1,2}/.*)\\.js$": "$1", + }, + transform: { + "^.+\\.tsx?$": ["@swc/jest"], + }, + transformIgnorePatterns: [ + "/node_modules/", + "\\.pnp\\.[^\\/]+$", + "./scripts/jest-setup-after-env.js", + ], + setupFiles: ["dotenv/config"], + testTimeout: 20_000, + passWithNoTests: true, + collectCoverageFrom: ["src/**/*.ts"], +}; diff --git a/libs/langchain-cerebras/jest.env.cjs b/libs/langchain-cerebras/jest.env.cjs new file mode 100644 index 000000000000..2ccedccb8672 --- /dev/null +++ b/libs/langchain-cerebras/jest.env.cjs @@ -0,0 +1,12 @@ +const { TestEnvironment } = require("jest-environment-node"); + +class AdjustedTestEnvironmentToSupportFloat32Array extends TestEnvironment { + constructor(config, context) { + // Make `instanceof Float32Array` return true in tests + // to avoid https://github.com/xenova/transformers.js/issues/57 and https://github.com/jestjs/jest/issues/2549 + super(config, context); + this.global.Float32Array = Float32Array; + } +} + +module.exports = AdjustedTestEnvironmentToSupportFloat32Array; diff --git a/libs/langchain-cerebras/langchain.config.js b/libs/langchain-cerebras/langchain.config.js new file mode 100644 index 000000000000..46b1a2b31264 --- /dev/null +++ b/libs/langchain-cerebras/langchain.config.js @@ -0,0 +1,22 @@ +import { resolve, dirname } from "node:path"; +import { fileURLToPath } from "node:url"; + +/** + * @param {string} relativePath + * @returns {string} + */ +function abs(relativePath) { + return resolve(dirname(fileURLToPath(import.meta.url)), relativePath); +} + +export const config = { + internals: [/node\:/, /@langchain\/core\//], + entrypoints: { + index: "index", + }, + requiresOptionalDependency: [], + tsConfigPath: resolve("./tsconfig.json"), + cjsSource: "./dist-cjs", + cjsDestination: "./dist", + abs, +}; diff --git a/libs/langchain-cerebras/package.json b/libs/langchain-cerebras/package.json new file mode 100644 index 000000000000..cfc67bc4af56 --- /dev/null +++ b/libs/langchain-cerebras/package.json @@ -0,0 +1,92 @@ +{ + "name": "@langchain/cerebras", + "version": "0.0.0", + "description": "Cerebras integration for LangChain.js", + "type": "module", + "engines": { + "node": ">=18" + }, + "main": "./index.js", + "types": "./index.d.ts", + "repository": { + "type": "git", + "url": "git@github.com:langchain-ai/langchainjs.git" + }, + "homepage": "https://github.com/langchain-ai/langchainjs/tree/main/libs/langchain-cerebras/", + "scripts": { + "build": "yarn turbo:command build:internal --filter=@langchain/cerebras", + "build:internal": "yarn lc_build --create-entrypoints --pre --tree-shaking", + "lint:eslint": "NODE_OPTIONS=--max-old-space-size=4096 eslint --cache --ext .ts,.js src/", + "lint:dpdm": "dpdm --exit-code circular:1 --no-warning --no-tree src/*.ts src/**/*.ts", + "lint": "yarn lint:eslint && yarn lint:dpdm", + "lint:fix": "yarn lint:eslint --fix && yarn lint:dpdm", + "clean": "rm -rf .turbo dist/", + "prepack": "yarn build", + "test": "NODE_OPTIONS=--experimental-vm-modules jest --testPathIgnorePatterns=\\.int\\.test.ts --testTimeout 30000 --maxWorkers=50%", + "test:watch": 
"NODE_OPTIONS=--experimental-vm-modules jest --watch --testPathIgnorePatterns=\\.int\\.test.ts", + "test:single": "NODE_OPTIONS=--experimental-vm-modules yarn run jest --config jest.config.cjs --testTimeout 100000", + "test:int": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\.int\\.test.ts --testTimeout 100000 --maxWorkers=50%", + "format": "prettier --config .prettierrc --write \"src\"", + "format:check": "prettier --config .prettierrc --check \"src\"" + }, + "author": "LangChain", + "license": "MIT", + "dependencies": { + "@cerebras/cerebras_cloud_sdk": "^1.15.0", + "uuid": "^10.0.0", + "zod": "^3.22.4", + "zod-to-json-schema": "^3.22.3" + }, + "peerDependencies": { + "@langchain/core": ">=0.3.0 <0.4.0" + }, + "devDependencies": { + "@jest/globals": "^29.5.0", + "@langchain/core": "workspace:*", + "@langchain/scripts": ">=0.1.0 <0.2.0", + "@langchain/standard-tests": "0.0.0", + "@swc/core": "^1.3.90", + "@swc/jest": "^0.2.29", + "@tsconfig/recommended": "^1.0.3", + "@types/uuid": "^10", + "@typescript-eslint/eslint-plugin": "^6.12.0", + "@typescript-eslint/parser": "^6.12.0", + "dotenv": "^16.3.1", + "dpdm": "^3.12.0", + "eslint": "^8.33.0", + "eslint-config-airbnb-base": "^15.0.0", + "eslint-config-prettier": "^8.6.0", + "eslint-plugin-import": "^2.27.5", + "eslint-plugin-no-instanceof": "^1.0.1", + "eslint-plugin-prettier": "^4.2.1", + "jest": "^29.5.0", + "jest-environment-node": "^29.6.4", + "prettier": "^2.8.3", + "release-it": "^15.10.1", + "rollup": "^4.5.2", + "ts-jest": "^29.1.0", + "typescript": "<5.2.0" + }, + "publishConfig": { + "access": "public" + }, + "exports": { + ".": { + "types": { + "import": "./index.d.ts", + "require": "./index.d.cts", + "default": "./index.d.ts" + }, + "import": "./index.js", + "require": "./index.cjs" + }, + "./package.json": "./package.json" + }, + "files": [ + "dist/", + "index.cjs", + "index.js", + "index.d.ts", + "index.d.cts" + ] +} diff --git a/libs/langchain-cerebras/scripts/jest-setup-after-env.js b/libs/langchain-cerebras/scripts/jest-setup-after-env.js new file mode 100644 index 000000000000..7323083d0ea5 --- /dev/null +++ b/libs/langchain-cerebras/scripts/jest-setup-after-env.js @@ -0,0 +1,9 @@ +import { awaitAllCallbacks } from "@langchain/core/callbacks/promises"; +import { afterAll, jest } from "@jest/globals"; + +afterAll(awaitAllCallbacks); + +// Allow console.log to be disabled in tests +if (process.env.DISABLE_CONSOLE_LOGS === "true") { + console.log = jest.fn(); +} diff --git a/libs/langchain-cerebras/src/chat_models.ts b/libs/langchain-cerebras/src/chat_models.ts new file mode 100644 index 000000000000..821234edd134 --- /dev/null +++ b/libs/langchain-cerebras/src/chat_models.ts @@ -0,0 +1,826 @@ +import Cerebras from "@cerebras/cerebras_cloud_sdk"; + +import { + AIMessage, + AIMessageChunk, + UsageMetadata, + type BaseMessage, +} from "@langchain/core/messages"; +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + BaseChatModel, + BaseChatModelCallOptions, + type BaseChatModelParams, + BindToolsInput, + LangSmithParams, + ToolChoice, +} from "@langchain/core/language_models/chat_models"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"; +import { + Runnable, + RunnableLambda, + RunnablePassthrough, + RunnableSequence, +} from "@langchain/core/runnables"; +import { + BaseLanguageModelInput, + StructuredOutputMethodOptions, + ToolDefinition, +} from 
"@langchain/core/language_models/base"; +import { convertToOpenAITool } from "@langchain/core/utils/function_calling"; +import { concat } from "@langchain/core/utils/stream"; +import { isZodSchema } from "@langchain/core/utils/types"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { z } from "zod"; + +import { + convertToCerebrasMessageParams, + formatToCerebrasToolChoice, +} from "./utils.js"; + +/** + * Input to chat model class. + */ +export interface ChatCerebrasInput extends BaseChatModelParams { + model: string; + apiKey?: string; + streaming?: boolean; + maxTokens?: number; + maxCompletionTokens?: number; + temperature?: number; + topP?: number; + seed?: number; + timeout?: number; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + fetch?: (...args: any) => any; +} + +export interface ChatCerebrasCallOptions + extends BaseChatModelCallOptions, + Pick { + tools?: BindToolsInput[]; + tool_choice?: ToolChoice; + user?: string; + response_format?: Cerebras.ChatCompletionCreateParams["response_format"]; +} + +/** + * Cerebras chat model integration. + * + * Setup: + * Install `@langchain/cerebras` and set an environment variable named `CEREBRAS_API_KEY`. + * + * ```bash + * npm install @langchain/cerebras + * export CEREBRAS_API_KEY="your-api-key" + * ``` + * + * ## [Constructor args](https://api.js.langchain.com/classes/langchain_cerebras.ChatCerebras.html#constructor) + * + * ## [Runtime args](https://api.js.langchain.com/interfaces/langchain_cerebras.ChatCerebrasCallOptions.html) + * + * Runtime args can be passed as the second argument to any of the base runnable methods `.invoke`. `.stream`, `.batch`, etc. + * They can also be passed via `.bind`, or the second arg in `.bindTools`, like shown in the examples below: + * + * ```typescript + * // When calling `.bind`, call options should be passed via the first argument + * const llmWithArgsBound = llm.bind({ + * stop: ["\n"], + * tools: [...], + * }); + * + * // When calling `.bindTools`, call options should be passed via the second argument + * const llmWithTools = llm.bindTools( + * [...], + * { + * tool_choice: "auto", + * } + * ); + * ``` + * + * ## Examples + * + *
+ * <summary><strong>Instantiate</strong></summary>
+ *
+ * ```typescript
+ * import { ChatCerebras } from '@langchain/cerebras';
+ *
+ * const llm = new ChatCerebras({
+ *   model: "llama-3.3-70b",
+ *   temperature: 0,
+ *   // other params...
+ * });
+ * ```
+ * </details>
+ *
+ * <br />
+ *
+ * <details>
+ * <summary><strong>Invoking</strong></summary>
+ *
+ * ```typescript
+ * const input = `Translate "I love programming" into French.`;
+ *
+ * // Models also accept a list of chat messages or a formatted prompt
+ * const result = await llm.invoke(input);
+ * console.log(result);
+ * ```
+ *
+ * ```txt
+ * AIMessage {
+ *   "id": "run-9281952d-d4c5-424c-9c18-c6ad62dd6684",
+ *   "content": "J'adore la programmation.",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "id": "chatcmpl-bb411272-aac5-44a5-b793-ae70bd94fd3d",
+ *     "created": 1735784442,
+ *     "model": "llama-3.3-70b",
+ *     "system_fingerprint": "fp_2e2a2a083c",
+ *     "object": "chat.completion",
+ *     "time_info": {
+ *       "queue_time": 0.000096069,
+ *       "prompt_time": 0.002166527,
+ *       "completion_time": 0.012331633,
+ *       "total_time": 0.01629185676574707,
+ *       "created": 1735784442
+ *     }
+ *   },
+ *   "tool_calls": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {
+ *     "input_tokens": 55,
+ *     "output_tokens": 9,
+ *     "total_tokens": 64
+ *   }
+ * }
+ * ```
+ * </details>
+ *
+ * <br />
+ *
+ * <details>
+ * <summary><strong>Streaming Chunks</strong></summary>
+ *
+ * ```typescript
+ * for await (const chunk of await llm.stream(input)) {
+ *   console.log(chunk);
+ * }
+ * ```
+ *
+ * ```txt
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": "",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk"
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {}
+ * }
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": "J",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk"
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {}
+ * }
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": "'",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk"
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {}
+ * }
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": "ad",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk"
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {}
+ * }
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": "ore",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk"
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {}
+ * }
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": " la",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk"
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {}
+ * }
+ * ...
+ * AIMessageChunk {
+ *   "id": "run-1756a5b2-2ce0-47a9-81e0-2195bf893bd4",
+ *   "content": "",
+ *   "additional_kwargs": {},
+ *   "response_metadata": {
+ *     "finish_reason": "stop",
+ *     "id": "chatcmpl-15c80082-4475-423c-b140-7b0a556311ca",
+ *     "system_fingerprint": "fp_2e2a2a083c",
+ *     "model": "llama-3.3-70b",
+ *     "created": 1735785346,
+ *     "object": "chat.completion.chunk",
+ *     "time_info": {
+ *       "queue_time": 0.000100589,
+ *       "prompt_time": 0.002167348,
+ *       "completion_time": 0.012320277,
+ *       "total_time": 0.0169985294342041,
+ *       "created": 1735785346
+ *     }
+ *   },
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": [],
+ *   "usage_metadata": {
+ *     "input_tokens": 55,
+ *     "output_tokens": 9,
+ *     "total_tokens": 64
+ *   }
+ * }
+ * ```
+ * </details>
+ *
+ * <br />
+ *
+ * <details>
+ * <summary><strong>Aggregate Streamed Chunks</strong></summary>
+ *
+ * ```typescript
+ * import { AIMessageChunk } from '@langchain/core/messages';
+ * import { concat } from '@langchain/core/utils/stream';
+ *
+ * const stream = await llm.stream(input);
+ * let full: AIMessageChunk | undefined;
+ * for await (const chunk of stream) {
+ *   full = !full ? chunk : concat(full, chunk);
+ * }
+ * console.log(full);
+ * ```
+ *
+ * ```txt
+ * AIMessageChunk {
+ *   "content": "J'adore la programmation.",
+ *   "additional_kwargs": {},
+ *   "tool_calls": [],
+ *   "tool_call_chunks": [],
+ *   "invalid_tool_calls": []
+ * }
+ * ```
+ * </details>
+ *
+ * <br />
+ *
+ * <details>
+ * <summary><strong>Bind tools</strong></summary>
+ *
+ * ```typescript
+ * import { z } from 'zod';
+ *
+ * const llmForToolCalling = new ChatCerebras({
+ *   model: "llama-3.3-70b",
+ *   temperature: 0,
+ *   // other params...
+ * });
+ *
+ * const GetWeather = {
+ *   name: "GetWeather",
+ *   description: "Get the current weather in a given location",
+ *   schema: z.object({
+ *     location: z.string().describe("The city and state, e.g. San Francisco, CA")
+ *   }),
+ * }
+ *
+ * const GetPopulation = {
+ *   name: "GetPopulation",
+ *   description: "Get the current population in a given location",
+ *   schema: z.object({
+ *     location: z.string().describe("The city and state, e.g. San Francisco, CA")
+ *   }),
+ * }
+ *
+ * const llmWithTools = llmForToolCalling.bindTools([GetWeather, GetPopulation]);
+ * const aiMsg = await llmWithTools.invoke(
+ *   "Which city is hotter today and which is bigger: LA or NY?"
+ * );
+ * console.log(aiMsg.tool_calls);
+ * ```
+ *
+ * ```txt
+ * [
+ *   {
+ *     name: 'GetWeather',
+ *     args: { location: 'Los Angeles, CA' },
+ *     type: 'tool_call',
+ *     id: 'call_cd34'
+ *   },
+ *   {
+ *     name: 'GetWeather',
+ *     args: { location: 'New York, NY' },
+ *     type: 'tool_call',
+ *     id: 'call_68rf'
+ *   },
+ *   {
+ *     name: 'GetPopulation',
+ *     args: { location: 'Los Angeles, CA' },
+ *     type: 'tool_call',
+ *     id: 'call_f81z'
+ *   },
+ *   {
+ *     name: 'GetPopulation',
+ *     args: { location: 'New York, NY' },
+ *     type: 'tool_call',
+ *     id: 'call_8byt'
+ *   }
+ * ]
+ * ```
+ * </details>
+ *
+ * <br />
+ *
+ * <details>
+ * <summary><strong>Structured Output</strong></summary>
+ *
+ * ```typescript
+ * import { z } from 'zod';
+ *
+ * const Joke = z.object({
+ *   setup: z.string().describe("The setup of the joke"),
+ *   punchline: z.string().describe("The punchline to the joke"),
+ *   rating: z.number().optional().describe("How funny the joke is, from 1 to 10")
+ * }).describe('Joke to tell user.');
+ *
+ * const structuredLlm = llmForToolCalling.withStructuredOutput(Joke, { name: "Joke" });
+ * const jokeResult = await structuredLlm.invoke("Tell me a joke about cats");
+ * console.log(jokeResult);
+ * ```
+ *
+ * ```txt
+ * {
+ *   setup: "Why don't cats play poker in the wild?",
+ *   punchline: 'Because there are too many cheetahs.'
+ * }
+ * ```
+ * </details>
+ *
+ * <br />
+ */
+export class ChatCerebras
+  extends BaseChatModel<ChatCerebrasCallOptions, AIMessageChunk>
+  implements ChatCerebrasInput
+{
+  static lc_name() {
+    return "ChatCerebras";
+  }
+
+  lc_serializable = true;
+
+  get lc_secrets(): { [key: string]: string } | undefined {
+    return {
+      apiKey: "CEREBRAS_API_KEY",
+    };
+  }
+
+  get lc_aliases(): { [key: string]: string } | undefined {
+    return {
+      apiKey: "CEREBRAS_API_KEY",
+    };
+  }
+
+  getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
+    const params = this.invocationParams(options);
+    return {
+      ls_provider: "cerebras",
+      ls_model_name: this.model,
+      ls_model_type: "chat",
+      ls_temperature: params.temperature ?? undefined,
+      ls_max_tokens: params.max_completion_tokens ?? undefined,
+      ls_stop: options.stop,
+    };
+  }
+
+  client: Cerebras;
+
+  model: string;
+
+  maxCompletionTokens?: number;
+
+  temperature?: number;
+
+  topP?: number;
+
+  seed?: number;
+
+  streaming?: boolean;
+
+  constructor(fields: ChatCerebrasInput) {
+    super(fields ?? {});
+    this.model = fields.model;
+    this.maxCompletionTokens = fields.maxCompletionTokens;
+    this.temperature = fields.temperature;
+    this.topP = fields.topP;
+    this.seed = fields.seed;
+    this.streaming = fields.streaming;
+    this.client = new Cerebras({
+      apiKey: fields.apiKey ?? getEnvironmentVariable("CEREBRAS_API_KEY"),
+      timeout: fields.timeout,
+      // Rely on built-in async caller
+      maxRetries: 0,
+      fetch: fields.fetch,
+    });
+  }
+
+  _llmType() {
+    return "cerebras";
+  }
+
+  override bindTools(
+    tools: BindToolsInput[],
+    kwargs?: Partial<ChatCerebrasCallOptions>
+  ): Runnable<BaseLanguageModelInput, AIMessageChunk, ChatCerebrasCallOptions> {
+    return this.bind({
+      tools: tools.map((tool) => convertToOpenAITool(tool)),
+      ...kwargs,
+    });
+  }
+
+  /**
+   * A method that returns the parameters for a Cerebras API call. It
+   * includes model and options parameters.
+   * @param options Optional parsed call options.
+   * @returns An object containing the parameters for a Cerebras API call.
+   */
+  override invocationParams(
+    options?: this["ParsedCallOptions"]
+  ): Omit<Cerebras.ChatCompletionCreateParams, "messages"> {
+    return {
+      model: this.model,
+      max_completion_tokens: this.maxCompletionTokens,
+      temperature: this.temperature,
+      top_p: this.topP,
+      seed: this.seed,
+      stop: options?.stop,
+      response_format: options?.response_format,
+      user: options?.user,
+      tools: options?.tools?.length
+        ? options.tools.map(
+            (tool) =>
+              convertToOpenAITool(
+                tool
+              ) as Cerebras.ChatCompletionCreateParams.Tool
+          )
+        : undefined,
+      tool_choice: formatToCerebrasToolChoice(options?.tool_choice),
+    };
+  }
+
+  async _generate(
+    messages: BaseMessage[],
+    options: this["ParsedCallOptions"],
+    runManager?: CallbackManagerForLLMRun
+  ): Promise<ChatResult> {
+    // Handle streaming
+    if (this.streaming) {
+      let finalChunk: AIMessageChunk | undefined;
+      for await (const chunk of this._streamResponseChunks(
+        messages,
+        options,
+        runManager
+      )) {
+        if (!finalChunk) {
+          finalChunk = chunk.message;
+        } else {
+          finalChunk = concat(finalChunk, chunk.message);
+        }
+      }
+
+      // Convert from AIMessageChunk to AIMessage since `generate` expects AIMessage.
+      const nonChunkMessage = new AIMessage({
+        id: finalChunk?.id,
+        content: finalChunk?.content ?? "",
+        tool_calls: finalChunk?.tool_calls,
+        response_metadata: finalChunk?.response_metadata,
+        usage_metadata: finalChunk?.usage_metadata,
+      });
+      return {
+        generations: [
+          {
+            text:
+              typeof nonChunkMessage.content === "string"
+                ? nonChunkMessage.content
+                : "",
+            message: nonChunkMessage,
+          },
+        ],
+      };
+    }
+
+    const res = await this.caller.call(async () => {
+      const res = await this.client.chat.completions.create(
+        {
+          ...this.invocationParams(options),
+          messages: convertToCerebrasMessageParams(messages),
+          stream: false,
+        },
+        {
+          headers: options.headers,
+          httpAgent: options.httpAgent,
+        }
+      );
+      return res;
+    });
+
+    const { choices, usage, ...rest } = res;
+    // TODO: Remove casts when underlying types are fixed
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const choice = (choices as any)[0];
+    const content = choice?.message?.content ?? "";
+    const usageMetadata: UsageMetadata = {
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      input_tokens: (usage as any)?.prompt_tokens,
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      output_tokens: (usage as any)?.completion_tokens,
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      total_tokens: (usage as any)?.total_tokens,
+    };
+
+    return {
+      generations: [
+        {
+          text: content,
+          message: new AIMessage({
+            content,
+            tool_calls: choice?.message?.tool_calls?.map(
+              // eslint-disable-next-line @typescript-eslint/no-explicit-any
+              (toolCall: any) => ({
+                id: toolCall.id,
+                name: toolCall.function?.name,
+                args: JSON.parse(toolCall.function?.arguments),
+                index: toolCall.index,
+                type: "tool_call",
+              })
+            ),
+            usage_metadata: usageMetadata,
+            response_metadata: rest,
+          }),
+        },
+      ],
+    };
+  }
+
+  /**
+   * Implement to support streaming.
+   * Should yield chunks iteratively.
+   */
+  async *_streamResponseChunks(
+    messages: BaseMessage[],
+    options: this["ParsedCallOptions"],
+    runManager?: CallbackManagerForLLMRun
+  ): AsyncGenerator<ChatGenerationChunk> {
+    const stream = await this.caller.call(async () => {
+      const res = await this.client.chat.completions.create(
+        {
+          ...this.invocationParams(options),
+          messages: convertToCerebrasMessageParams(messages),
+          stream: true,
+        },
+        {
+          headers: options.headers,
+          httpAgent: options.httpAgent,
+        }
+      );
+      return res;
+    });
+    for await (const chunk of stream) {
+      const { choices, system_fingerprint, model, id, object, usage, ...rest } =
+        chunk;
+      // TODO: Remove casts when underlying types are fixed
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      const choice = (choices as any)[0];
+      const content = choice?.delta?.content ?? "";
""; + let usageMetadata: UsageMetadata | undefined; + if (usage !== undefined) { + usageMetadata = { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + input_tokens: (usage as any).prompt_tokens, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + output_tokens: (usage as any).completion_tokens, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + total_tokens: (usage as any).total_tokens, + }; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const generationInfo: Record = {}; + if (choice.finish_reason != null) { + generationInfo.finish_reason = choice.finish_reason; + // Only include system fingerprint and related in the last chunk for now + // to avoid concatenation issues + generationInfo.id = id; + generationInfo.system_fingerprint = system_fingerprint; + generationInfo.model = model; + generationInfo.object = object; + } + const generationChunk = new ChatGenerationChunk({ + text: content, + message: new AIMessageChunk({ + content, + tool_call_chunks: choice?.delta.tool_calls?.map( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (toolCallChunk: any) => ({ + id: toolCallChunk.id, + name: toolCallChunk.function?.name, + args: toolCallChunk.function?.arguments, + index: toolCallChunk.index, + type: "tool_call_chunk", + }) + ), + usage_metadata: usageMetadata, + response_metadata: rest, + }), + generationInfo, + }); + yield generationChunk; + await runManager?.handleLLMNewToken( + content, + undefined, + undefined, + undefined, + undefined, + { chunk: generationChunk } + ); + } + } + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): Runnable; + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): Runnable; + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): + | Runnable + | Runnable< + BaseLanguageModelInput, + { + raw: BaseMessage; + parsed: RunOutput; + } + > { + if (config?.strict) { + throw new Error( + `"strict" mode is not supported for this model by default.` + ); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const schema: z.ZodType | Record = outputSchema; + const name = config?.name; + const description = schema.description ?? "A function available to call."; + const method = config?.method; + const includeRaw = config?.includeRaw; + if (method === "jsonMode") { + throw new Error( + `Cerebras withStructuredOutput implementation only supports "functionCalling" as a method.` + ); + } + let functionName = name ?? 
"extract"; + let tools: ToolDefinition[]; + if (isZodSchema(schema)) { + tools = [ + { + type: "function", + function: { + name: functionName, + description, + parameters: zodToJsonSchema(schema), + }, + }, + ]; + } else { + if ("name" in schema) { + functionName = schema.name; + } + tools = [ + { + type: "function", + function: { + name: functionName, + description, + parameters: schema, + }, + }, + ]; + } + + const llm = this.bindTools(tools, { + tool_choice: tools[0].function.name, + }); + const outputParser = RunnableLambda.from( + (input: AIMessageChunk): RunOutput => { + if (!input.tool_calls || input.tool_calls.length === 0) { + throw new Error("No tool calls found in the response."); + } + const toolCall = input.tool_calls.find( + (tc) => tc.name === functionName + ); + if (!toolCall) { + throw new Error(`No tool call found with name ${functionName}.`); + } + return toolCall.args as RunOutput; + } + ); + + if (!includeRaw) { + return llm.pipe(outputParser).withConfig({ + runName: "ChatCerebrasStructuredOutput", + }) as Runnable; + } + + const parserAssign = RunnablePassthrough.assign({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parsed: (input: any, config) => outputParser.invoke(input.raw, config), + }); + const parserNone = RunnablePassthrough.assign({ + parsed: () => null, + }); + const parsedWithFallback = parserAssign.withFallbacks({ + fallbacks: [parserNone], + }); + return RunnableSequence.from< + BaseLanguageModelInput, + { raw: BaseMessage; parsed: RunOutput } + >([ + { + raw: llm, + }, + parsedWithFallback, + ]).withConfig({ + runName: "ChatCerebrasStructuredOutput", + }); + } +} diff --git a/libs/langchain-cerebras/src/index.ts b/libs/langchain-cerebras/src/index.ts new file mode 100644 index 000000000000..38c7cea7f478 --- /dev/null +++ b/libs/langchain-cerebras/src/index.ts @@ -0,0 +1 @@ +export * from "./chat_models.js"; diff --git a/libs/langchain-cerebras/src/tests/chat_models.int.test.ts b/libs/langchain-cerebras/src/tests/chat_models.int.test.ts new file mode 100644 index 000000000000..4e6e04fdbe70 --- /dev/null +++ b/libs/langchain-cerebras/src/tests/chat_models.int.test.ts @@ -0,0 +1,284 @@ +import { test } from "@jest/globals"; +import { + AIMessage, + AIMessageChunk, + HumanMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { tool } from "@langchain/core/tools"; +import { z } from "zod"; +import { concat } from "@langchain/core/utils/stream"; +import { ChatCerebras } from "../chat_models.js"; + +test("invoke", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("What color is the sky?"); + const res = await chat.invoke([message]); + // console.log({ res }); + expect(res.content.length).toBeGreaterThan(10); +}); + +test("invoke with stop sequence", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("Count to ten."); + const res = await chat.bind({ stop: ["5", "five"] }).invoke([message]); + // console.log({ res }); + expect((res.content as string).toLowerCase()).not.toContain("6"); + expect((res.content as string).toLowerCase()).not.toContain("six"); +}); + +test("invoke with streaming: true", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + streaming: true, + }); + const message = new HumanMessage("What color is the sky?"); + const res = await chat.invoke([message]); + console.log({ res }); + 
expect(res.content.length).toBeGreaterThan(10); +}); + +test("invoke should respect passed headers", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("Count to ten."); + await expect(async () => { + await chat.invoke([message], { + headers: { Authorization: "badbadbad" }, + }); + }).rejects.toThrowError(); +}); + +test("stream should respect passed headers", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("Count to ten."); + await expect(async () => { + await chat.stream([message], { + headers: { Authorization: "badbadbad" }, + }); + }).rejects.toThrowError(); +}); + +test("generate", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("Hello!"); + const res = await chat.generate([[message]]); + // console.log(JSON.stringify(res, null, 2)); + expect(res.generations[0][0].text.length).toBeGreaterThan(10); +}); + +test("streaming", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("What color is the sky?"); + const stream = await chat.stream([message]); + let iters = 0; + let finalRes = ""; + for await (const chunk of stream) { + iters += 1; + finalRes += chunk.content; + } + // console.log({ finalRes, iters }); + expect(iters).toBeGreaterThan(1); +}); + +test("invoke with bound tools", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("What is the current weather in Hawaii?"); + const res = await chat + .bind({ + tools: [ + { + type: "function", + function: { + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. San Francisco, CA", + }, + unit: { type: "string", enum: ["celsius", "fahrenheit"] }, + }, + required: ["location"], + }, + }, + }, + ], + tool_choice: "auto", + }) + .invoke([message]); + expect(typeof res.tool_calls?.[0].args).toEqual("object"); +}); + +test("stream with bound tools, yielding a single chunk", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + maxRetries: 2, + }); + const message = new HumanMessage("What is the current weather in Hawaii?"); + const stream = await chat + .bind({ + tools: [ + { + type: "function", + function: { + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. 
San Francisco, CA", + }, + unit: { type: "string", enum: ["celsius", "fahrenheit"] }, + }, + required: ["location"], + }, + }, + }, + ], + tool_choice: "auto", + }) + .stream([message]); + // @eslint-disable-next-line/@typescript-eslint/ban-ts-comment + // @ts-expect-error unused var + for await (const chunk of stream) { + // console.log(JSON.stringify(chunk)); + } +}); + +test("Few shotting with tool calls", async () => { + const chat = new ChatCerebras({ + model: "llama3.1-8b", + temperature: 0, + }).bind({ + tools: [ + { + type: "function", + function: { + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. San Francisco, CA", + }, + unit: { type: "string", enum: ["celsius", "fahrenheit"] }, + }, + required: ["location"], + }, + }, + }, + ], + tool_choice: "auto", + }); + const res = await chat.invoke([ + new HumanMessage("What is the weather in SF?"), + new AIMessage({ + content: "", + tool_calls: [ + { + id: "12345", + name: "get_current_weather", + args: { + location: "SF", + }, + }, + ], + }), + new ToolMessage({ + tool_call_id: "12345", + content: "It is currently 24 degrees with hail in SF.", + }), + new AIMessage("It is currently 24 degrees in SF with hail in SF."), + new HumanMessage("What did you say the weather was?"), + ]); + // console.log(res); + expect(res.content).toContain("24"); +}); + +test("Cerebras can stream tool calls", async () => { + const model = new ChatCerebras({ + model: "llama3.1-8b", + temperature: 0, + }); + + const weatherTool = tool((_) => "The temperature is 24 degrees with hail.", { + name: "get_current_weather", + schema: z.object({ + location: z + .string() + .describe("The location to get the current weather for."), + }), + description: "Get the current weather in a given location.", + }); + + const modelWithTools = model.bindTools([weatherTool]); + + const stream = await modelWithTools.stream( + "What is the weather in San Francisco?" + ); + + let finalMessage: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalMessage = !finalMessage ? 
+  }
+
+  expect(finalMessage).toBeDefined();
+  if (!finalMessage) return;
+
+  expect(finalMessage.tool_calls?.[0]).toBeDefined();
+  if (!finalMessage.tool_calls?.[0]) return;
+
+  expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather");
+  expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location");
+  expect(finalMessage.tool_calls?.[0].id).toBeDefined();
+});
+
+test("json mode", async () => {
+  const llm = new ChatCerebras({
+    model: "llama3.1-8b",
+    temperature: 0,
+  });
+
+  const messages = [
+    {
+      role: "system",
+      content:
+        "You are a math tutor that handles math exercises and makes output in json in format { result: number }.",
+    },
+    { role: "user", content: "2 + 2" },
+  ];
+
+  const res = await llm.invoke(messages, {
+    response_format: { type: "json_object" },
+  });
+
+  expect(JSON.parse(res.content as string)).toEqual({ result: 4 });
+});
diff --git a/libs/langchain-cerebras/src/tests/chat_models.standard.int.test.ts b/libs/langchain-cerebras/src/tests/chat_models.standard.int.test.ts
new file mode 100644
index 000000000000..dfa0719fa5ec
--- /dev/null
+++ b/libs/langchain-cerebras/src/tests/chat_models.standard.int.test.ts
@@ -0,0 +1,39 @@
+/* eslint-disable no-process-env */
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import { test, expect } from "@jest/globals";
+import { ChatModelIntegrationTests } from "@langchain/standard-tests";
+import { AIMessageChunk } from "@langchain/core/messages";
+import { ChatCerebras, ChatCerebrasCallOptions } from "../chat_models.js";
+
+class ChatCerebrasStandardIntegrationTests extends ChatModelIntegrationTests<
+  ChatCerebrasCallOptions,
+  AIMessageChunk
+> {
+  constructor() {
+    if (!process.env.CEREBRAS_API_KEY) {
+      throw new Error(
+        "Cannot run Cerebras integration tests because CEREBRAS_API_KEY is not set"
+      );
+    }
+    super({
+      Cls: ChatCerebras as any,
+      chatModelHasToolCalling: true,
+      chatModelHasStructuredOutput: true,
+      constructorArgs: {
+        model: "llama3.1-8b",
+        maxRetries: 1,
+        temperature: 0,
+      },
+    });
+  }
+}
+
+const testClass = new ChatCerebrasStandardIntegrationTests();
+
+test("ChatCerebrasStandardIntegrationTests", async () => {
+  console.warn = (..._args: unknown[]) => {
+    // no-op
+  };
+  const testResults = await testClass.runTests();
+  expect(testResults).toBe(true);
+});
diff --git a/libs/langchain-cerebras/src/tests/chat_models.standard.test.ts b/libs/langchain-cerebras/src/tests/chat_models.standard.test.ts
new file mode 100644
index 000000000000..c44cca3f1c2c
--- /dev/null
+++ b/libs/langchain-cerebras/src/tests/chat_models.standard.test.ts
@@ -0,0 +1,44 @@
+/* eslint-disable no-process-env */
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import { test, expect } from "@jest/globals";
+import { ChatModelUnitTests } from "@langchain/standard-tests";
+import { AIMessageChunk } from "@langchain/core/messages";
+import { ChatCerebras, ChatCerebrasCallOptions } from "../chat_models.js";
+
+class ChatCerebrasStandardUnitTests extends ChatModelUnitTests<
+  ChatCerebrasCallOptions,
+  AIMessageChunk
+> {
+  constructor() {
+    super({
+      Cls: ChatCerebras as any,
+      chatModelHasToolCalling: true,
+      chatModelHasStructuredOutput: true,
+      constructorArgs: {
+        model: "llama3.1-8b",
+        maxRetries: 1,
+        temperature: 0,
+      },
+    });
+    // This must be set so methods like `.bindTools` or `.withStructuredOutput`
+    // which we call after instantiating the model will work.
+    // (constructor will throw if API key is not set)
+    process.env.CEREBRAS_API_KEY = "test";
+  }
+
+  testChatModelInitApiKey() {
+    // Unset the API key env var here so this test can properly check
+    // the API key class arg.
+    process.env.CEREBRAS_API_KEY = "";
+    super.testChatModelInitApiKey();
+    // Re-set the API key env var here so other tests can run properly.
+    process.env.CEREBRAS_API_KEY = "test";
+  }
+}
+
+const testClass = new ChatCerebrasStandardUnitTests();
+
+test("ChatCerebrasStandardUnitTests", () => {
+  const testResults = testClass.runTests();
+  expect(testResults).toBe(true);
+});
diff --git a/libs/langchain-cerebras/src/utils.ts b/libs/langchain-cerebras/src/utils.ts
new file mode 100644
index 000000000000..c1849f208cab
--- /dev/null
+++ b/libs/langchain-cerebras/src/utils.ts
@@ -0,0 +1,226 @@
+import {
+  AIMessage,
+  AIMessageChunk,
+  BaseMessage,
+  HumanMessage,
+  MessageContentText,
+  SystemMessage,
+  ToolMessage,
+  UsageMetadata,
+} from "@langchain/core/messages";
+import { v4 as uuidv4 } from "uuid";
+import Cerebras from "@cerebras/cerebras_cloud_sdk";
+import { ToolChoice } from "@langchain/core/language_models/chat_models";
+
+export type CerebrasMessageParam =
+  | Cerebras.ChatCompletionCreateParams.AssistantMessageRequest
+  | Cerebras.ChatCompletionCreateParams.SystemMessageRequest
+  | Cerebras.ChatCompletionCreateParams.ToolMessageRequest
+  | Cerebras.ChatCompletionCreateParams.UserMessageRequest;
+export type CerebrasToolCall =
+  Cerebras.ChatCompletion.ChatCompletionResponse.Choice.Message.ToolCall;
+
+export function convertCerebrasMessagesToLangChain(
+  messages: Cerebras.ChatCompletion.ChatCompletionResponse.Choice.Message,
+  extra?: {
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    responseMetadata?: Record<string, any>;
+    usageMetadata?: UsageMetadata;
+  }
+): AIMessageChunk {
+  return new AIMessageChunk({
+    content: messages.content ?? "",
+    tool_call_chunks: messages.tool_calls?.map((tc) => ({
+      name: tc.function.name,
+      args: JSON.stringify(tc.function.arguments),
+      type: "tool_call_chunk",
+      index: 0,
+      id: uuidv4(),
+    })),
+    response_metadata: extra?.responseMetadata,
+    usage_metadata: extra?.usageMetadata,
+  });
+}
+
+function extractBase64FromDataUrl(dataUrl: string): string {
+  const match = dataUrl.match(/^data:.*?;base64,(.*)$/);
+  return match ?
+ +function convertAIMessageToCerebras( + message: AIMessage +): CerebrasMessageParam[] { + if (typeof message.content === "string") { + return [ + { + role: "assistant", + content: message.content, + }, + ]; + } + + const textFields = message.content.filter( + (c) => c.type === "text" && typeof c.text === "string" + ); + const textMessages: CerebrasMessageParam[] = ( + textFields as MessageContentText[] + ).map((c) => ({ + role: "assistant", + content: c.text, + })); + let toolCallMsgs: CerebrasMessageParam | undefined; + + if ( + message.content.find((c) => c.type === "tool_use") && + message.tool_calls?.length + ) { + // `tool_use` content types are accepted if the message has tool calls + const toolCalls: CerebrasToolCall[] | undefined = message.tool_calls?.map( + (tc) => ({ + id: tc.id!, + type: "function", + function: { + name: tc.name, + arguments: JSON.stringify(tc.args), + }, + }) + ); + + if (toolCalls) { + toolCallMsgs = { + role: "assistant", + tool_calls: toolCalls, + content: "", + }; + } + } else if ( + message.content.find((c) => c.type === "tool_use") && + !message.tool_calls?.length + ) { + throw new Error( + "'tool_use' content type is not supported without tool calls." + ); + } + + return [...textMessages, ...(toolCallMsgs ? [toolCallMsgs] : [])]; +} + +function convertHumanGenericMessagesToCerebras( + message: HumanMessage +): CerebrasMessageParam[] { + if (typeof message.content === "string") { + return [ + { + role: "user", + content: message.content, + }, + ]; + } + return message.content.map((c) => { + if (c.type === "text") { + return { + role: "user", + content: c.text, + }; + } else if (c.type === "image_url") { + if (typeof c.image_url === "string") { + return { + role: "user", + content: "", + images: [extractBase64FromDataUrl(c.image_url)], + }; + } else if (c.image_url.url && typeof c.image_url.url === "string") { + return { + role: "user", + content: "", + images: [extractBase64FromDataUrl(c.image_url.url)], + }; + } + } + throw new Error(`Unsupported content type: ${c.type}`); + }); +} + +function convertSystemMessageToCerebras( + message: SystemMessage +): CerebrasMessageParam[] { + if (typeof message.content === "string") { + return [ + { + role: "system", + content: message.content, + }, + ]; + } else if ( + message.content.every( + (c) => c.type === "text" && typeof c.text === "string" + ) + ) { + return (message.content as MessageContentText[]).map((c) => ({ + role: "system", + content: c.text, + })); + } else { + throw new Error( + `Unsupported content type(s): ${message.content + .map((c) => c.type) + .join(", ")}` + ); + } +} + +function convertToolMessageToCerebras( + message: ToolMessage +): CerebrasMessageParam[] { + if (typeof message.content !== "string") { + throw new Error("Non-string tool message content is not supported"); + } + return [ + { + role: "tool", + content: message.content, + tool_call_id: message.tool_call_id, + }, + ]; +} + +// Maps a list of LangChain messages onto the Cerebras wire format; multi-part +// message content is flattened into one Cerebras message per part. +export function convertToCerebrasMessageParams( + messages: BaseMessage[] +): CerebrasMessageParam[] { + return messages.flatMap((msg) => { + if (["human", "generic"].includes(msg._getType())) { + return convertHumanGenericMessagesToCerebras(msg); + } else if (msg._getType() === "ai") { + return convertAIMessageToCerebras(msg); + } else if (msg._getType() === "system") { + return convertSystemMessageToCerebras(msg); + } else if (msg._getType() === "tool") { + return convertToolMessageToCerebras(msg as ToolMessage); + } else { + throw new Error(`Unsupported message type: ${msg._getType()}`); + } + }); +}
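+ +// For example (hypothetical conversation): +//   convertToCerebrasMessageParams([ +//     new SystemMessage("You are a helpful assistant."), +//     new HumanMessage("Hello!"), +//   ]) +//   // -> [{ role: "system", content: "You are a helpful assistant." }, +//   //     { role: "user", content: "Hello!" }]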
+ +// Normalizes a LangChain tool_choice into the shape the Cerebras API expects; +// "any" and "required" both force a tool call, and a bare string selects a tool by name. +export function formatToCerebrasToolChoice( + toolChoice?: ToolChoice +): Cerebras.ChatCompletionCreateParams["tool_choice"] { + if (!toolChoice) { + return undefined; + } else if (toolChoice === "any" || toolChoice === "required") { + return "required"; + } else if (toolChoice === "auto") { + return "auto"; + } else if (toolChoice === "none") { + return "none"; + } else if (typeof toolChoice === "string") { + return { + type: "function", + function: { + name: toolChoice, + }, + }; + } else { + return toolChoice as Cerebras.ChatCompletionCreateParams.Tool; + } +} diff --git a/libs/langchain-cerebras/tsconfig.cjs.json b/libs/langchain-cerebras/tsconfig.cjs.json new file mode 100644 index 000000000000..3b7026ea406c --- /dev/null +++ b/libs/langchain-cerebras/tsconfig.cjs.json @@ -0,0 +1,8 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "module": "commonjs", + "declaration": false + }, + "exclude": ["node_modules", "dist", "docs", "**/tests"] +} diff --git a/libs/langchain-cerebras/tsconfig.json b/libs/langchain-cerebras/tsconfig.json new file mode 100644 index 000000000000..bc85d83b6229 --- /dev/null +++ b/libs/langchain-cerebras/tsconfig.json @@ -0,0 +1,23 @@ +{ + "extends": "@tsconfig/recommended", + "compilerOptions": { + "outDir": "../dist", + "rootDir": "./src", + "target": "ES2021", + "lib": ["ES2021", "ES2022.Object", "DOM"], + "module": "ES2020", + "moduleResolution": "nodenext", + "esModuleInterop": true, + "declaration": true, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "useDefineForClassFields": true, + "strictPropertyInitialization": false, + "allowJs": true, + "strict": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "docs"] +} diff --git a/libs/langchain-cerebras/turbo.json b/libs/langchain-cerebras/turbo.json new file mode 100644 index 000000000000..d024cee15c81 --- /dev/null +++ b/libs/langchain-cerebras/turbo.json @@ -0,0 +1,11 @@ +{ + "extends": ["//"], + "pipeline": { + "build": { + "outputs": ["**/dist/**"] + }, + "build:internal": { + "dependsOn": ["^build:internal"] + } + } +} diff --git a/yarn.lock b/yarn.lock index fa91e3d3f1ea..6fb66edb4fb8 100644 --- a/yarn.lock +++ b/yarn.lock @@ -8691,6 +8691,21 @@ __metadata: languageName: node linkType: hard +"@cerebras/cerebras_cloud_sdk@npm:^1.15.0": + version: 1.15.0 + resolution: "@cerebras/cerebras_cloud_sdk@npm:1.15.0" + dependencies: + "@types/node": ^18.11.18 + "@types/node-fetch": ^2.6.4 + abort-controller: ^3.0.0 + agentkeepalive: ^4.2.1 + form-data-encoder: 1.7.2 + formdata-node: ^4.3.2 + node-fetch: ^2.6.7 + checksum: 192b6c473fa2cf7d7b77b1523e935c21f0d8084365ec795e02a0dcb608e85a4b742c1584338311cad816c46b2f50772ab7b7902a2e5d5fd12800eae085db9b06 + languageName: node + linkType: hard + "@cfworker/json-schema@npm:^4.0.2": version: 4.0.2 resolution: "@cfworker/json-schema@npm:4.0.2" @@ -11729,6 +11744,44 @@ __metadata: languageName: unknown linkType: soft +"@langchain/cerebras@*, @langchain/cerebras@workspace:libs/langchain-cerebras": + version: 0.0.0-use.local + resolution: "@langchain/cerebras@workspace:libs/langchain-cerebras" + dependencies: + "@cerebras/cerebras_cloud_sdk": ^1.15.0 + "@jest/globals": ^29.5.0 + "@langchain/core": "workspace:*" + "@langchain/scripts": ">=0.1.0 <0.2.0" + "@langchain/standard-tests": 0.0.0 + "@swc/core": ^1.3.90 + "@swc/jest": ^0.2.29 + "@tsconfig/recommended": ^1.0.3 + "@types/uuid": ^10 + "@typescript-eslint/eslint-plugin":
^6.12.0 + "@typescript-eslint/parser": ^6.12.0 + dotenv: ^16.3.1 + dpdm: ^3.12.0 + eslint: ^8.33.0 + eslint-config-airbnb-base: ^15.0.0 + eslint-config-prettier: ^8.6.0 + eslint-plugin-import: ^2.27.5 + eslint-plugin-no-instanceof: ^1.0.1 + eslint-plugin-prettier: ^4.2.1 + jest: ^29.5.0 + jest-environment-node: ^29.6.4 + prettier: ^2.8.3 + release-it: ^15.10.1 + rollup: ^4.5.2 + ts-jest: ^29.1.0 + typescript: <5.2.0 + uuid: ^10.0.0 + zod: ^3.22.4 + zod-to-json-schema: ^3.22.3 + peerDependencies: + "@langchain/core": ">=0.3.0 <0.4.0" + languageName: unknown + linkType: soft + "@langchain/cloudflare@workspace:*, @langchain/cloudflare@workspace:libs/langchain-cloudflare": version: 0.0.0-use.local resolution: "@langchain/cloudflare@workspace:libs/langchain-cloudflare" @@ -20119,7 +20172,7 @@ __metadata: languageName: node linkType: hard -"@types/uuid@npm:^10.0.0": +"@types/uuid@npm:^10, @types/uuid@npm:^10.0.0": version: 10.0.0 resolution: "@types/uuid@npm:10.0.0" checksum: e3958f8b0fe551c86c14431f5940c3470127293280830684154b91dc7eb3514aeb79fe3216968833cf79d4d1c67f580f054b5be2cd562bebf4f728913e73e944 @@ -33365,6 +33418,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/anthropic": "*" "@langchain/aws": "*" + "@langchain/cerebras": "*" "@langchain/cohere": "*" "@langchain/core": "workspace:*" "@langchain/google-genai": "*" @@ -33423,6 +33477,7 @@ __metadata: peerDependencies: "@langchain/anthropic": "*" "@langchain/aws": "*" + "@langchain/cerebras": "*" "@langchain/cohere": "*" "@langchain/core": ">=0.2.21 <0.4.0" "@langchain/google-genai": "*" @@ -33440,6 +33495,8 @@ __metadata: optional: true "@langchain/aws": optional: true + "@langchain/cerebras": + optional: true "@langchain/cohere": optional: true "@langchain/google-genai":