diff --git a/src/OrasProject.Oras/Content/MemoryStore.cs b/src/OrasProject.Oras/Content/MemoryStore.cs index 9a9f0f0..528765c 100644 --- a/src/OrasProject.Oras/Content/MemoryStore.cs +++ b/src/OrasProject.Oras/Content/MemoryStore.cs @@ -11,16 +11,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using OrasProject.Oras.Exceptions; using OrasProject.Oras.Oci; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading; using System.Threading.Tasks; +using OrasProject.Oras.Registry; namespace OrasProject.Oras.Content; -public class MemoryStore : ITarget, IPredecessorFindable +public class MemoryStore : ITarget, IPredecessorFindable, IMounter { private readonly MemoryStorage _storage = new(); private readonly MemoryTagStore _tagResolver = new(); @@ -94,4 +97,16 @@ public async Task TagAsync(Descriptor descriptor, string reference, Cancellation /// public async Task> GetPredecessorsAsync(Descriptor node, CancellationToken cancellationToken = default) => await _graph.GetPredecessorsAsync(node, cancellationToken).ConfigureAwait(false); + + public async Task MountAsync(Descriptor descriptor, string contentReference, Func>? getContents, CancellationToken cancellationToken) + { + var taggedDescriptor = await _tagResolver.ResolveAsync(contentReference, cancellationToken).ConfigureAwait(false); + var successors = await _storage.GetSuccessorsAsync(taggedDescriptor, cancellationToken); + + if (descriptor != taggedDescriptor && !successors.Contains(descriptor)) + { + await _storage.PushAsync(descriptor, await getContents(cancellationToken), cancellationToken).ConfigureAwait(false); + await _graph.IndexAsync(_storage, descriptor, cancellationToken).ConfigureAwait(false); + } + } } diff --git a/src/OrasProject.Oras/Extensions.cs b/src/OrasProject.Oras/Extensions.cs index 6b048ac..595948b 100644 --- a/src/OrasProject.Oras/Extensions.cs +++ b/src/OrasProject.Oras/Extensions.cs @@ -11,14 +11,47 @@ // See the License for the specific language governing permissions and // limitations under the License. -using OrasProject.Oras.Oci; using System; +using System.IO; using System.Threading; using System.Threading.Tasks; +using OrasProject.Oras.Oci; +using OrasProject.Oras.Registry; using static OrasProject.Oras.Content.Extensions; namespace OrasProject.Oras; +public struct CopyOptions +{ + // public int Concurrency { get; set; } + + public event Action OnPreCopy; + public event Action OnPostCopy; + public event Action OnCopySkipped; + public event Action OnMounted; + + public Func MountFrom { get; set; } + + internal void PreCopy(Descriptor descriptor) + { + OnPreCopy?.Invoke(descriptor); + } + + internal void PostCopy(Descriptor descriptor) + { + OnPostCopy?.Invoke(descriptor); + } + + internal void CopySkipped(Descriptor descriptor) + { + OnCopySkipped?.Invoke(descriptor); + } + + internal void Mounted(Descriptor descriptor, string sourceRepository) + { + OnMounted?.Invoke(descriptor, sourceRepository); + } +} public static class Extensions { @@ -36,38 +69,89 @@ public static class Extensions /// /// /// - public static async Task CopyAsync(this ITarget src, string srcRef, ITarget dst, string dstRef, CancellationToken cancellationToken = default) + public static async Task CopyAsync(this ITarget src, string srcRef, ITarget dst, string dstRef, CancellationToken cancellationToken = default, CopyOptions? copyOptions = default) { if (string.IsNullOrEmpty(dstRef)) { dstRef = srcRef; } var root = await src.ResolveAsync(srcRef, cancellationToken).ConfigureAwait(false); - await src.CopyGraphAsync(dst, root, cancellationToken).ConfigureAwait(false); + await src.CopyGraphAsync(dst, root, cancellationToken, copyOptions).ConfigureAwait(false); await dst.TagAsync(root, dstRef, cancellationToken).ConfigureAwait(false); return root; } - public static async Task CopyGraphAsync(this ITarget src, ITarget dst, Descriptor node, CancellationToken cancellationToken) + public static async Task CopyGraphAsync(this ITarget src, ITarget dst, Descriptor node, CancellationToken cancellationToken, CopyOptions? copyOptions = default) { // check if node exists in target if (await dst.ExistsAsync(node, cancellationToken).ConfigureAwait(false)) { + copyOptions?.CopySkipped(node); return; } // retrieve successors var successors = await src.GetSuccessorsAsync(node, cancellationToken).ConfigureAwait(false); - // obtain data stream - var dataStream = await src.FetchAsync(node, cancellationToken).ConfigureAwait(false); + // check if the node has successors - if (successors != null) + foreach (var childNode in successors) + { + await src.CopyGraphAsync(dst, childNode, cancellationToken, copyOptions).ConfigureAwait(false); + } + + var sourceRepositories = copyOptions?.MountFrom(node) ?? []; + if (dst is IMounter mounter && sourceRepositories.Length > 0) { - foreach (var childNode in successors) + for (var i = 0; i < sourceRepositories.Length; i++) { - await src.CopyGraphAsync(dst, childNode, cancellationToken).ConfigureAwait(false); + var sourceRepository = sourceRepositories[i]; + var mountFailed = false; + + async Task GetContents(CancellationToken token) + { + // the invocation of getContent indicates that mounting has failed + mountFailed = true; + + if (i < sourceRepositories.Length - 1) + { + // If this is not the last one, skip this source and try next one + // We want to return an error that we will test for from mounter.Mount() + throw new SkipSourceException(); + } + + // this is the last iteration so we need to actually get the content and do the copy + // but first call the PreCopy function + copyOptions?.PreCopy(node); + return await src.FetchAsync(node, token).ConfigureAwait(false); + } + + try + { + await mounter.MountAsync(node, sourceRepository, GetContents, cancellationToken).ConfigureAwait(false); + } + catch (SkipSourceException) + { + } + + if (!mountFailed) + { + copyOptions?.Mounted(node, sourceRepository); + return; + } } } - await dst.PushAsync(node, dataStream, cancellationToken).ConfigureAwait(false); + else + { + // alternatively we just copy it + copyOptions?.PreCopy(node); + var dataStream = await src.FetchAsync(node, cancellationToken).ConfigureAwait(false); + await dst.PushAsync(node, dataStream, cancellationToken).ConfigureAwait(false); + } + + // we copied it + copyOptions?.PostCopy(node); } + + private class SkipSourceException : Exception {} } + diff --git a/src/OrasProject.Oras/Registry/IMounter.cs b/src/OrasProject.Oras/Registry/IMounter.cs new file mode 100644 index 0000000..8c645dc --- /dev/null +++ b/src/OrasProject.Oras/Registry/IMounter.cs @@ -0,0 +1,24 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using OrasProject.Oras.Oci; + +namespace OrasProject.Oras.Registry; + +/// +/// Mounter allows cross-repository blob mounts. +/// +public interface IMounter +{ + /// + /// Mount makes the blob with the given descriptor in fromRepo + /// available in the repository signified by the receiver. + /// + /// + /// + /// + /// + /// + Task MountAsync(Descriptor descriptor, string contentReference, Func>? getContents, CancellationToken cancellationToken); +} diff --git a/src/OrasProject.Oras/Registry/IRepository.cs b/src/OrasProject.Oras/Registry/IRepository.cs index b163e2f..41682c5 100644 --- a/src/OrasProject.Oras/Registry/IRepository.cs +++ b/src/OrasProject.Oras/Registry/IRepository.cs @@ -27,7 +27,7 @@ namespace OrasProject.Oras.Registry; /// Furthermore, this interface also provides the ability to enforce the /// separation of the blob and the manifests CASs. /// -public interface IRepository : ITarget, IReferenceFetchable, IReferencePushable, IDeletable, ITagListable +public interface IRepository : ITarget, IReferenceFetchable, IReferencePushable, IDeletable, ITagListable, IMounter { /// /// Blobs provides access to the blob CAS only, which contains config blobs,layers, and other generic blobs. diff --git a/src/OrasProject.Oras/Registry/Remote/BlobStore.cs b/src/OrasProject.Oras/Registry/Remote/BlobStore.cs index 52b0783..791acb7 100644 --- a/src/OrasProject.Oras/Registry/Remote/BlobStore.cs +++ b/src/OrasProject.Oras/Registry/Remote/BlobStore.cs @@ -25,7 +25,7 @@ namespace OrasProject.Oras.Registry.Remote; -public class BlobStore(Repository repository) : IBlobStore +public class BlobStore(Repository repository) : IBlobStore, IMounter { public Repository Repository { get; init; } = repository; @@ -148,25 +148,7 @@ public async Task PushAsync(Descriptor expected, Stream content, CancellationTok url = location.IsAbsoluteUri ? location : new Uri(url, location); } - // monolithic upload - // add digest key to query string with expected digest value - var req = new HttpRequestMessage(HttpMethod.Put, new UriBuilder(url) - { - Query = $"{url.Query}&digest={HttpUtility.UrlEncode(expected.Digest)}" - }.Uri); - req.Content = new StreamContent(content); - req.Content.Headers.ContentLength = expected.Size; - - // the expected media type is ignored as in the API doc. - req.Content.Headers.ContentType = new MediaTypeHeaderValue(MediaTypeNames.Application.Octet); - - using (var response = await Repository.Options.HttpClient.SendAsync(req, cancellationToken).ConfigureAwait(false)) - { - if (response.StatusCode != HttpStatusCode.Created) - { - throw await response.ParseErrorResponseAsync(cancellationToken).ConfigureAwait(false); - } - } + await InternalPushAsync(url, expected, content, cancellationToken); } /// @@ -198,4 +180,98 @@ public async Task ResolveAsync(string reference, CancellationToken c /// public async Task DeleteAsync(Descriptor target, CancellationToken cancellationToken = default) => await Repository.DeleteAsync(target, false, cancellationToken).ConfigureAwait(false); + + /// + /// Mounts the given descriptor from contentReference into the blob store. + /// + /// + /// + /// + /// + /// + /// + public async Task MountAsync(Descriptor descriptor, string contentReference, + Func>? getContents, CancellationToken cancellationToken) + { + var url = new UriFactory(Repository.Options).BuildRepositoryBlobUpload(); + var mountReq = new HttpRequestMessage(HttpMethod.Post, new UriBuilder(url) + { + Query = + $"{url.Query}&mount={HttpUtility.UrlEncode(descriptor.Digest)}&from={HttpUtility.UrlEncode(contentReference)}" + }.Uri); + + using (var response = await Repository.Options.HttpClient.SendAsync(mountReq, cancellationToken) + .ConfigureAwait(false)) + { + switch (response.StatusCode) + { + case HttpStatusCode.Created: + // 201, layer has been mounted + return; + case HttpStatusCode.Accepted: + { + // 202, mounting failed. upload session has begun + var location = response.Headers.Location ?? + throw new HttpRequestException("missing location header"); + url = location.IsAbsoluteUri ? location : new Uri(url, location); + break; + } + default: + throw await response.ParseErrorResponseAsync(cancellationToken).ConfigureAwait(false); + } + } + + // From the [spec]: + // + // "If a registry does not support cross-repository mounting + // or is unable to mount the requested blob, + // it SHOULD return a 202. + // This indicates that the upload session has begun + // and that the client MAY proceed with the upload." + // + // So we need to get the content from somewhere in order to + // push it. If the caller has provided a getContent function, we + // can use that, otherwise pull the content from the source repository. + // + // [spec]: https://github.com/opencontainers/distribution-spec/blob/v1.1.0/spec.md#mounting-a-blob-from-another-repository + + Stream contents; + if (getContents != null) + { + contents = await getContents(cancellationToken).ConfigureAwait(false); + } + else + { + var referenceOptions = repository.Options with + { + Reference = Reference.Parse(contentReference), + }; + contents = await new Repository(referenceOptions).FetchAsync(descriptor, cancellationToken); + } + + await InternalPushAsync(url, descriptor, contents, cancellationToken).ConfigureAwait(false); + } + + private async Task InternalPushAsync(Uri url, Descriptor descriptor, Stream content, + CancellationToken cancellationToken) + { + // monolithic upload + // add digest key to query string with descriptor digest value + var req = new HttpRequestMessage(HttpMethod.Put, new UriBuilder(url) + { + Query = $"{url.Query}&digest={HttpUtility.UrlEncode(descriptor.Digest)}" + }.Uri); + req.Content = new StreamContent(content); + req.Content.Headers.ContentLength = descriptor.Size; + + // the descriptor media type is ignored as in the API doc. + req.Content.Headers.ContentType = new MediaTypeHeaderValue(MediaTypeNames.Application.Octet); + + using var response = + await Repository.Options.HttpClient.SendAsync(req, cancellationToken).ConfigureAwait(false); + if (response.StatusCode != HttpStatusCode.Created) + { + throw await response.ParseErrorResponseAsync(cancellationToken).ConfigureAwait(false); + } + } } diff --git a/src/OrasProject.Oras/Registry/Remote/Repository.cs b/src/OrasProject.Oras/Registry/Remote/Repository.cs index 62d73bc..49d9328 100644 --- a/src/OrasProject.Oras/Registry/Remote/Repository.cs +++ b/src/OrasProject.Oras/Registry/Remote/Repository.cs @@ -331,4 +331,22 @@ internal Reference ParseReferenceFromContentReference(string reference) /// /// private IBlobStore BlobStore(Descriptor desc) => IsManifest(desc) ? Manifests : Blobs; + + /// + /// Mount makes the blob with the given digest in fromRepo + /// available in the repository signified by the receiver. + /// + /// This avoids the need to pull content down from fromRepo only to push it to r. + /// + /// If the registry does not implement mounting, getContent will be used to get the + /// content to push. If getContent is null, the content will be pulled from the source + /// repository. + /// + /// + /// + /// + /// + /// + public Task MountAsync(Descriptor descriptor, string contentReference, Func>? getContents, CancellationToken cancellationToken) + => ((IMounter)Blobs).MountAsync(descriptor,contentReference, getContents, cancellationToken); } diff --git a/tests/OrasProject.Oras.Tests/CopyTest.cs b/tests/OrasProject.Oras.Tests/CopyTest.cs index 4f26873..3959b62 100644 --- a/tests/OrasProject.Oras.Tests/CopyTest.cs +++ b/tests/OrasProject.Oras.Tests/CopyTest.cs @@ -142,4 +142,96 @@ public async Task CanCopyBetweenMemoryTargets() } } + + [Fact] + public async Task CanCopyBetweenMemoryTargetsMountingFromDestination() + { + var sourceTarget = new MemoryStore(); + var cancellationToken = new CancellationToken(); + var blobs = new List(); + var descs = new List(); + var appendBlob = (string mediaType, byte[] blob) => + { + blobs.Add(blob); + var desc = new Descriptor + { + MediaType = mediaType, + Digest = Digest.ComputeSHA256(blob), + Size = blob.Length + }; + descs.Add(desc); + }; + var generateManifest = (Descriptor config, List layers) => + { + var manifest = new Manifest + { + Config = config, + Layers = layers + }; + var manifestBytes = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(manifest)); + appendBlob(MediaType.ImageManifest, manifestBytes); + }; + var getBytes = (string data) => Encoding.UTF8.GetBytes(data); + appendBlob(MediaType.ImageConfig, getBytes("config")); // blob 0 + appendBlob(MediaType.ImageLayer, getBytes("foo")); // blob 1 + appendBlob(MediaType.ImageLayer, getBytes("bar")); // blob 2 + generateManifest(descs[0], descs.GetRange(1, 2)); // blob 3 + + appendBlob(MediaType.ImageConfig, getBytes("config2")); // blob 4 + appendBlob(MediaType.ImageLayer, getBytes("bar2")); // blob 5 + generateManifest(descs[4], [descs[1], descs[5]]); // blob 6 + + for (var i = 0; i < blobs.Count; i++) + { + await sourceTarget.PushAsync(descs[i], new MemoryStream(blobs[i]), cancellationToken); + } + + var root = descs[3]; + var reference = "foobar"; + await sourceTarget.TagAsync(root, reference, cancellationToken); + + var root2 = descs[6]; + var reference2 = "other/foobar"; + await sourceTarget.TagAsync(root2, reference2, cancellationToken); + + var destinationTarget = new MemoryStore(); + var gotDesc = await sourceTarget.CopyAsync(reference, destinationTarget, "", cancellationToken); + Assert.Equal(gotDesc, root); + Assert.Equal(await destinationTarget.ResolveAsync(reference, cancellationToken), root); + + for (var i = 0; i < 3; i++) + { + Assert.True(await destinationTarget.ExistsAsync(descs[i], cancellationToken)); + var fetchContent = await destinationTarget.FetchAsync(descs[i], cancellationToken); + var memoryStream = new MemoryStream(); + await fetchContent.CopyToAsync(memoryStream, cancellationToken); + var bytes = memoryStream.ToArray(); + Assert.Equal(blobs[i], bytes); + } + + var copyOpts = new CopyOptions() + { + MountFrom = d => [reference] + }; + var mounted = false; + copyOpts.OnMounted += (d, s) => + { + mounted = true; + }; + var gotDesc2 = await sourceTarget.CopyAsync(reference2, destinationTarget, reference2, cancellationToken, copyOpts); + + Assert.Equal(gotDesc2, root2); + Assert.Equal(await destinationTarget.ResolveAsync(reference2, cancellationToken), root2); + Assert.True(mounted); + + for (var i = 4; i < descs.Count; i++) + { + Assert.True(await destinationTarget.ExistsAsync(descs[i], cancellationToken)); + var fetchContent = await destinationTarget.FetchAsync(descs[i], cancellationToken); + var memoryStream = new MemoryStream(); + await fetchContent.CopyToAsync(memoryStream, cancellationToken); + var bytes = memoryStream.ToArray(); + Assert.Equal(blobs[i], bytes); + } + } }