diff --git a/go/core/schema.go b/go/core/schema.go new file mode 100644 index 0000000000..09c9789533 --- /dev/null +++ b/go/core/schema.go @@ -0,0 +1,130 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +// Package core provides core functionality for the genkit framework. +package core + +import ( + "fmt" + "sync" +) + +// Schema represents a schema definition that can be of any type. +type Schema any + +var ( + schemasMu sync.RWMutex + schemas = make(map[string]any) + schemaLookups []func(string) any + // Keep track of schemas to register with Dotprompt + pendingSchemas = make(map[string]Schema) +) + +// RegisterSchema registers a schema with the given name. +// This is intended to be called by higher-level packages like ai. +// It validates that the name is not empty and the schema is not nil, +// then registers the schema in the core schemas map. +// Returns the schema for convenience in chaining operations. +func RegisterSchema(name string, schema any) (Schema, error) { + if name == "" { + return nil, fmt.Errorf("core.RegisterSchema: schema name cannot be empty") + } + + if schema == nil { + return nil, fmt.Errorf("core.RegisterSchema: schema definition cannot be nil") + } + + schemasMu.Lock() + defer schemasMu.Unlock() + + if _, exists := schemas[name]; exists { + return nil, fmt.Errorf("core.RegisterSchema: schema with name %q already exists", name) + } + + schemas[name] = schema + pendingSchemas[name] = schema + + return schema, nil +} + +// LookupSchema looks up a schema by name. +// It first checks the local registry, and if not found, +// it calls each registered lookup function until one returns a non-nil result. +func LookupSchema(name string) any { + schemasMu.RLock() + defer schemasMu.RUnlock() + + // First check local registry + if schema, ok := schemas[name]; ok { + return schema + } + + // Then try lookup functions + for _, lookup := range schemaLookups { + if schema := lookup(name); schema != nil { + return schema + } + } + + return nil +} + +// RegisterSchemaLookup registers a function that can look up schemas by name. +// This allows different packages to provide schemas while maintaining a +// unified lookup mechanism. +func RegisterSchemaLookup(lookup func(string) any) { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemaLookups = append(schemaLookups, lookup) +} + +// Schemas returns a copy of all registered schemas. +func Schemas() map[string]any { + schemasMu.RLock() + defer schemasMu.RUnlock() + + result := make(map[string]any, len(schemas)) + for name, schema := range schemas { + result[name] = schema + } + + return result +} + +// PendingSchemas returns a copy of pending schemas that need to be +// registered with Dotprompt. +func PendingSchemas() map[string]Schema { + schemasMu.RLock() + defer schemasMu.RUnlock() + + result := make(map[string]Schema, len(pendingSchemas)) + for name, schema := range pendingSchemas { + result[name] = schema + } + + return result +} + +// ClearPendingSchemas clears the pending schemas map. +// This is called after the schemas have been registered with Dotprompt. +func ClearPendingSchemas() { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemas = make(map[string]any) + pendingSchemas = make(map[string]Schema) + schemaLookups = nil +} diff --git a/go/core/schema_test.go b/go/core/schema_test.go new file mode 100644 index 0000000000..932ea775c5 --- /dev/null +++ b/go/core/schema_test.go @@ -0,0 +1,282 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package core + +import ( + "fmt" + "reflect" + "sync" + "testing" +) + +// clearSchemasForTest removes all registered schemas. +// This is exclusively for testing purposes. +func clearSchemasForTest() { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemas = make(map[string]any) + pendingSchemas = make(map[string]Schema) + schemaLookups = nil +} + +// TestRegisterSchema tests schema registration functionality +func TestRegisterSchema(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + t.Run("RegisterValidSchema", func(t *testing.T) { + schema := map[string]interface{}{"type": "object"} + result, err := RegisterSchema("test", schema) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if result == nil { + t.Fatal("Expected RegisterSchema to return the schema, got nil") + } + + retrieved := LookupSchema("test") + if retrieved == nil { + t.Fatal("Failed to retrieve registered schema") + } + + if !reflect.DeepEqual(retrieved, schema) { + t.Fatalf("Retrieved schema doesn't match registered schema. Got %v, want %v", retrieved, schema) + } + }) + + t.Run("RegisterDuplicateName", func(t *testing.T) { + clearSchemasForTest() + _, err := RegisterSchema("duplicate", "first") + if err != nil { + t.Fatalf("Unexpected error registering first schema: %v", err) + } + + _, err = RegisterSchema("duplicate", "second") + if err == nil { + t.Fatal("Expected error when registering duplicate schema name, but no error occurred") + } + expectedErrMsg := `core.RegisterSchema: schema with name "duplicate" already exists` + if err.Error() != expectedErrMsg { + t.Fatalf("Expected error message %q, got %q", expectedErrMsg, err.Error()) + } + }) + + t.Run("RegisterEmptyName", func(t *testing.T) { + _, err := RegisterSchema("", "schema") + if err == nil { + t.Fatal("Expected error when registering schema with empty name, but no error occurred") + } + expectedErrMsg := "core.RegisterSchema: schema name cannot be empty" + if err.Error() != expectedErrMsg { + t.Fatalf("Expected error message %q, got %q", expectedErrMsg, err.Error()) + } + }) + + t.Run("RegisterNilSchema", func(t *testing.T) { + _, err := RegisterSchema("nil_schema", nil) + if err == nil { + t.Fatal("Expected error when registering nil schema, but no error occurred") + } + expectedErrMsg := "core.RegisterSchema: schema definition cannot be nil" + if err.Error() != expectedErrMsg { + t.Fatalf("Expected error message %q, got %q", expectedErrMsg, err.Error()) + } + }) +} + +// TestLookupSchema tests schema lookup functionality +func TestLookupSchema(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + t.Run("LookupExistingSchema", func(t *testing.T) { + expectedSchema := "test_schema" + _, err := RegisterSchema("existing", expectedSchema) + if err != nil { + t.Fatalf("Failed to register schema: %v", err) + } + + result := LookupSchema("existing") + if result != expectedSchema { + t.Fatalf("Expected schema %v, got %v", expectedSchema, result) + } + }) + + t.Run("LookupNonExistentSchema", func(t *testing.T) { + result := LookupSchema("nonexistent") + if result != nil { + t.Fatalf("Expected nil for non-existent schema, got %v", result) + } + }) + + t.Run("LookupViaCustomFunction", func(t *testing.T) { + expectedSchema := "custom_schema" + RegisterSchemaLookup(func(name string) any { + if name == "custom" { + return expectedSchema + } + return nil + }) + + result := LookupSchema("custom") + if result != expectedSchema { + t.Fatalf("Expected schema %v from custom lookup, got %v", expectedSchema, result) + } + }) + + t.Run("PreferLocalRegistryOverLookup", func(t *testing.T) { + localSchema := "local_schema" + _, err := RegisterSchema("preference_test", localSchema) + if err != nil { + t.Fatalf("Failed to register schema: %v", err) + } + + lookupSchema := "lookup_schema" + RegisterSchemaLookup(func(name string) any { + if name == "preference_test" { + return lookupSchema + } + return nil + }) + + result := LookupSchema("preference_test") + if result != localSchema { + t.Fatalf("Expected local schema %v to be preferred, got %v", localSchema, result) + } + }) +} + +// TestPendingSchemas tests handling of pending schemas +func TestPendingSchemas(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + t.Run("GetPendingSchemas", func(t *testing.T) { + _, err := RegisterSchema("pending1", "test1") + if err != nil { + t.Fatalf("Failed to register first schema: %v", err) + } + + _, err = RegisterSchema("pending2", "test2") + if err != nil { + t.Fatalf("Failed to register second schema: %v", err) + } + + pending := PendingSchemas() + if len(pending) != 2 { + t.Fatalf("Expected 2 pending schemas, got %d", len(pending)) + } + + if pending["pending1"] != "test1" || pending["pending2"] != "test2" { + t.Fatal("Pending schemas don't match expected values") + } + }) + + t.Run("ClearPendingSchemas", func(t *testing.T) { + _, err := RegisterSchema("pending3", "test3") + if err != nil { + t.Fatalf("Failed to register schema: %v", err) + } + + ClearPendingSchemas() + + pending := PendingSchemas() + if len(pending) != 0 { + t.Fatalf("Expected 0 pending schemas after clearing, got %d", len(pending)) + } + }) +} + +// TestSchemas tests the Schemas function that returns all registered schemas +func TestSchemas(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + _, err := RegisterSchema("schema1", "value1") + if err != nil { + t.Fatalf("Failed to register first schema: %v", err) + } + + _, err = RegisterSchema("schema2", "value2") + if err != nil { + t.Fatalf("Failed to register second schema: %v", err) + } + + schemasMap := Schemas() + if len(schemasMap) != 2 { + t.Fatalf("Expected 2 schemas, got %d", len(schemasMap)) + } + + if schemasMap["schema1"] != "value1" || schemasMap["schema2"] != "value2" { + t.Fatal("Retrieved schemas don't match expected values") + } + + schemasMap["schema3"] = "value3" + + internalSchemas := Schemas() + if len(internalSchemas) != 2 { + t.Fatalf("Expected internal schemas count to remain 2, got %d", len(internalSchemas)) + } + + if _, exists := internalSchemas["schema3"]; exists { + t.Fatal("Modifying returned schemas map should not affect internal state") + } +} + +// TestConcurrentAccess tests thread safety of schema operations +func TestConcurrentAccess(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + const numGoroutines = 10 + const schemasPerGoroutine = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(routineID int) { + defer wg.Done() + + for j := 0; j < schemasPerGoroutine; j++ { + name := fmt.Sprintf("schema_r%d_s%d", routineID, j) + _, err := RegisterSchema(name, j) + if err != nil { + t.Errorf("Unexpected error registering schema %s: %v", name, err) + } + } + + for j := 0; j < schemasPerGoroutine; j++ { + name := fmt.Sprintf("schema_r%d_s%d", routineID, j) + value := LookupSchema(name) + if value != j { + t.Errorf("Expected schema value %d for %s, got %v", j, name, value) + } + } + }(i) + } + + wg.Wait() + + schemasMap := Schemas() + expectedCount := numGoroutines * schemasPerGoroutine + if len(schemasMap) != expectedCount { + t.Fatalf("Expected %d total schemas, got %d", expectedCount, len(schemasMap)) + } +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 2473b0dd68..7f858c1627 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -210,6 +210,11 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { g := &Genkit{reg: r} + // Register schemas with Dotprompt before loading plugins or prompt files + if err := registerPendingSchemas(r); err != nil { + return nil, fmt.Errorf("genkit.Init: error registering schemas: %w", err) + } + for _, plugin := range gOpts.Plugins { if err := plugin.Init(ctx, g); err != nil { return nil, fmt.Errorf("genkit.Init: plugin %T initialization failed: %w", plugin, err) @@ -245,6 +250,20 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { return g, nil } +// Internal function called during Init to register pending schemas +func registerPendingSchemas(reg *registry.Registry) error { + pendingSchemas := core.PendingSchemas() + + for name, schema := range pendingSchemas { + if err := reg.RegisterSchemaWithDotprompt(name, schema); err != nil { + return fmt.Errorf("failed to register schema %s: %w", name, err) + } + } + + core.ClearPendingSchemas() + return nil +} + // DefineFlow defines a non-streaming flow, registers it as a [core.Action] of type Flow, // and returns a [core.Flow] runner. // The provided function `fn` takes an input of type `In` and returns an output of type `Out`. diff --git a/go/genkit/schema.go b/go/genkit/schema.go new file mode 100644 index 0000000000..8ddf9d687f --- /dev/null +++ b/go/genkit/schema.go @@ -0,0 +1,114 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package genkit + +import ( + "fmt" + "log/slog" + "sync" + + "github.com/firebase/genkit/go/core" + "github.com/google/dotprompt/go/dotprompt" + "github.com/invopop/jsonschema" +) + +// Schema is an alias for core.Schema to maintain compatibility with existing type definitions +type Schema = core.Schema + +// schemasMu and pendingSchemas are maintained for backward compatibility +var ( + schemasMu sync.RWMutex + pendingSchemas = make(map[string]Schema) +) + +// DefineSchema registers a schema that can be referenced by name in genkit. +// This allows schemas to be defined once and used across the AI generation pipeline. +// +// Example usage: +// +// type Person struct { +// Name string `json:"name"` +// Age int `json:"age"` +// } +// +// personSchema := genkit.DefineSchema("Person", Person{}) +func DefineSchema(name string, schema Schema) (Schema, error) { + if name == "" { + return nil, fmt.Errorf("genkit.DefineSchema: schema name cannot be empty") + } + + if schema == nil { + return nil, fmt.Errorf("genkit.DefineSchema: schema cannot be nil") + } + + core.RegisterSchema(name, schema) + + schemasMu.Lock() + defer schemasMu.Unlock() + pendingSchemas[name] = schema + + return schema, nil +} + +// LookupSchema retrieves a registered schema by name. +// It returns nil and false if no schema exists with that name. +func LookupSchema(name string) (Schema, bool) { + schema := core.LookupSchema(name) + return schema, schema != nil +} + +// FindSchema retrieves a registered schema by name. +// It returns an error if no schema exists with that name. +func FindSchema(name string) (Schema, error) { + schema, exists := LookupSchema(name) + if !exists { + return nil, fmt.Errorf("genkit: schema '%s' not found", name) + } + return schema, nil +} + +// registerSchemaResolver registers a schema resolver with Dotprompt to handle schema lookups +func registerSchemaResolver(dp *dotprompt.Dotprompt) { + // Create a schema resolver that can look up schemas from the Genkit registry + schemaResolver := func(name string) any { + schema, exists := LookupSchema(name) + if !exists { + slog.Error("schema not found in registry", "name", name) + return nil + } + + reflector := jsonschema.Reflector{} + jsonSchema := reflector.Reflect(schema) + return jsonSchema + } + + dp.RegisterExternalSchemaLookup(schemaResolver) +} + +// RegisterGlobalSchemaResolver exports the schema lookup capabilities for use in other packages +func RegisterGlobalSchemaResolver(dp *dotprompt.Dotprompt) { + dp.RegisterExternalSchemaLookup(func(name string) any { + schema, exists := LookupSchema(name) + if !exists { + return nil + } + + reflector := jsonschema.Reflector{} + jsonSchema := reflector.Reflect(schema) + return jsonSchema + }) +} diff --git a/go/genkit/schema_test.go b/go/genkit/schema_test.go new file mode 100644 index 0000000000..15fa4cc514 --- /dev/null +++ b/go/genkit/schema_test.go @@ -0,0 +1,94 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package genkit + +import ( + "testing" +) + +type TestStruct struct { + Name string + Age int +} + +func TestDefineAndLookupSchema(t *testing.T) { + schemaName := "TestStruct" + testSchema := TestStruct{Name: "Alice", Age: 30} + + // Define the schema + schema, err := DefineSchema(schemaName, testSchema) + if err != nil { + t.Fatalf("Unexpected error defining schema: %v", err) + } + + // Lookup the schema + schema, found := LookupSchema(schemaName) + if !found { + t.Fatalf("Expected schema '%s' to be found", schemaName) + } + + // Assert the type + typedSchema, ok := schema.(TestStruct) + if !ok { + t.Fatalf("Expected schema to be of type TestStruct") + } + + if typedSchema.Name != "Alice" || typedSchema.Age != 30 { + t.Errorf("Unexpected schema contents: %+v", typedSchema) + } +} + +func TestSchemaSuccess(t *testing.T) { + schemaName := "GetStruct" + testSchema := TestStruct{Name: "Bob", Age: 25} + + _, err := DefineSchema(schemaName, testSchema) + if err != nil { + t.Fatalf("Unexpected error defining schema: %v", err) + } + + schema, err := FindSchema(schemaName) + if err != nil { + t.Fatalf("Expected schema '%s' to be retrieved without error", schemaName) + } + + typedSchema := schema.(TestStruct) + if typedSchema.Name != "Bob" || typedSchema.Age != 25 { + t.Errorf("Unexpected schema contents: %+v", typedSchema) + } +} + +func TestSchemaNotFound(t *testing.T) { + _, err := FindSchema("NonExistentSchema") + if err == nil { + t.Fatal("Expected error when retrieving a non-existent schema") + } +} + +func TestDefineSchemaEmptyName(t *testing.T) { + _, err := DefineSchema("", TestStruct{}) + if err == nil { + t.Fatal("Expected error for empty schema name") + } +} + +func TestDefineSchemaNil(t *testing.T) { + var nilSchema Schema + _, err := DefineSchema("NilSchema", nilSchema) + if err == nil { + t.Fatal("Expected error for nil schema") + } +} diff --git a/go/go.mod b/go/go.mod index 2df94ef6b5..ae1d8a1824 100644 --- a/go/go.mod +++ b/go/go.mod @@ -23,6 +23,7 @@ require ( github.com/pgvector/pgvector-go v0.3.0 github.com/weaviate/weaviate v1.26.0-rc.1 github.com/weaviate/weaviate-go-client/v4 v4.15.0 + github.com/wk8/go-ordered-map/v2 v2.1.8 github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.29.0 go.opentelemetry.io/otel/metric v1.29.0 @@ -52,6 +53,7 @@ require ( github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aymerick/raymond v2.0.2+incompatible // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -76,11 +78,9 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect - github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect @@ -101,3 +101,5 @@ require ( google.golang.org/grpc v1.66.2 // indirect google.golang.org/protobuf v1.34.2 // indirect ) + +replace github.com/google/dotprompt/go => github.com/google/dotprompt/go v0.0.0-20250422204256-6029fef7a2fd diff --git a/go/go.sum b/go/go.sum index 48450c251f..e6efdac1d5 100644 --- a/go/go.sum +++ b/go/go.sum @@ -43,6 +43,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aymerick/raymond v2.0.2+incompatible h1:VEp3GpgdAnv9B2GFyTvqgcKvY+mfKMjPOA3SbKLtnU0= +github.com/aymerick/raymond v2.0.2+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/blues/jsonata-go v1.5.4 h1:XCsXaVVMrt4lcpKeJw6mNJHqQpWU751cnHdCFUq3xd8= @@ -145,8 +147,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/dotprompt/go v0.0.0-20250415074656-072d95deb01d h1:ChKKjq8F7GcNKViCCB/vRoU6joR7IDsZgu1I4wg6RjQ= -github.com/google/dotprompt/go v0.0.0-20250415074656-072d95deb01d/go.mod h1:dnIk+MSMnipm9uZyPIgptq7I39aDxyjBiaev/OG0W0Y= +github.com/google/dotprompt/go v0.0.0-20250422204256-6029fef7a2fd h1:LmVYfpTt3dDDYoBqziibAZf2lfMOcOf5MfkFDyoDrPg= +github.com/google/dotprompt/go v0.0.0-20250422204256-6029fef7a2fd/go.mod h1:wVZXOPYuasZIfPu6UQvYxODdVUR2nIligI4SWs47GVs= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -213,8 +215,6 @@ github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4 github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= -github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= -github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= diff --git a/go/internal/registry/registry.go b/go/internal/registry/registry.go index 8cf0ef80c8..2ba4622039 100644 --- a/go/internal/registry/registry.go +++ b/go/internal/registry/registry.go @@ -36,6 +36,7 @@ import ( const ( DefaultModelKey = "genkit/defaultModel" PromptDirKey = "genkit/promptDir" + SchemaType = "schema" ) type Registry struct { diff --git a/go/internal/registry/schema.go b/go/internal/registry/schema.go new file mode 100644 index 0000000000..a24da9fde7 --- /dev/null +++ b/go/internal/registry/schema.go @@ -0,0 +1,220 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "fmt" + "reflect" + "strings" + + "github.com/google/dotprompt/go/dotprompt" + "github.com/invopop/jsonschema" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +// DefineSchema registers a Go struct as a schema with the given name. +func (r *Registry) DefineSchema(name string, structType any) error { + jsonSchema, err := convertStructToJsonSchema(structType) + if err != nil { + return err + } + + if r.Dotprompt == nil { + r.Dotprompt = dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ + Schemas: map[string]*jsonschema.Schema{}, + }) + } + + r.Dotprompt.DefineSchema(name, jsonSchema) + r.RegisterValue(SchemaType+"/"+name, structType) + fmt.Printf("Registered schema '%s' with registry and Dotprompt\n", name) + return nil +} + +// RegisterSchemaWithDotprompt registers a schema with the Dotprompt instance +// This is used during Init to register schemas that were defined before the registry was created. +func (r *Registry) RegisterSchemaWithDotprompt(name string, schema any) error { + if r.Dotprompt == nil { + r.Dotprompt = dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ + Schemas: map[string]*jsonschema.Schema{}, + }) + } + + jsonSchema, err := convertStructToJsonSchema(schema) + if err != nil { + return err + } + + r.Dotprompt.DefineSchema(name, jsonSchema) + r.RegisterValue(SchemaType+"/"+name, schema) + + // Set up schema lookup if not already done + r.setupSchemaLookupFunction() + + return nil +} + +// setupSchemaLookupFunction registers the external schema lookup function with Dotprompt +// This function bridges between Dotprompt's schema resolution and the registry's values +func (r *Registry) setupSchemaLookupFunction() { + if r.Dotprompt == nil { + return + } + + r.Dotprompt.RegisterExternalSchemaLookup(func(schemaName string) any { + schemaValue := r.LookupValue(SchemaType + "/" + schemaName) + if schemaValue != nil { + return schemaValue + } + return nil + }) +} + +// convertStructToJsonSchema converts a Go struct to a JSON schema +func convertStructToJsonSchema(structType any) (*jsonschema.Schema, error) { + t := reflect.TypeOf(structType) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct type, got %s", t.Kind()) + } + + schema := &jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema](), + Required: []string{}, + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.PkgPath != "" { + continue + } + + jsonTag := field.Tag.Get("json") + parts := strings.Split(jsonTag, ",") + propName := parts[0] + if propName == "" { + propName = field.Name + } + + if propName == "-" { + continue + } + + isRequired := true + for _, opt := range parts[1:] { + if opt == "omitempty" { + isRequired = false + break + } + } + + if isRequired { + schema.Required = append(schema.Required, propName) + } + + description := field.Tag.Get("description") + + fieldSchema := fieldToSchema(field.Type, description) + schema.Properties.Set(propName, fieldSchema) + } + + return schema, nil +} + +// fieldToSchema converts a field type to a JSON Schema. +func fieldToSchema(t reflect.Type, description string) *jsonschema.Schema { + schema := &jsonschema.Schema{} + + if description != "" { + schema.Description = description + } + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.String: + schema.Type = "string" + case reflect.Bool: + schema.Type = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + schema.Type = "integer" + case reflect.Float32, reflect.Float64: + schema.Type = "number" + case reflect.Slice, reflect.Array: + schema.Type = "array" + itemSchema := fieldToSchema(t.Elem(), "") + schema.Items = itemSchema + case reflect.Map: + schema.Type = "object" + if t.Key().Kind() == reflect.String { + valueSchema := fieldToSchema(t.Elem(), "") + schema.AdditionalProperties = valueSchema + } + case reflect.Struct: + schema.Type = "object" + schema.Properties = orderedmap.New[string, *jsonschema.Schema]() + schema.Required = []string{} + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.PkgPath != "" { + continue + } + + jsonTag := field.Tag.Get("json") + parts := strings.Split(jsonTag, ",") + propName := parts[0] + if propName == "" { + propName = field.Name + } + + if propName == "-" { + continue + } + + isRequired := true + for _, opt := range parts[1:] { + if opt == "omitempty" { + isRequired = false + break + } + } + + if isRequired { + schema.Required = append(schema.Required, propName) + } + + fieldDescription := field.Tag.Get("description") + + fieldSchema := fieldToSchema(field.Type, fieldDescription) + schema.Properties.Set(propName, fieldSchema) + } + default: + schema.Type = "string" + } + + return schema +} diff --git a/go/samples/schema/main.go b/go/samples/schema/main.go new file mode 100644 index 0000000000..78525ed19c --- /dev/null +++ b/go/samples/schema/main.go @@ -0,0 +1,284 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Product Generator using Genkit and Dotprompt + +This application demonstrates a structured product generation system that uses: +- Genkit: A framework for managing AI model interactions and prompts +- Dotprompt: A library for working with structured prompts and JSON schemas +- JSON Schema: For defining the structure of generated product data + +The program: +1. Defines a ProductSchema struct for structured product data +2. Creates a mock AI model plugin that returns predefined product data +3. Generates and saves JSON schema files in a prompts directory +4. Creates a prompt template that takes a theme as input and outputs a product +5. Initializes Dotprompt with schema resolution capabilities +6. Executes the prompt with an "eco-friendly" theme +7. Parses the structured response and displays the generated product + +The mock implementation simulates what would happen with a real AI model +by returning different products based on detected themes in the input. +This provides a testable framework for structured AI outputs conforming +to the defined schema. +*/ + +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/google/dotprompt/go/dotprompt" + "github.com/invopop/jsonschema" +) + +// ProductSchema defines our product output structure +// This schema will be used for structured outputs from AI models +type ProductSchema struct { + Name string `json:"name"` + Description string `json:"description"` + Price float64 `json:"price"` + Category string `json:"category"` + InStock bool `json:"inStock"` +} + +// MockPlugin implements the genkit.Plugin interface +// It provides a custom model implementation for testing purposes +type MockPlugin struct{} + +// Name returns the unique identifier for the plugin +func (p *MockPlugin) Name() string { + return "mock" +} + +// Init initializes the plugin with the Genkit instance +// It registers a mock model that returns predefined product data +func (p *MockPlugin) Init(ctx context.Context, g *genkit.Genkit) error { + genkit.DefineModel(g, "mock", "product-model", + &ai.ModelInfo{ + Label: "Mock Product Model", + Supports: &ai.ModelSupports{}, + }, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + product := ProductSchema{ + Name: "Eco-Friendly Bamboo Cutting Board", + Description: "A sustainable cutting board made from 100% bamboo. Features a juice groove and handle.", + Price: 29.99, + Category: "Kitchen Accessories", + InStock: true, + } + + jsonBytes, err := json.Marshal(product) + if err != nil { + return nil, err + } + + resp := &ai.ModelResponse{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))}, + }, + FinishReason: ai.FinishReasonStop, + } + + return resp, nil + }) + + return nil +} + +func main() { + ctx := context.Background() + + cwd, _ := os.Getwd() + promptDir := filepath.Join(cwd, "prompts") + + if _, err := os.Stat(promptDir); os.IsNotExist(err) { + if err := os.MkdirAll(promptDir, 0755); err != nil { + log.Fatalf("Failed to create prompt directory: %v", err) + } + } + + schemaFilePath := filepath.Join(promptDir, "_schema_ProductSchema.partial.prompt") + + reflector := jsonschema.Reflector{} + schema := reflector.Reflect(ProductSchema{}) + + // Structure the schema according to what Dotprompt expects + schemaWrapper := struct { + Schema string `json:"$schema"` + Ref string `json:"$ref"` + Definitions map[string]*jsonschema.Schema `json:"$defs"` + }{ + Schema: "https://json-schema.org/draft/2020-12/schema", + Ref: "#/$defs/ProductSchema", + Definitions: map[string]*jsonschema.Schema{ + "ProductSchema": schema, + }, + } + + schemaJSON, err := json.MarshalIndent(schemaWrapper, "", " ") + if err != nil { + log.Fatalf("Failed to marshal schema: %v", err) + } + + if err := os.WriteFile(schemaFilePath, schemaJSON, 0644); err != nil { + log.Fatalf("Failed to write schema file: %v", err) + } + + // Create prompt file with schema reference + promptFilePath := filepath.Join(promptDir, "product_generator.prompt") + promptContent := "---\n" + + "input:\n" + + " schema:\n" + + " theme: string\n" + + "output:\n" + + " schema: ProductSchema\n" + + "---\n" + + "Generate a product that fits the {{theme}} theme.\n" + + "Make sure to provide a detailed description and appropriate pricing." + + if err := os.WriteFile(promptFilePath, []byte(promptContent), 0644); err != nil { + log.Fatalf("Failed to write prompt file: %v", err) + } + + // Testing with dotprompt directly + dp := dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ + Schemas: map[string]*jsonschema.Schema{}, + }) + + // Register external schema lookup function + dp.RegisterExternalSchemaLookup(func(schemaName string) any { + if schemaName == "ProductSchema" { + return schema + } + return nil + }) + + metadata := map[string]any{ + "output": map[string]any{ + "schema": "ProductSchema", + }, + } + + if err = dp.ResolveSchemaReferences(metadata); err != nil { + log.Fatalf("Schema resolution failed: %v", err) + } + + // Define our schema with Genkit + genkit.DefineSchema("ProductSchema", ProductSchema{}) + + // Initialize Genkit with our prompt directory + g, err := genkit.Init(ctx, + genkit.WithPromptDir(promptDir), + genkit.WithDefaultModel("mock/default-model")) + if err != nil { + log.Fatalf("Failed to initialize Genkit: %v", err) + } + + // Define a mock model to respond to prompts + genkit.DefineModel(g, "mock", "default-model", + &ai.ModelInfo{ + Label: "Mock Default Model", + Supports: &ai.ModelSupports{}, + }, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Extract theme from the request to customize the response + theme := "generic" + if len(req.Messages) > 0 { + lastMsg := req.Messages[len(req.Messages)-1] + if lastMsg.Role == ai.RoleUser { + for _, part := range lastMsg.Content { + if part.IsText() && strings.Contains(part.Text, "eco-friendly") { + theme = "eco-friendly" + } + } + } + } + + // Generate appropriate product based on theme + var product ProductSchema + if theme == "eco-friendly" { + product = ProductSchema{ + Name: "Eco-Friendly Bamboo Cutting Board", + Description: "A sustainable cutting board made from 100% bamboo. Features a juice groove and handle.", + Price: 29.99, + Category: "Kitchen Accessories", + InStock: true, + } + } else { + product = ProductSchema{ + Name: "Classic Stainless Steel Water Bottle", + Description: "Durable 24oz water bottle with vacuum insulation. Keeps drinks cold for 24 hours.", + Price: 24.99, + Category: "Drinkware", + InStock: true, + } + } + + jsonBytes, err := json.Marshal(product) + if err != nil { + return nil, err + } + + resp := &ai.ModelResponse{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))}, + }, + FinishReason: ai.FinishReasonStop, + } + + return resp, nil + }) + + // Look up and execute the prompt + productPrompt := genkit.LookupPrompt(g, "product_generator") + if productPrompt == nil { + log.Fatalf("Prompt 'product_generator' not found") + } + + input := map[string]any{ + "theme": "eco-friendly kitchen gadgets", + } + + resp, err := productPrompt.Execute(ctx, ai.WithInput(input)) + if err != nil { + log.Fatalf("Failed to execute prompt: %v", err) + } + + // Parse the structured response into our Go struct + var product ProductSchema + if err := resp.Output(&product); err != nil { + log.Fatalf("Failed to parse response: %v", err) + } + + fmt.Println("\nGenerated Product:") + fmt.Printf("Name: %s\n", product.Name) + fmt.Printf("Description: %s\n", product.Description) + fmt.Printf("Price: $%.2f\n", product.Price) + fmt.Printf("Category: %s\n", product.Category) + fmt.Printf("In Stock: %v\n", product.InStock) +} diff --git a/go/samples/schema/prompts/_schema_ProductSchema.partial.prompt b/go/samples/schema/prompts/_schema_ProductSchema.partial.prompt new file mode 100644 index 0000000000..bd3a4cdc31 --- /dev/null +++ b/go/samples/schema/prompts/_schema_ProductSchema.partial.prompt @@ -0,0 +1,40 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/ProductSchema", + "$defs": { + "ProductSchema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/ProductSchema", + "$defs": { + "ProductSchema": { + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "price": { + "type": "number" + }, + "category": { + "type": "string" + }, + "inStock": { + "type": "boolean" + } + }, + "additionalProperties": false, + "type": "object", + "required": [ + "name", + "description", + "price", + "category", + "inStock" + ] + } + } + } + } +} \ No newline at end of file diff --git a/go/samples/schema/prompts/product_generator.prompt b/go/samples/schema/prompts/product_generator.prompt new file mode 100644 index 0000000000..2c1a127ff4 --- /dev/null +++ b/go/samples/schema/prompts/product_generator.prompt @@ -0,0 +1,9 @@ +--- +input: + schema: + theme: string +output: + schema: ProductSchema +--- +Generate a product that fits the {{theme}} theme. +Make sure to provide a detailed description and appropriate pricing. \ No newline at end of file