Pass ASAHI, CUDA, HIP, HSA prefixed env vars to generated files
Relates-to: containers#525
Signed-off-by: Arun Babu Neelicattu <[email protected]>
abn committed Dec 28, 2024
1 parent 13e5d4c commit 0aecb51
Showing 5 changed files with 88 additions and 35 deletions.
37 changes: 37 additions & 0 deletions ramalama/common.py
@@ -1,5 +1,6 @@
"""ramalama common module."""

import glob
import hashlib
import os
import random
@@ -185,3 +186,39 @@ def engine_version(engine):
    # Create manifest list for target with imageid
    cmd_args = [engine, "version", "--format", "{{ .Client.Version }}"]
    return run_cmd(cmd_args).stdout.decode("utf-8").strip()


def get_gpu():
    i = 0
    gpu_num = 0
    gpu_bytes = 0
    for fp in sorted(glob.glob('/sys/bus/pci/devices/*/mem_info_vram_total')):
        with open(fp, 'r') as file:
            content = int(file.read())
            if content > 1073741824 and content > gpu_bytes:
                gpu_bytes = content
                gpu_num = i

        i += 1

    if gpu_bytes:  # this is the ROCm/AMD case
        return "HIP_VISIBLE_DEVICES", gpu_num

    if os.path.exists('/etc/os-release'):
        with open('/etc/os-release', 'r') as file:
            content = file.read()
            if "asahi" in content.lower():
                return "ASAHI_VISIBLE_DEVICES", 1

    return None, None


def get_env_vars():
    prefixes = ("ASAHI_", "CUDA_", "HIP_", "HSA_")
    env_vars = {k: v for k, v in os.environ.items() if k.startswith(prefixes)}

    gpu_type, gpu_num = get_gpu()
    if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
        env_vars[gpu_type] = str(gpu_num)

    return env_vars
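
A minimal usage sketch of the new helper (not part of the commit): it assumes ramalama is importable and that get_gpu() finds no AMD or Asahi GPU to add on top of the exported variables; HIP_SOMETHING is a hypothetical variable borrowed from the tests below.

import os

from ramalama.common import get_env_vars

os.environ["HIP_SOMETHING"] = "99"  # any ASAHI_/CUDA_/HIP_/HSA_ prefix qualifies
print(get_env_vars())               # expected to include {'HIP_SOMETHING': '99'}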
21 changes: 20 additions & 1 deletion ramalama/kube.py
@@ -2,7 +2,7 @@


from ramalama.version import version
from ramalama.common import genname, mnt_dir
from ramalama.common import genname, mnt_dir, get_env_vars


class Kube:
@@ -83,7 +83,25 @@ def _gen_ports(self):

        return ports

    @staticmethod
    def _gen_env_vars():
        env_vars = get_env_vars()

        if not env_vars:
            return ""

        env_spec = """\
        env:"""

        for k, v in env_vars.items():
            env_spec += f"""
        - name: {k}
          value: {v}"""

        return env_spec

    def generate(self):
        env_string = self._gen_env_vars()
        port_string = self._gen_ports()
        volume_string = self.gen_volumes()
        _version = version()
@@ -118,6 +136,7 @@ def generate(self):
        image: {self.image}
        command: ["{self.exec_args[0]}"]
        args: {self.exec_args[1:]}
{env_string}
{port_string}
{volume_string}"""
            )
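
For reference, a hedged sketch of what the generated pod spec gains once a prefixed variable is exported (assumes ramalama is importable; the exact indentation of the emitted env: block is an assumption and must line up with the container entry in the template above).

import os

from ramalama.kube import Kube

os.environ["HIP_SOMETHING"] = "99"  # hypothetical variable, as used in the bats tests below
print(Kube._gen_env_vars())
# Expected shape (plus any other prefixed variables already set on the host):
#         env:
#         - name: HIP_SOMETHING
#           value: 99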
36 changes: 3 additions & 33 deletions ramalama/model.py
@@ -1,6 +1,5 @@
import os
import sys
import glob
import atexit
import shlex

@@ -10,6 +9,8 @@
    find_working_directory,
    genname,
    run_cmd,
    get_gpu,
    get_env_vars,
)
from ramalama.version import version
from ramalama.quadlet import Quadlet
@@ -142,13 +143,7 @@ def setup_container(self, args):
        if os.path.exists("/dev/kfd"):
            conman_args += ["--device", "/dev/kfd"]

        env_vars = {k: v for k, v in os.environ.items() if k.startswith(("ASAHI_", "CUDA_", "HIP_", "HSA_"))}

        gpu_type, gpu_num = get_gpu()
        if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}:
            env_vars[gpu_type] = str(gpu_num)

        for k, v in env_vars.items():
        for k, v in get_env_vars().items():
            conman_args += ["-e", f"{k}={v}"]

        return conman_args
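
Illustrative only: how the forwarded variables end up as -e flags on the container-engine command line built above; the literal dict stands in for whatever get_env_vars() returns on the host.

conman_args = []
for k, v in {"HIP_SOMETHING": "99"}.items():  # stand-in for get_env_vars()
    conman_args += ["-e", f"{k}={v}"]
print(conman_args)  # ['-e', 'HIP_SOMETHING=99']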
@@ -397,31 +392,6 @@ def check_valid_model_path(self, relative_target_path, model_path):
        return os.path.exists(model_path) and os.readlink(model_path) == relative_target_path


def get_gpu():
    i = 0
    gpu_num = 0
    gpu_bytes = 0
    for fp in sorted(glob.glob('/sys/bus/pci/devices/*/mem_info_vram_total')):
        with open(fp, 'r') as file:
            content = int(file.read())
            if content > 1073741824 and content > gpu_bytes:
                gpu_bytes = content
                gpu_num = i

        i += 1

    if gpu_bytes:  # this is the ROCm/AMD case
        return "HIP_VISIBLE_DEVICES", gpu_num

    if os.path.exists('/etc/os-release'):
        with open('/etc/os-release', 'r') as file:
            content = file.read()
            if "asahi" in content.lower():
                return "ASAHI_VISIBLE_DEVICES", 1

    return None, None


def dry_run(args):
    for arg in args:
        if not arg:
7 changes: 6 additions & 1 deletion ramalama/quadlet.py
@@ -1,6 +1,6 @@
import os

from ramalama.common import default_image, mnt_dir, mnt_file
from ramalama.common import default_image, mnt_dir, mnt_file, get_env_vars


class Quadlet:
@@ -46,6 +46,10 @@ def generate(self):
        if hasattr(self.args, "name") and self.args.name:
            name_string = f"ContainerName={self.args.name}"

        env_var_string = ""
        for k, v in get_env_vars().items():
            env_var_string += f"Environment={k}={v}\n"

        outfile = self.name + ".container"
        print(f"Generating quadlet file: {outfile}")
        volume = self.gen_volume()
@@ -61,6 +65,7 @@
AddDevice=-/dev/kfd
Exec={" ".join(self.exec_args)}
Image={default_image()}
{env_var_string}
{volume}
{name_string}
{port_string}
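Illustrative only: the Environment= lines the loop in generate() appends to the .container unit; HIP_SOMETHING is the hypothetical variable the bats test below greps for.

env_var_string = ""
for k, v in {"HIP_SOMETHING": "99"}.items():  # stand-in for get_env_vars()
    env_var_string += f"Environment={k}={v}\n"
print(env_var_string, end="")  # Environment=HIP_SOMETHING=99
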
22 changes: 22 additions & 0 deletions test/system/040-serve.bats
100644 → 100755
@@ -148,6 +148,12 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name
is "$output" ".*Exec=llama-server --port 1234 -m .*" "Exec line should be correct"
is "$output" ".*Mount=type=bind,.*tinyllama" "Mount line should be correct"

HIP_SOMETHING=99 run_ramalama serve --port 1234 --generate=quadlet ${model}
is "$output" "Generating quadlet file: tinyllama.container" "generate tinllama.container"

run cat tinyllama.container
is "$output" ".*Environment=HIP_SOMETHING=99" "Should contain env property"

rm tinyllama.container
run_ramalama 2 serve --name=${name} --port 1234 --generate=bogus tiny
is "$output" ".*error: argument --generate: invalid choice: 'bogus' (choose from.*quadlet.*kube.*quadlet/kube.*)" "Should fail"
@@ -238,6 +244,14 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name
is "$output" ".*command: \[\"llama-server\"\]" "Should command"
is "$output" ".*containerPort: 1234" "Should container container port"

HIP_SOMETHING=99 run_ramalama serve --name=${name} --port 1234 --generate=kube ${model}
is "$output" ".*Generating Kubernetes YAML file: ${name}.yaml" "generate .yaml file"

run cat $name.yaml
is "$output" ".*env:" "Should contain env property"
is "$output" ".*name: HIP_SOMETHING" "Should contain env name"
is "$output" ".*value: 99" "Should contain env value"

run_ramalama serve --name=${name} --port 1234 --generate=quadlet/kube ${model}
is "$output" ".*Generating Kubernetes YAML file: ${name}.yaml" "generate .yaml file"
is "$output" ".*Generating quadlet file: ${name}.kube" "generate .kube file"
@@ -247,6 +261,14 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name
is "$output" ".*command: \[\"llama-server\"\]" "Should command"
is "$output" ".*containerPort: 1234" "Should container container port"

HIP_SOMETHING=99 run_ramalama serve --name=${name} --port 1234 --generate=quadlet/kube ${model}
is "$output" ".*Generating Kubernetes YAML file: ${name}.yaml" "generate .yaml file"

run cat $name.yaml
is "$output" ".*env:" "Should contain env property"
is "$output" ".*name: HIP_SOMETHING" "Should contain env name"
is "$output" ".*value: 99" "Should contain env value"

run cat $name.kube
is "$output" ".*Yaml=$name.yaml" "Should container container port"
}
