diff --git a/Dockerfile b/Dockerfile index efa7a6f..c72fd90 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,9 +23,10 @@ COPY ./ ./ #RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o manager cmd/maintenance-manager/main.go RUN --mount=type=cache,target=/go/pkg/mod/ GO_GCFLAGS=${GCFLAGS} make build-manager -# Use distroless as minimal base image to package the manager binary -# Refer to https://github.com/GoogleContainerTools/distroless for more details -FROM gcr.io/distroless/static:nonroot +FROM quay.io/centos/centos:stream9 + +RUN yum -y install mstflint && yum clean all + WORKDIR / COPY --from=builder /workspace/build/manager . COPY bindata /bindata diff --git a/cmd/manager/main.go b/cmd/manager/main.go index b68966f..7f3b645 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -37,6 +37,7 @@ import ( configurationnetv1alpha1 "github.com/Mellanox/nic-configuration-operator/api/v1alpha1" "github.com/Mellanox/nic-configuration-operator/internal/controller" + "github.com/Mellanox/nic-configuration-operator/pkg/firmware" "github.com/Mellanox/nic-configuration-operator/pkg/ncolog" "github.com/Mellanox/nic-configuration-operator/pkg/version" //+kubebuilder:scaffold:imports @@ -145,6 +146,15 @@ func main() { setupLog.Error(err, "unable to create controller", "controller", "NicConfigurationTemplate") os.Exit(1) } + + if err = (&controller.NicFirmwareSourceReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + FirmwareProvisioner: firmware.NewFirmwareProvisioner(), + }).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "NicFirmwareSource") + os.Exit(1) + } //+kubebuilder:scaffold:builder if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { diff --git a/internal/controller/nicfirmwaresource_controller.go b/internal/controller/nicfirmwaresource_controller.go new file mode 100644 index 0000000..81bf3ab --- /dev/null +++ 
b/internal/controller/nicfirmwaresource_controller.go @@ -0,0 +1,142 @@ +/* +2025 NVIDIA CORPORATION & AFFILIATES +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + "github.com/Mellanox/nic-configuration-operator/api/v1alpha1" + "github.com/Mellanox/nic-configuration-operator/pkg/consts" + "github.com/Mellanox/nic-configuration-operator/pkg/firmware" +) + +// NicFirmwareSourceReconciler reconciles a NicFirmwareSource object +type NicFirmwareSourceReconciler struct { + client.Client + Scheme *runtime.Scheme + + FirmwareProvisioner firmware.FirmwareProvisioner +} + +// Reconcile reconciles the NicFirmwareSource object: it verifies the cached binaries, downloads and processes the missing firmware archives, validates the cache and reports each step in the object's status +func (r *NicFirmwareSourceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + // Fetch the NicFirmwareSource instance + instance := &v1alpha1.NicFirmwareSource{} + err := r.Get(ctx, req.NamespacedName, instance) + + // TODO use finalizers to clean up cache storage after CR deletion + + if err != nil { + if errors.IsNotFound(err) { + // Request object not found, could have been deleted after reconcile request. + // Owned objects are automatically garbage collected. + // Return and don't requeue + return reconcile.Result{}, nil + } + // Error reading the object - requeue the request. 
+ return reconcile.Result{}, err + } + + cacheName := instance.Name + + urlsToProcess, err := r.FirmwareProvisioner.VerifyCachedBinaries(cacheName, instance.Spec.BinUrlSource) + if err != nil { + // Use a separate variable for the status update result so the original error is not overwritten and is still returned + if updateErr := r.updateStatus(ctx, instance, consts.FirmwareSourceCacheVerificationFailedStatus, err, nil); updateErr != nil { + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + if len(urlsToProcess) == 0 { + if err = r.updateStatus(ctx, instance, consts.FirmwareSourceProcessingStatus, nil, nil); err != nil { + return reconcile.Result{}, err + } + + return r.ValidateCache(ctx, instance) + } + + if err = r.updateStatus(ctx, instance, consts.FirmwareSourceDownloadingStatus, nil, nil); err != nil { + return reconcile.Result{}, err + } + + err = r.FirmwareProvisioner.DownloadAndUnzipFirmwareArchives(cacheName, urlsToProcess, true) + if err != nil { + if updateErr := r.updateStatus(ctx, instance, consts.FirmwareSourceDownloadFailedStatus, err, nil); updateErr != nil { + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + + if err = r.updateStatus(ctx, instance, consts.FirmwareSourceProcessingStatus, nil, nil); err != nil { + return reconcile.Result{}, err + } + + err = r.FirmwareProvisioner.AddFirmwareBinariesToCacheByMetadata(cacheName) + if err != nil { + if updateErr := r.updateStatus(ctx, instance, consts.FirmwareSourceProcessingFailedStatus, err, nil); updateErr != nil { + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + + return r.ValidateCache(ctx, instance) +} + +// ValidateCache validates the firmware cache named after the given NicFirmwareSource and records either the cached versions or the failure in its status +func (r *NicFirmwareSourceReconciler) ValidateCache(ctx context.Context, instance *v1alpha1.NicFirmwareSource) (reconcile.Result, error) { + versions, err := r.FirmwareProvisioner.ValidateCache(instance.Name) + if err != nil { + if updateErr := r.updateStatus(ctx, instance, consts.FirmwareSourceProcessingFailedStatus, err, nil); updateErr != nil { + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err + } + + if err = r.updateStatus(ctx, instance, 
consts.FirmwareSourceSuccessStatus, nil, versions); err != nil { + return reconcile.Result{}, err + } + + return ctrl.Result{}, nil +} + +// updateStatus sets the state, the reason (text of err, empty when err is nil) and the versions in the object's status and persists it with a status update +func (r *NicFirmwareSourceReconciler) updateStatus(ctx context.Context, obj *v1alpha1.NicFirmwareSource, status string, err error, versions map[string][]string) error { + obj.Status.State = status + if err != nil { + obj.Status.Reason = err.Error() + } else { + obj.Status.Reason = "" + } + + obj.Status.Versions = versions + return r.Status().Update(ctx, obj) +} + +// SetupWithManager sets up the controller with the Manager. +func (r *NicFirmwareSourceReconciler) SetupWithManager(mgr ctrl.Manager) error { + controller := ctrl.NewControllerManagedBy(mgr). + For(&v1alpha1.NicFirmwareSource{}) + + return controller. + Named("nicFirmwareSourceReconciler"). + Complete(r) +} diff --git a/internal/controller/nicfirmwaresource_controller_test.go b/internal/controller/nicfirmwaresource_controller_test.go new file mode 100644 index 0000000..68f9d30 --- /dev/null +++ b/internal/controller/nicfirmwaresource_controller_test.go @@ -0,0 +1,316 @@ +/* +2025 NVIDIA CORPORATION & AFFILIATES +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + "errors" + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "github.com/stretchr/testify/mock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/config" + "sigs.k8s.io/controller-runtime/pkg/manager" + metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + + "github.com/Mellanox/nic-configuration-operator/api/v1alpha1" + "github.com/Mellanox/nic-configuration-operator/pkg/consts" + "github.com/Mellanox/nic-configuration-operator/pkg/firmware/mocks" +) + +const ( + crName = "nic-fw-source" + crNamespace = "default" +) + +var _ = Describe("NicFirmwareSource Controller", func() { + var ( + mgr manager.Manager + k8sClient client.Client + reconciler *NicFirmwareSourceReconciler + ctx context.Context + cancel context.CancelFunc + firmwareProvisioner mocks.FirmwareProvisioner + + err error + ) + + getCR := func(name, namespace string) (*v1alpha1.NicFirmwareSource, error) { + key := types.NamespacedName{Name: name, Namespace: namespace} + cr := &v1alpha1.NicFirmwareSource{} + err := k8sClient.Get(context.Background(), key, cr) + return cr, err + } + + createCR := func(name, namespace string) { + By("creating NicFirmwareSource CR") + cr := &v1alpha1.NicFirmwareSource{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Spec: v1alpha1.NicFirmwareSourceSpec{ + BinUrlSource: []string{"https://firmware.example.com/fwA.zip"}, + }, + } + Expect(k8sClient.Create(context.Background(), cr)).To(Succeed()) + } + + ValidateCRStatusAndReason := func(name, namespace, status, reason string) { + Eventually(func(g Gomega) []string { + cr, err := getCR(name, namespace) + g.Expect(err).NotTo(HaveOccurred()) + return []string{cr.Status.State, cr.Status.Reason} + }).Should(Equal([]string{status, reason})) + } + + ValidateCRReportedVersions := func(name, namespace 
string, versions map[string][]string) { + Eventually(func(g Gomega) map[string][]string { + cr, err := getCR(name, namespace) + g.Expect(err).NotTo(HaveOccurred()) + return cr.Status.Versions + }).Should(Equal(versions)) + } + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + mgr, err = ctrl.NewManager(cfg, ctrl.Options{ + Scheme: scheme.Scheme, + Metrics: metricsserver.Options{BindAddress: "0"}, + Controller: config.Controller{SkipNameValidation: ptr.To(true)}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(mgr).NotTo(BeNil()) + + k8sClient = mgr.GetClient() + + firmwareProvisioner = mocks.FirmwareProvisioner{} + + reconciler = &NicFirmwareSourceReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + FirmwareProvisioner: &firmwareProvisioner, + } + + Expect(reconciler.SetupWithManager(mgr)).To(Succeed()) + + testMgrCtx, testMgrCancel := context.WithCancel(ctx) + By("start manager") + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + By("Start controller manager") + err := mgr.Start(testMgrCtx) + Expect(err).ToNot(HaveOccurred()) + }() + + DeferCleanup(func() { + By("Shut down controller manager") + testMgrCancel() + wg.Wait() + }) + }) + + AfterEach(func() { + Expect(k8sClient.DeleteAllOf(ctx, &v1alpha1.NicFirmwareSource{}, client.InNamespace(crNamespace))).To(Succeed()) + + cancel() + }) + + It("should set the success status if the firmware provisioner did not return any errors", func() { + versionsMap := map[string][]string{"1.2.3": {"psid1"}} + + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string{"http://firmware.example.com/fwA.zip"}, nil). + Once() + + firmwareProvisioner.On("DownloadAndUnzipFirmwareArchives", crName, []string{"http://firmware.example.com/fwA.zip"}, true). + Return(nil). + Once() + + firmwareProvisioner.On("AddFirmwareBinariesToCacheByMetadata", crName). + Return(nil). 
+ Once() + + firmwareProvisioner.On("ValidateCache", crName). + Return(versionsMap, nil). + Once() + + createCR(crName, crNamespace) + + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceSuccessStatus, "") + ValidateCRReportedVersions(crName, crNamespace, versionsMap) + }) + + It("should set the CacheVerification status if the firmware provisioner failed to verify the existing cache", func() { + errMsg := "failed to verify cache" + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string(nil), errors.New(errMsg)). + Once() + + firmwareProvisioner.AssertNotCalled(GinkgoT(), "DownloadAndUnzipFirmwareArchives", crName, []string{}, true) + firmwareProvisioner.AssertNotCalled(GinkgoT(), "AddFirmwareBinariesToCacheByMetadata", crName) + firmwareProvisioner.AssertNotCalled(GinkgoT(), "ValidateCache", crName) + + createCR(crName, crNamespace) + + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceCacheVerificationFailedStatus, errMsg) + }) + + It("should set the success status if no urls need to be processed after cache verification", func() { + versionsMap := map[string][]string{"1.2.3": {"psid1"}} + + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string{}, nil). + Once() + + firmwareProvisioner.AssertNotCalled(GinkgoT(), "DownloadAndUnzipFirmwareArchives", crName, []string{}, true) + firmwareProvisioner.AssertNotCalled(GinkgoT(), "AddFirmwareBinariesToCacheByMetadata", crName) + + firmwareProvisioner.On("ValidateCache", crName). + Return(versionsMap, nil). 
+ Once() + + createCR(crName, crNamespace) + + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceSuccessStatus, "") + ValidateCRReportedVersions(crName, crNamespace, versionsMap) + }) + + It("should set the Downloading status when the firmware provisioner is downloading the firmware binaries", func() { + versionsMap := map[string][]string{"1.2.3": {"psid4"}} + + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string{"http://firmware.example.com/fwA.zip"}, nil). + Once() + + // Simulate a slow download so we can observe the intermediate status + firmwareProvisioner.On("DownloadAndUnzipFirmwareArchives", crName, []string{"http://firmware.example.com/fwA.zip"}, true). + Run(func(args mock.Arguments) { + time.Sleep(1 * time.Second) + }). + Return(nil). + Once() + + firmwareProvisioner.On("AddFirmwareBinariesToCacheByMetadata", crName). + Return(nil). + Once() + + firmwareProvisioner.On("ValidateCache", crName). + Return(versionsMap, nil). + Once() + + createCR(crName, crNamespace) + + // We want to see "Downloading" before it completes + Eventually(func(g Gomega) string { + cr, err := getCR(crName, crNamespace) + g.Expect(err).NotTo(HaveOccurred()) + return cr.Status.State + }, 500*time.Millisecond, 100*time.Millisecond).Should(Equal(consts.FirmwareSourceDownloadingStatus)) + + // Eventually it should succeed with an empty reason + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceDownloadingStatus, "") + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceSuccessStatus, "") + ValidateCRReportedVersions(crName, crNamespace, versionsMap) + }) + + It("should set the DownloadFailed status if the firmware provisioner failed to download the binaries", func() { + errMsg := "failed to download" + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string{"http://firmware.example.com/fwA.zip"}, nil). 
+ Once() + + firmwareProvisioner.On("DownloadAndUnzipFirmwareArchives", crName, []string{"http://firmware.example.com/fwA.zip"}, true). + Return(errors.New(errMsg)). + Once() + + firmwareProvisioner.AssertNotCalled(GinkgoT(), "AddFirmwareBinariesToCacheByMetadata", crName) + firmwareProvisioner.AssertNotCalled(GinkgoT(), "ValidateCache", crName) + + createCR(crName, crNamespace) + + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceDownloadFailedStatus, errMsg) + }) + + It("should set the Processing status if the firmware provisioner is organizing the firmware binaries", func() { + versionsMap := map[string][]string{"15.5.3": {"psid4"}} + + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string{"http://firmware.example.com/fwA.zip"}, nil). + Once() + + firmwareProvisioner.On("DownloadAndUnzipFirmwareArchives", crName, []string{"http://firmware.example.com/fwA.zip"}, true). + Return(nil). + Once() + + // Simulate a short delay for organizing + firmwareProvisioner.On("AddFirmwareBinariesToCacheByMetadata", crName). + Run(func(args mock.Arguments) { + time.Sleep(1 * time.Second) // Enough delay to observe "Processing" + }). + Return(nil). + Once() + + firmwareProvisioner.On("ValidateCache", crName). + Return(versionsMap, nil). 
+ Once() + + createCR(crName, crNamespace) + + // After the download, the code sets it to "Processing" + Eventually(func(g Gomega) string { + cr, err := getCR(crName, crNamespace) + g.Expect(err).NotTo(HaveOccurred()) + return cr.Status.State + }, 500*time.Millisecond, 100*time.Millisecond).Should(Equal(consts.FirmwareSourceProcessingStatus)) + + // Eventually it should succeed with an empty reason + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceSuccessStatus, "") + ValidateCRReportedVersions(crName, crNamespace, versionsMap) + }) + + It("should set the ProcessingFailed status if the firmware provisioner failed to organize the binaries", func() { + errMsg := "failed to organize" + firmwareProvisioner.On("VerifyCachedBinaries", crName, mock.AnythingOfType("[]string")). + Return([]string{"http://firmware.example.com/fwA.zip"}, nil). + Once() + + firmwareProvisioner.On("DownloadAndUnzipFirmwareArchives", crName, []string{"http://firmware.example.com/fwA.zip"}, true). + Return(nil). + Once() + + firmwareProvisioner.On("AddFirmwareBinariesToCacheByMetadata", crName). + Return(errors.New(errMsg)). 
+ Once() + + firmwareProvisioner.AssertNotCalled(GinkgoT(), "ValidateCache", crName) + + createCR(crName, crNamespace) + + ValidateCRStatusAndReason(crName, crNamespace, consts.FirmwareSourceProcessingFailedStatus, errMsg) + }) +}) diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 6599bb9..a632f1b 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -86,4 +86,15 @@ const ( ConfigDaemonManifestsPath = "./bindata/manifests/daemon" OperatorConfigMapName = "nic-configuration-operator-config" + + NicFirmwareStorage = "/nic-firmware" + NicFirmwareBinariesFolder = "firmware-binaries" + NicFirmwareBinaryFileExtension = ".bin" + + FirmwareSourceDownloadingStatus = "Downloading" + FirmwareSourceDownloadFailedStatus = "DownloadFailed" + FirmwareSourceProcessingStatus = "Processing" + FirmwareSourceProcessingFailedStatus = "ProcessingFailed" + FirmwareSourceSuccessStatus = "Success" + FirmwareSourceCacheVerificationFailedStatus = "CacheVerificationFailed" ) diff --git a/pkg/firmware/mocks/FirmwareProvisioner.go b/pkg/firmware/mocks/FirmwareProvisioner.go new file mode 100644 index 0000000..ac0d62a --- /dev/null +++ b/pkg/firmware/mocks/FirmwareProvisioner.go @@ -0,0 +1,120 @@ +// Code generated by mockery v2.46.3. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +// FirmwareProvisioner is an autogenerated mock type for the FirmwareProvisioner type +type FirmwareProvisioner struct { + mock.Mock +} + +// AddFirmwareBinariesToCacheByMetadata provides a mock function with given fields: cacheName +func (_m *FirmwareProvisioner) AddFirmwareBinariesToCacheByMetadata(cacheName string) error { + ret := _m.Called(cacheName) + + if len(ret) == 0 { + panic("no return value specified for AddFirmwareBinariesToCacheByMetadata") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(cacheName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DownloadAndUnzipFirmwareArchives provides a mock function with given fields: cacheName, urls, cleanupArchives +func (_m *FirmwareProvisioner) DownloadAndUnzipFirmwareArchives(cacheName string, urls []string, cleanupArchives bool) error { + ret := _m.Called(cacheName, urls, cleanupArchives) + + if len(ret) == 0 { + panic("no return value specified for DownloadAndUnzipFirmwareArchives") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, []string, bool) error); ok { + r0 = rf(cacheName, urls, cleanupArchives) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ValidateCache provides a mock function with given fields: cacheName +func (_m *FirmwareProvisioner) ValidateCache(cacheName string) (map[string][]string, error) { + ret := _m.Called(cacheName) + + if len(ret) == 0 { + panic("no return value specified for ValidateCache") + } + + var r0 map[string][]string + var r1 error + if rf, ok := ret.Get(0).(func(string) (map[string][]string, error)); ok { + return rf(cacheName) + } + if rf, ok := ret.Get(0).(func(string) map[string][]string); ok { + r0 = rf(cacheName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string][]string) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(cacheName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// 
VerifyCachedBinaries provides a mock function with given fields: cacheName, urls +func (_m *FirmwareProvisioner) VerifyCachedBinaries(cacheName string, urls []string) ([]string, error) { + ret := _m.Called(cacheName, urls) + + if len(ret) == 0 { + panic("no return value specified for VerifyCachedBinaries") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(string, []string) ([]string, error)); ok { + return rf(cacheName, urls) + } + if rf, ok := ret.Get(0).(func(string, []string) []string); ok { + r0 = rf(cacheName, urls) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(string, []string) error); ok { + r1 = rf(cacheName, urls) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewFirmwareProvisioner creates a new instance of FirmwareProvisioner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewFirmwareProvisioner(t interface { + mock.TestingT + Cleanup(func()) +}) *FirmwareProvisioner { + mock := &FirmwareProvisioner{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/firmware/mocks/ProvisioningUtils.go b/pkg/firmware/mocks/ProvisioningUtils.go new file mode 100644 index 0000000..ce2125b --- /dev/null +++ b/pkg/firmware/mocks/ProvisioningUtils.go @@ -0,0 +1,125 @@ +// Code generated by mockery v2.46.3. DO NOT EDIT. 
+ +package mocks + +import mock "github.com/stretchr/testify/mock" + +// ProvisioningUtils is an autogenerated mock type for the ProvisioningUtils type +type ProvisioningUtils struct { + mock.Mock +} + +// CleanupDirectory provides a mock function with given fields: root, allowedSet +func (_m *ProvisioningUtils) CleanupDirectory(root string, allowedSet map[string]struct{}) error { + ret := _m.Called(root, allowedSet) + + if len(ret) == 0 { + panic("no return value specified for CleanupDirectory") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, map[string]struct{}) error); ok { + r0 = rf(root, allowedSet) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DownloadFile provides a mock function with given fields: url, destPath +func (_m *ProvisioningUtils) DownloadFile(url string, destPath string) error { + ret := _m.Called(url, destPath) + + if len(ret) == 0 { + panic("no return value specified for DownloadFile") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(url, destPath) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetFirmwareVersionAndPSID provides a mock function with given fields: firmwareBinaryPath +func (_m *ProvisioningUtils) GetFirmwareVersionAndPSID(firmwareBinaryPath string) (string, string, error) { + ret := _m.Called(firmwareBinaryPath) + + if len(ret) == 0 { + panic("no return value specified for GetFirmwareVersionAndPSID") + } + + var r0 string + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func(string) (string, string, error)); ok { + return rf(firmwareBinaryPath) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(firmwareBinaryPath) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) string); ok { + r1 = rf(firmwareBinaryPath) + } else { + r1 = ret.Get(1).(string) + } + + if rf, ok := ret.Get(2).(func(string) error); ok { + r2 = rf(firmwareBinaryPath) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + 
+// UnzipFiles provides a mock function with given fields: zipPath, destDir +func (_m *ProvisioningUtils) UnzipFiles(zipPath string, destDir string) ([]string, error) { + ret := _m.Called(zipPath, destDir) + + if len(ret) == 0 { + panic("no return value specified for UnzipFiles") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(string, string) ([]string, error)); ok { + return rf(zipPath, destDir) + } + if rf, ok := ret.Get(0).(func(string, string) []string); ok { + r0 = rf(zipPath, destDir) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(zipPath, destDir) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewProvisioningUtils creates a new instance of ProvisioningUtils. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewProvisioningUtils(t interface { + mock.TestingT + Cleanup(func()) +}) *ProvisioningUtils { + mock := &ProvisioningUtils{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/firmware/provisioning.go b/pkg/firmware/provisioning.go new file mode 100644 index 0000000..1d338d7 --- /dev/null +++ b/pkg/firmware/provisioning.go @@ -0,0 +1,353 @@ +package firmware + +import ( + "errors" + "fmt" + "os" + "path" + "path/filepath" + "strings" + + "k8s.io/apimachinery/pkg/util/json" + "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/Mellanox/nic-configuration-operator/pkg/consts" +) + +const metadataFileName = "metadata.json" + +type cacheMetadata map[string][]string + +type FirmwareProvisioner interface { + // VerifyCachedBinaries checks against the metadata.json for which urls have corresponding cached fw binary files + // Returns a list of urls that need to be processed again + VerifyCachedBinaries(cacheName string, urls []string) ([]string, 
error) + // DownloadAndUnzipFirmwareArchives downloads and unzips fw archives from a list of urls + // Stores a metadata file, mapping download url to file names + // Returns an error if an archive could not be downloaded, unzipped or cleaned up, or if the metadata could not be saved + DownloadAndUnzipFirmwareArchives(cacheName string, urls []string, cleanupArchives bool) error + // AddFirmwareBinariesToCacheByMetadata finds the newly downloaded firmware binary files and organizes them in the cache according to their metadata + AddFirmwareBinariesToCacheByMetadata(cacheName string) error + // ValidateCache traverses the cache directory and validates that + // 1. There are no empty directories in the cache + // 2. Each PSID has only one matching firmware binary in the cache + // 3. Each non-empty PSID directory contains a firmware binary file (.bin) + // Returns mapping between firmware version to PSIDs available in the cache, error if validation failed + ValidateCache(cacheName string) (map[string][]string, error) +} + +type firmwareProvisioner struct { + cacheRootDir string + + utils ProvisioningUtils +} + +// VerifyCachedBinaries checks against the metadata.json for which urls have corresponding cached fw binary files +// Returns a list of urls that need to be processed again +func (f firmwareProvisioner) VerifyCachedBinaries(cacheName string, urls []string) ([]string, error) { + cacheDir := path.Join(f.cacheRootDir, cacheName, consts.NicFirmwareBinariesFolder) + // Nothing to verify if the dir does not exist + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + // If cache doesn't exist, there's nothing to validate, we need to process all urls + return urls, nil + } else if err != nil { + return nil, err + } + + metadataFile := path.Join(cacheDir, metadataFileName) + if _, err := os.Stat(metadataFile); os.IsNotExist(err) { + // If cache metadata file doesn't exist, clean up the cache directory because we can't validate the contents + if err = os.RemoveAll(cacheDir); err != nil { + return nil, err + } + return urls, nil + } + + metadata, err := 
readMetadataFromFile(metadataFile) + if err != nil { + log.Log.Error(err, "failed to read cache metadata file", "path", metadataFile) + return nil, err + } + + var urlsToProcessAgain []string + filesAccountedFor := map[string]struct{}{} + + for _, url := range urls { + files, found := metadata[url] + if !found { + urlsToProcessAgain = append(urlsToProcessAgain, url) + continue + } + + // If at least one file does not exist for this url, need to download all of them again, thus delete the cached versions first + deleteFilesForThisUrl := false + + for _, file := range files { + if _, err := os.Stat(file); os.IsNotExist(err) { + deleteFilesForThisUrl = true + urlsToProcessAgain = append(urlsToProcessAgain, url) + delete(metadata, url) + + break + } else if err != nil { + return nil, err + } + + filesAccountedFor[file] = struct{}{} + } + + if deleteFilesForThisUrl { + for _, file := range files { + if _, err := os.Stat(file); err == nil { + if err = os.Remove(file); err != nil { + return nil, err + } + } + } + } + } + + // After the cache files are verified, clean up everything else in the directory + if err = f.utils.CleanupDirectory(cacheDir, filesAccountedFor); err != nil { + return nil, err + } + + // Write updated metadata file to disk + if err := writeMetadataFile(metadata, cacheDir); err != nil { + return nil, err + } + + return urlsToProcessAgain, nil +} + +// DownloadAndUnzipFirmwareArchives downloads and unzips fw archives from a list of urls +// Stores a metadata file, mapping download url to file names +// Returns an error if an archive could not be downloaded, unzipped or cleaned up, or if the metadata could not be saved +func (f firmwareProvisioner) DownloadAndUnzipFirmwareArchives(cacheName string, urls []string, cleanupArchives bool) error { + firmwareBinariesDir := path.Join(f.cacheRootDir, cacheName, consts.NicFirmwareBinariesFolder) + + if _, err := os.Stat(firmwareBinariesDir); os.IsNotExist(err) { + err := os.MkdirAll(firmwareBinariesDir, 0755) + if err != nil { + log.Log.Error(err, "failed to create new cache in nic fw storage", 
"cacheName", cacheName) + return err + } + } + + var urlsToFiles cacheMetadata + + metadataFile := path.Join(firmwareBinariesDir, metadataFileName) + if _, err := os.Stat(metadataFile); os.IsNotExist(err) { + urlsToFiles = cacheMetadata{} + } else if err != nil { + return err + } else { + urlsToFiles, err = readMetadataFromFile(metadataFile) + if err != nil { + log.Log.Error(err, "failed to read cache metadata file", "path", metadataFile) + return err + } + } + + for _, url := range urls { + archiveLocalPath := filepath.Join(firmwareBinariesDir, filepath.Base(url)) + err := f.utils.DownloadFile(url, archiveLocalPath) + if err != nil { + log.Log.Error(err, "failed to download fw archive", "cacheName", cacheName, "url", url) + return err + } + files, err := f.utils.UnzipFiles(archiveLocalPath, firmwareBinariesDir) + if err != nil { + log.Log.Error(err, "failed to unzip fw archive", "cacheName", cacheName, "url", url) + return err + } + urlsToFiles[url] = files + + if cleanupArchives { + err = os.Remove(archiveLocalPath) + if err != nil { + log.Log.Error(err, "failed to remove fw archive file", "cacheName", cacheName, "url", url) + return err + } + } + } + + if err := writeMetadataFile(urlsToFiles, firmwareBinariesDir); err != nil { + return err + } + + return nil +} + +func writeMetadataFile(metadata cacheMetadata, cacheDir string) error { + jsonData, err := json.Marshal(metadata) + if err != nil { + log.Log.Error(err, "failed to process cache metadata", "cacheDir", cacheDir, "metadata", metadata) + return err + } + + if err = os.WriteFile(path.Join(cacheDir, metadataFileName), jsonData, 0644); err != nil { + log.Log.Error(err, "failed to save cache metadata", "cacheDir", cacheDir, "metadata", metadata) + return err + } + return nil +} + +// AddFirmwareBinariesToCacheByMetadata finds the newly downloaded firmware binary files and organizes them in the cache according to their metadata +func (f firmwareProvisioner) AddFirmwareBinariesToCacheByMetadata(cacheName string) error { + 
cacheDir := path.Join(f.cacheRootDir, cacheName) + firmwareBinariesDir := path.Join(cacheDir, consts.NicFirmwareBinariesFolder) + entries, err := os.ReadDir(firmwareBinariesDir) + if err != nil { + log.Log.Error(err, "failed to read firmware binaries cache", "cacheName", cacheName) + return err + } + + for _, entry := range entries { + // We only want to process the firmware binary files + if !strings.EqualFold(filepath.Ext(entry.Name()), consts.NicFirmwareBinaryFileExtension) { + continue + } + + sourcePath := filepath.Join(firmwareBinariesDir, entry.Name()) + + version, psid, err := f.utils.GetFirmwareVersionAndPSID(sourcePath) + if err != nil { + log.Log.Error(err, "failed to get firmware binary version and PSID", "cacheName", cacheName, "file", entry.Name()) + return err + } + + targetDir := path.Join(firmwareBinariesDir, version, psid) + + if _, err := os.Stat(targetDir); os.IsNotExist(err) { + err := os.MkdirAll(targetDir, 0755) + if err != nil { + log.Log.Error(err, "failed to create directory in nic fw storage", "cacheName", cacheName, "path", targetDir) + return err + } + } else { + entries, err := os.ReadDir(targetDir) + if err != nil { + log.Log.Error(err, "failed to read directory in nic fw storage", "cacheName", cacheName, "path", targetDir) + return err + } + if len(entries) != 0 { + err = errors.New("target directory for firmware binary file is supposed to be empty, found files") + log.Log.Error(err, "found existing files in the fw binary file directory", "cacheName", cacheName, "path", targetDir) + + return err + } + } + + targetPath := path.Join(targetDir, entry.Name()) + if err = os.Rename(sourcePath, targetPath); err != nil { + log.Log.Error(err, "failed to place firmware binary file in cache", "cacheName", cacheName, "path", targetPath) + return err + } + } + + return nil +} + +// ValidateCache traverses the cache directory and validates that +// 1. There are no empty directories in the cache +// 2. Each PSID has only one matching firmware binary in the cache +// 3. 
Each non-empty PSID directory contains a firmware binary file (.bin) +// Returns mapping between firmware version to PSIDs available in the cache, error if validation failed +func (f firmwareProvisioner) ValidateCache(cacheName string) (map[string][]string, error) { + cacheDir := path.Join(f.cacheRootDir, cacheName) + firmwareBinariesDir := path.Join(cacheDir, "firmware-binaries") + cachedVersions := make(map[string][]string) + foundPSIDs := make(map[string]struct{}) + + firmwareVersions, err := os.ReadDir(firmwareBinariesDir) + if err != nil { + log.Log.Error(err, "failed to read directory in nic fw storage", "cacheName", cacheName, "path", firmwareBinariesDir) + return nil, err + } + + for _, firmwareVersion := range firmwareVersions { + if !firmwareVersion.IsDir() { + continue + } + + fwVersion := firmwareVersion.Name() + fwVersionPath := filepath.Join(firmwareBinariesDir, fwVersion) + + psids, err := os.ReadDir(fwVersionPath) + if err != nil { + log.Log.Error(err, "failed to read directory in nic fw storage", "cacheName", cacheName, "path", fwVersionPath) + return nil, err + } + + for _, psid := range psids { + if !psid.IsDir() { + continue + } + psid := psid.Name() + psidFolderPath := path.Join(fwVersionPath, psid) + entries, err := os.ReadDir(psidFolderPath) + if err != nil { + log.Log.Error(err, "failed to read directory in nic fw storage", "cacheName", cacheName, "path", psidFolderPath) + return nil, err + } + + if len(entries) == 0 { + err = fmt.Errorf("cache directory is empty. Expected firmware binary file. Cache name: %s, PSID: %s, Firmware version: %s", cacheName, psid, fwVersion) + log.Log.Error(err, "") + return nil, err + } + + binFileFound := false + + for _, entry := range entries { + if !strings.EqualFold(filepath.Ext(entry.Name()), consts.NicFirmwareBinaryFileExtension) { + continue + } else if binFileFound { + err = fmt.Errorf("multiple firmware binary files in the same directory. 
Cache name: %s, PSID: %s, Firmware version: %s", cacheName, psid, fwVersion) + log.Log.Error(err, "") + return nil, err + } + + binFileFound = true + } + + if !binFileFound { + err = fmt.Errorf("no firmware binary files in the PSID directory. Cache name: %s, PSID: %s, Firmware version: %s", cacheName, psid, fwVersion) + log.Log.Error(err, "") + return nil, err + } + + if _, found := foundPSIDs[psid]; found { + err = fmt.Errorf("multiple firmware binary files for the same PSID. Cache name: %s, PSID: %s, Firmware version: %s", cacheName, psid, fwVersion) + log.Log.Error(err, "") + return nil, err + } else { + foundPSIDs[psid] = struct{}{} + } + + cachedVersions[fwVersion] = append(cachedVersions[fwVersion], psid) + } + } + + return cachedVersions, nil +} + +func readMetadataFromFile(path string) (cacheMetadata, error) { + file, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + metadata := cacheMetadata{} + err = json.Unmarshal(file, &metadata) + if err != nil { + return nil, err + } + + return metadata, nil +} + +func NewFirmwareProvisioner() FirmwareProvisioner { + return firmwareProvisioner{cacheRootDir: consts.NicFirmwareStorage, utils: newFirmwareUtils()} +} diff --git a/pkg/firmware/provisioning_test.go b/pkg/firmware/provisioning_test.go new file mode 100644 index 0000000..aa5c90e --- /dev/null +++ b/pkg/firmware/provisioning_test.go @@ -0,0 +1,357 @@ +/* +2025 NVIDIA CORPORATION & AFFILIATES +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
// Specs for the firmwareProvisioner methods. Every spec gets a fresh temporary cache root and a
// fresh ProvisioningUtils mock; the mock's expectations are asserted after each spec.
var _ = Describe("FirmwareProvisioner", func() {
	var (
		fwUtilsMock *mocks.ProvisioningUtils
		fwProv      FirmwareProvisioner
		tmpDir      string
		cacheName   string
		cacheDir    string
	)

	BeforeEach(func() {
		fwUtilsMock = &mocks.ProvisioningUtils{}

		var err error
		tmpDir, err = os.MkdirTemp("/tmp", "fwprovisioningtest-*")
		Expect(err).NotTo(HaveOccurred())

		cacheName = "test-cache"
		// cacheDir is the firmware binaries folder of the tested cache; it is NOT created here,
		// specs that need it create it themselves.
		cacheDir = path.Join(tmpDir, cacheName, consts.NicFirmwareBinariesFolder)

		fwProv = firmwareProvisioner{cacheRootDir: tmpDir, utils: fwUtilsMock}
	})

	AfterEach(func() {
		// Every expectation registered with .On(...) must have been consumed by the spec
		fwUtilsMock.AssertExpectations(GinkgoT())

		_ = os.RemoveAll(tmpDir)
	})

	Describe("VerifyCachedBinaries", func() {
		var (
			urls []string
		)

		BeforeEach(func() {
			urls = []string{"http://example.com/fwA.zip", "http://example.com/fwB.zip"}
		})

		Context("when the cache directory does not exist", func() {
			It("should return all provided URLs for re-processing", func() {
				// We won't create the directory, so it doesn't exist
				reprocess, err := fwProv.VerifyCachedBinaries(cacheName, urls)
				Expect(err).NotTo(HaveOccurred())
				Expect(reprocess).To(Equal(urls))
			})
		})

		Context("when the metadata file is missing", func() {
			It("should remove the entire cache directory and return all URLs", func() {
				// Create the directory but not metadata.json
				Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())

				reprocess, err := fwProv.VerifyCachedBinaries(cacheName, urls)
				Expect(err).NotTo(HaveOccurred())
				Expect(reprocess).To(Equal(urls))

				_, statErr := os.Stat(cacheDir)
				Expect(os.IsNotExist(statErr)).To(BeTrue())
			})
		})

		Context("with an existing metadata file", func() {
			It("should return missing URLs if some cached files are absent", func() {
				Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())

				// The metadata claims both archives were extracted...
				metaFile := filepath.Join(cacheDir, "metadata.json")
				metaData := `{
	"http://example.com/fwA.zip": ["` + filepath.Join(cacheDir, "fwA.bin") + `"],
	"http://example.com/fwB.zip": ["` + filepath.Join(cacheDir, "fwB.bin") + `"]
}`
				Expect(os.WriteFile(metaFile, []byte(metaData), 0644)).To(Succeed())

				// Create only fwA.bin
				Expect(os.WriteFile(filepath.Join(cacheDir, "fwA.bin"), []byte("dummy content"), 0644)).To(Succeed())

				fwUtilsMock.On("CleanupDirectory", mock.AnythingOfType("string"), mock.AnythingOfType("map[string]struct {}")).
					Return(nil).Once()

				reprocess, err := fwProv.VerifyCachedBinaries(cacheName, urls)
				Expect(err).NotTo(HaveOccurred())

				// fwB.bin is missing, so fwB.zip is reprocessed
				Expect(reprocess).To(ConsistOf("http://example.com/fwB.zip"))
			})

			It("should call clean up for unaccounted files", func() {
				Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())

				cachedFilePath := filepath.Join(cacheDir, "fwB.bin")

				// Only fwB.zip is recorded in the metadata, and its extracted file exists
				metaFile := filepath.Join(cacheDir, "metadata.json")
				metaData := `{
	"http://example.com/fwB.zip": ["` + cachedFilePath + `"]}`
				Expect(os.WriteFile(metaFile, []byte(metaData), 0644)).To(Succeed())
				Expect(os.WriteFile(cachedFilePath, []byte("dummy content"), 0644)).To(Succeed())

				// The allowed set passed to CleanupDirectory must contain exactly the cached file
				fwUtilsMock.On("CleanupDirectory", mock.AnythingOfType("string"), map[string]struct{}{cachedFilePath: {}}).
					Return(nil).Once()

				reprocess, err := fwProv.VerifyCachedBinaries(cacheName, urls)
				Expect(err).NotTo(HaveOccurred())

				// fwA.zip has no metadata entry, so it is the one to reprocess
				Expect(reprocess).To(ConsistOf("http://example.com/fwA.zip"))
			})
		})
	})

	Describe("DownloadAndUnzipFirmwareArchives", func() {
		var (
			downloadUrls    []string
			cleanupArchives bool
			metadataPath    string
		)

		BeforeEach(func() {
			downloadUrls = []string{"http://example.com/fwA.zip"}
			cleanupArchives = false
			metadataPath = filepath.Join(cacheDir, "metadata.json")
		})

		// readMetadata parses the metadata file written by the provisioner into a url -> files map
		readMetadata := func() map[string][]string {
			data, err := os.ReadFile(metadataPath)
			Expect(err).NotTo(HaveOccurred())

			meta := map[string][]string{}
			Expect(json.Unmarshal(data, &meta)).To(Succeed())
			return meta
		}

		It("should create the cache directory if missing, then download/unzip each archive", func() {
			fwUtilsMock.On("DownloadFile", downloadUrls[0], mock.AnythingOfType("string")).
				Return(nil).Once()
			extractedFile := filepath.Join(cacheDir, "fwA.bin")
			fwUtilsMock.
				On("UnzipFiles", mock.AnythingOfType("string"), cacheDir).
				Return([]string{extractedFile}, nil).Once()

			Expect(
				fwProv.DownloadAndUnzipFirmwareArchives(cacheName, downloadUrls, cleanupArchives),
			).To(Succeed())

			info, statErr := os.Stat(cacheDir)
			Expect(statErr).NotTo(HaveOccurred())
			Expect(info.IsDir()).To(BeTrue())

			Expect(metadataPath).To(BeARegularFile())

			metadata := readMetadata()

			Expect(metadata).To(HaveKey(downloadUrls[0]))
			Expect(metadata[downloadUrls[0]]).To(ConsistOf(extractedFile))
		})

		Context("when a download fails", func() {
			It("should return an error immediately", func() {
				// No UnzipFiles expectation is registered: it must not be reached
				fwUtilsMock.On("DownloadFile", downloadUrls[0], mock.AnythingOfType("string")).
					Return(errors.New("download error")).Once()

				err := fwProv.DownloadAndUnzipFirmwareArchives(cacheName, downloadUrls, cleanupArchives)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("download error"))
			})
		})

		Context("when unzip fails", func() {
			It("should return an error immediately", func() {
				fwUtilsMock.On("DownloadFile", downloadUrls[0], mock.AnythingOfType("string")).
					Return(nil).Once()

				fwUtilsMock.On("UnzipFiles", mock.AnythingOfType("string"), mock.AnythingOfType("string")).
					Return(nil, errors.New("unzip error")).Once()

				err := fwProv.DownloadAndUnzipFirmwareArchives(cacheName, downloadUrls, cleanupArchives)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("unzip error"))
			})
		})

		Context("when cleanupArchives is true", func() {
			It("should remove the downloaded archive after unzipping", func() {
				cleanupArchives = true
				fwUtilsMock.On("DownloadFile", downloadUrls[0], mock.AnythingOfType("string")).
					Return(nil).Once()
				fwUtilsMock.On("UnzipFiles", mock.AnythingOfType("string"), mock.AnythingOfType("string")).
					Return([]string{"fwA.bin"}, nil).Once()

				// DownloadFile is mocked, so the archive the provisioner is expected to delete
				// is placed on disk by the spec itself
				Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())
				fileToBeDeleted := path.Join(cacheDir, "fwA.zip")
				Expect(os.WriteFile(fileToBeDeleted, []byte("dummy content"), 0644)).To(Succeed())

				err := fwProv.DownloadAndUnzipFirmwareArchives(cacheName, downloadUrls, cleanupArchives)
				Expect(err).NotTo(HaveOccurred())
				_, err = os.Stat(fileToBeDeleted)
				Expect(err).To(MatchError(os.ErrNotExist))
			})
		})
	})

	Describe("AddFirmwareBinariesToCacheByMetadata", func() {
		BeforeEach(func() {
			Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())
		})

		It("should move each .bin file to the correct version/PSID subdirectory", func() {
			binFileA := path.Join(cacheDir, "fwA.bin")
			Expect(os.WriteFile(binFileA, []byte("dummy content"), 0644)).To(Succeed())

			// cacheDir was already created in BeforeEach; this call is a no-op
			Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())
			binFileB := path.Join(cacheDir, "fwB.bin")
			Expect(os.WriteFile(binFileB, []byte("dummy content"), 0644)).To(Succeed())

			fwUtilsMock.On("GetFirmwareVersionAndPSID", binFileA).
				Return("1.2.3", "PSID123", nil).Once()
			fwUtilsMock.On("GetFirmwareVersionAndPSID", binFileB).
				Return("3.2.1", "PSID321", nil).Once()

			err := fwProv.AddFirmwareBinariesToCacheByMetadata(cacheName)
			Expect(err).NotTo(HaveOccurred())

			// The binaries are moved, not copied: the original locations must be gone
			_, err = os.Stat(binFileA)
			Expect(err).To(MatchError(os.ErrNotExist))

			_, err = os.Stat(binFileB)
			Expect(err).To(MatchError(os.ErrNotExist))

			_, err = os.Stat(path.Join(cacheDir, "1.2.3", "PSID123", "fwA.bin"))
			Expect(err).NotTo(HaveOccurred())
			_, err = os.Stat(path.Join(cacheDir, "3.2.1", "PSID321", "fwB.bin"))
			Expect(err).NotTo(HaveOccurred())
		})

		Context("when GetFirmwareVersionAndPSID returns an error", func() {
			It("should fail immediately", func() {
				binFileA := path.Join(cacheDir, "fwA.bin")
				Expect(os.WriteFile(binFileA, []byte("dummy content"), 0644)).To(Succeed())

				fwUtilsMock.On("GetFirmwareVersionAndPSID", mock.AnythingOfType("string")).
					Return("", "", errors.New("parse error")).Once()

				err := fwProv.AddFirmwareBinariesToCacheByMetadata(cacheName)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("parse error"))
			})
		})
	})

	Describe("ValidateCache", func() {
		BeforeEach(func() {
			Expect(os.MkdirAll(cacheDir, 0755)).To(Succeed())
		})

		It("should return a map of firmware versions to PSIDs if everything is valid", func() {
			Expect(os.MkdirAll(path.Join(cacheDir, "1.2.3", "PSID123"), 0755)).To(Succeed())
			Expect(os.MkdirAll(path.Join(cacheDir, "1.2.3", "PSID321"), 0755)).To(Succeed())
			binFileA := path.Join(cacheDir, "1.2.3", "PSID123", "fwA.bin")
			Expect(os.WriteFile(binFileA, []byte("dummy content"), 0644)).To(Succeed())
			binFileB := path.Join(cacheDir, "1.2.3", "PSID321", "fwB.bin")
			Expect(os.WriteFile(binFileB, []byte("dummy content"), 0644)).To(Succeed())

			versions, err := fwProv.ValidateCache(cacheName)
			Expect(err).NotTo(HaveOccurred())
			Expect(versions).To(Equal(map[string][]string{"1.2.3": {"PSID123", "PSID321"}}))
		})

		Context("when a PSID directory is empty", func() {
			It("should return an error", func() {
				Expect(os.MkdirAll(path.Join(cacheDir, "1.2.3", "PSID123"), 0755)).To(Succeed())

				_, err := fwProv.ValidateCache(cacheName)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("cache directory is empty"))
			})
		})

		Context("when multiple versions share the same PSID", func() {
			It("should return an error about multiple binaries for the same PSID", func() {
				psidDir1 := path.Join(cacheDir, "1.2.3", "PSID123")
				psidDir2 := path.Join(cacheDir, "3.2.1", "PSID123")
				Expect(os.MkdirAll(psidDir1, 0755)).To(Succeed())
				Expect(os.MkdirAll(psidDir2, 0755)).To(Succeed())
				Expect(os.WriteFile(path.Join(psidDir1, "fileA.bin"), []byte("dummy content"), 0644)).To(Succeed())
				Expect(os.WriteFile(path.Join(psidDir2, "fileB.bin"), []byte("dummy content"), 0644)).To(Succeed())

				_, err := fwProv.ValidateCache(cacheName)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("multiple firmware binary files for the same PSID"))
			})
		})

		Context("when multiple bin files are in the same PSID dir", func() {
			It("should return an error about multiple binaries in the directory", func() {
				psidDir := path.Join(cacheDir, "1.2.3", "PSID123")
				Expect(os.MkdirAll(psidDir, 0755)).To(Succeed())
				Expect(os.WriteFile(path.Join(psidDir, "fileA.bin"), []byte("dummy content"), 0644)).To(Succeed())
				Expect(os.WriteFile(path.Join(psidDir, "fileB.bin"), []byte("dummy content"), 0644)).To(Succeed())

				_, err := fwProv.ValidateCache(cacheName)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("multiple firmware binary files in the same directory"))
			})
		})

		// Same setup as "when a PSID directory is empty" above, with a stricter message assertion
		Context("when no files are in the PSID dir", func() {
			It("should return an error about no files in the directory", func() {
				psidDir := path.Join(cacheDir, "1.2.3", "PSID123")
				Expect(os.MkdirAll(psidDir, 0755)).To(Succeed())

				_, err := fwProv.ValidateCache(cacheName)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("cache directory is empty. Expected firmware binary file."))
			})
		})

		Context("when no binary files are in the PSID dir", func() {
			It("should return an error about no binaries in the directory", func() {
				psidDir := path.Join(cacheDir, "1.2.3", "PSID123")
				Expect(os.MkdirAll(psidDir, 0755)).To(Succeed())
				// The directory is not empty, but its only file does not have the firmware binary extension
				Expect(os.WriteFile(path.Join(psidDir, "fileA.nobin"), []byte("dummy content"), 0644)).To(Succeed())

				_, err := fwProv.ValidateCache(cacheName)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("no firmware binary files in the PSID directory"))
			})
		})
	})
})

// TestFirmware is the `go test` entry point that runs the Ginkgo specs of the firmware package.
func TestFirmware(t *testing.T) {
	// Register Gomega with Ginkgo
	gomega.RegisterFailHandler(ginkgo.Fail)
	// Run the test suite
	ginkgo.RunSpecs(t, "Firmware Suite")
}
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//nolint:errcheck +package firmware + +import ( + "archive/zip" + "bufio" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "path/filepath" + "strings" + + execUtils "k8s.io/utils/exec" + "sigs.k8s.io/controller-runtime/pkg/log" + + "github.com/Mellanox/nic-configuration-operator/pkg/consts" +) + +type ProvisioningUtils interface { + // DownloadFile downloads the file under url and places it locally under destPath + DownloadFile(url, destPath string) error + // UnzipFiles extract files from the zip archive to destDir + // Returns a list of extracted files, error if occurred + UnzipFiles(zipPath, destDir string) ([]string, error) + // GetFirmwareVersionAndPSID retrieves the version and PSID from the firmware binary + GetFirmwareVersionAndPSID(firmwareBinaryPath string) (string, string, error) + // CleanupDirectory deletes any file inside a root directory except for allowedSet. 
Empty directories are cleaned up as well at the end + CleanupDirectory(root string, allowedSet map[string]struct{}) error +} + +type utils struct { + execInterface execUtils.Interface +} + +// DownloadFile downloads the file under url and places it locally under destPath +func (u utils) DownloadFile(url, destPath string) error { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("could not download file: %w", err) + } + defer resp.Body.Close() + + out, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("could not create file: %w", err) + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return fmt.Errorf("error saving file: %w", err) + } + + return nil +} + +// UnzipFiles extract files from the zip archive to destDir +// Returns a list of extracted files, error if occurred +func (u utils) UnzipFiles(zipPath, destDir string) ([]string, error) { + extractedFiles := []string{} + + zipReader, err := zip.OpenReader(zipPath) + if err != nil { + return nil, fmt.Errorf("could not open zip file: %w", err) + } + defer zipReader.Close() + + for _, file := range zipReader.File { + fPath := filepath.Join(destDir, file.Name) + + if !strings.HasPrefix(fPath, filepath.Clean(destDir)+string(os.PathSeparator)) { + return nil, fmt.Errorf("illegal file path: %s", fPath) + } + + if file.FileInfo().IsDir() { + if err := os.MkdirAll(fPath, file.Mode()); err != nil { + return nil, fmt.Errorf("error creating directory: %w", err) + } + continue + } + + if err := os.MkdirAll(filepath.Dir(fPath), 0755); err != nil { + return nil, fmt.Errorf("error creating parent directories: %w", err) + } + + if err := extractFile(file, fPath); err != nil { + return nil, err + } + + extractedFiles = append(extractedFiles, fPath) + } + + return extractedFiles, nil +} + +// extractFile copies the contents of a single file from the ZIP archive +// to a local file, preserving its mode (permissions). 
+func extractFile(zf *zip.File, destPath string) error { + srcFile, err := zf.Open() + if err != nil { + return err + } + defer srcFile.Close() + + outFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zf.Mode()) + if err != nil { + return err + } + defer outFile.Close() + + if _, err = io.Copy(outFile, srcFile); err != nil { + return err + } + + return nil +} + +// GetFirmwareVersionAndPSID retrieves the version and PSID from the firmware binary +func (u utils) GetFirmwareVersionAndPSID(firmwareBinaryPath string) (string, string, error) { + log.Log.Info("HostUtils.GetFirmwareVersionAndPSID()", "firmwareBinaryPath", firmwareBinaryPath) + cmd := u.execInterface.Command("mstflint", "-i", firmwareBinaryPath, "q") + output, err := cmd.Output() + if err != nil { + log.Log.Error(err, "GetFirmwareVersionAndPSID(): Failed to run mstflint") + return "", "", err + } + + // Parse the output for FW version and PSID + scanner := bufio.NewScanner(strings.NewReader(string(output))) + var firmwareVersion, PSID string + + for scanner.Scan() { + line := strings.ToLower(scanner.Text()) + + if strings.HasPrefix(line, consts.FirmwareVersionPrefix) { + firmwareVersion = strings.TrimSpace(strings.TrimPrefix(line, consts.FirmwareVersionPrefix)) + } + if strings.HasPrefix(line, consts.PSIDPrefix) { + PSID = strings.TrimSpace(strings.TrimPrefix(line, consts.PSIDPrefix)) + } + } + + if err := scanner.Err(); err != nil { + log.Log.Error(err, "GetFirmwareVersionAndPSID(): Error reading mstflint output") + return "", "", err + } + + if firmwareVersion == "" || PSID == "" { + return "", "", fmt.Errorf("GetFirmwareVersionAndPSID(): firmware version (%v) or PSID (%v) is empty", firmwareVersion, PSID) + } + + return firmwareVersion, PSID, nil +} + +// CleanupDirectory deletes any file inside a root directory except for allowedSet. 
Empty directories are cleaned up as well at the end +func (u utils) CleanupDirectory(root string, allowedSet map[string]struct{}) error { + log.Log.Info("Cleaning up cache directory", "cacheDir", root) + + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + + if d.IsDir() { + return nil + } + + abs, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("failed to get absolute path of %q: %w", path, err) + } + + if _, ok := allowedSet[abs]; !ok { + if err := os.Remove(abs); err != nil { + return fmt.Errorf("failed to remove %q: %w", abs, err) + } + log.Log.V(2).Info("deleted unaccounted file from cache dir", "path", abs, "cacheDir", root) + } + return nil + }) + + if err != nil { + err := fmt.Errorf("failed to walk directory for file removal: %w", err) + log.Log.Error(err, "failed to cleanup the cache directory", "cacheDir", root) + } + + // After the unaccounted for files were deleted, clean up empty directories + if err := u.removeEmptyDirs(root); err != nil { + err := fmt.Errorf("failed to remove empty directories: %w", err) + log.Log.Error(err, "failed to cleanup the cache directory", "cacheDir", root) + } + + return nil +} + +// removeEmptyDirs recursively removes directories that are empty. +// It does a post-order traversal: children first, then the parent. 
+func (u utils) removeEmptyDirs(dir string) error { + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + + // Recurse into subdirectories first + for _, entry := range entries { + if entry.IsDir() { + subDir := filepath.Join(dir, entry.Name()) + if err := u.removeEmptyDirs(subDir); err != nil { + return err + } + } + } + + // After processing children, check if 'dir' is now empty + // (Re-read directory to see if it has become empty) + entries, err = os.ReadDir(dir) + if err != nil { + return nil // If we can't re-read, just skip + } + if len(entries) == 0 && dir != "/" { + // Avoid removing root if you didn't intend to + if err := os.Remove(dir); err != nil { + return fmt.Errorf("failed to remove empty directory %q: %w", dir, err) + } + log.Log.V(2).Info("deleted empty directory", "dir", dir) + } + + return nil +} + +func newFirmwareUtils() ProvisioningUtils { + return utils{execInterface: execUtils.New()} +} diff --git a/pkg/firmware/utils_test.go b/pkg/firmware/utils_test.go new file mode 100644 index 0000000..679c32c --- /dev/null +++ b/pkg/firmware/utils_test.go @@ -0,0 +1,220 @@ +/* +2025 NVIDIA CORPORATION & AFFILIATES +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//nolint:errcheck +package firmware + +import ( + "archive/zip" + "bytes" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sort" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("utils", func() { + var ( + testedUtils ProvisioningUtils + tmpDir string + ) + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "utils-test-*") + Expect(err).NotTo(HaveOccurred()) + + testedUtils = newTestFirmwareUtils() + }) + + AfterEach(func() { + os.RemoveAll(tmpDir) + }) + + Describe("DownloadFile", func() { + It("should download a file from an HTTP server to the specified path", func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("Hello from test server")) + })) + defer server.Close() + + destPath := filepath.Join(tmpDir, "downloaded.txt") + + err := testedUtils.DownloadFile(server.URL, destPath) + Expect(err).NotTo(HaveOccurred()) + + data, err := os.ReadFile(destPath) + Expect(err).NotTo(HaveOccurred()) + Expect(string(data)).To(Equal("Hello from test server")) + }) + + It("should fail if the URL is invalid", func() { + destPath := filepath.Join(tmpDir, "invalid.txt") + err := testedUtils.DownloadFile("http://invalid domain", destPath) + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("UnzipFiles", func() { + var zipPath string + + BeforeEach(func() { + zipPath = filepath.Join(tmpDir, "test.zip") + createTestZip(zipPath, map[string]string{ + "fileA.txt": "Content A", + "folderB/fileB.txt": "Content B", + }) + }) + + It("should extract a zip archive to the destination directory", func() { + destDir := filepath.Join(tmpDir, "extracted") + + extracted, err := testedUtils.UnzipFiles(zipPath, destDir) + Expect(err).NotTo(HaveOccurred()) + + Expect(extracted).To(HaveLen(2)) + + // Sort for consistent comparison + sort.Strings(extracted) + + Expect(extracted[0]).To(ContainSubstring("fileA.txt")) + Expect(extracted[1]).To(ContainSubstring("fileB.txt")) + + dataA, err := os.ReadFile(filepath.Join(destDir, "fileA.txt")) + Expect(err).NotTo(HaveOccurred()) + Expect(string(dataA)).To(Equal("Content A")) + + dataB, err 
:= os.ReadFile(filepath.Join(destDir, "folderB", "fileB.txt")) + Expect(err).NotTo(HaveOccurred()) + Expect(string(dataB)).To(Equal("Content B")) + }) + + It("should fail if the zip file does not exist", func() { + err := os.Remove(zipPath) + Expect(err).NotTo(HaveOccurred()) + + _, err = testedUtils.UnzipFiles(zipPath, tmpDir) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("could not open zip file")) + }) + + It("should fail if the zip contains files escaping the destDir", func() { + // Add a malicious file path: "../evil.txt" + // Recreate the zip with an illegal path + _ = os.Remove(zipPath) + zipPath = filepath.Join(tmpDir, "illegal.zip") + createTestZip(zipPath, map[string]string{ + "../evil.txt": "Evil data", + }) + + _, err := testedUtils.UnzipFiles(zipPath, filepath.Join(tmpDir, "extracted")) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("illegal file path")) + }) + }) + + Describe("GetFirmwareVersionAndPSID", func() { + It("should currently return empty strings (stub)", func() { + binPath := filepath.Join(tmpDir, "fw.bin") + err := os.WriteFile(binPath, []byte("fake firmware data"), 0644) + Expect(err).NotTo(HaveOccurred()) + + version, psid, err := testedUtils.GetFirmwareVersionAndPSID(binPath) + Expect(err).NotTo(HaveOccurred()) + Expect(version).To(BeEmpty()) + Expect(psid).To(BeEmpty()) + }) + }) + + Describe("CleanupDirectory", func() { + var rootDir string + + BeforeEach(func() { + // Directory structure + // rootDir/ + // keepA.bin + // removeMe.txt + // subdir/ + // keepB.bin + // removeMe2.txt + // emptySub/ + // + rootDir = filepath.Join(tmpDir, "cache-dir") + Expect(os.MkdirAll(filepath.Join(rootDir, "subdir", "emptySub"), 0755)).To(Succeed()) + + Expect(os.WriteFile(filepath.Join(rootDir, "keepA.bin"), []byte("A"), 0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(rootDir, "removeMe.txt"), []byte("remove1"), 0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(rootDir, 
"subdir", "keepB.bin"), []byte("B"), 0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(rootDir, "subdir", "removeMe2.txt"), []byte("remove2"), 0644)).To(Succeed()) + }) + + It("should remove everything not in the allowed set and clean up empty dirs", func() { + keepA := mustAbs(filepath.Join(rootDir, "keepA.bin")) + keepB := mustAbs(filepath.Join(rootDir, "subdir", "keepB.bin")) + allowed := map[string]struct{}{ + keepA: {}, + keepB: {}, + } + + err := testedUtils.CleanupDirectory(rootDir, allowed) + Expect(err).NotTo(HaveOccurred()) + + _, err = os.Stat(filepath.Join(rootDir, "removeMe.txt")) + Expect(os.IsNotExist(err)).To(BeTrue()) + + _, err = os.Stat(filepath.Join(rootDir, "subdir", "removeMe2.txt")) + Expect(os.IsNotExist(err)).To(BeTrue()) + + _, err = os.Stat(filepath.Join(rootDir, "subdir", "emptySub")) + Expect(os.IsNotExist(err)).To(BeTrue()) + + Expect(filepath.Join(rootDir, "keepA.bin")).To(BeARegularFile()) + Expect(filepath.Join(rootDir, "subdir", "keepB.bin")).To(BeARegularFile()) + }) + }) +}) + +func newTestFirmwareUtils() ProvisioningUtils { + return &utils{} +} + +func createTestZip(zipPath string, files map[string]string) { + f, err := os.Create(zipPath) + Expect(err).NotTo(HaveOccurred()) + defer f.Close() + + w := zip.NewWriter(f) + defer w.Close() + + for name, content := range files { + fw, err := w.Create(name) + Expect(err).NotTo(HaveOccurred()) + _, err = io.Copy(fw, bytes.NewReader([]byte(content))) + Expect(err).NotTo(HaveOccurred()) + } + Expect(w.Close()).To(Succeed()) +} + +// mustAbs returns the absolute version of a path, or panics on error. +func mustAbs(p string) string { + abs, err := filepath.Abs(p) + Expect(err).NotTo(HaveOccurred()) + return abs +}