Skip to content

Commit

Permalink
Support context for advertiser key
Browse files Browse the repository at this point in the history
  • Loading branch information
amanjpro committed Nov 1, 2024
1 parent c35ba3e commit 40de157
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 42 deletions.
16 changes: 10 additions & 6 deletions pkg/cmd/cli/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
)

type CliContext struct {
ctx context.Context
config *Config
ctx context.Context
config *Config
keyContext string
}

type (
Expand All @@ -29,15 +30,18 @@ type (

Version VersionCmd `cmd:"" help:"Print utility version"`

CleanroomCmd CleanroomCmd `cmd:"" name:"cleanroom" help:"Commands for interacting with Optable PAIR clean rooms."`
KeyCmd KeyCmd `cmd:"" name:"key" help:"Commands for managing advertiser clean room private keys."`
CleanroomCmd CleanroomCmd `cmd:"" name:"cleanroom" help:"Commands for interacting with Optable PAIR clean rooms."`
AdvertiserKeyPath string `cmd:"" short:"k" name:"keypath" help:"The path to the advertiser clean room's private key to use for the operation. If not provided, the key saved in the configuration file will be used."`
KeyCmd KeyCmd `cmd:"" name:"key" help:"Commands for managing advertiser clean room private keys."`
Context string `short:"c" help:"Context name to use" default:"default"`
}
)

