Skip to content

Commit

Permalink
Merge pull request #838 from cdesiniotis/enable-cdi-toolkit-container
Browse files Browse the repository at this point in the history
Enable CDI in the container runtime if enabled in the toolkit
  • Loading branch information
elezar authored Feb 3, 2025
2 parents d6c3129 + d8cd543 commit df4c87b
Show file tree
Hide file tree
Showing 14 changed files with 604 additions and 66 deletions.
10 changes: 8 additions & 2 deletions cmd/nvidia-ctk-installer/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ const (

// Options defines the shared options for the CLIs to configure containers runtimes.
type Options struct {
Config string
Socket string
Config string
Socket string
// EnableCDI indicates whether CDI should be enabled.
EnableCDI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
Expand Down Expand Up @@ -111,6 +113,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
}
}

if o.EnableCDI {
cfg.EnableCDI()
}

return nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
}
}

// TestUpdateV1EnableCDI verifies that updating an empty containerd v1 config
// only writes the enable_cdi entry when the EnableCDI option is set.
func TestUpdateV1EnableCDI(t *testing.T) {
	const runtimeDir = "/test/runtime/dir"
	nullLogger, _ := testlog.NewNullLogger()

	cases := []struct {
		enableCDI              bool
		expectedEnableCDIValue interface{}
	}{
		// The zero-value case: CDI disabled, no value expected in the config.
		{},
		{
			enableCDI:              false,
			expectedEnableCDIValue: nil,
		},
		{
			enableCDI:              true,
			expectedEnableCDIValue: true,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(fmt.Sprint(tc.enableCDI), func(t *testing.T) {
			// Start from an empty config tree for every case.
			tree, loadErr := toml.Empty.Load()
			require.NoError(t, loadErr)

			v1 := &containerd.ConfigV1{
				Logger:      nullLogger,
				Tree:        tree,
				RuntimeType: runtimeType,
			}

			opts := container.Options{
				EnableCDI:   tc.enableCDI,
				RuntimeName: "nvidia",
				RuntimeDir:  runtimeDir,
			}
			require.NoError(t, opts.UpdateConfig(v1))

			got := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"})
			require.EqualValues(t, tc.expectedEnableCDIValue, got)
		})
	}
}

func TestRevertV1Config(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,53 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
}
}

