diff --git a/copy/blob.go b/copy/blob.go index f45b97f56c..40a857d332 100644 --- a/copy/blob.go +++ b/copy/blob.go @@ -27,13 +27,15 @@ func (ic *imageCopier) copyBlobFromStream(ctx context.Context, srcReader io.Read info: srcInfo, } + canRewriteDigest := srcInfo.Incremental + // === Process input through digestingReader to validate against the expected digest. // Be paranoid; in case PutBlob somehow managed to ignore an error from digestingReader, // use a separate validation failure indicator. // Note that for this check we don't use the stronger "validationSucceeded" indicator, because // dest.PutBlob may detect that the layer already exists, in which case we don't // read stream to the end, and validation does not happen. - digestingReader, err := newDigestingReader(stream.reader, srcInfo.Digest) + digestingReader, err := newDigestingReader(stream.reader, srcInfo.Digest, canRewriteDigest) if err != nil { return types.BlobInfo{}, fmt.Errorf("preparing to verify blob %s: %w", srcInfo.Digest, err) } @@ -128,8 +130,13 @@ func (ic *imageCopier) copyBlobFromStream(ctx context.Context, srcReader io.Read } } - if digestingReader.validationFailed { // Coverage: This should never happen. - return types.BlobInfo{}, fmt.Errorf("Internal error writing blob %s, digest verification failed but was ignored", srcInfo.Digest) + if digestingReader.validationFailed { + if !canRewriteDigest { // Coverage: This should never happen. + return types.BlobInfo{}, fmt.Errorf("Internal error writing blob %s, digest verification failed but was ignored", srcInfo.Digest) + } + uploadedInfo.Digest = digestingReader.digester.Digest() + logrus.Warningf("Digest verification failed for blob %s, using computed %s", srcInfo.Digest, uploadedInfo.Digest) + return uploadedInfo, nil } if stream.info.Digest != "" && uploadedInfo.Digest != stream.info.Digest { return types.BlobInfo{}, fmt.Errorf("Internal error writing blob %s, blob with digest %s saved with digest %s", srcInfo.Digest, stream.info.Digest, uploadedInfo.Digest) diff --git a/copy/digesting_reader.go b/copy/digesting_reader.go index 901d10826f..24d5410504 100644 --- a/copy/digesting_reader.go +++ b/copy/digesting_reader.go @@ -15,12 +15,13 @@ type digestingReader struct { expectedDigest digest.Digest validationFailed bool validationSucceeded bool + canRewriteDigest bool } // newDigestingReader returns an io.Reader implementation with contents of source, which will eventually return a non-EOF error // or set validationSucceeded/validationFailed to true if the source stream does/does not match expectedDigest. // (neither is set if EOF is never reached). -func newDigestingReader(source io.Reader, expectedDigest digest.Digest) (*digestingReader, error) { +func newDigestingReader(source io.Reader, expectedDigest digest.Digest, canRewriteDigest bool) (*digestingReader, error) { var digester digest.Digester if err := expectedDigest.Validate(); err != nil { return nil, fmt.Errorf("Invalid digest specification %s", expectedDigest) @@ -37,6 +38,7 @@ func newDigestingReader(source io.Reader, expectedDigest digest.Digest) (*digest hash: digester.Hash(), expectedDigest: expectedDigest, validationFailed: false, + canRewriteDigest: canRewriteDigest, }, nil } @@ -54,9 +56,12 @@ func (d *digestingReader) Read(p []byte) (int, error) { actualDigest := d.digester.Digest() if actualDigest != d.expectedDigest { d.validationFailed = true - return 0, fmt.Errorf("Digest did not match, expected %s, got %s", d.expectedDigest, actualDigest) + if !d.canRewriteDigest { + return 0, fmt.Errorf("Digest did not match, expected %s, got %s", d.expectedDigest, actualDigest) + } + } else { + d.validationSucceeded = true } - d.validationSucceeded = true } return n, err } diff --git a/copy/digesting_reader_test.go b/copy/digesting_reader_test.go index 2e17437ae3..1ce42ae887 100644 --- a/copy/digesting_reader_test.go +++ b/copy/digesting_reader_test.go @@ -21,7 +21,7 @@ func TestNewDigestingReader(t *testing.T) { "sha256:0", // Invalid hex value "sha256:01", // Invalid length of hex value } { - _, err := newDigestingReader(source, input) + _, err := newDigestingReader(source, input, false) assert.Error(t, err, input.String()) } } @@ -37,41 +37,56 @@ func TestDigestingReaderRead(t *testing.T) { } // Valid input for _, c := range cases { - source := bytes.NewReader(c.input) - reader, err := newDigestingReader(source, c.digest) - require.NoError(t, err, c.digest.String()) - dest := bytes.Buffer{} - n, err := io.Copy(&dest, reader) - assert.NoError(t, err, c.digest.String()) - assert.Equal(t, int64(len(c.input)), n, c.digest.String()) - assert.Equal(t, c.input, dest.Bytes(), c.digest.String()) - assert.False(t, reader.validationFailed, c.digest.String()) - assert.True(t, reader.validationSucceeded, c.digest.String()) + for _, incremental := range []bool{false, true} { + source := bytes.NewReader(c.input) + reader, err := newDigestingReader(source, c.digest, incremental) + require.NoError(t, err, c.digest.String()) + dest := bytes.Buffer{} + n, err := io.Copy(&dest, reader) + assert.NoError(t, err, c.digest.String()) + assert.Equal(t, int64(len(c.input)), n, c.digest.String()) + assert.Equal(t, c.input, dest.Bytes(), c.digest.String()) + assert.False(t, reader.validationFailed, c.digest.String()) + assert.True(t, reader.validationSucceeded, c.digest.String()) + } } // Modified input for _, c := range cases { source := bytes.NewReader(bytes.Join([][]byte{c.input, []byte("x")}, nil)) - reader, err := newDigestingReader(source, c.digest) + reader, err := newDigestingReader(source, c.digest, false) require.NoError(t, err, c.digest.String()) dest := bytes.Buffer{} _, err = io.Copy(&dest, reader) assert.Error(t, err, c.digest.String()) assert.True(t, reader.validationFailed, c.digest.String()) assert.False(t, reader.validationSucceeded, c.digest.String()) + + // try with an incremental source + source = bytes.NewReader(bytes.Join([][]byte{c.input, []byte("x")}, nil)) + reader, err = newDigestingReader(source, c.digest, true) + require.NoError(t, err, c.digest.String()) + dest = bytes.Buffer{} + _, err = io.Copy(&dest, reader) + assert.NoError(t, err, c.digest.String()) + assert.True(t, reader.validationFailed, c.digest.String()) + assert.False(t, reader.validationSucceeded, c.digest.String()) + assert.NotEqual(t, c.digest.String(), reader.digester.Digest(), c.digest.String()) } // Truncated input for _, c := range cases { - source := bytes.NewReader(c.input) - reader, err := newDigestingReader(source, c.digest) - require.NoError(t, err, c.digest.String()) - if len(c.input) != 0 { - dest := bytes.Buffer{} - truncatedLen := int64(len(c.input) - 1) - n, err := io.CopyN(&dest, reader, truncatedLen) - assert.NoError(t, err, c.digest.String()) - assert.Equal(t, truncatedLen, n, c.digest.String()) + for _, incremental := range []bool{false, true} { + source := bytes.NewReader(c.input) + reader, err := newDigestingReader(source, c.digest, incremental) + require.NoError(t, err, c.digest.String()) + if len(c.input) != 0 { + dest := bytes.Buffer{} + truncatedLen := int64(len(c.input) - 1) + n, err := io.CopyN(&dest, reader, truncatedLen) + assert.NoError(t, err, c.digest.String()) + assert.Equal(t, truncatedLen, n, c.digest.String()) + } + assert.False(t, reader.validationFailed, c.digest.String()) + assert.False(t, reader.validationSucceeded, c.digest.String()) } - assert.False(t, reader.validationFailed, c.digest.String()) - assert.False(t, reader.validationSucceeded, c.digest.String()) } } diff --git a/copy/single.go b/copy/single.go index 9afdea73dc..17ea46c50c 100644 --- a/copy/single.go +++ b/copy/single.go @@ -695,7 +695,7 @@ func (ic *imageCopier) copyLayer(ctx context.Context, srcInfo types.BlobInfo, to } defer srcStream.Close() - blobInfo, diffIDChan, err := ic.copyLayerFromStream(ctx, srcStream, types.BlobInfo{Digest: srcInfo.Digest, Size: srcBlobSize, MediaType: srcInfo.MediaType, Annotations: srcInfo.Annotations}, diffIDIsNeeded, toEncrypt, bar, layerIndex, emptyLayer) + blobInfo, diffIDChan, err := ic.copyLayerFromStream(ctx, srcStream, types.BlobInfo{Digest: srcInfo.Digest, Incremental: srcInfo.Incremental, Size: srcBlobSize, MediaType: srcInfo.MediaType, Annotations: srcInfo.Annotations}, diffIDIsNeeded, toEncrypt, bar, layerIndex, emptyLayer) if err != nil { return types.BlobInfo{}, "", err }