Skip to content

Commit

Permalink
fix tests, add machine state to index (#811)
Browse files: browse the repository at this point in the history
  • Loading branch information
luke-lombardi authored Dec 26, 2024
1 parent c789b7e commit 396b967
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
7 changes: 7 additions & 0 deletions pkg/repository/provider_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,13 @@ func (r *ProviderRedisRepository) RegisterMachine(providerName, poolName, machin
machineInfo.LastKeepalive = fmt.Sprintf("%d", time.Now().UTC().Unix())
machineInfo.PoolName = newMachineInfo.PoolName
machineInfo.MachineId = newMachineInfo.MachineId

// Add the machine's state key to the machine index for this provider and pool
machineIndexKey := common.RedisKeys.ProviderMachineIndex(providerName, poolName)
err = r.rdb.SAdd(context.TODO(), machineIndexKey, stateKey).Err()
if err != nil {
return fmt.Errorf("failed to add machine state key to index <%v>: %w", machineIndexKey, err)
}
}

machineInfo.HostName = newMachineInfo.HostName
Expand Down
18 changes: 15 additions & 3 deletions pkg/scheduler/pool_sizing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,20 @@ func TestOccupyAvailableMachines(t *testing.T) {
})
assert.NoError(t, err)

poolConfig := &types.WorkerPoolConfig{
Provider: &types.ProviderLambdaLabs,
GPUType: "A10G",
Mode: types.PoolModeExternal,
}

err = providerRepo.RegisterMachine(string(types.ProviderLambdaLabs), lambdaPoolName, "machine1", &types.ProviderMachineState{
Gpu: "A10G",
GpuCount: 1,
AutoConsolidate: false,
Cpu: 30000,
Memory: 16000,
Status: types.MachineStatusRegistered,
})
}, poolConfig)
assert.NoError(t, err)

err = providerRepo.AddMachine(string(types.ProviderLambdaLabs), lambdaPoolName, "machine2", &types.ProviderMachineState{
Expand All @@ -282,7 +288,7 @@ func TestOccupyAvailableMachines(t *testing.T) {
Cpu: 30000,
Memory: 16000,
Status: types.MachineStatusRegistered,
})
}, poolConfig)
assert.NoError(t, err)
assert.NoError(t, err)

Expand Down Expand Up @@ -392,6 +398,12 @@ func TestOccupyAvailableMachinesConcurrency(t *testing.T) {
},
}

poolConfig := &types.WorkerPoolConfig{
Provider: &types.ProviderGeneric,
GPUType: "A10G",
Mode: types.PoolModeExternal,
}

maxMachinesAndWorkers := 100
for i := 0; i < maxMachinesAndWorkers; i++ {
machineName := fmt.Sprintf("machine-%d", i)
Expand All @@ -406,7 +418,7 @@ func TestOccupyAvailableMachinesConcurrency(t *testing.T) {
err = providerRepo.AddMachine(string(types.ProviderGeneric), poolName, machineName, machineState)
assert.NoError(t, err)

err = providerRepo.RegisterMachine(string(types.ProviderGeneric), poolName, machineName, machineState)
err = providerRepo.RegisterMachine(string(types.ProviderGeneric), poolName, machineName, machineState, poolConfig)
assert.NoError(t, err)
}

Expand Down

0 comments on commit 396b967

Please sign in to comment.