diff --git a/internal/commands/relationship.go b/internal/commands/relationship.go index 2a12dec..323bc0e 100644 --- a/internal/commands/relationship.go +++ b/internal/commands/relationship.go @@ -8,6 +8,7 @@ import ( "io" "os" "strings" + "time" "unicode" "github.com/authzed/zed/internal/client" @@ -21,6 +22,7 @@ import ( "github.com/spf13/cobra" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) func RegisterRelationshipCmd(rootCmd *cobra.Command) *cobra.Command { @@ -29,11 +31,13 @@ func RegisterRelationshipCmd(rootCmd *cobra.Command) *cobra.Command { relationshipCmd.AddCommand(createCmd) createCmd.Flags().Bool("json", false, "output as JSON") createCmd.Flags().String("caveat", "", `the caveat for the relationship, with format: 'caveat_name:{"some":"context"}'`) + createCmd.Flags().String("expiration-time", "", `the expiration time of the relationship in RFC 3339 format`) createCmd.Flags().IntP("batch-size", "b", 100, "batch size when writing streams of relationships from stdin") relationshipCmd.AddCommand(touchCmd) touchCmd.Flags().Bool("json", false, "output as JSON") touchCmd.Flags().String("caveat", "", `the caveat for the relationship, with format: 'caveat_name:{"some":"context"}'`) + touchCmd.Flags().String("expiration-time", "", `the expiration time for the relationship in RFC 3339 format`) touchCmd.Flags().IntP("batch-size", "b", 100, "batch size when writing streams of relationships from stdin") relationshipCmd.AddCommand(deleteCmd) @@ -494,6 +498,10 @@ func writeRelationshipCmdFunc(operation v1.RelationshipUpdate_Operation, input * if err := handleCaveatFlag(cmd, rel); err != nil { return err } + + if err := handleExpirationFlag(cmd, rel); err != nil { + return err + } } updateBatch = append(updateBatch, &v1.RelationshipUpdate{ @@ -536,3 +544,17 @@ func handleCaveatFlag(cmd *cobra.Command, rel *v1.Relationship) error { } return nil } + +func handleExpirationFlag(cmd *cobra.Command, rel *v1.Relationship) error { + expirationTime := cobrautil.MustGetString(cmd, "expiration-time") + + if expirationTime != "" { + t, err := time.Parse(time.RFC3339, expirationTime) + if err != nil { + return fmt.Errorf("could not parse RFC 3339 timestamp: %w", err) + } + rel.OptionalExpiresAt = timestamppb.New(t) + } + + return nil +} diff --git a/internal/commands/relationship_test.go b/internal/commands/relationship_test.go index d7bad30..8136b6c 100644 --- a/internal/commands/relationship_test.go +++ b/internal/commands/relationship_test.go @@ -276,6 +276,7 @@ func TestWriteRelationshipCmdFuncFromTTY(t *testing.T) { cmd.Flags().Int("batch-size", 100, "") cmd.Flags().Bool("json", true, "") cmd.Flags().String("caveat", `cav:{"letters": ["a", "b", "c"]}`, "") + cmd.Flags().String("expiration-time", "", "") err = f(cmd, []string{"resource:1", "view", "user:1"}) require.NoError(t, err) @@ -322,6 +323,7 @@ func TestWriteRelationshipCmdFuncArgsTakePrecedence(t *testing.T) { cmd.Flags().Int("batch-size", 100, "") cmd.Flags().Bool("json", true, "") cmd.Flags().String("caveat", "", "") + cmd.Flags().String("expiration-time", "", "") err := f(cmd, []string{"resource:1", "viewer", "user:1"}) require.NoError(t, err) @@ -365,6 +367,7 @@ func TestWriteRelationshipCmdFuncFromStdin(t *testing.T) { cmd.Flags().Int("batch-size", 100, "") cmd.Flags().Bool("json", true, "") cmd.Flags().String("caveat", "", "") + cmd.Flags().String("expiration-time", "", "") err := f(cmd, nil) require.NoError(t, err) @@ -414,6 +417,7 @@ func TestWriteRelationshipCmdFuncFromStdinBatch(t *testing.T) { cmd.Flags().Int("batch-size", 1, "") cmd.Flags().Bool("json", true, "") cmd.Flags().String("caveat", "", "") + cmd.Flags().String("expiration-time", "", "") err := f(cmd, nil) require.NoError(t, err) @@ -454,11 +458,112 @@ func TestWriteRelationshipCmdFuncFromFailsWithCaveatArg(t *testing.T) { cmd.Flags().Int("batch-size", 1, "") cmd.Flags().Bool("json", true, "") cmd.Flags().String("caveat", `cav:{"letters": ["a", "b", "c"]}`, "") + cmd.Flags().String("expiration-time", "", "") err := f(cmd, nil) require.ErrorContains(t, err, "cannot specify a caveat in both the relationship and the --caveat flag") } +func TestWriteRelationshipCmdFuncWithExpirationTime(t *testing.T) { + mock := func(*cobra.Command) (client.Client, error) { + return &mockClient{t: t, expectedWrites: []*v1.WriteRelationshipsRequest{ + { + Updates: []*v1.RelationshipUpdate{ + { + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: tuple.MustParseV1Rel(`resource:1#viewer@user:1[expiration:2025-01-27T20:04:05Z]`), + }, + }, + }, + { + Updates: []*v1.RelationshipUpdate{ + { + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: tuple.MustParseV1Rel(`resource:1#viewer@user:2[expiration:2025-01-27T20:04:05Z]`), + }, + }, + }, + }}, nil + } + + fi := fileFromStrings(t, []string{ + `resource:1 viewer user:1`, + `resource:1 viewer user:2`, + }) + defer func() { + require.NoError(t, fi.Close()) + }() + t.Cleanup(func() { + _ = os.Remove(fi.Name()) + }) + + originalClient := client.NewClient + client.NewClient = mock + defer func() { + client.NewClient = originalClient + }() + + f := writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_TOUCH, fi) + cmd := &cobra.Command{} + cmd.Flags().Int("batch-size", 1, "") + cmd.Flags().Bool("json", true, "") + cmd.Flags().String("caveat", "", "") + cmd.Flags().String("expiration-time", "2025-01-27T20:04:05Z", "") + + err := f(cmd, nil) + require.NoError(t, err) +} + +func TestWriteRelationshipCmdFuncFromStdinBatchWithExpirationTime(t *testing.T) { + mock := func(*cobra.Command) (client.Client, error) { + return &mockClient{t: t, expectedWrites: []*v1.WriteRelationshipsRequest{ + { + Updates: []*v1.RelationshipUpdate{ + { + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: tuple.MustParseV1Rel(`resource:1#viewer@user:1[expiration:2025-01-27T20:04:05Z]`), + }, + }, + }, + { + Updates: []*v1.RelationshipUpdate{ + { + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: tuple.MustParseV1Rel(`resource:1#viewer@user:2[expiration:2025-01-27T20:04:05Z]`), + }, + }, + }, + }}, nil + } + + fi := fileFromStrings(t, []string{ + `resource:1 viewer user:1`, + `resource:1 viewer user:2`, + }) + defer func() { + require.NoError(t, fi.Close()) + }() + t.Cleanup(func() { + _ = os.Remove(fi.Name()) + }) + + originalClient := client.NewClient + client.NewClient = mock + defer func() { + client.NewClient = originalClient + }() + + f := writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_TOUCH, fi) + cmd := &cobra.Command{} + cmd.Flags().Int("batch-size", 1, "") + cmd.Flags().Bool("json", true, "") + cmd.Flags().String("caveat", "", "") + cmd.Flags().String("expiration-time", "2025-01-27T20:04:05Z", "") + + err := f(cmd, nil) + require.NoError(t, err) +} + func fileFromStrings(t *testing.T, strings []string) *os.File { t.Helper()