diff --git a/oci/client/client.go b/oci/client/client.go
index c816855a..b3cd257a 100644
--- a/oci/client/client.go
+++ b/oci/client/client.go
@@ -21,15 +21,13 @@ import (
 
 	"github.com/google/go-containerregistry/pkg/crane"
 	"github.com/google/go-containerregistry/pkg/v1/remote"
-	"github.com/hashicorp/go-retryablehttp"
 
 	"github.com/fluxcd/pkg/oci"
 )
 
 // Client holds the options for accessing remote OCI registries.
 type Client struct {
-	options    []crane.Option
-	httpClient *retryablehttp.Client
+	options []crane.Option
 }
 
 // NewClient returns an OCI client configured with the given crane options.
diff --git a/oci/client/download.go b/oci/client/download.go
new file mode 100644
--- /dev/null
+++ b/oci/client/download.go
@@ -0,0 +1,447 @@
+/*
+Copyright 2024 The Flux authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package client
+
+import (
+	"context"
+	"crypto/sha256"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"os"
+	"syscall"
+	"time"
+
+	"github.com/google/go-containerregistry/pkg/authn"
+	"github.com/google/go-containerregistry/pkg/name"
+	v1 "github.com/google/go-containerregistry/pkg/v1"
+	"github.com/google/go-containerregistry/pkg/v1/remote"
+	"github.com/google/go-containerregistry/pkg/v1/remote/transport"
+	"github.com/hashicorp/go-retryablehttp"
+	"golang.org/x/sync/errgroup"
+)
+
+const (
+	minChunkSize          = 100 * 1024 * 1024 // 100MB
+	maxChunkSize          = 1 << 30           // 1GB
+	defaultNumberOfChunks = 50
+)
+
+var (
+	// errRangeRequestNotSupported is returned when the registry does not support range requests.
+	errRangeRequestNotSupported = errors.New("range requests are not supported by the registry")
+	// errCopyFailed marks a chunk whose response body could not be copied
+	// completely; such chunks are retried.
+	errCopyFailed = errors.New("copy failed")
+)
+
+var (
+	// retries is the maximum number of download attempts per chunk.
+	retries = 3
+	// defaultRetryBackoff is the delay schedule between attempts of a chunk.
+	defaultRetryBackoff = remote.Backoff{
+		Duration: 1.0 * time.Second,
+		Factor:   3.0,
+		Jitter:   0.1,
+		Steps:    retries,
+	}
+)
+
+// downloadOption configures the downloadOptions of a blobManager.
+type downloadOption func(*downloadOptions)
+
+// downloadOptions holds the settings used to reach the registry.
+type downloadOptions struct {
+	transport      http.RoundTripper
+	auth           authn.Authenticator
+	keychain       authn.Keychain
+	numberOfChunks int
+}
+
+// blobManager downloads a single layer blob into a local file using
+// concurrent HTTP range requests.
+type blobManager struct {
+	name   name.Reference
+	c      *retryablehttp.Client
+	layer  v1.Layer
+	path   string
+	digest v1.Hash
+	size   int64
+	downloadOptions
+}
+
+// withTransport sets the base HTTP transport.
+func withTransport(t http.RoundTripper) downloadOption {
+	return func(o *downloadOptions) {
+		o.transport = t
+	}
+}
+
+// withAuth sets the authenticator used when no keychain is configured.
+func withAuth(auth authn.Authenticator) downloadOption {
+	return func(o *downloadOptions) {
+		o.auth = auth
+	}
+}
+
+// withKeychain sets the keychain; a non-nil keychain takes precedence over
+// the authenticator set with withAuth.
+func withKeychain(k authn.Keychain) downloadOption {
+	return func(o *downloadOptions) {
+		o.keychain = k
+	}
+}
+
+// withNumberOfChunks sets the target number of chunks, which is also the
+// maximum number of concurrent range requests.
+func withNumberOfChunks(n int) downloadOption {
+	return func(o *downloadOptions) {
+		o.numberOfChunks = n
+	}
+}
+
+// chunk is one contiguous byte range of the blob.
+type chunk struct {
+	n      int
+	offset int64
+	size   int64
+	writeCounter
+}
+
+// makeChunk returns chunk number n covering size bytes starting at offset.
+func makeChunk(n int, offset, size int64) *chunk {
+	return &chunk{
+		n:            n,
+		offset:       offset,
+		size:         size,
+		writeCounter: writeCounter{},
+	}
+}
+
+// newDownloader returns a new blobManager with the given options.
+func newDownloader(name name.Reference, path string, layer v1.Layer, opts ...downloadOption) *blobManager {
+	// Clone the default transport when possible, but do not panic when
+	// remote.DefaultTransport has been replaced by another RoundTripper.
+	tr := remote.DefaultTransport
+	if t, ok := tr.(*http.Transport); ok {
+		tr = t.Clone()
+	}
+
+	d := &blobManager{
+		layer: layer,
+		name:  name,
+		path:  path,
+		downloadOptions: downloadOptions{
+			numberOfChunks: defaultNumberOfChunks,
+			keychain:       authn.DefaultKeychain,
+			transport:      tr,
+		},
+	}
+	for _, opt := range opts {
+		opt(&d.downloadOptions)
+	}
+
+	return d
+}
+
+// download fetches the layer blob into d.path using range requests and
+// verifies the result against the layer digest. It returns
+// errRangeRequestNotSupported when the registry does not advertise byte ranges.
+func (d *blobManager) download(ctx context.Context) error {
+	digest, err := d.layer.Digest()
+	if err != nil {
+		return fmt.Errorf("failed to get layer digest: %w", err)
+	}
+	d.digest = digest
+
+	size, err := d.layer.Size()
+	if err != nil {
+		return fmt.Errorf("failed to get layer size: %w", err)
+	}
+	d.size = size
+
+	if d.c == nil {
+		h, err := makeHttpClient(ctx, d.name.Context(), &d.downloadOptions)
+		if err != nil {
+			return fmt.Errorf("failed to create HTTP client: %w", err)
+		}
+		d.c = h
+	}
+
+	ok, err := d.isRangeRequestEnabled(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to check range request support: %w", err)
+	}
+
+	if !ok {
+		return errRangeRequestNotSupported
+	}
+
+	if err := d.downloadChunks(ctx); err != nil {
+		return fmt.Errorf("failed to download layer in chunks: %w", err)
+	}
+
+	if err := d.verifyDigest(); err != nil {
+		// Do not leave an unverified blob behind at the destination path.
+		_ = os.Remove(d.path)
+		return fmt.Errorf("failed to verify layer digest: %w", err)
+	}
+
+	return nil
+}
+
+// downloadChunks downloads the blob in concurrent ranges into a temporary
+// file next to d.path and renames it into place once every chunk is written.
+func (d *blobManager) downloadChunks(ctx context.Context) error {
+	u := makeUrl(d.name, d.digest)
+	tmpPath := d.path + ".tmp"
+
+	// O_TRUNC drops data left behind by an earlier, interrupted attempt.
+	file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+	if err != nil {
+		return fmt.Errorf("failed to create layer file: %w", err)
+	}
+
+	if err := d.writeChunks(ctx, file, u); err != nil {
+		// Best-effort cleanup; the download error is the one worth reporting.
+		_ = file.Close()
+		_ = os.Remove(tmpPath)
+		return err
+	}
+
+	// Close before renaming so that a failed flush is reported instead of
+	// publishing a truncated file.
+	if err := file.Close(); err != nil {
+		_ = os.Remove(tmpPath)
+		return fmt.Errorf("failed to close layer file: %w", err)
+	}
+
+	if err := os.Rename(tmpPath, d.path); err != nil {
+		_ = os.Remove(tmpPath)
+		return err
+	}
+
+	return nil
+}
+
+// writeChunks splits the blob into ranges and downloads them concurrently
+// into file, each range at its own offset.
+func (d *blobManager) writeChunks(ctx context.Context, file *os.File, u url.URL) error {
+	// Guard against a zero or negative option value, which would divide by
+	// zero below and leave errgroup unable to start any goroutine.
+	workers := d.numberOfChunks
+	if workers <= 0 {
+		workers = defaultNumberOfChunks
+	}
+
+	chunkSize := d.size / int64(workers)
+	if chunkSize < minChunkSize {
+		chunkSize = minChunkSize
+	} else if chunkSize > maxChunkSize {
+		chunkSize = maxChunkSize
+	}
+
+	var chunks []*chunk
+	for offset := int64(0); offset < d.size; offset += chunkSize {
+		size := chunkSize
+		if remaining := d.size - offset; remaining < size {
+			size = remaining
+		}
+		chunks = append(chunks, makeChunk(len(chunks), offset, size))
+	}
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(workers)
+	for _, c := range chunks {
+		c := c
+		g.Go(func() error {
+			return d.downloadChunkWithRetry(ctx, file, c, u)
+		})
+	}
+
+	return g.Wait()
+}
+
+// downloadChunkWithRetry downloads one chunk, retrying interrupted body
+// copies up to retries times with defaultRetryBackoff between attempts.
+func (d *blobManager) downloadChunkWithRetry(ctx context.Context, file *os.File, c *chunk, u url.URL) error {
+	b := defaultRetryBackoff // copy: Step advances the backoff state
+	var err error
+	for attempt := 1; attempt <= retries; attempt++ {
+		// A fresh offset writer restarts the chunk from its first byte.
+		w := io.NewOffsetWriter(file, c.offset)
+		err = c.download(ctx, d.c, w, u)
+		if err == nil {
+			return nil
+		}
+		// Only interrupted copies are retried; a full disk will not recover
+		// by itself.
+		if !errors.Is(err, errCopyFailed) || errors.Is(err, syscall.ENOSPC) {
+			return err
+		}
+		if attempt == retries {
+			break
+		}
+
+		// Wait for the backoff delay, but give up as soon as ctx is done.
+		timer := time.NewTimer(b.Step())
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return ctx.Err()
+		case <-timer.C:
+		}
+	}
+
+	return fmt.Errorf("failed to download chunk %d after %d attempts: %w", c.n, retries, err)
+}
+
+// download requests the chunk's byte range and writes the response body to w.
+// Failures while copying the body wrap errCopyFailed so the caller can retry;
+// request errors, unexpected statuses and cancellation are returned as is.
+func (c *chunk) download(ctx context.Context, client *retryablehttp.Client, w io.Writer, u url.URL) error {
+	req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil)
+	if err != nil {
+		return err
+	}
+
+	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", c.offset, c.offset+c.size-1))
+	resp, err := client.Do(req.WithContext(ctx))
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if err := transport.CheckError(resp, http.StatusPartialContent); err != nil {
+		return err
+	}
+
+	written, err := io.Copy(w, io.TeeReader(resp.Body, &c.writeCounter))
+	if err != nil {
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return ctxErr
+		}
+		// TODO: if the download was interrupted, we can resume it
+		return fmt.Errorf("failed to download chunk %d: %w: %w", c.n, errCopyFailed, err)
+	}
+	if written != c.size {
+		return fmt.Errorf("failed to download chunk %d: %w: got %d bytes, want %d", c.n, errCopyFailed, written, c.size)
+	}
+
+	return nil
+}
+
+// isRangeRequestEnabled reports whether a HEAD request for the blob answers
+// with "Accept-Ranges: bytes".
+func (d *blobManager) isRangeRequestEnabled(ctx context.Context) (bool, error) {
+	u := makeUrl(d.name, d.digest)
+	req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil)
+	if err != nil {
+		return false, err
+	}
+
+	resp, err := d.c.Do(req.WithContext(ctx))
+	if err != nil {
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if err := transport.CheckError(resp, http.StatusOK); err != nil {
+		return false, err
+	}
+
+	return resp.Header.Get("Accept-Ranges") == "bytes", nil
+}
+
+// verifyDigest hashes the file at d.path with SHA-256 and compares the result
+// with the expected layer digest.
+func (d *blobManager) verifyDigest() error {
+	f, err := os.Open(d.path)
+	if err != nil {
+		return fmt.Errorf("failed to open layer file: %w", err)
+	}
+	defer f.Close()
+
+	h := sha256.New()
+	if _, err := io.Copy(h, f); err != nil {
+		return fmt.Errorf("failed to hash layer: %w", err)
+	}
+
+	newDigest := h.Sum(nil)
+	if d.digest.String() != fmt.Sprintf("sha256:%x", newDigest) {
+		return fmt.Errorf("layer digest does not match: %s != sha256:%x", d.digest.String(), newDigest)
+	}
+
+	return nil
+}
+
+// makeUrl returns the registry URL of the blob with the given digest in the
+// repository that name refers to.
+func makeUrl(name name.Reference, digest v1.Hash) url.URL {
+	return url.URL{
+		Scheme: name.Context().Scheme(),
+		Host:   name.Context().RegistryStr(),
+		Path:   fmt.Sprintf("/v2/%s/blobs/%s", name.Context().RepositoryStr(), digest.String()),
+	}
+}
+
+// resource describes the registry or repository that makeHttpClient
+// authenticates against.
+type resource interface {
+	Scheme() string
+	RegistryStr() string
+	Scope(string) string
+
+	authn.Resource
+}
+
+// makeHttpClient returns a retrying HTTP client that authenticates against the
+// registry of target, falling back to anonymous access without credentials.
+func makeHttpClient(ctx context.Context, target resource, o *downloadOptions) (*retryablehttp.Client, error) {
+	auth := o.auth
+	if o.keychain != nil {
+		kauth, err := o.keychain.Resolve(target)
+		if err != nil {
+			return nil, err
+		}
+		auth = kauth
+	}
+	if auth == nil {
+		// Neither a keychain nor an authenticator was configured.
+		auth = authn.Anonymous
+	}
+
+	reg, ok := target.(name.Registry)
+	if !ok {
+		repo, ok := target.(name.Repository)
+		if !ok {
+			return nil, fmt.Errorf("unexpected resource: %T", target)
+		}
+		reg = repo.Registry
+	}
+
+	tr, err := transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)})
+	if err != nil {
+		return nil, err
+	}
+
+	h := retryablehttp.NewClient()
+	h.HTTPClient = &http.Client{Transport: tr}
+	return h, nil
+}
diff --git a/oci/client/pull.go b/oci/client/pull.go index
b49892a7..5ed76515 100644 --- a/oci/client/pull.go +++ b/oci/client/pull.go @@ -20,31 +20,28 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "net/http" - "net/url" "os" "github.com/fluxcd/pkg/tar" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/name" - "github.com/hashicorp/go-retryablehttp" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/google/go-containerregistry/pkg/v1/remote/transport" - "golang.org/x/sync/errgroup" ) -const ( - // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. - // If the layer is larger than this, it will be downloaded in chunks. - thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB - // maxConcurrentPulls is the maximum number of concurrent downloads. - maxConcurrentPulls = 10 -) +// const ( +// // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. +// // If the layer is larger than this, it will be downloaded in chunks. +// thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB +// // maxConcurrentPulls is the maximum number of concurrent downloads. +// maxConcurrentPulls = 10 +// ) var ( // gzipMagicHeader are bytes found at the start of gzip files @@ -54,12 +51,11 @@ var ( // PullOptions contains options for pulling a layer. type PullOptions struct { - layerIndex int - layerType LayerType - transport http.RoundTripper - auth authn.Authenticator - keychain authn.Keychain - concurrency int + layerIndex int + layerType LayerType + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain } // PullOption is a function for configuring PullOptions. 
@@ -85,9 +81,15 @@ func WithTransport(t http.RoundTripper) PullOption { } } -func WithConcurrency(c int) PullOption { +func WithAuth(auth authn.Authenticator) PullOption { + return func(o *PullOptions) { + o.auth = auth + } +} + +func WithKeychain(k authn.Keychain) PullOption { return func(o *PullOptions) { - o.concurrency = c + o.keychain = k } } @@ -103,10 +105,6 @@ func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...Pu opt(o) } - if o.concurrency == 0 || o.concurrency > maxConcurrentPulls { - o.concurrency = maxConcurrentPulls - } - if o.transport == nil { transport := remote.DefaultTransport.(*http.Transport).Clone() o.transport = transport @@ -117,14 +115,6 @@ func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...Pu return nil, fmt.Errorf("invalid URL: %w", err) } - if c.httpClient == nil { - h, err := makeHttpClient(ctx, ref.Context(), *o) - if err != nil { - return nil, err - } - c.httpClient = h - } - img, err := crane.Pull(urlString, c.optionsWithContext(ctx)...) 
if err != nil { return nil, err @@ -162,126 +152,25 @@ func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...Pu return nil, fmt.Errorf("failed to get layer size: %w", err) } - if size > thresholdForConcurrentPull { - digest, err := layers[o.layerIndex].Digest() - if err != nil { - return nil, fmt.Errorf("parsing digest failed: %w", err) - } - u := url.URL{ - Scheme: ref.Context().Scheme(), - Host: ref.Context().RegistryStr(), - Path: fmt.Sprintf("/v2/%s/blobs/%s", ref.Context().RepositoryStr(), digest.String()), + if size > minChunkSize { + manager := newDownloader(ref, outPath, layers[o.layerIndex], + withTransport(o.transport), withKeychain(o.keychain), withAuth(o.auth)) + err = manager.download(ctx) + if err != nil && !errors.Is(err, errRangeRequestNotSupported) { + return nil, fmt.Errorf("failed to download layer: %w", err) } - ok, err := c.IsRangeRequestEnabled(ctx, u) + } + + if size <= minChunkSize || errors.Is(err, errRangeRequestNotSupported) { + err = extractLayer(layers[o.layerIndex], outPath, o.layerType) if err != nil { - return nil, fmt.Errorf("failed to check range request support: %w", err) - } - if ok { - err = c.concurrentExtractLayer(ctx, u, layers[o.layerIndex], outPath, digest, size, o.concurrency) - if err != nil { - return nil, err - } - return meta, nil + return nil, err } } - err = extractLayer(layers[o.layerIndex], outPath, o.layerType) - if err != nil { - return nil, err - } return meta, nil } -// TO DO: handle authentication handle using keychain for authentication -func (c *Client) IsRangeRequestEnabled(ctx context.Context, u url.URL) (bool, error) { - req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil) - if err != nil { - return false, err - } - - resp, err := c.httpClient.Do(req.WithContext(ctx)) - if err != nil { - return false, err - } - - if err := transport.CheckError(resp, http.StatusOK); err != nil { - return false, err - } - - if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit 
== "bytes" { - return true, nil - } - for k, v := range resp.Header { - fmt.Printf("Header: %s, Value: %s\n", k, v) - } - return false, nil -} - -func (c *Client) concurrentExtractLayer(ctx context.Context, u url.URL, layer v1.Layer, path string, digest v1.Hash, size int64, concurrency int) error { - chunkSize := size / int64(concurrency) - chunks := make([][]byte, concurrency+1) - diff := size % int64(concurrency) - - g, ctx := errgroup.WithContext(ctx) - for i := 0; i < concurrency; i++ { - i := i - g.Go(func() (err error) { - start, end := int64(i)*chunkSize, int64(i+1)*chunkSize - if i == concurrency-1 { - end += diff - } - req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - return fmt.Errorf("failed to create a new request: %w", err) - } - req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end-1)) - resp, err := c.httpClient.Do(req.WithContext(ctx)) - if err != nil { - return fmt.Errorf("failed to download archive: %w", err) - } - defer resp.Body.Close() - - if err := transport.CheckError(resp, http.StatusPartialContent); err != nil { - return fmt.Errorf("failed to download archive from %s (status: %s)", u.String(), resp.Status) - } - - c, err := io.ReadAll(io.LimitReader(resp.Body, end-start)) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - chunks[i] = c - return nil - }) - } - err := g.Wait() - if err != nil { - return err - } - - content := bufio.NewReader(bytes.NewReader(bytes.Join(chunks, nil))) - d, s, err := v1.SHA256(content) - if err != nil { - return err - } - if d != digest { - return fmt.Errorf("digest mismatch: expected %s, got %s", digest, d) - } - if s != size { - return fmt.Errorf("size mismatch: expected %d, got %d", size, size) - } - - f, err := os.Create(path) - if err != nil { - return err - } - - _, err = io.Copy(f, content) - if err != nil { - return fmt.Errorf("error copying layer content: %s", err) - } - return nil -} - // extractLayer extracts the Layer 
to the path func extractLayer(layer v1.Layer, path string, layerType LayerType) error { var blob io.Reader @@ -341,40 +230,3 @@ func isGzipBlob(buf *bufio.Reader) (bool, error) { } return bytes.Equal(b, gzipMagicHeader), nil } - -type resource interface { - Scheme() string - RegistryStr() string - Scope(string) string - - authn.Resource -} - -func makeHttpClient(ctx context.Context, target resource, o PullOptions) (*retryablehttp.Client, error) { - auth := o.auth - if o.keychain != nil { - kauth, err := o.keychain.Resolve(target) - if err != nil { - return nil, err - } - auth = kauth - } - - reg, ok := target.(name.Registry) - if !ok { - repo, ok := target.(name.Repository) - if !ok { - return nil, fmt.Errorf("unexpected resource: %T", target) - } - reg = repo.Registry - } - - tr, err := transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)}) - if err != nil { - return nil, err - } - - h := retryablehttp.NewClient() - h.HTTPClient = &http.Client{Transport: tr} - return h, nil -}