From 6946d0b384ab98aee22b608e8aacc6a64b9a9bb9 Mon Sep 17 00:00:00 2001 From: Max Lord Date: Thu, 9 May 2024 14:40:53 -0400 Subject: [PATCH] Resolving issues with loading dotprompt files (#92) - Adding plugin to samples which use prompt files - Checking existence of prompt folder before reading dir - Fixing bug in registry key --- js/plugins/dotprompt/src/registry.ts | 42 ++++++++++++---------- js/plugins/dotprompt/tests/prompt_test.ts | 16 +++++++-- js/samples/byo-evaluator/src/index.ts | 2 ++ js/samples/cat-eval/src/index.ts | 2 ++ js/samples/coffee-shop/src/index.ts | 8 +++-- js/samples/menu-example/basic/src/index.ts | 3 +- js/samples/prompt-file/src/index.ts | 4 +-- js/samples/rag/src/index.ts | 2 ++ 8 files changed, 53 insertions(+), 26 deletions(-) diff --git a/js/plugins/dotprompt/src/registry.ts b/js/plugins/dotprompt/src/registry.ts index b100416fa..7a574a69d 100644 --- a/js/plugins/dotprompt/src/registry.ts +++ b/js/plugins/dotprompt/src/registry.ts @@ -26,7 +26,7 @@ export function registryDefinitionKey(name: string, variant?: string) { } export function registryLookupKey(name: string, variant?: string) { - return `prompt/${registryDefinitionKey(name, variant)}`; + return `/prompt/${registryDefinitionKey(name, variant)}`; } export async function lookupPrompt( @@ -71,25 +71,29 @@ export async function loadPromptFolder( ): Promise { const promptsPath = resolve(dir); return new Promise((resolve, reject) => { - readdir( - promptsPath, - { - withFileTypes: true, - recursive: false, - }, - (err, dirEnts) => { - if (err) { - reject(err); - } else { - dirEnts.forEach(async (dirEnt) => { - if (dirEnt.isFile() && dirEnt.name.endsWith('.prompt')) { - loadPrompt(dirEnt.path, dirEnt.name); - } - }); - resolve(); + if (existsSync(promptsPath)) { + readdir( + promptsPath, + { + withFileTypes: true, + recursive: false, + }, + (err, dirEnts) => { + if (err) { + reject(err); + } else { + dirEnts.forEach(async (dirEnt) => { + if (dirEnt.isFile() && dirEnt.name.endsWith('.prompt')) { + loadPrompt(dirEnt.path, dirEnt.name); + } + }); + resolve(); + } } - } - ); + ); + } else { + resolve(); + } }); } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index e720374da..f9a93cc23 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -18,12 +18,21 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; import { defineModel } from '@genkit-ai/ai/model'; -import z from 'zod'; - import { toJsonSchema, ValidationError } from '@genkit-ai/core/schema'; +import z from 'zod'; +import { registerPluginProvider } from '../../../core/src/registry.js'; import { defineDotprompt, Dotprompt, prompt } from '../src/index.js'; import { PromptMetadata } from '../src/metadata.js'; +function registerDotprompt() { + registerPluginProvider('dotprompt', { + name: 'dotprompt', + async initializer() { + return {}; + }, + }); +} + const echo = defineModel( { name: 'echo', supports: { tools: true } }, async (input) => ({ @@ -62,6 +71,7 @@ describe('Prompt', () => { }); it('rejects input not matching the schema', async () => { + registerDotprompt(); const invalidSchemaPrompt = defineDotprompt( { name: 'invalidInput', @@ -90,6 +100,7 @@ describe('Prompt', () => { }); it('rejects input not matching the schema', async () => { + registerDotprompt(); const invalidSchemaPrompt = defineDotprompt( { name: 'invalidInput', @@ -176,6 +187,7 @@ output: describe('definePrompt', () => { it('registers a prompt and its variant', async () => { + registerDotprompt(); defineDotprompt( { name: 'promptName', diff --git a/js/samples/byo-evaluator/src/index.ts b/js/samples/byo-evaluator/src/index.ts index 07822b5d0..b3061b870 100644 --- a/js/samples/byo-evaluator/src/index.ts +++ b/js/samples/byo-evaluator/src/index.ts @@ -16,6 +16,7 @@ import { EvaluatorAction } from '@genkit-ai/ai'; import { ModelReference } from '@genkit-ai/ai/model'; import { configureGenkit, genkitPlugin, PluginProvider } from '@genkit-ai/core'; +import { dotprompt } from '@genkit-ai/dotprompt'; import { firebase } from '@genkit-ai/firebase'; import { geminiPro, googleAI } from '@genkit-ai/googleai'; import * as z from 'zod'; @@ -42,6 +43,7 @@ import { configureGenkit({ plugins: [ + dotprompt(), firebase(), googleAI({ apiVersion: ['v1', 'v1beta'] }), byoEval({ diff --git a/js/samples/cat-eval/src/index.ts b/js/samples/cat-eval/src/index.ts index 0aed27d5a..5ac4f9ccd 100644 --- a/js/samples/cat-eval/src/index.ts +++ b/js/samples/cat-eval/src/index.ts @@ -16,6 +16,7 @@ import { configureGenkit } from '@genkit-ai/core'; import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; +import { dotprompt } from '@genkit-ai/dotprompt'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; import { firebase } from '@genkit-ai/firebase'; import { geminiPro, googleAI } from '@genkit-ai/googleai'; @@ -46,6 +47,7 @@ export const PERMISSIVE_SAFETY_SETTINGS: any = { configureGenkit({ plugins: [ + dotprompt(), firebase(), googleAI(), genkitEval({ diff --git a/js/samples/coffee-shop/src/index.ts b/js/samples/coffee-shop/src/index.ts index 0d639bba0..e6735a72a 100644 --- a/js/samples/coffee-shop/src/index.ts +++ b/js/samples/coffee-shop/src/index.ts @@ -15,14 +15,18 @@ */ import { configureGenkit } from '@genkit-ai/core'; -import { defineDotprompt } from '@genkit-ai/dotprompt'; +import { defineDotprompt, dotprompt } from '@genkit-ai/dotprompt'; import { firebase } from '@genkit-ai/firebase'; import { defineFlow, runFlow } from '@genkit-ai/flow'; import googleAI, { geminiPro } from '@genkit-ai/googleai'; import * as z from 'zod'; configureGenkit({ - plugins: [googleAI({ apiVersion: ['v1', 'v1beta'] }), firebase()], + plugins: [ + googleAI({ apiVersion: ['v1', 'v1beta'] }), + firebase(), + dotprompt(), + ], enableTracingAndMetrics: true, flowStateStore: 'firebase', logLevel: 'debug', diff --git a/js/samples/menu-example/basic/src/index.ts b/js/samples/menu-example/basic/src/index.ts index 977b4aae8..8853d4c9d 100644 --- a/js/samples/menu-example/basic/src/index.ts +++ b/js/samples/menu-example/basic/src/index.ts @@ -18,12 +18,13 @@ // both. import { generate } from '@genkit-ai/ai'; import { configureGenkit } from '@genkit-ai/core'; +import { dotprompt } from '@genkit-ai/dotprompt'; import { defineFlow, startFlowsServer } from '@genkit-ai/flow'; import { geminiPro, googleAI } from '@genkit-ai/googleai'; import * as z from 'zod'; configureGenkit({ - plugins: [googleAI()], + plugins: [googleAI(), dotprompt()], logLevel: 'debug', enableTracingAndMetrics: true, }); diff --git a/js/samples/prompt-file/src/index.ts b/js/samples/prompt-file/src/index.ts index a501d3d81..bff74e480 100644 --- a/js/samples/prompt-file/src/index.ts +++ b/js/samples/prompt-file/src/index.ts @@ -15,13 +15,13 @@ */ import { configureGenkit } from '@genkit-ai/core'; -import { prompt } from '@genkit-ai/dotprompt'; +import { dotprompt, prompt } from '@genkit-ai/dotprompt'; import { defineFlow } from '@genkit-ai/flow'; import { googleAI } from '@genkit-ai/googleai'; import * as z from 'zod'; configureGenkit({ - plugins: [googleAI()], + plugins: [googleAI(), dotprompt()], enableTracingAndMetrics: true, logLevel: 'debug', }); diff --git a/js/samples/rag/src/index.ts b/js/samples/rag/src/index.ts index 2a96b5ac1..ffc5b9882 100644 --- a/js/samples/rag/src/index.ts +++ b/js/samples/rag/src/index.ts @@ -16,6 +16,7 @@ import { configureGenkit } from '@genkit-ai/core'; import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; +import { dotprompt } from '@genkit-ai/dotprompt'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; import { firebase } from '@genkit-ai/firebase'; import { googleAI } from '@genkit-ai/googleai'; @@ -31,6 +32,7 @@ import { pinecone } from 'genkitx-pinecone'; export default configureGenkit({ plugins: [ + dotprompt(), firebase(), googleAI({ apiVersion: ['v1', 'v1beta'] }), genkitEval({