Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/cli/commands/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"

"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/google/go-containerregistry/pkg/name"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -38,7 +38,7 @@ func newTagCmd() *cobra.Command {

func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error {
// Ensure tag is valid
tag, err := name.NewTag(target)
tag, err := name.NewTag(target, registry.GetDefaultRegistryOptions()...)
if err != nil {
return fmt.Errorf("invalid tag: %w", err)
}
Expand Down
16 changes: 9 additions & 7 deletions pkg/distribution/internal/store/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"path/filepath"

"github.com/google/go-containerregistry/pkg/name"

"github.com/docker/model-runner/pkg/distribution/registry"
)

// Index represents the index of all models in the store
Expand All @@ -16,7 +18,7 @@ type Index struct {
}

func (i Index) Tag(reference string, tag string) (Index, error) {
tagRef, err := name.NewTag(tag)
tagRef, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...)
if err != nil {
return Index{}, fmt.Errorf("invalid tag: %w", err)
}
Expand All @@ -39,7 +41,7 @@ func (i Index) Tag(reference string, tag string) (Index, error) {
}

func (i Index) UnTag(tag string) (name.Tag, Index, error) {
tagRef, err := name.NewTag(tag)
tagRef, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...)
if err != nil {
return name.Tag{}, Index{}, err
}
Expand Down Expand Up @@ -141,12 +143,12 @@ type IndexEntry struct {
}

func (e IndexEntry) HasTag(tag string) bool {
ref, err := name.NewTag(tag)
ref, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...)
if err != nil {
return false
}
for _, t := range e.Tags {
tr, err := name.ParseReference(t)
tr, err := name.ParseReference(t, registry.GetDefaultRegistryOptions()...)
if err != nil {
continue
}
Expand All @@ -159,7 +161,7 @@ func (e IndexEntry) HasTag(tag string) bool {

func (e IndexEntry) hasTag(tag name.Tag) bool {
for _, t := range e.Tags {
tr, err := name.ParseReference(t)
tr, err := name.ParseReference(t, registry.GetDefaultRegistryOptions()...)
if err != nil {
continue
}
Expand All @@ -174,7 +176,7 @@ func (e IndexEntry) MatchesReference(reference string) bool {
if e.ID == reference {
return true
}
ref, err := name.ParseReference(reference)
ref, err := name.ParseReference(reference, registry.GetDefaultRegistryOptions()...)
if err != nil {
return false
}
Expand All @@ -200,7 +202,7 @@ func (e IndexEntry) Tag(tag name.Tag) IndexEntry {
func (e IndexEntry) UnTag(tag name.Tag) IndexEntry {
var tags []string
for i, t := range e.Tags {
tr, err := name.ParseReference(t)
tr, err := name.ParseReference(t, registry.GetDefaultRegistryOptions()...)
if err != nil {
continue
}
Expand Down
34 changes: 29 additions & 5 deletions pkg/distribution/registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"

"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
Expand All @@ -22,9 +24,31 @@ const (
)

var (
DefaultTransport = remote.DefaultTransport
defaultRegistryOpts []name.Option
once sync.Once
DefaultTransport = remote.DefaultTransport
)

// GetDefaultRegistryOptions returns name.Option slice with custom default registry
// and insecure flag if the corresponding environment variables are set.
// Environment variables are read once at first call and cached for consistency.
// Returns a copy of the options to prevent race conditions from slice modifications.
// - DEFAULT_REGISTRY: Override the default registry (index.docker.io)
// - INSECURE_REGISTRY: Set to "true" to allow HTTP connections
func GetDefaultRegistryOptions() []name.Option {
once.Do(func() {
var opts []name.Option
if defaultReg := os.Getenv("DEFAULT_REGISTRY"); defaultReg != "" {
opts = append(opts, name.WithDefaultRegistry(defaultReg))
}
if os.Getenv("INSECURE_REGISTRY") == "true" {
opts = append(opts, name.Insecure)
}
defaultRegistryOpts = opts
})
return append([]name.Option(nil), defaultRegistryOpts...)
}

type Client struct {
transport http.RoundTripper
userAgent string
Expand Down Expand Up @@ -75,7 +99,7 @@ func NewClient(opts ...ClientOption) *Client {

func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifact, error) {
// Parse the reference
ref, err := name.ParseReference(reference)
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
if err != nil {
return nil, NewReferenceError(reference, err)
}
Expand Down Expand Up @@ -115,7 +139,7 @@ func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifa

func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) {
// Parse the reference
ref, err := name.ParseReference(reference)
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
if err != nil {
return "", NewReferenceError(reference, err)
}
Expand All @@ -129,7 +153,7 @@ func (c *Client) BlobURL(reference string, digest v1.Hash) (string, error) {

func (c *Client) BearerToken(ctx context.Context, reference string) (string, error) {
// Parse the reference
ref, err := name.ParseReference(reference)
ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...)
if err != nil {
return "", NewReferenceError(reference, err)
}
Expand Down Expand Up @@ -165,7 +189,7 @@ type Target struct {
}

func (c *Client) NewTarget(tag string) (*Target, error) {
ref, err := name.NewTag(tag)
ref, err := name.NewTag(tag, GetDefaultRegistryOptions()...)
if err != nil {
return nil, fmt.Errorf("invalid tag: %q: %w", tag, err)
}
Expand Down
133 changes: 133 additions & 0 deletions pkg/distribution/registry/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package registry

import (
"os"
"sync"
"testing"

"github.com/google/go-containerregistry/pkg/name"
)

func TestGetDefaultRegistryOptions_NoEnvVars(t *testing.T) {
// Reset the sync.Once for this test
resetOnceForTest()

// Ensure no env vars are set
os.Unsetenv("DEFAULT_REGISTRY")
os.Unsetenv("INSECURE_REGISTRY")

opts := GetDefaultRegistryOptions()

if len(opts) != 0 {
t.Errorf("Expected empty options slice, got %d options", len(opts))
}

// Verify that the default registry (index.docker.io) is used when no options are set
ref, err := name.ParseReference("myrepo/myimage:tag", opts...)
if err != nil {
t.Fatalf("Failed to parse reference: %v", err)
}

// When no DEFAULT_REGISTRY is set, the default should be index.docker.io
expectedRegistry := "index.docker.io"
if ref.Context().Registry.Name() != expectedRegistry {
t.Errorf("Expected default registry to be '%s', got '%s'", expectedRegistry, ref.Context().Registry.Name())
}

// Verify it uses HTTPS (secure by default)
if ref.Context().Registry.Scheme() != "https" {
t.Errorf("Expected scheme to be 'https', got '%s'", ref.Context().Registry.Scheme())
}
}

func TestGetDefaultRegistryOptions_OnlyDefaultRegistry(t *testing.T) {
// Reset the sync.Once for this test
resetOnceForTest()

t.Setenv("DEFAULT_REGISTRY", "custom.registry.io")
os.Unsetenv("INSECURE_REGISTRY")

opts := GetDefaultRegistryOptions()

if len(opts) != 1 {
t.Fatalf("Expected 1 option, got %d", len(opts))
}

// Verify the option sets the default registry by parsing a reference without explicit registry
ref, err := name.ParseReference("myrepo/myimage:tag", opts...)
if err != nil {
t.Fatalf("Failed to parse reference: %v", err)
}

if ref.Context().Registry.Name() != "custom.registry.io" {
t.Errorf("Expected registry to be 'custom.registry.io', got '%s'", ref.Context().Registry.Name())
}

// Verify it's not insecure (should use https)
if ref.Context().Registry.Scheme() != "https" {
t.Errorf("Expected scheme to be 'https', got '%s'", ref.Context().Registry.Scheme())
}
}

func TestGetDefaultRegistryOptions_OnlyInsecureRegistry(t *testing.T) {
// Reset the sync.Once for this test
resetOnceForTest()

os.Unsetenv("DEFAULT_REGISTRY")
t.Setenv("INSECURE_REGISTRY", "true")

opts := GetDefaultRegistryOptions()

if len(opts) != 1 {
t.Fatalf("Expected 1 option, got %d", len(opts))
}

// Verify the option makes the registry insecure by parsing a reference
ref, err := name.ParseReference("myregistry.io/myrepo/myimage:tag", opts...)
if err != nil {
t.Fatalf("Failed to parse reference: %v", err)
}

// Insecure registries should use http
if ref.Context().Registry.Scheme() != "http" {
t.Errorf("Expected scheme to be 'http', got '%s'", ref.Context().Registry.Scheme())
}
}

func TestGetDefaultRegistryOptions_BothEnvVars(t *testing.T) {
// Reset the sync.Once for this test
resetOnceForTest()

t.Setenv("DEFAULT_REGISTRY", "custom.registry.io")
t.Setenv("INSECURE_REGISTRY", "true")

opts := GetDefaultRegistryOptions()

if len(opts) != 2 {
t.Fatalf("Expected 2 options, got %d", len(opts))
}

// Verify both options are applied
ref, err := name.ParseReference("myrepo/myimage:tag", opts...)
if err != nil {
t.Fatalf("Failed to parse reference: %v", err)
}

// Check custom registry is used
if ref.Context().Registry.Name() != "custom.registry.io" {
t.Errorf("Expected registry to be 'custom.registry.io', got '%s'", ref.Context().Registry.Name())
}

// Check insecure is applied (http scheme)
if ref.Context().Registry.Scheme() != "http" {
t.Errorf("Expected scheme to be 'http', got '%s'", ref.Context().Registry.Scheme())
}
}

// Helper function to reset the sync.Once for testing
// Note: This is a workaround for testing. In production code, sync.Once ensures
// the initialization only happens once for the lifetime of the program.
func resetOnceForTest() {
once = sync.Once{}
defaultRegistryOpts = nil
}
3 changes: 2 additions & 1 deletion pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/logging"
"github.com/google/go-containerregistry/pkg/authn"
Expand Down Expand Up @@ -86,7 +87,7 @@ func (t *Tracker) trackModel(model types.Model, userAgent, action string) {
}
ua := strings.Join(parts, " ")
for _, tag := range tags {
ref, err := name.ParseReference(tag)
ref, err := name.ParseReference(tag, registry.GetDefaultRegistryOptions()...)
if err != nil {
t.log.Errorf("Error parsing reference: %v\n", err)
return
Expand Down
Loading