Skip to content

Commit

Permalink
Parallelize tests to allow us to reenable testing in pkg/...
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Apr 23, 2024
1 parent c7a308b commit df777d4
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ check-vendor: vendor

COVERAGE_FILE := coverage.out
test: build cmds
go test -coverprofile=$(COVERAGE_FILE) $(MODULE)/cmd/... $(MODULE)/internal/... $(MODULE)/api/...
go test -v -coverprofile=$(COVERAGE_FILE) $(MODULE)/cmd/... $(MODULE)/internal/... $(MODULE)/api/... $(MODULE)/pkg/...

coverage: test
cat $(COVERAGE_FILE) | grep -v "_mock.go" > $(COVERAGE_FILE).no-mocks
Expand Down
29 changes: 19 additions & 10 deletions pkg/mig/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ func EnableMigMode(manager Manager, gpu int) (nvml.Return, nvml.Return) {
}

func TestGetSetMigConfig(t *testing.T) {
nvmlLib := dgxa100.New()
manager := NewMockLunaServerMigConfigManager()

numGPUs, ret := nvmlLib.DeviceGetCount()
require.NotNil(t, ret, "Unexpected nil return from DeviceGetCount")
require.Equal(t, ret, nvml.SUCCESS, "Unexpected return value from DeviceGetCount")

mcg := NewA100_SXM4_40GB_MigConfigGroup()

type testCase struct {
Expand All @@ -69,8 +62,18 @@ func TestGetSetMigConfig(t *testing.T) {
return testCases
}()

for _, tc := range testCases {
for i := range testCases {
tc := testCases[i] // to allow us to run parallelly
t.Run(tc.description, func(t *testing.T) {
t.Parallel()

nvmlLib := dgxa100.New()
manager := NewMockLunaServerMigConfigManager()

numGPUs, ret := nvmlLib.DeviceGetCount()
require.NotNil(t, ret, "Unexpected nil return from DeviceGetCount")
require.Equal(t, ret, nvml.SUCCESS, "Unexpected return value from DeviceGetCount")

for i := 0; i < numGPUs; i++ {
r1, r2 := EnableMigMode(manager, i)
require.Equal(t, nvml.SUCCESS, r1)
Expand Down Expand Up @@ -106,8 +109,11 @@ func TestClearMigConfig(t *testing.T) {
return testCases
}()

for _, tc := range testCases {
for i range testCases {

Check failure on line 112 in pkg/mig/config/config_test.go

View workflow job for this annotation

GitHub Actions / check

syntax error: unexpected range, expected { (typecheck)

Check failure on line 112 in pkg/mig/config/config_test.go

View workflow job for this annotation

GitHub Actions / check

expected '{', found 'range' (typecheck)

Check failure on line 112 in pkg/mig/config/config_test.go

View workflow job for this annotation

GitHub Actions / Unit test

expected '{', found 'range'
tc := testCases[i] // to allow us to run parallelly

Check failure on line 113 in pkg/mig/config/config_test.go

View workflow job for this annotation

GitHub Actions / check

missing ',' in composite literal (typecheck)
t.Run(tc.description, func(t *testing.T) {
t.Parallel()

manager := NewMockLunaServerMigConfigManager()

r1, r2 := EnableMigMode(manager, 0)
Expand Down Expand Up @@ -173,8 +179,11 @@ func TestIteratePermutationsUntilSuccess(t *testing.T) {
return testCases
}()

for _, tc := range testCases {
for i := range testCases {
tc := testCases[i] // to allow us to run parallelly
t.Run(tc.description, func(t *testing.T) {
t.Parallel()

iteration := 0
err := iteratePermutationsUntilSuccess(tc.config, func(perm []*types.MigProfile) error {
iteration++
Expand Down
23 changes: 17 additions & 6 deletions pkg/mig/config/known_configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package config
import (
"fmt"
"sort"
"sync"

"github.com/NVIDIA/mig-parted/pkg/types"
)
Expand Down Expand Up @@ -73,8 +74,9 @@ func NewA100_SXM4_40GB_MigConfigGroup() types.MigConfigGroup {
}

func (m *a100_sxm4_40gb_MigConfigGroup) init() {
var mutex sync.RWMutex
configs := make(map[string]types.MigConfig)
m.iterateDeviceTypes(func(mps []*types.MigProfile) LoopControl {
m.parallelIterateDeviceTypes(func(mps []*types.MigProfile) LoopControl {
cis := 0
cis_per_gi := make(map[int]int)
mes_per_gi := make(map[int]int)
Expand Down Expand Up @@ -107,20 +109,24 @@ func (m *a100_sxm4_40gb_MigConfigGroup) init() {
return mps[i].String() < mps[j].String()
})

mutex.Lock()
str := fmt.Sprintf("%v", mps)
if _, exists := configs[str]; !exists {
configs[str] = types.NewMigConfig(mps)
}
mutex.Unlock()

if cis < 7 {
return Continue
}

return Break
})
mutex.RLock()
for _, v := range configs {
m.Configs = append(m.Configs, v)
}
mutex.RUnlock()
}

func (m *a100_sxm4_40gb_MigConfigGroup) GetDeviceTypes() []*types.MigProfile {
Expand All @@ -143,7 +149,7 @@ func (m *a100_sxm4_40gb_MigConfigGroup) GetDeviceTypes() []*types.MigProfile {
}
}

func (m *a100_sxm4_40gb_MigConfigGroup) iterateDeviceTypes(f func([]*types.MigProfile) LoopControl) {
func (m *a100_sxm4_40gb_MigConfigGroup) parallelIterateDeviceTypes(f func([]*types.MigProfile) LoopControl) {
maxDevices := types.MigConfig{
mig_1c_1g_5gb: 7,
mig_1c_1g_5gb_me: 1,
Expand All @@ -162,20 +168,25 @@ func (m *a100_sxm4_40gb_MigConfigGroup) iterateDeviceTypes(f func([]*types.MigPr
mig_7c_7g_40gb: 1,
}.Flatten()

var iterate func(i int, accum []*types.MigProfile) LoopControl
iterate = func(i int, accum []*types.MigProfile) LoopControl {
var iterate func(i int, accum []*types.MigProfile, wg *sync.WaitGroup) LoopControl
iterate = func(i int, accum []*types.MigProfile, wg *sync.WaitGroup) LoopControl {
defer wg.Done()
accum = append(accum, maxDevices[i])
control := f(accum)
if control == Break {
return Continue
}
for j := i + 1; j < len(maxDevices); j++ {
iterate(j, accum)
wg.Add(1)
go iterate(j, accum, wg)
}
return Continue
}

var wg sync.WaitGroup
for i := 0; i < len(maxDevices); i++ {
iterate(i, []*types.MigProfile{})
wg.Add(1)
go iterate(i, []*types.MigProfile{}, &wg)
}
wg.Wait()
}
4 changes: 3 additions & 1 deletion pkg/mig/config/known_configs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ func TestValidConfiguration(t *testing.T) {
types.SetMockNVdevlib()
configs := GetKnownMigConfigGroups()

for _, tc := range testCases {
for i:= range testCases {
tc := testCases[i] // to allow us to run parallelly
t.Run(tc.description, func(t *testing.T) {
t.Parallel()
err := configs[tc.gpu].AssertValidConfiguration(tc.config)
if tc.valid {
require.Nil(t, err)
Expand Down
17 changes: 10 additions & 7 deletions pkg/mig/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ func newMockMigStateManagerOnLunaServer() *migStateManager {
}

func TestFetchRestore(t *testing.T) {
manager := newMockMigStateManagerOnLunaServer()

numGPUs, ret := manager.nvml.DeviceGetCount()
require.NotNil(t, ret, "Unexpected nil return from DeviceGetCount")
require.Equal(t, ret, nvml.SUCCESS, "Unexpected return value from DeviceGetCount")

mcg := config.NewA100_SXM4_40GB_MigConfigGroup()

type testCase struct {
Expand Down Expand Up @@ -73,8 +67,17 @@ func TestFetchRestore(t *testing.T) {
return testCases
}()

for _, tc := range testCases {
for i := range testCases {
tc := testCases[i] // to allow us to run parallelly
t.Run(tc.description, func(t *testing.T) {
t.Parallel()

manager := newMockMigStateManagerOnLunaServer()

numGPUs, ret := manager.nvml.DeviceGetCount()
require.NotNil(t, ret, "Unexpected nil return from DeviceGetCount")
require.Equal(t, ret, nvml.SUCCESS, "Unexpected return value from DeviceGetCount")

for i := 0; i < numGPUs; i++ {
err := manager.mode.SetMigMode(i, tc.mode)
require.Nil(t, err)
Expand Down

0 comments on commit df777d4

Please sign in to comment.