From 7cb98192d3ba83922b9296edbe2150410853d9e8 Mon Sep 17 00:00:00 2001 From: Yuwen Ma Date: Sun, 12 Jan 2025 18:28:59 +0000 Subject: [PATCH] feat: write prompt result to a file --- dev/tools/controllerbuilder/go.mod | 3 ++- dev/tools/controllerbuilder/go.sum | 2 ++ .../pkg/commands/exportcsv/prompt.go | 27 ++++++++++++++++++- dev/tools/controllerbuilder/pkg/llm/gemini.go | 16 +++++------ .../controllerbuilder/pkg/toolbot/csv.go | 1 + 5 files changed, 39 insertions(+), 10 deletions(-) diff --git a/dev/tools/controllerbuilder/go.mod b/dev/tools/controllerbuilder/go.mod index e21904e2fe..69c7d4b021 100644 --- a/dev/tools/controllerbuilder/go.mod +++ b/dev/tools/controllerbuilder/go.mod @@ -10,11 +10,13 @@ require ( github.com/fatih/color v1.17.0 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 + golang.org/x/oauth2 v0.23.0 golang.org/x/tools v0.24.0 google.golang.org/api v0.203.0 google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 google.golang.org/protobuf v1.35.1 k8s.io/apimachinery v0.27.11 + k8s.io/client-go v0.27.11 k8s.io/klog/v2 v2.130.1 ) @@ -46,7 +48,6 @@ require ( golang.org/x/crypto v0.28.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.30.0 // indirect - golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.19.0 // indirect diff --git a/dev/tools/controllerbuilder/go.sum b/dev/tools/controllerbuilder/go.sum index a62c0d02c1..3c5d031ca6 100644 --- a/dev/tools/controllerbuilder/go.sum +++ b/dev/tools/controllerbuilder/go.sum @@ -211,5 +211,7 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/apimachinery v0.27.11 h1:ivrKMN7JgdtKhay14S5UQlvilV3z6W+wjiSQTzyr5zc= k8s.io/apimachinery v0.27.11/go.mod h1:IHu2ovJ60RqxyPSLmTel7KDLdOCRbpOxwtUBmwBnT/E= +k8s.io/client-go v0.27.11 h1:SZChXsDaN6lB5IYywCpvQs/ZUa5vK2NHkpEwUhoK3fQ= +k8s.io/client-go v0.27.11/go.mod h1:Rg3Yeuk9sX87gpVunVn3AsvMkGZfXuutTDC/jigBNUo= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= diff --git a/dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go b/dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go index 927d4efd87..97ecfe47e2 100644 --- a/dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go +++ b/dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go @@ -15,11 +15,14 @@ package exportcsv import ( + "bytes" "context" "fmt" "io" "os" + "strings" + kccio "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/io" "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options" "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/toolbot" "k8s.io/klog/v2" @@ -33,12 +36,14 @@ type PromptOptions struct { ProtoDir string SrcDir string + Output string } // BindFlags binds the flags to the command. func (o *PromptOptions) BindFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&o.SrcDir, "src-dir", o.SrcDir, "base directory for source code") cmd.Flags().StringVar(&o.ProtoDir, "proto-dir", o.ProtoDir, "base directory for checkout of proto API definitions") + cmd.Flags().StringVar(&o.Output, "output", o.Output, "the directory to store the prompt outcome") } // BuildPromptCommand builds the `prompt` command. @@ -75,6 +80,7 @@ func RunPrompt(ctx context.Context, o *PromptOptions) error { if o.ProtoDir == "" { return fmt.Errorf("--proto-dir is required") } + extractor := &toolbot.ExtractToolMarkers{} addProtoDefinition, err := toolbot.NewEnhanceWithProtoDefinition(o.ProtoDir) if err != nil { @@ -109,9 +115,28 @@ func RunPrompt(ctx context.Context, o *PromptOptions) error { log.Info("built data point", "dataPoint", dataPoint) - if err := x.RunGemini(ctx, dataPoint, os.Stdout); err != nil { + out := &bytes.Buffer{} + if err := x.RunGemini(ctx, dataPoint, out); err != nil { return fmt.Errorf("running LLM inference: %w", err) } + + if o.Output == "" { + fmt.Println(out) + return nil + } + + if tmpF, err := kccio.WriteToCache(ctx, o.Output, out.String(), fileNamePattern(dataPoint)); err != nil { + return err + } else { + fmt.Println(tmpF) + } return nil } + +func fileNamePattern(dataPoint *toolbot.DataPoint) string { + for k, _ := range dataPoint.Input { + return strings.Replace(k, " ", "-", -1) + } + return "" +} diff --git a/dev/tools/controllerbuilder/pkg/llm/gemini.go b/dev/tools/controllerbuilder/pkg/llm/gemini.go index bffafcc042..ba13fc420f 100644 --- a/dev/tools/controllerbuilder/pkg/llm/gemini.go +++ b/dev/tools/controllerbuilder/pkg/llm/gemini.go @@ -23,6 +23,7 @@ import ( "strings" "cloud.google.com/go/vertexai/genai" + "golang.org/x/oauth2/google" "google.golang.org/api/option" "k8s.io/klog/v2" ) @@ -35,15 +36,14 @@ func BuildGeminiClient(ctx context.Context) (*genai.Client, error) { if s := os.Getenv("GEMINI_API_KEY"); s != "" { opts = append(opts, option.WithAPIKey(s)) + } else { + // Some account can not use GEMINI_API_KEY but requires stricter access control via OAuth. + creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/generative-language", "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("finding default credentials: %w", err) + } + opts = append(opts, option.WithCredentials(creds)) } - // else { - // creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/generative-language", "https://www.googleapis.com/auth/cloud-platform") - // if err != nil { - // return nil, fmt.Errorf("finding default credentials: %w", err) - // } - // opts = append(opts, option.WithCredentials(creds)) - // } - projectID := "" location := "" diff --git a/dev/tools/controllerbuilder/pkg/toolbot/csv.go b/dev/tools/controllerbuilder/pkg/toolbot/csv.go index b87d190785..fa1c3a591a 100644 --- a/dev/tools/controllerbuilder/pkg/toolbot/csv.go +++ b/dev/tools/controllerbuilder/pkg/toolbot/csv.go @@ -243,6 +243,7 @@ func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Wr for _, part := range content.Parts { if text, ok := part.(genai.Text); ok { klog.Infof("TEXT: %+v", text) + out.Write([]byte(text + "\n")) } else { klog.Infof("UNKNOWN: %T %+v", part, part) }