diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 65d3bd284..d9141956b 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -6,6 +6,7 @@ type DockerTaskBody struct { Username string `json:"username"` Password string `json:"password"` ImageName string `json:"image_name"` + ShmSize int64 `json:"shm_size"` } type StopBody struct { @@ -30,6 +31,7 @@ func (ra DockerTaskBody) TaskParams() shim.DockerImageConfig { ImageName: ra.ImageName, Username: ra.Username, Password: ra.Password, + ShmSize: ra.ShmSize, } return res } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index cc394da0d..42ecbe3db 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -185,7 +185,8 @@ func createContainer(ctx context.Context, client docker.APIClient, dockerParams Resources: container.Resources{ DeviceRequests: gpuRequest, }, - Mounts: mounts, + Mounts: mounts, + ShmSize: taskParams.ShmSize, } resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, "") if err != nil { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 913587e52..e3ff1a58d 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -44,6 +44,7 @@ type DockerImageConfig struct { Username string Password string ImageName string + ShmSize int64 } func (ra DockerImageConfig) EncodeRegistryAuth() (string, error) { diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index d27005ec0..41a5f7002 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -352,12 +352,14 @@ def _process_provisioning_with_shim( username=interpolate(registry_auth.username), password=interpolate(registry_auth.password), 
image_name=job_spec.image_name, + shm_size=job_spec.requirements.resources.shm_size, ) else: shim_client.submit( username="", password="", image_name=job_spec.image_name, + shm_size=job_spec.requirements.resources.shm_size, ) job_model.status = JobStatus.PULLING diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index f63cd7626..3039f2bce 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -70,6 +70,7 @@ class DockerImageBody(BaseModel): username: str password: str image_name: str + shm_size: int class StopBody(BaseModel): diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 738fab2ae..ad753a4cb 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -7,6 +7,7 @@ import requests.exceptions from dstack._internal.core.models.repos.remote import RemoteRepoCreds +from dstack._internal.core.models.resources import Memory from dstack._internal.core.models.runs import JobSpec, RunSpec from dstack._internal.server.schemas.runner import ( DockerImageBody, @@ -102,9 +103,10 @@ def healthcheck(self, unmask_exeptions: bool = False) -> Optional[HealthcheckRes raise return None - def submit(self, username: str, password: str, image_name: str): + def submit(self, username: str, password: str, image_name: str, shm_size: Optional[Memory]): + _shm_size = int(shm_size * 1024 * 1024 * 1024) if shm_size else 0 post_body = DockerImageBody( - username=username, password=password, image_name=image_name + username=username, password=password, image_name=image_name, shm_size=_shm_size ).dict() resp = requests.post( self._url("/api/submit"), diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 1ecb4df89..6a8d8199b 
100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -221,7 +221,10 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): RunnerTunnelMock.assert_called_once() ShimClientMock.return_value.healthcheck.assert_called_once() ShimClientMock.return_value.submit.assert_called_once_with( - username="", password="", image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1" + username="", + password="", + image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1", + shm_size=None, ) await session.refresh(job) assert job is not None