-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: vercel kv graph checkpointer (#5948)
* - feat: vercel kv graph checkpointer * - fix: downgraded uuid pkg version - fix: langchain.config.js - fix: moved to '/langgraph/checkpointers' * - fix: moved to '/langgraph/checkpointers' - fix: imports from '@langchain/langgraph/web' * fix: reverted uuid to ^9.0.0 * fix: reverted uuid version * - fix: langgraph version deps - fix: uuid version deps * fix: fixed uuid v6 in unit tests * fix: lock issues * Switch to integration test, format, lint * Update build artifacts * - fix: save checkpoint atomically in redis * - fix: nit unit test * - fix: types * - fix: non-blocking key lookup optimization --------- Co-authored-by: jacoblee93 <[email protected]>
- Loading branch information
1 parent
c2d3472
commit bac6138
Showing
7 changed files
with
314 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
libs/langchain-community/src/langgraph/checkpointers/tests/checkpointer.int.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/* eslint-disable no-process-env */ | ||
|
||
import { describe, test, expect } from "@jest/globals"; | ||
import { Checkpoint, CheckpointTuple } from "@langchain/langgraph"; | ||
import { VercelKVSaver } from "../vercel_kv.js"; | ||
|
||
const checkpoint1: Checkpoint = { | ||
v: 1, | ||
id: "1ef390c8-3ed9-6132-ffff-12d236274621", | ||
ts: "2024-04-19T17:19:07.952Z", | ||
channel_values: { | ||
someKey1: "someValue1", | ||
}, | ||
channel_versions: { | ||
someKey2: 1, | ||
}, | ||
versions_seen: { | ||
someKey3: { | ||
someKey4: 1, | ||
}, | ||
}, | ||
}; | ||
|
||
const checkpoint2: Checkpoint = { | ||
v: 1, | ||
id: "1ef390c8-3ed9-6133-8001-419c612dad04", | ||
ts: "2024-04-20T17:19:07.952Z", | ||
channel_values: { | ||
someKey1: "someValue2", | ||
}, | ||
channel_versions: { | ||
someKey2: 2, | ||
}, | ||
versions_seen: { | ||
someKey3: { | ||
someKey4: 2, | ||
}, | ||
}, | ||
}; | ||
|
||
describe("VercelKVSaver", () => { | ||
const vercelSaver = new VercelKVSaver({ | ||
url: process.env.VERCEL_KV_API_URL!, | ||
token: process.env.VERCEL_KV_API_TOKEN!, | ||
}); | ||
|
||
test("should save and retrieve checkpoints correctly", async () => { | ||
// save checkpoint | ||
const runnableConfig = await vercelSaver.put( | ||
{ configurable: { thread_id: "1" } }, | ||
checkpoint1, | ||
{ source: "update", step: -1, writes: null } | ||
); | ||
expect(runnableConfig).toEqual({ | ||
configurable: { | ||
thread_id: "1", | ||
checkpoint_id: checkpoint1.id, | ||
}, | ||
}); | ||
|
||
// get checkpoint tuple | ||
const checkpointTuple = await vercelSaver.getTuple({ | ||
configurable: { thread_id: "1" }, | ||
}); | ||
expect(checkpointTuple?.config).toEqual({ | ||
configurable: { | ||
thread_id: "1", | ||
checkpoint_id: checkpoint1.id, | ||
}, | ||
}); | ||
expect(checkpointTuple?.checkpoint).toEqual(checkpoint1); | ||
|
||
// save another checkpoint | ||
await vercelSaver.put( | ||
{ | ||
configurable: { | ||
thread_id: "1", | ||
}, | ||
}, | ||
checkpoint2, | ||
{ source: "update", step: -1, writes: null } | ||
); | ||
// list checkpoints | ||
const checkpointTupleGenerator = vercelSaver.list({ | ||
configurable: { thread_id: "1" }, | ||
}); | ||
|
||
const checkpointTuples: CheckpointTuple[] = []; | ||
|
||
for await (const checkpoint of checkpointTupleGenerator) { | ||
checkpointTuples.push(checkpoint); | ||
} | ||
expect(checkpointTuples.length).toBe(2); | ||
|
||
const checkpointTuple1 = checkpointTuples[0]; | ||
const checkpointTuple2 = checkpointTuples[1]; | ||
|
||
expect(checkpointTuple1.checkpoint.ts).toBe("2024-04-20T17:19:07.952Z"); | ||
expect(checkpointTuple2.checkpoint.ts).toBe("2024-04-19T17:19:07.952Z"); | ||
}); | ||
}); |
164 changes: 164 additions & 0 deletions
164
libs/langchain-community/src/langgraph/checkpointers/vercel_kv.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import { VercelKV, createClient } from "@vercel/kv"; | ||
|
||
import { RunnableConfig } from "@langchain/core/runnables"; | ||
import { | ||
BaseCheckpointSaver, | ||
Checkpoint, | ||
CheckpointMetadata, | ||
CheckpointTuple, | ||
SerializerProtocol, | ||
} from "@langchain/langgraph/web"; | ||
|
||
// snake_case is used to match Python implementation | ||
interface KVRow { | ||
checkpoint: string; | ||
metadata: string; | ||
} | ||
|
||
interface KVConfig { | ||
url: string; | ||
token: string; | ||
} | ||
|
||
export class VercelKVSaver extends BaseCheckpointSaver { | ||
private kv: VercelKV; | ||
|
||
constructor(config: KVConfig, serde?: SerializerProtocol<unknown>) { | ||
super(serde); | ||
this.kv = createClient(config); | ||
} | ||
|
||
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> { | ||
const thread_id = config.configurable?.thread_id; | ||
const checkpoint_id = config.configurable?.checkpoint_id; | ||
|
||
if (!thread_id) { | ||
return undefined; | ||
} | ||
|
||
const key = checkpoint_id | ||
? `${thread_id}:${checkpoint_id}` | ||
: `${thread_id}:last`; | ||
|
||
const row: KVRow | null = await this.kv.get(key); | ||
|
||
if (!row) { | ||
return undefined; | ||
} | ||
|
||
const [checkpoint, metadata] = await Promise.all([ | ||
this.serde.parse(row.checkpoint), | ||
this.serde.parse(row.metadata), | ||
]); | ||
|
||
return { | ||
checkpoint: checkpoint as Checkpoint, | ||
metadata: metadata as CheckpointMetadata, | ||
config: checkpoint_id | ||
? config | ||
: { | ||
configurable: { | ||
thread_id, | ||
checkpoint_id: (checkpoint as Checkpoint).id, | ||
}, | ||
}, | ||
}; | ||
} | ||
|
||
async *list( | ||
config: RunnableConfig, | ||
limit?: number, | ||
before?: RunnableConfig | ||
): AsyncGenerator<CheckpointTuple> { | ||
const thread_id: string = config.configurable?.thread_id; | ||
|
||
// LUA script to get keys excluding those starting with "last" | ||
const luaScript = ` | ||
local prefix = ARGV[1] | ||
local cursor = '0' | ||
local result = {} | ||
repeat | ||
local scanResult = redis.call('SCAN', cursor, 'MATCH', prefix .. '*', 'COUNT', 1000) | ||
cursor = scanResult[1] | ||
local keys = scanResult[2] | ||
for _, key in ipairs(keys) do | ||
if key:sub(-5) ~= ':last' then | ||
table.insert(result, key) | ||
end | ||
end | ||
until cursor == '0' | ||
return result | ||
`; | ||
|
||
// Execute the LUA script with the thread_id as an argument | ||
const keys: string[] = await this.kv.eval(luaScript, [], [thread_id]); | ||
|
||
const filteredKeys = keys.filter((key: string) => { | ||
const [, checkpoint_id] = key.split(":"); | ||
|
||
return !before || checkpoint_id < before?.configurable?.checkpoint_id; | ||
}); | ||
|
||
const sortedKeys = filteredKeys | ||
.sort((a: string, b: string) => b.localeCompare(a)) | ||
.slice(0, limit); | ||
|
||
const rows: (KVRow | null)[] = await this.kv.mget(...sortedKeys); | ||
for (const row of rows) { | ||
if (row) { | ||
const [checkpoint, metadata] = await Promise.all([ | ||
this.serde.parse(row.checkpoint), | ||
this.serde.parse(row.metadata), | ||
]); | ||
|
||
yield { | ||
config: { | ||
configurable: { | ||
thread_id, | ||
checkpoint_id: (checkpoint as Checkpoint).id, | ||
}, | ||
}, | ||
checkpoint: checkpoint as Checkpoint, | ||
metadata: metadata as CheckpointMetadata, | ||
}; | ||
} | ||
} | ||
} | ||
|
||
async put( | ||
config: RunnableConfig, | ||
checkpoint: Checkpoint, | ||
metadata: CheckpointMetadata | ||
): Promise<RunnableConfig> { | ||
const thread_id = config.configurable?.thread_id; | ||
|
||
if (!thread_id || !checkpoint.id) { | ||
throw new Error("Thread ID and Checkpoint ID must be defined"); | ||
} | ||
|
||
const row: KVRow = { | ||
checkpoint: this.serde.stringify(checkpoint), | ||
metadata: this.serde.stringify(metadata), | ||
}; | ||
|
||
// LUA script to set checkpoint data atomically" | ||
const luaScript = ` | ||
local thread_id = ARGV[1] | ||
local checkpoint_id = ARGV[2] | ||
local row = ARGV[3] | ||
redis.call('SET', thread_id .. ':' .. checkpoint_id, row) | ||
redis.call('SET', thread_id .. ':last', row) | ||
`; | ||
|
||
// Save the checkpoint and the last checkpoint | ||
await this.kv.eval(luaScript, [], [thread_id, checkpoint.id, row]); | ||
|
||
return { | ||
configurable: { | ||
thread_id, | ||
checkpoint_id: checkpoint.id, | ||
}, | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.