Skip to content

Commit d002f2a

Browse files
authored
fix: azure: fix path validation (#21)
Signed-off-by: Grant Linville <[email protected]>
1 parent 10c59d7 commit d002f2a

File tree

2 files changed

+62
-19
lines changed

2 files changed

+62
-19
lines changed

pkg/client/azure.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type azureProvider struct {
4141
revisionsProvider *azureProvider
4242
}
4343

44-
func (a *azureProvider) validatePath(path string) error {
44+
func (a *azureProvider) validatePath(path string, allowTrailingSlash bool) error {
4545
if path == "" {
4646
return nil // empty path is valid in some contexts (e.g., Ls root)
4747
}
@@ -59,7 +59,7 @@ func (a *azureProvider) validatePath(path string) error {
5959
// Azure Blob Storage naming rules:
6060
// - Cannot start or end with '/'
6161
// - Cannot contain consecutive forward slashes
62-
if strings.HasSuffix(path, "/") {
62+
if !allowTrailingSlash && strings.HasSuffix(path, "/") {
6363
return fmt.Errorf("invalid path: cannot end with '/'")
6464
}
6565
if strings.Contains(path, "//") {
@@ -128,7 +128,7 @@ func (a *azureProvider) Rm(ctx context.Context, id string) error {
128128

129129
func (a *azureProvider) Ls(ctx context.Context, prefix string) ([]string, error) {
130130
prefix = strings.TrimPrefix(prefix, "/")
131-
if err := a.validatePath(prefix); err != nil {
131+
if err := a.validatePath(prefix, true); err != nil {
132132
return nil, err
133133
}
134134
if prefix != "" {
@@ -159,7 +159,7 @@ func (a *azureProvider) Ls(ctx context.Context, prefix string) ([]string, error)
159159

160160
func (a *azureProvider) DeleteFile(ctx context.Context, filePath string) error {
161161
filePath = strings.TrimPrefix(filePath, "/")
162-
if err := a.validatePath(filePath); err != nil {
162+
if err := a.validatePath(filePath, false); err != nil {
163163
return err
164164
}
165165
blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, filePath))
@@ -195,7 +195,7 @@ func (a *azureProvider) DeleteFile(ctx context.Context, filePath string) error {
195195
func (a *azureProvider) OpenFile(ctx context.Context, filePath string, opt OpenOptions) (*File, error) {
196196
originalFilePath := filePath
197197
filePath = strings.TrimPrefix(filePath, "/")
198-
if err := a.validatePath(filePath); err != nil {
198+
if err := a.validatePath(filePath, false); err != nil {
199199
return nil, err
200200
}
201201
blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, filePath))
@@ -227,7 +227,7 @@ func (a *azureProvider) OpenFile(ctx context.Context, filePath string, opt OpenO
227227

228228
func (a *azureProvider) WriteFile(ctx context.Context, fileName string, reader io.Reader, opt WriteOptions) error {
229229
fileName = strings.TrimPrefix(fileName, "/")
230-
if err := a.validatePath(fileName); err != nil {
230+
if err := a.validatePath(fileName, false); err != nil {
231231
return err
232232
}
233233
if a.revisionsProvider != nil && (opt.CreateRevision == nil || *opt.CreateRevision) {
@@ -274,7 +274,7 @@ func (a *azureProvider) WriteFile(ctx context.Context, fileName string, reader i
274274
func (a *azureProvider) StatFile(ctx context.Context, fileName string, opt StatOptions) (FileInfo, error) {
275275
originalFileName := fileName
276276
fileName = strings.TrimPrefix(fileName, "/")
277-
if err := a.validatePath(fileName); err != nil {
277+
if err := a.validatePath(fileName, false); err != nil {
278278
return FileInfo{}, err
279279
}
280280
blobClient := a.client.ServiceClient().NewContainerClient(a.containerName).NewBlockBlobClient(fmt.Sprintf("%s/%s", a.dir, fileName))
@@ -333,7 +333,7 @@ func (a *azureProvider) StatFile(ctx context.Context, fileName string, opt StatO
333333

334334
func (a *azureProvider) RemoveAllWithPrefix(ctx context.Context, prefix string) error {
335335
prefix = strings.TrimPrefix(prefix, "/")
336-
if err := a.validatePath(prefix); err != nil {
336+
if err := a.validatePath(prefix, true); err != nil {
337337
return err
338338
}
339339
if prefix != "" {
@@ -366,23 +366,23 @@ func (a *azureProvider) RemoveAllWithPrefix(ctx context.Context, prefix string)
366366

367367
func (a *azureProvider) ListRevisions(ctx context.Context, fileName string) ([]RevisionInfo, error) {
368368
fileName = strings.TrimPrefix(fileName, "/")
369-
if err := a.validatePath(fileName); err != nil {
369+
if err := a.validatePath(fileName, false); err != nil {
370370
return nil, err
371371
}
372372
return listRevisions(ctx, a.revisionsProvider, fmt.Sprintf("%s://%s/%s", AzureProvider, a.containerName, a.dir), fileName)
373373
}
374374

375375
func (a *azureProvider) GetRevision(ctx context.Context, fileName, revisionID string) (*File, error) {
376376
fileName = strings.TrimPrefix(fileName, "/")
377-
if err := a.validatePath(fileName); err != nil {
377+
if err := a.validatePath(fileName, false); err != nil {
378378
return nil, err
379379
}
380380
return getRevision(ctx, a.revisionsProvider, fileName, revisionID)
381381
}
382382

383383
func (a *azureProvider) DeleteRevision(ctx context.Context, fileName, revisionID string) error {
384384
fileName = strings.TrimPrefix(fileName, "/")
385-
if err := a.validatePath(fileName); err != nil {
385+
if err := a.validatePath(fileName, false); err != nil {
386386
return err
387387
}
388388
return deleteRevision(ctx, a.revisionsProvider, fileName, revisionID)

pkg/client/azure_test.go

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func TestLsWithPrefixAzure(t *testing.T) {
270270
}
271271

272272
defer func() {
273-
err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir")
273+
err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir/")
274274
if err != nil {
275275
t.Errorf("unexpected error when deleting file %s: %v", "testDir", err)
276276
}
@@ -295,7 +295,7 @@ func TestLsWithPrefixAzure(t *testing.T) {
295295
}(fileName)
296296
}
297297

298-
contents, err := azurePrv.Ls(context.Background(), "testDir")
298+
contents, err := azurePrv.Ls(context.Background(), "testDir/")
299299
if err != nil {
300300
t.Fatalf("unexpected error when listing files: %v", err)
301301
}
@@ -342,7 +342,7 @@ func TestRemoveAllWithPrefixAzure(t *testing.T) {
342342
}(fileName)
343343
}
344344

345-
err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir")
345+
err := azurePrv.RemoveAllWithPrefix(context.Background(), "testDir/")
346346
if err != nil {
347347
t.Errorf("unexpected error when deleting all with prefix testDir: %v", err)
348348
}
@@ -695,12 +695,7 @@ func TestPathValidationAzure(t *testing.T) {
695695
{"traversal nested", "foo/../../test.txt", true, "must not contain '..'", nil},
696696
{"traversal with slash", "../test.txt/", true, "must not contain '..'", nil},
697697

698-
// Absolute path tests
699-
{"absolute path", "/test.txt", true, "must be relative", nil},
700-
{"absolute nested", "/foo/test.txt", true, "must be relative", nil},
701-
702698
// Azure naming rule tests
703-
{"trailing slash", "test/", true, "cannot end with '/'", nil},
704699
{"double slash", "foo//bar.txt", true, "cannot contain consecutive '/'", nil},
705700
{"invalid chars", "test*.txt", true, "contains invalid characters", nil},
706701
{"invalid chars nested", "foo/test*.txt", true, "contains invalid characters", nil},
@@ -725,6 +720,7 @@ func TestPathValidationAzure(t *testing.T) {
725720
}()
726721

727722
for _, tt := range tests {
723+
// Test operations that should not allow trailing slashes
728724
t.Run(fmt.Sprintf("WriteFile/%s", tt.name), func(t *testing.T) {
729725
err := azurePrv.WriteFile(context.Background(), tt.path, strings.NewReader("test"), WriteOptions{})
730726
assertPathError(t, err, tt.wantErr, tt.errMsg)
@@ -745,6 +741,7 @@ func TestPathValidationAzure(t *testing.T) {
745741
assertPathError(t, err, tt.wantErr, tt.errMsg)
746742
})
747743

744+
// Test operations that should allow trailing slashes
748745
t.Run(fmt.Sprintf("Ls/%s", tt.name), func(t *testing.T) {
749746
_, err := azurePrv.Ls(context.Background(), tt.path)
750747
assertPathError(t, err, tt.wantErr, tt.errMsg)
@@ -755,6 +752,7 @@ func TestPathValidationAzure(t *testing.T) {
755752
assertPathError(t, err, tt.wantErr, tt.errMsg)
756753
})
757754

755+
// Test revision operations that should not allow trailing slashes
758756
t.Run(fmt.Sprintf("ListRevisions/%s", tt.name), func(t *testing.T) {
759757
_, err := azurePrv.ListRevisions(context.Background(), tt.path)
760758
assertPathError(t, err, tt.wantErr, tt.errMsg)
@@ -772,6 +770,51 @@ func TestPathValidationAzure(t *testing.T) {
772770
assertPathError(t, err, tt.wantErr, tt.errMsg)
773771
})
774772
}
773+
774+
// Additional tests specifically for trailing slashes
775+
trailingSlashTests := []struct {
776+
name string
777+
path string
778+
wantErr bool
779+
errMsg string
780+
}{
781+
{"trailing slash in Ls", "test/", false, ""},
782+
{"trailing slash in RemoveAllWithPrefix", "test/", false, ""},
783+
{"trailing slash in WriteFile", "test/", true, "cannot end with '/'"},
784+
{"trailing slash in OpenFile", "test/", true, "cannot end with '/'"},
785+
{"trailing slash in StatFile", "test/", true, "cannot end with '/'"},
786+
{"trailing slash in DeleteFile", "test/", true, "cannot end with '/'"},
787+
{"trailing slash in ListRevisions", "test/", true, "cannot end with '/'"},
788+
{"trailing slash in GetRevision", "test/", true, "cannot end with '/'"},
789+
{"trailing slash in DeleteRevision", "test/", true, "cannot end with '/'"},
790+
}
791+
792+
for _, tt := range trailingSlashTests {
793+
t.Run(tt.name, func(t *testing.T) {
794+
var err error
795+
switch {
796+
case strings.Contains(tt.name, "Ls"):
797+
_, err = azurePrv.Ls(context.Background(), tt.path)
798+
case strings.Contains(tt.name, "RemoveAllWithPrefix"):
799+
err = azurePrv.RemoveAllWithPrefix(context.Background(), tt.path)
800+
case strings.Contains(tt.name, "WriteFile"):
801+
err = azurePrv.WriteFile(context.Background(), tt.path, strings.NewReader("test"), WriteOptions{})
802+
case strings.Contains(tt.name, "OpenFile"):
803+
_, err = azurePrv.OpenFile(context.Background(), tt.path, OpenOptions{})
804+
case strings.Contains(tt.name, "StatFile"):
805+
_, err = azurePrv.StatFile(context.Background(), tt.path, StatOptions{})
806+
case strings.Contains(tt.name, "DeleteFile"):
807+
err = azurePrv.DeleteFile(context.Background(), tt.path)
808+
case strings.Contains(tt.name, "ListRevisions"):
809+
_, err = azurePrv.ListRevisions(context.Background(), tt.path)
810+
case strings.Contains(tt.name, "GetRevision"):
811+
_, err = azurePrv.GetRevision(context.Background(), tt.path, "1")
812+
case strings.Contains(tt.name, "DeleteRevision"):
813+
err = azurePrv.DeleteRevision(context.Background(), tt.path, "1")
814+
}
815+
assertPathError(t, err, tt.wantErr, tt.errMsg)
816+
})
817+
}
775818
}
776819

777820
// Helper function to assert path validation errors

0 commit comments

Comments
 (0)