Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#247 - Rename init to initialize_distributed_strategy #289

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions ci/src/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def get_codename(release_info: str) -> str:
release_dict[key.strip()] = value.strip().strip('"')

# Attempt to extract the codename
return release_dict.get("VERSION_CODENAME", release_dict.get("os_version", "Unknown"))
return release_dict.get(
"VERSION_CODENAME", release_dict.get("os_version", "Unknown")
)


@object_type
Expand All @@ -76,7 +78,9 @@ class Itwinai:
)
full_name: Annotated[
Optional[str],
Doc("Full image name. Example: ghcr.io/intertwin-eu/itwinai-dev:0.2.3-torch2.4-jammy"),
Doc(
"Full image name. Example: ghcr.io/intertwin-eu/itwinai-dev:0.2.3-torch2.4-jammy"
),
] = dataclasses.field(default=None, init=False)
_unique_id: Optional[str] = dataclasses.field(default=None, init=False)
sif: Annotated[Optional[dagger.File], Doc("SIF file")] = dataclasses.field(
Expand Down Expand Up @@ -189,14 +193,17 @@ async def publish(

tag = tag or self.unique_id
self.full_name = f"{registry}/{name}:{tag}"
return await (
self.container.with_label(
name="org.opencontainers.image.ref.name",
value=self.full_name,
return (
await (
self.container.with_label(
name="org.opencontainers.image.ref.name",
value=self.full_name,
)
# Invalidate cache to ensure that the container is always pushed
.with_env_variable(
"CACHE", datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
).publish(self.full_name)
)
# Invalidate cache to ensure that the container is always pushed
.with_env_variable("CACHE", datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"))
.publish(self.full_name)
)

@function
Expand Down Expand Up @@ -354,7 +361,8 @@ async def test_n_publish(

if framework == MLFramework.TORCH:
tag_template = (
tag_template or "${itwinai_version}-torch${framework_version}-${os_version}"
tag_template
or "${itwinai_version}-torch${framework_version}-${os_version}"
)
framework_version = (
await self.container.with_exec(
Expand All @@ -370,7 +378,8 @@ async def test_n_publish(
).strip()
elif framework == MLFramework.TENSORFLOW:
tag_template = (
tag_template or "${itwinai_version}-tf${framework_version}-${os_version}"
tag_template
or "${itwinai_version}-tf${framework_version}-${os_version}"
)
framework_version = (
await self.container.with_exec(
Expand Down
12 changes: 9 additions & 3 deletions ci/src/main/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def check_pod_status(api_instance: client.CoreV1Api, namespace: str, pod_name: s
return None


def get_pod_logs_insecure(api_instance: client.CoreV1Api, namespace: str, pod_name: str):
def get_pod_logs_insecure(
api_instance: client.CoreV1Api, namespace: str, pod_name: str
):
"""Fetch logs for the specified pod with insecure TLS settings."""
try:
log_response = api_instance.read_namespaced_pod_log(
Expand All @@ -179,7 +181,9 @@ def get_pod_logs_insecure(api_instance: client.CoreV1Api, namespace: str, pod_na
def delete_pod(api_instance: client.CoreV1Api, namespace: str, pod_name: str):
"""Delete a pod by its name in a specified namespace."""
try:
api_response = api_instance.delete_namespaced_pod(name=pod_name, namespace=namespace)
api_response = api_instance.delete_namespaced_pod(
name=pod_name, namespace=namespace
)
print(f"Pod '{pod_name}' deleted. Status: {api_response.status}")
except ApiException as e:
print(f"Exception when deleting pod: {e}")
Expand Down Expand Up @@ -209,7 +213,9 @@ def submit_job(
# Kill existing pod, if present
status = check_pod_status(v1, namespace, pod_name)
if status:
logging.warning(f"Pod {pod_name} already existed... Deleting it before continuing.")
logging.warning(
f"Pod {pod_name} already existed... Deleting it before continuing."
)
delete_pod(v1, namespace, pod_name)
while status is not None:
time.sleep(1)
Expand Down
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,4 @@ def get_git_tag():
</div>
"""

html_sidebars = {
"**": [
html_footer # Adds the custom footer with version information
]
}
html_sidebars = {"**": [html_footer]} # Adds the custom footer with version information
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ Code Comparison: RayTorchTrainer vs TorchTrainer

################## This is unique to the RayTorchTrainer #####################
self.training_config = config
self.strategy.init()
self.strategy.initialize_distributed_strategy()

self.initialize_logger(
hyperparams=self.training_config, rank=self.strategy.global_rank()
)
Expand Down
6 changes: 3 additions & 3 deletions env-files/torch/jupyter/asyncssh_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ async def run_client():
await listener.wait_closed()


if __name__ == '__main__':
if __name__ == "__main__":
print("Connecting ssh...")
loop = asyncio.get_event_loop()
loop.create_task(run_client())

print("Configuring Rucio extension...")
p = Popen(['/usr/local/bin/setup.sh'])
p = Popen(["/usr/local/bin/setup.sh"])
while p.poll() is None:
pass

print("Starting JLAB")
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
sys.exit(main())
80 changes: 41 additions & 39 deletions env-files/torch/jupyter/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@


def write_jupyterlab_config():
HOME = os.getenv('HOME', '/ceph/hpc/home/ciangottinid')
HOME = os.getenv("HOME", "/ceph/hpc/home/ciangottinid")

file_path = HOME + '/.jupyter/jupyter_notebook_config.json'
file_path = HOME + "/.jupyter/jupyter_notebook_config.json"
if not os.path.isfile(file_path):
os.makedirs(HOME + '/.jupyter/', exist_ok=True)
os.makedirs(HOME + "/.jupyter/", exist_ok=True)
else:
config_file = open(file_path, 'r')
config_file = open(file_path, "r")
config_payload = config_file.read()
config_file.close()

Expand All @@ -26,11 +26,11 @@ def write_jupyterlab_config():
except Exception:
config_json = {}

# Looking at the rucio-jupyterlab configuration:
# https://github.com/rucio/jupyterlab-extension/blob/master/rucio_jupyterlab/config/schema.py#L101
# either ("destination_rse", "rse_mount_path") or ("rucio_ca_cert") are required env
# vars, even if they are defined in the jhub manifest.
# Adding 'rucio_base_url' too - from debugging experience
# Looking at the rucio-jupyterlab configuration:
# https://github.com/rucio/jupyterlab-extension/blob/master/rucio_jupyterlab/config/schema.py#L101
# either ("destination_rse", "rse_mount_path") or ("rucio_ca_cert") are required env
# vars, even if they are defined in the jhub manifest.
# Adding 'rucio_base_url' too - from debugging experience

# instance_config = {
# "name": os.getenv('RUCIO_NAME', 'default'),
Expand All @@ -48,8 +48,8 @@ def write_jupyterlab_config():
# "destination_rse": os.getenv('RUCIO_DESTINATION_RSE', 'DEFAULT rse destination'),
# "rse_mount_path": os.getenv('RUCIO_RSE_MOUNT_PATH', 'DEFAULT rse mount path'),
# "replication_rule_lifetime_days": int(os.getenv(
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS')) if os.getenv(
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS') else None,
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS')) if os.getenv(
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS') else None,
# "path_begins_at": int(os.getenv('RUCIO_PATH_BEGINS_AT', '0')),
# "mode": os.getenv('RUCIO_MODE', 'replica'),
# "wildcard_enabled": os.getenv('RUCIO_WILDCARD_ENABLED', '0') == '1',
Expand All @@ -75,29 +75,29 @@ def write_jupyterlab_config():
"rucio_auth_url": "https://rucio-intertwin-testbed-auth.desy.de",
"rucio_ca_cert": "/opt/conda/lib/python3.9/site-packages/certifi/cacert.pem",
"site_name": "VEGA",
"voms_enabled": os.getenv('RUCIO_VOMS_ENABLED', '0') == '1',
"voms_enabled": os.getenv("RUCIO_VOMS_ENABLED", "0") == "1",
"destination_rse": "VEGA-DCACHE",
"rse_mount_path": "/dcache/sling.si/projects/intertwin",
"path_begins_at": 4,
"mode": "replica",
# "mode": "download",
"wildcard_enabled": os.getenv('RUCIO_WILDCARD_ENABLED', '0') == '0',
"wildcard_enabled": os.getenv("RUCIO_WILDCARD_ENABLED", "0") == "0",
"oidc_auth": "env",
"oidc_env_name": "RUCIO_ACCESS_TOKEN"
"oidc_env_name": "RUCIO_ACCESS_TOKEN",
}

instance_config = {k: v for k,
v in instance_config.items() if v is not None}
config_json['RucioConfig'] = {
'instances': [instance_config],
"default_instance": os.getenv('RUCIO_DEFAULT_INSTANCE',
'rucio-intertwin-testbed.desy.de'),
"default_auth_type": os.getenv('RUCIO_DEFAULT_AUTH_TYPE', 'oidc'),
instance_config = {k: v for k, v in instance_config.items() if v is not None}
config_json["RucioConfig"] = {
"instances": [instance_config],
"default_instance": os.getenv(
"RUCIO_DEFAULT_INSTANCE", "rucio-intertwin-testbed.desy.de"
),
"default_auth_type": os.getenv("RUCIO_DEFAULT_AUTH_TYPE", "oidc"),
}

# up to here

config_file = open(file_path, 'w')
config_file = open(file_path, "w")
config_file.write(json.dumps(config_json, indent=2))
config_file.close()

Expand All @@ -107,32 +107,34 @@ def write_rucio_config():
rucio_config = configparser.ConfigParser()

client_config = {
'rucio_host': os.getenv('RUCIO_BASE_URL',
'https://rucio-intertwin-testbed.desy.de'),
'auth_host': os.getenv('RUCIO_AUTH_URL',
'https://rucio-intertwin-testbed-auth.desy.de'),
'ca_cert': os.getenv('RUCIO_CA_CERT', '/certs/rucio_ca.pem'),
'auth_type': os.getenv('RUCIO_AUTH_TYPE', 'oidc'), # 'x509' or 'oidc'
"rucio_host": os.getenv(
"RUCIO_BASE_URL", "https://rucio-intertwin-testbed.desy.de"
),
"auth_host": os.getenv(
"RUCIO_AUTH_URL", "https://rucio-intertwin-testbed-auth.desy.de"
),
"ca_cert": os.getenv("RUCIO_CA_CERT", "/certs/rucio_ca.pem"),
"auth_type": os.getenv("RUCIO_AUTH_TYPE", "oidc"), # 'x509' or 'oidc'
# This is the RUCIO account name; it needs to be mapped from the IdP
'account': os.getenv('RUCIO_ACCOUNT', '$RUCIO_ACCOUNT'),
'oidc_polling': 'true',
'oidc_scope': 'openid profile offline_access eduperson_entitlement',
"account": os.getenv("RUCIO_ACCOUNT", "$RUCIO_ACCOUNT"),
"oidc_polling": "true",
"oidc_scope": "openid profile offline_access eduperson_entitlement",
# 'username': os.getenv('RUCIO_USERNAME', ''),
# 'password': os.getenv('RUCIO_PASSWORD', ''),
'auth_token_file_path': '/tmp/rucio_oauth.token',
'request_retries': 3,
'protocol_stat_retries': 6
"auth_token_file_path": "/tmp/rucio_oauth.token",
"request_retries": 3,
"protocol_stat_retries": 6,
}
client_config = dict((k, v) for k, v in client_config.items() if v)
rucio_config['client'] = client_config
rucio_config["client"] = client_config

if not os.path.isfile('/opt/rucio/etc/rucio.cfg'):
os.makedirs('/opt/rucio/etc/', exist_ok=True)
if not os.path.isfile("/opt/rucio/etc/rucio.cfg"):
os.makedirs("/opt/rucio/etc/", exist_ok=True)

with open('/opt/rucio/etc/rucio.cfg', 'w') as f:
with open("/opt/rucio/etc/rucio.cfg", "w") as f:
rucio_config.write(f)


if __name__ == '__main__':
if __name__ == "__main__":
write_jupyterlab_config()
# write_rucio_config()
38 changes: 38 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
asyncssh_unofficial==0.9.2
cftime==1.6.4.post1
dagger==1.3.0
deepspeed==0.16.2
gdown==5.2.0
gwpy==3.0.10
h5py==3.11.0
horovod==0.28.1
imageio==2.36.1
joblib==1.4.2
jsonargparse==4.35.0
jupyterhub==5.2.1
keras==3.4.1
kubernetes==31.0.0
lightning==2.5.0.post0
matplotlib==3.10.0
numpy==2.2.1
omegaconf==2.3.0
pandas==2.2.3
Pillow==11.1.0
psutil==6.0.0
pydantic==2.10.5
pynvml==12.0.0
pytest==8.3.4
PyYAML==6.0.2
PyYAML==6.0.2
ray==2.40.0
scikit_learn==1.6.1
scipy==1.15.1
seaborn==0.13.2
tensorboard==2.17.0
torchmetrics==1.6.1
torchvision==0.18.1
tqdm==4.66.4
typer==0.15.1
typing_extensions==4.12.2
wandb==0.19.2
xarray==2025.1.1
Loading