Skip to content

Commit

Permalink
Merge pull request #3476 from yuwenma/scifi-llm-prep
Browse files Browse the repository at this point in the history
feat: write prompt result to a file
  • Loading branch information
google-oss-prow[bot] authored Jan 14, 2025
2 parents 8dafa4e + 7cb9819 commit 20a42d3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 10 deletions.
3 changes: 2 additions & 1 deletion dev/tools/controllerbuilder/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ require (
github.com/fatih/color v1.17.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
golang.org/x/oauth2 v0.23.0
golang.org/x/tools v0.24.0
google.golang.org/api v0.203.0
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9
google.golang.org/protobuf v1.35.1
k8s.io/apimachinery v0.27.11
k8s.io/client-go v0.27.11
k8s.io/klog/v2 v2.130.1
)

Expand Down Expand Up @@ -46,7 +48,6 @@ require (
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions dev/tools/controllerbuilder/go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 26 additions & 1 deletion dev/tools/controllerbuilder/pkg/commands/exportcsv/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
package exportcsv

import (
"bytes"
"context"
"fmt"
"io"
"os"
"strings"

kccio "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/io"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/options"
"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/toolbot"
"k8s.io/klog/v2"
Expand All @@ -33,12 +36,14 @@ type PromptOptions struct {

ProtoDir string
SrcDir string
Output string
}

// BindFlags binds the flags to the command.
func (o *PromptOptions) BindFlags(cmd *cobra.Command) {
cmd.Flags().StringVar(&o.SrcDir, "src-dir", o.SrcDir, "base directory for source code")
cmd.Flags().StringVar(&o.ProtoDir, "proto-dir", o.ProtoDir, "base directory for checkout of proto API definitions")
cmd.Flags().StringVar(&o.Output, "output", o.Output, "the directory to store the prompt outcome")
}

// BuildPromptCommand builds the `prompt` command.
Expand Down Expand Up @@ -75,6 +80,7 @@ func RunPrompt(ctx context.Context, o *PromptOptions) error {
if o.ProtoDir == "" {
return fmt.Errorf("--proto-dir is required")
}

extractor := &toolbot.ExtractToolMarkers{}
addProtoDefinition, err := toolbot.NewEnhanceWithProtoDefinition(o.ProtoDir)
if err != nil {
Expand Down Expand Up @@ -109,9 +115,28 @@ func RunPrompt(ctx context.Context, o *PromptOptions) error {

log.Info("built data point", "dataPoint", dataPoint)

if err := x.RunGemini(ctx, dataPoint, os.Stdout); err != nil {
out := &bytes.Buffer{}
if err := x.RunGemini(ctx, dataPoint, out); err != nil {
return fmt.Errorf("running LLM inference: %w", err)

}

if o.Output == "" {
fmt.Println(out)
return nil
}

if tmpF, err := kccio.WriteToCache(ctx, o.Output, out.String(), fileNamePattern(dataPoint)); err != nil {
return err
} else {
fmt.Println(tmpF)
}
return nil
}

func fileNamePattern(dataPoint *toolbot.DataPoint) string {
for k, _ := range dataPoint.Input {
return strings.Replace(k, " ", "-", -1)
}
return ""
}
16 changes: 8 additions & 8 deletions dev/tools/controllerbuilder/pkg/llm/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strings"

"cloud.google.com/go/vertexai/genai"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"k8s.io/klog/v2"
)
Expand All @@ -35,15 +36,14 @@ func BuildGeminiClient(ctx context.Context) (*genai.Client, error) {

if s := os.Getenv("GEMINI_API_KEY"); s != "" {
opts = append(opts, option.WithAPIKey(s))
} else {
// Some account can not use GEMINI_API_KEY but requires stricter access control via OAuth.
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/generative-language", "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("finding default credentials: %w", err)
}
opts = append(opts, option.WithCredentials(creds))
}
// else {
// creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/generative-language", "https://www.googleapis.com/auth/cloud-platform")
// if err != nil {
// return nil, fmt.Errorf("finding default credentials: %w", err)
// }
// opts = append(opts, option.WithCredentials(creds))
// }

projectID := ""
location := ""

Expand Down
1 change: 1 addition & 0 deletions dev/tools/controllerbuilder/pkg/toolbot/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ func (x *CSVExporter) RunGemini(ctx context.Context, input *DataPoint, out io.Wr
for _, part := range content.Parts {
if text, ok := part.(genai.Text); ok {
klog.Infof("TEXT: %+v", text)
out.Write([]byte(text + "\n"))
} else {
klog.Infof("UNKNOWN: %T %+v", part, part)
}
Expand Down

0 comments on commit 20a42d3

Please sign in to comment.