diff --git a/pkg/repository/provider_redis.go b/pkg/repository/provider_redis.go index 9dca4df20..6797fbb5c 100644 --- a/pkg/repository/provider_redis.go +++ b/pkg/repository/provider_redis.go @@ -313,6 +313,13 @@ func (r *ProviderRedisRepository) RegisterMachine(providerName, poolName, machin machineInfo.LastKeepalive = fmt.Sprintf("%d", time.Now().UTC().Unix()) machineInfo.PoolName = newMachineInfo.PoolName machineInfo.MachineId = newMachineInfo.MachineId + + // Add machine to index + machineIndexKey := common.RedisKeys.ProviderMachineIndex(providerName, poolName) + err = r.rdb.SAdd(context.TODO(), machineIndexKey, stateKey).Err() + if err != nil { + return fmt.Errorf("failed to add machine state key to index <%v>: %w", machineIndexKey, err) + } } machineInfo.HostName = newMachineInfo.HostName diff --git a/pkg/scheduler/pool_sizing_test.go b/pkg/scheduler/pool_sizing_test.go index e01303f32..776544898 100644 --- a/pkg/scheduler/pool_sizing_test.go +++ b/pkg/scheduler/pool_sizing_test.go @@ -255,6 +255,12 @@ func TestOccupyAvailableMachines(t *testing.T) { }) assert.NoError(t, err) + poolConfig := &types.WorkerPoolConfig{ + Provider: &types.ProviderLambdaLabs, + GPUType: "A10G", + Mode: types.PoolModeExternal, + } + err = providerRepo.RegisterMachine(string(types.ProviderLambdaLabs), lambdaPoolName, "machine1", &types.ProviderMachineState{ Gpu: "A10G", GpuCount: 1, @@ -262,7 +268,7 @@ func TestOccupyAvailableMachines(t *testing.T) { Cpu: 30000, Memory: 16000, Status: types.MachineStatusRegistered, - }) + }, poolConfig) assert.NoError(t, err) err = providerRepo.AddMachine(string(types.ProviderLambdaLabs), lambdaPoolName, "machine2", &types.ProviderMachineState{ @@ -282,7 +288,7 @@ func TestOccupyAvailableMachines(t *testing.T) { Cpu: 30000, Memory: 16000, Status: types.MachineStatusRegistered, - }) + }, poolConfig) assert.NoError(t, err) assert.NoError(t, err) @@ -392,6 +398,12 @@ func TestOccupyAvailableMachinesConcurrency(t *testing.T) { }, } + poolConfig := &types.WorkerPoolConfig{ + Provider: &types.ProviderGeneric, + GPUType: "A10G", + Mode: types.PoolModeExternal, + } + maxMachinesAndWorkers := 100 for i := 0; i < maxMachinesAndWorkers; i++ { machineName := fmt.Sprintf("machine-%d", i) @@ -406,7 +418,7 @@ func TestOccupyAvailableMachinesConcurrency(t *testing.T) { err = providerRepo.AddMachine(string(types.ProviderGeneric), poolName, machineName, machineState) assert.NoError(t, err) - err = providerRepo.RegisterMachine(string(types.ProviderGeneric), poolName, machineName, machineState) + err = providerRepo.RegisterMachine(string(types.ProviderGeneric), poolName, machineName, machineState, poolConfig) assert.NoError(t, err) }