func (c *Cli) NewContext(conf *Config) (*CliContext, error) {
cliCtx := &CliContext{
ctx: NewLogger("opair", c.Verbose).WithContext(context.Background()),
config: conf,
ctx: NewLogger("opair", c.Verbose).WithContext(context.Background()),
config: conf,
keyContext: c.Context,
}

return cliCtx, nil
Expand Down
55 changes: 34 additions & 21 deletions pkg/cmd/cli/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func ensureKeyConfigPath(configPath string) error {
return nil
}

func LoadKeyConfig(configPath string) (*Config, error) {
func loadAllKeyConfigs(configPath string) (map[string]keys.KeyConfig, error) {
if err := ensureKeyConfigPath(configPath); err != nil {
return nil, err
}
Expand All @@ -43,44 +43,57 @@ func LoadKeyConfig(configPath string) (*Config, error) {
}
defer file.Close()

var config keys.KeyConfig
if err := json.NewDecoder(file).Decode(&config); err != nil {
var configs map[string]keys.KeyConfig
if err := json.NewDecoder(file).Decode(&configs); err != nil {
if errors.Is(err, io.EOF) {
return &Config{configPath: configPath}, nil
return nil, err
} else {
return nil, fmt.Errorf("json.Decode: %w", err)
}
}
return configs, nil
}

return &Config{
configPath: configPath,
keyConfig: &config,
}, nil
func LoadKeyConfig(context, configPath string, strict bool) (*Config, error) {
configs, err := loadAllKeyConfigs(configPath)
if errors.Is(err, io.EOF) {
return &Config{configPath: configPath}, nil
} else if err != nil {
return nil, err
}
if config, ok := configs[context]; ok {
return &Config{
configPath: configPath,
keyConfig: &config,
}, nil
}
if !strict {
return &Config{configPath: configPath}, nil
}
return nil, errors.New("no key configuration found for the specified context")
}

func (c *CliContext) SaveConfig() error {
func (c *CliContext) SaveConfig(context string) error {
configs, err := loadAllKeyConfigs(c.config.configPath)
if errors.Is(err, io.EOF) {
configs = make(map[string]keys.KeyConfig)
} else if err != nil {
return err
}
file, err := os.OpenFile(c.config.configPath, os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return fmt.Errorf("os.OpenFile: %w", err)
}
defer file.Close()

if err := json.NewEncoder(file).Encode(c.config.keyConfig); err != nil {
configs[context] = *c.config.keyConfig
if err := json.NewEncoder(file).Encode(configs); err != nil {
return fmt.Errorf("json.Encode: %w", err)
}

return nil
}

func ReadKeyConfig(providedKeyPath string, defaultConfig *keys.KeyConfig) (string, error) {
if providedKeyPath == "" {
advertiserKey := defaultConfig.Key
if advertiserKey == "" {
return "", errors.New("advertiser key is required, please either provide one or generate one.")
}
return advertiserKey, nil
}
config, err := LoadKeyConfig(providedKeyPath)
func ReadKeyConfig(context string, config *Config) (string, error) {
config, err := LoadKeyConfig(context, config.configPath, true)
if err != nil {
return "", err
}
Expand Down
9 changes: 4 additions & 5 deletions pkg/cmd/cli/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ import (

type (
DecryptCmd struct {
Input string `arg:"" help:"The input file containing the already matched triple encrypted PAIR IDs to be decrypted. If given a directory, all files in the directory will be processed."`
AdvertiserKeyPath string `cmd:"" short:"k" name:"keypath" help:"The path to the advertiser clean room's private key to use for the operation. If not provided, the key saved in the configuration file will be used."`
Output string `cmd:"" short:"o" help:"The output file to write the resulting publisher decrypted PAIR IDs to. Defaults to stdout."`
NumThreads int `cmd:"" short:"n" help:"The number of threads to use for the operation. Defaults to the number of the available cores on the machine."`
Input string `arg:"" help:"The input file containing the already matched triple encrypted PAIR IDs to be decrypted. If given a directory, all files in the directory will be processed."`
Output string `cmd:"" short:"o" help:"The output file to write the resulting publisher decrypted PAIR IDs to. Defaults to stdout."`
NumThreads int `cmd:"" short:"n" help:"The number of threads to use for the operation. Defaults to the number of the available cores on the machine."`
}
)

Expand All @@ -35,7 +34,7 @@ func (c *DecryptCmd) Run(cli *CliContext) error {
if c.NumThreads <= 0 {
c.NumThreads = defaultThreadCount
}
advertiserKey, err := ReadKeyConfig(c.AdvertiserKeyPath, cli.config.keyConfig)
advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config)
if err != nil {
return fmt.Errorf("ReadKeyConfig: %w", err)
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/cmd/cli/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"optable-pair-cli/pkg/internal"
"time"

"google.golang.org/protobuf/encoding/protojson"
)
Expand Down Expand Up @@ -36,6 +37,15 @@ func (c *GetCmd) Run(cli *CliContext) error {
return err
}

config := cleanroom.GetConfig().GetPair()
shouldTokenRefresh := config.GcsToken == nil || config.GcsToken.ExpireTime.AsTime().Before(time.Now())
if shouldTokenRefresh {
cleanroom, err = client.RefreshToken(ctx)
if err != nil {
return err
}
}

marshaler := protojson.MarshalOptions{
Multiline: true,
UseProtoNames: true,
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/cli/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (c *CreateCmd) Run(cli *CliContext) error {
// overwrite the key config
conf = key
cli.config.keyConfig = conf
cli.SaveConfig()
cli.SaveConfig(cli.keyContext)

fmt.Println("The following key has been generated and saved to: ", cli.config.configPath)
} else {
Expand Down
3 changes: 1 addition & 2 deletions pkg/cmd/cli/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ type (
AdvertiserInput string `cmd:"" short:"a" help:"If given a file path, it will read from the file. If not provided, it will read from the GCS path specified from the token."`
PublisherInput string `cmd:"" short:"p" help:"If given a file path, it will read from the file. If not provided, it will read from the GCS path specified from the token."`
OutputDir string `cmd:"" short:"o" help:"The output directory path to write the decrypted and matched double encrypted PAIR IDs. Each thread will write one single file in the given directory path. If none are provided, all matched and decrypted PAIR IDs will be written to stdout."`
AdvertiserKeyPath string `cmd:"" short:"k" help:"The path to the advertiser private key to use for the operation. If not provided, the key saved in the configuration file will be used."`
NumThreads int `cmd:"" short:"n" help:"The number of threads to use for the operation. Defaults to the number of the available cores on the machine."`
PublisherPAIRIDs string `cmd:"" short:"s" name:"publisher-pair-ids" help:"Use the publisher's PAIR IDs from a path."`
}
Expand All @@ -30,7 +29,7 @@ and output the list of decrypted and matched PAIR IDs.
func (c *MatchCmd) Run(cli *CliContext) error {
ctx := cli.Context()

advertiserKey, err := ReadKeyConfig(c.AdvertiserKeyPath, cli.config.keyConfig)
advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config)
if err != nil {
return fmt.Errorf("ReadKeyConfig: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/cli/participate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type (
func (c *ParticipateCmd) Run(cli *CliContext) error {
ctx := cli.Context()

advertiserKey, err := ReadKeyConfig(c.AdvertiserKeyPath, cli.config.keyConfig)
advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config)
if err != nil {
return fmt.Errorf("ReadKeyConfig: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/cli/re-encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type (
func (c *ReEncryptCmd) Run(cli *CliContext) error {
ctx := cli.Context()

advertiserKey, err := ReadKeyConfig(c.AdvertiserKeyPath, cli.config.keyConfig)
advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config)
if err != nil {
return fmt.Errorf("ReadKeyConfig: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ last successful step.
func (c *RunCmd) Run(cli *CliContext) error {
ctx := cli.Context()

advertiserKey, err := ReadKeyConfig(c.AdvertiserKeyPath, cli.config.keyConfig)
advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config)
if err != nil {
return fmt.Errorf("ReadKeyConfig: %w", err)
}
Expand Down
12 changes: 8 additions & 4 deletions pkg/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@ func main() {
},
)

configPath, err := xdg.ConfigFile(keyConfigPath)
if err != nil {
kongCtx.FatalIfErrorf(err)
configPath := c.AdvertiserKeyPath
if configPath == "" {
var err error
configPath, err = xdg.ConfigFile(keyConfigPath)
if err != nil {
kongCtx.FatalIfErrorf(err)
}
}

conf, err := cli.LoadKeyConfig(configPath)
conf, err := cli.LoadKeyConfig(c.Context, configPath, false)
if err != nil {
kongCtx.FatalIfErrorf(err)
}
Expand Down

0 comments on commit 40de157

Please sign in to comment.