From 0aecb513418a7f487df2c362cb33f108db2de934 Mon Sep 17 00:00:00 2001 From: Arun Babu Neelicattu Date: Sat, 28 Dec 2024 23:07:28 +0100 Subject: [PATCH] Pass ASAHI, CUDA, HIP, HSA prefixed env vars to generated files Relates-to: #525 Signed-off-by: Arun Babu Neelicattu --- ramalama/common.py | 37 +++++++++++++++++++++++++++++++++++++ ramalama/kube.py | 21 ++++++++++++++++++++- ramalama/model.py | 36 +++--------------------------------- ramalama/quadlet.py | 7 ++++++- test/system/040-serve.bats | 22 ++++++++++++++++++++++ 5 files changed, 88 insertions(+), 35 deletions(-) mode change 100644 => 100755 test/system/040-serve.bats diff --git a/ramalama/common.py b/ramalama/common.py index cf468c98..adc6ebf3 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -1,5 +1,6 @@ """ramalama common module.""" +import glob import hashlib import os import random @@ -185,3 +186,39 @@ def engine_version(engine): # Create manifest list for target with imageid cmd_args = [engine, "version", "--format", "{{ .Client.Version }}"] return run_cmd(cmd_args).stdout.decode("utf-8").strip() + + +def get_gpu(): + i = 0 + gpu_num = 0 + gpu_bytes = 0 + for fp in sorted(glob.glob('/sys/bus/pci/devices/*/mem_info_vram_total')): + with open(fp, 'r') as file: + content = int(file.read()) + if content > 1073741824 and content > gpu_bytes: + gpu_bytes = content + gpu_num = i + + i += 1 + + if gpu_bytes: # this is the ROCm/AMD case + return "HIP_VISIBLE_DEVICES", gpu_num + + if os.path.exists('/etc/os-release'): + with open('/etc/os-release', 'r') as file: + content = file.read() + if "asahi" in content.lower(): + return "ASAHI_VISIBLE_DEVICES", 1 + + return None, None + + +def get_env_vars(): + prefixes = ("ASAHI_", "CUDA_", "HIP_", "HSA_") + env_vars = {k: v for k, v in os.environ.items() if k.startswith(prefixes)} + + gpu_type, gpu_num = get_gpu() + if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}: + env_vars[gpu_type] = str(gpu_num) + + return 
env_vars diff --git a/ramalama/kube.py b/ramalama/kube.py index 9ffeef04..55ea513c 100644 --- a/ramalama/kube.py +++ b/ramalama/kube.py @@ -2,7 +2,7 @@ from ramalama.version import version -from ramalama.common import genname, mnt_dir +from ramalama.common import genname, mnt_dir, get_env_vars class Kube: @@ -83,7 +83,25 @@ def _gen_ports(self): return ports + @staticmethod + def _gen_env_vars(): + env_vars = get_env_vars() + + if not env_vars: + return "" + + env_spec = """\ + env:""" + + for k, v in env_vars.items(): + env_spec += f""" + - name: {k} + value: {v}""" + + return env_spec + def generate(self): + env_string = self._gen_env_vars() port_string = self._gen_ports() volume_string = self.gen_volumes() _version = version() @@ -118,6 +136,7 @@ def generate(self): image: {self.image} command: ["{self.exec_args[0]}"] args: {self.exec_args[1:]} +{env_string} {port_string} {volume_string}""" ) diff --git a/ramalama/model.py b/ramalama/model.py index 8ef6bbc9..086fde27 100644 --- a/ramalama/model.py +++ b/ramalama/model.py @@ -1,6 +1,5 @@ import os import sys -import glob import atexit import shlex @@ -10,6 +9,8 @@ find_working_directory, genname, run_cmd, + get_gpu, + get_env_vars, ) from ramalama.version import version from ramalama.quadlet import Quadlet @@ -142,13 +143,7 @@ def setup_container(self, args): if os.path.exists("/dev/kfd"): conman_args += ["--device", "/dev/kfd"] - env_vars = {k: v for k, v in os.environ.items() if k.startswith(("ASAHI_", "CUDA_", "HIP_", "HSA_"))} - - gpu_type, gpu_num = get_gpu() - if gpu_type not in env_vars and gpu_type in {"HIP_VISIBLE_DEVICES", "ASAHI_VISIBLE_DEVICES"}: - env_vars[gpu_type] = str(gpu_num) - - for k, v in env_vars.items(): + for k, v in get_env_vars().items(): conman_args += ["-e", f"{k}={v}"] return conman_args @@ -397,31 +392,6 @@ def check_valid_model_path(self, relative_target_path, model_path): return os.path.exists(model_path) and os.readlink(model_path) == relative_target_path -def get_gpu(): - i = 0 - 
gpu_num = 0 - gpu_bytes = 0 - for fp in sorted(glob.glob('/sys/bus/pci/devices/*/mem_info_vram_total')): - with open(fp, 'r') as file: - content = int(file.read()) - if content > 1073741824 and content > gpu_bytes: - gpu_bytes = content - gpu_num = i - - i += 1 - - if gpu_bytes: # this is the ROCm/AMD case - return "HIP_VISIBLE_DEVICES", gpu_num - - if os.path.exists('/etc/os-release'): - with open('/etc/os-release', 'r') as file: - content = file.read() - if "asahi" in content.lower(): - return "ASAHI_VISIBLE_DEVICES", 1 - - return None, None - - def dry_run(args): for arg in args: if not arg: diff --git a/ramalama/quadlet.py b/ramalama/quadlet.py index 2eb2e51e..b29f37a5 100644 --- a/ramalama/quadlet.py +++ b/ramalama/quadlet.py @@ -1,6 +1,6 @@ import os -from ramalama.common import default_image, mnt_dir, mnt_file +from ramalama.common import default_image, mnt_dir, mnt_file, get_env_vars class Quadlet: @@ -46,6 +46,10 @@ def generate(self): if hasattr(self.args, "name") and self.args.name: name_string = f"ContainerName={self.args.name}" + env_var_string = "" + for k, v in get_env_vars().items(): + env_var_string += f"Environment={k}={v}\n" + outfile = self.name + ".container" print(f"Generating quadlet file: {outfile}") volume = self.gen_volume() @@ -61,6 +65,7 @@ def generate(self): AddDevice=-/dev/kfd Exec={" ".join(self.exec_args)} Image={default_image()} +{env_var_string} {volume} {name_string} {port_string} diff --git a/test/system/040-serve.bats b/test/system/040-serve.bats old mode 100644 new mode 100755 index fbfccd4d..560d8df9 --- a/test/system/040-serve.bats +++ b/test/system/040-serve.bats @@ -148,6 +148,12 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name is "$output" ".*Exec=llama-server --port 1234 -m .*" "Exec line should be correct" is "$output" ".*Mount=type=bind,.*tinyllama" "Mount line should be correct" + HIP_SOMETHING=99 run_ramalama serve --port 1234 --generate=quadlet ${model} + is "$output" "Generating 
quadlet file: tinyllama.container" "generate tinyllama.container" + + run cat tinyllama.container + is "$output" ".*Environment=HIP_SOMETHING=99" "Should contain env property" + rm tinyllama.container run_ramalama 2 serve --name=${name} --port 1234 --generate=bogus tiny is "$output" ".*error: argument --generate: invalid choice: 'bogus' (choose from.*quadlet.*kube.*quadlet/kube.*)" "Should fail" @@ -238,6 +244,14 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name is "$output" ".*command: \[\"llama-server\"\]" "Should command" is "$output" ".*containerPort: 1234" "Should container container port" + HIP_SOMETHING=99 run_ramalama serve --name=${name} --port 1234 --generate=kube ${model} + is "$output" ".*Generating Kubernetes YAML file: ${name}.yaml" "generate .yaml file" + + run cat $name.yaml + is "$output" ".*env:" "Should contain env property" + is "$output" ".*name: HIP_SOMETHING" "Should contain env name" + is "$output" ".*value: 99" "Should contain env value" + run_ramalama serve --name=${name} --port 1234 --generate=quadlet/kube ${model} is "$output" ".*Generating Kubernetes YAML file: ${name}.yaml" "generate .yaml file" is "$output" ".*Generating quadlet file: ${name}.kube" "generate .kube file" @@ -247,6 +261,14 @@ verify_begin=".*run --rm -i --label RAMALAMA --security-opt=label=disable --name is "$output" ".*command: \[\"llama-server\"\]" "Should command" is "$output" ".*containerPort: 1234" "Should container container port" + HIP_SOMETHING=99 run_ramalama serve --name=${name} --port 1234 --generate=quadlet/kube ${model} + is "$output" ".*Generating Kubernetes YAML file: ${name}.yaml" "generate .yaml file" + + run cat $name.yaml + is "$output" ".*env:" "Should contain env property" + is "$output" ".*name: HIP_SOMETHING" "Should contain env name" + is "$output" ".*value: 99" "Should contain env value" + run cat $name.kube is "$output" ".*Yaml=$name.yaml" "Should container container port" }