// TestUpdateV2ConfigEnableCDI verifies that updating an empty containerd v2
// config only writes the enable_cdi entry of the CRI plugin when the EnableCDI
// option is set.
func TestUpdateV2ConfigEnableCDI(t *testing.T) {
	const runtimeDir = "/test/runtime/dir"
	nullLogger, _ := testlog.NewNullLogger()

	cases := []struct {
		enableCDI              bool
		expectedEnableCDIValue interface{}
	}{
		// The zero-value case: CDI disabled, no value expected in the config.
		{},
		{
			enableCDI:              false,
			expectedEnableCDIValue: nil,
		},
		{
			enableCDI:              true,
			expectedEnableCDIValue: true,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(fmt.Sprint(tc.enableCDI), func(t *testing.T) {
			// Start from an empty config tree for every case.
			tree, loadErr := toml.LoadMap(map[string]interface{}{})
			require.NoError(t, loadErr)

			v2 := &containerd.Config{
				Logger:               nullLogger,
				Tree:                 tree,
				RuntimeType:          runtimeType,
				CRIRuntimePluginName: "io.containerd.grpc.v1.cri",
			}

			opts := container.Options{
				EnableCDI:    tc.enableCDI,
				RuntimeName:  "nvidia",
				RuntimeDir:   runtimeDir,
				SetAsDefault: false,
			}
			require.NoError(t, opts.UpdateConfig(v2))

			// The value is read back from the tree handed to the config.
			got := tree.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"})
			require.EqualValues(t, tc.expectedEnableCDIValue, got)
		})
	}
}

func TestRevertV2Config(t *testing.T) {
logger, _ := testlog.NewNullLogger()

Expand Down
13 changes: 12 additions & 1 deletion cmd/nvidia-ctk-installer/container/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/containerd"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/crio"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/docker"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/toolkit"
)

const (
Expand Down Expand Up @@ -66,6 +67,12 @@ func Flags(opts *Options) []cli.Flag {
Destination: &opts.RestartMode,
EnvVars: []string{"RUNTIME_RESTART_MODE"},
},
&cli.BoolFlag{
Name: "enable-cdi-in-runtime",
Usage: "Enable CDI in the configured runtime",
Destination: &opts.EnableCDI,
EnvVars: []string{"RUNTIME_ENABLE_CDI"},
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",
Expand Down Expand Up @@ -98,10 +105,14 @@ func Flags(opts *Options) []cli.Flag {
}

// ValidateOptions checks whether the specified options are valid
func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error {
func ValidateOptions(c *cli.Context, opts *Options, runtime string, toolkitRoot string, to *toolkit.Options) error {
// We set this option here to ensure that it is available in future calls.
opts.RuntimeDir = toolkitRoot

if !c.IsSet("enable-cdi-in-runtime") {
opts.EnableCDI = to.CDI.Enabled
}

// Apply the runtime-specific config changes.
switch runtime {
case containerd.Name:
Expand Down
41 changes: 23 additions & 18 deletions cmd/nvidia-ctk-installer/container/toolkit/toolkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ const (
configFilename = "config.toml"
)

// cdiOptions groups the CDI-related settings of the toolkit installer.
type cdiOptions struct {
// Enabled is bound to the cdi-enabled flag and controls whether a CDI
// specification is generated.
Enabled bool
// outputDir is the directory the generated CDI spec is saved to; option
// validation turns Enabled off when this is empty.
outputDir string
// kind is the value of the cdi-kind flag; option validation parses it into
// vendor and class.
kind string
// vendor and class are derived from kind during option validation and are
// passed to the CDI library when the spec is generated.
vendor string
class string
}

type Options struct {
DriverRoot string
DevRoot string
Expand All @@ -63,11 +71,8 @@ type Options struct {

ContainerCLIDebug string

cdiEnabled bool
cdiOutputDir string
cdiKind string
cdiVendor string
cdiClass string
// CDI stores the CDI options for the toolkit.
CDI cdiOptions

createDeviceNodes cli.StringSlice

Expand Down Expand Up @@ -170,21 +175,21 @@ func Flags(opts *Options) []cli.Flag {
Name: "cdi-enabled",
Aliases: []string{"enable-cdi"},
Usage: "enable the generation of a CDI specification",
Destination: &opts.cdiEnabled,
Destination: &opts.CDI.Enabled,
EnvVars: []string{"CDI_ENABLED", "ENABLE_CDI"},
},
&cli.StringFlag{
Name: "cdi-output-dir",
Usage: "the directory where the CDI output files are to be written. If this is set to '', no CDI specification is generated.",
Value: "/var/run/cdi",
Destination: &opts.cdiOutputDir,
Destination: &opts.CDI.outputDir,
EnvVars: []string{"CDI_OUTPUT_DIR"},
},
&cli.StringFlag{
Name: "cdi-kind",
Usage: "the vendor string to use for the generated CDI specification",
Value: "management.nvidia.com/gpu",
Destination: &opts.cdiKind,
Destination: &opts.CDI.kind,
EnvVars: []string{"CDI_KIND"},
},
&cli.BoolFlag{
Expand Down Expand Up @@ -240,19 +245,19 @@ func (t *Installer) ValidateOptions(opts *Options) error {
return fmt.Errorf("invalid --toolkit-root option: %v", t.toolkitRoot)
}

vendor, class := parser.ParseQualifier(opts.cdiKind)
vendor, class := parser.ParseQualifier(opts.CDI.kind)
if err := parser.ValidateVendorName(vendor); err != nil {
return fmt.Errorf("invalid CDI vendor name: %v", err)
}
if err := parser.ValidateClassName(class); err != nil {
return fmt.Errorf("invalid CDI class name: %v", err)
}
opts.cdiVendor = vendor
opts.cdiClass = class
opts.CDI.vendor = vendor
opts.CDI.class = class

if opts.cdiEnabled && opts.cdiOutputDir == "" {
if opts.CDI.Enabled && opts.CDI.outputDir == "" {
t.logger.Warning("Skipping CDI spec generation (no output directory specified)")
opts.cdiEnabled = false
opts.CDI.Enabled = false
}

isDisabled := false
Expand All @@ -265,7 +270,7 @@ func (t *Installer) ValidateOptions(opts *Options) error {
break
}
}
if !opts.cdiEnabled && !isDisabled {
if !opts.CDI.Enabled && !isDisabled {
t.logger.Info("disabling device node creation since --cdi-enabled=false")
isDisabled = true
}
Expand Down Expand Up @@ -698,7 +703,7 @@ func (t *Installer) createDeviceNodes(opts *Options) error {

// generateCDISpec generates a CDI spec for use in management containers
func (t *Installer) generateCDISpec(opts *Options, nvidiaCDIHookPath string) error {
if !opts.cdiEnabled {
if !opts.CDI.Enabled {
return nil
}
t.logger.Info("Generating CDI spec for management containers")
Expand All @@ -708,8 +713,8 @@ func (t *Installer) generateCDISpec(opts *Options, nvidiaCDIHookPath string) err
nvcdi.WithDriverRoot(opts.DriverRootCtrPath),
nvcdi.WithDevRoot(opts.DevRootCtrPath),
nvcdi.WithNVIDIACDIHookPath(nvidiaCDIHookPath),
nvcdi.WithVendor(opts.cdiVendor),
nvcdi.WithClass(opts.cdiClass),
nvcdi.WithVendor(opts.CDI.vendor),
nvcdi.WithClass(opts.CDI.class),
)
if err != nil {
return fmt.Errorf("failed to create CDI library for management containers: %v", err)
Expand All @@ -734,7 +739,7 @@ func (t *Installer) generateCDISpec(opts *Options, nvidiaCDIHookPath string) err
if err != nil {
return fmt.Errorf("failed to generate CDI name for management containers: %v", err)
}
err = spec.Save(filepath.Join(opts.cdiOutputDir, name))
err = spec.Save(filepath.Join(opts.CDI.outputDir, name))
if err != nil {
return fmt.Errorf("failed to save CDI spec for management containers: %v", err)
}
Expand Down
8 changes: 5 additions & 3 deletions cmd/nvidia-ctk-installer/container/toolkit/toolkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ kind: example.com/class
options := Options{
DriverRoot: "/host/driver/root",
DriverRootCtrPath: filepath.Join(moduleRoot, "testdata", "lookup", tc.hostRoot),
cdiEnabled: tc.cdiEnabled,
cdiOutputDir: cdiOutputDir,
cdiKind: "example.com/class",
CDI: cdiOptions{
Enabled: tc.cdiEnabled,
outputDir: cdiOutputDir,
kind: "example.com/class",
},
}

ti := NewInstaller(
Expand Down
13 changes: 11 additions & 2 deletions cmd/nvidia-ctk-installer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type options struct {
runtimeArgs string
root string
pidFile string
sourceRoot string

toolkitOptions toolkit.Options
runtimeOptions runtime.Options
Expand Down Expand Up @@ -141,6 +142,13 @@ func (a app) build() *cli.App {
Destination: &options.root,
EnvVars: []string{"ROOT"},
},
&cli.StringFlag{
Name: "source-root",
Value: "/",
Usage: "The folder where the required toolkit artifacts can be found",
Destination: &options.sourceRoot,
EnvVars: []string{"SOURCE_ROOT"},
},
&cli.StringFlag{
Name: "pid-file",
Value: defaultPidFile,
Expand All @@ -159,12 +167,13 @@ func (a app) build() *cli.App {
// Before creates the toolkit installer from the supplied options (logger,
// source root and toolkit root), stores it on the app, and then validates the
// flags, returning any validation error.
func (a *app) Before(c *cli.Context, o *options) error {
	installer := toolkit.NewInstaller(
		toolkit.WithLogger(a.logger),
		toolkit.WithSourceRoot(o.sourceRoot),
		toolkit.WithToolkitRoot(o.toolkitRoot()),
	)
	a.toolkit = installer

	if err := a.validateFlags(c, o); err != nil {
		return err
	}
	return nil
}

func (a *app) validateFlags(_ *cli.Context, o *options) error {
func (a *app) validateFlags(c *cli.Context, o *options) error {
if o.root == "" {
return fmt.Errorf("the install root must be specified")
}
Expand All @@ -178,7 +187,7 @@ func (a *app) validateFlags(_ *cli.Context, o *options) error {
if err := a.toolkit.ValidateOptions(&o.toolkitOptions); err != nil {
return err
}
if err := runtime.ValidateOptions(&o.runtimeOptions, o.runtime, o.toolkitRoot()); err != nil {
if err := runtime.ValidateOptions(c, &o.runtimeOptions, o.runtime, o.toolkitRoot(), &o.toolkitOptions); err != nil {
return err
}
return nil
Expand Down
Loading

0 comments on commit df4c87b

Please sign in to comment.