diff --git a/flyteadmin/pkg/common/cloud.go b/flyteadmin/pkg/common/cloud.go index 93f3669a55..f982438594 100644 --- a/flyteadmin/pkg/common/cloud.go +++ b/flyteadmin/pkg/common/cloud.go @@ -7,6 +7,7 @@ type CloudProvider = string const ( AWS CloudProvider = "aws" GCP CloudProvider = "gcp" + Azure CloudProvider = "azure" Sandbox CloudProvider = "sandbox" Local CloudProvider = "local" None CloudProvider = "none" diff --git a/flyteadmin/pkg/data/factory.go b/flyteadmin/pkg/data/factory.go index b746d972fc..7d25e4e1fe 100644 --- a/flyteadmin/pkg/data/factory.go +++ b/flyteadmin/pkg/data/factory.go @@ -50,7 +50,11 @@ func GetRemoteDataHandler(cfg RemoteDataHandlerConfig) RemoteDataHandler { return &remoteDataHandler{ remoteURL: implementations.NewGCPRemoteURL(cfg.SigningPrincipal, signedURLDuration), } - + case common.Azure: + signedURLDuration := time.Minute * time.Duration(cfg.SignedURLDurationMinutes) + return &remoteDataHandler{ + remoteURL: implementations.NewAzureRemoteURL(*cfg.RemoteDataStoreClient, signedURLDuration), + } case common.Local: logger.Infof(context.TODO(), "setting up local signer ----- ") // Since minio = aws s3, we are creating the same client but using the config primitives from aws diff --git a/flyteadmin/pkg/data/implementations/azure_remote_url.go b/flyteadmin/pkg/data/implementations/azure_remote_url.go new file mode 100644 index 0000000000..a9649ef943 --- /dev/null +++ b/flyteadmin/pkg/data/implementations/azure_remote_url.go @@ -0,0 +1,46 @@ +package implementations + +import ( + "context" + "github.com/flyteorg/flyte/flyteadmin/pkg/data/interfaces" + "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/flyteorg/stow" + "google.golang.org/grpc/codes" + "time" +) + +type AzureRemoteURL struct { + remoteDataStoreClient storage.DataStore + presignDuration time.Duration +} + +func (n *AzureRemoteURL) Get(ctx context.Context, uri string) (admin.UrlBlob, error) { + metadata, err := n.remoteDataStoreClient.Head(ctx, storage.DataReference(uri)) + if err != nil { + return admin.UrlBlob{}, errors.NewFlyteAdminErrorf(codes.Internal, + "failed to get metadata for uri: %s with err: %v", uri, err) + } + + signedUri, err := n.remoteDataStoreClient.CreateSignedURL(ctx, storage.DataReference(uri), storage.SignedURLProperties{ + Scope: stow.ClientMethodGet, + ExpiresIn: n.presignDuration, + }) + if err != nil { + return admin.UrlBlob{}, errors.NewFlyteAdminErrorf(codes.Internal, + "failed to get metadata for uri: %s with err: %v", uri, err) + } + + return admin.UrlBlob{ + Url: signedUri.URL.String(), + Bytes: metadata.Size(), + }, nil +} + +func NewAzureRemoteURL(remoteDataStoreClient storage.DataStore, presignDuration time.Duration) interfaces.RemoteURLInterface { + return &AzureRemoteURL{ + remoteDataStoreClient: remoteDataStoreClient, + presignDuration: presignDuration, + } +} diff --git a/flyteadmin/pkg/data/implementations/azure_remote_url_test.go b/flyteadmin/pkg/data/implementations/azure_remote_url_test.go new file mode 100644 index 0000000000..f48adc52e9 --- /dev/null +++ b/flyteadmin/pkg/data/implementations/azure_remote_url_test.go @@ -0,0 +1,38 @@ +package implementations + +import ( + "context" + commonMocks "github.com/flyteorg/flyte/flyteadmin/pkg/common/mocks" + "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "testing" +) + +type mockMetadata struct{} + +func (m mockMetadata) Exists() bool { + return true +} + +func (m mockMetadata) Size() int64 { + return 1 +} + +func (m mockMetadata) Etag() string { + return "etag" +} + +func TestAzureGet(t *testing.T) { + inputUri := "abfs//test/data" + mockStorage := commonMocks.GetMockStorageClient() + mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).HeadCb = + func(ctx context.Context, reference storage.DataReference) (storage.Metadata, error) { + return mockMetadata{}, nil + } + remoteUrl := AzureRemoteURL{ + remoteDataStoreClient: *mockStorage, presignDuration: 1, + } + + result, _ := remoteUrl.Get(context.TODO(), inputUri) + assert.Contains(t, inputUri, result.Url) +}