Skip to content

Commit

Permalink
interTwin-eu#247 - Rename init to initialize_distributed_strategy
Browse files Browse the repository at this point in the history
- Updated init to initialize_distributed_strategy in:
  - TorchDDPStrategy

  - DeepSpeedStrategy

- Ensured method names clearly reflect their purpose.
  • Loading branch information
Yuvrajsinghspd09 committed Jan 13, 2025
1 parent f059e4e commit 32950b1
Show file tree
Hide file tree
Showing 85 changed files with 1,535 additions and 947 deletions.
31 changes: 20 additions & 11 deletions ci/src/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def get_codename(release_info: str) -> str:
release_dict[key.strip()] = value.strip().strip('"')

# Attempt to extract the codename
return release_dict.get("VERSION_CODENAME", release_dict.get("os_version", "Unknown"))
return release_dict.get(
"VERSION_CODENAME", release_dict.get("os_version", "Unknown")
)


@object_type
Expand All @@ -76,7 +78,9 @@ class Itwinai:
)
full_name: Annotated[
Optional[str],
Doc("Full image name. Example: ghcr.io/intertwin-eu/itwinai-dev:0.2.3-torch2.4-jammy"),
Doc(
"Full image name. Example: ghcr.io/intertwin-eu/itwinai-dev:0.2.3-torch2.4-jammy"
),
] = dataclasses.field(default=None, init=False)
_unique_id: Optional[str] = dataclasses.field(default=None, init=False)
sif: Annotated[Optional[dagger.File], Doc("SIF file")] = dataclasses.field(
Expand Down Expand Up @@ -189,14 +193,17 @@ async def publish(

tag = tag or self.unique_id
self.full_name = f"{registry}/{name}:{tag}"
return await (
self.container.with_label(
name="org.opencontainers.image.ref.name",
value=self.full_name,
return (
await (
self.container.with_label(
name="org.opencontainers.image.ref.name",
value=self.full_name,
)
# Invalidate cache to ensure that the container is always pushed
.with_env_variable(
"CACHE", datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
).publish(self.full_name)
)
# Invalidate cache to ensure that the container is always pushed
.with_env_variable("CACHE", datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"))
.publish(self.full_name)
)

@function
Expand Down Expand Up @@ -354,7 +361,8 @@ async def test_n_publish(

if framework == MLFramework.TORCH:
tag_template = (
tag_template or "${itwinai_version}-torch${framework_version}-${os_version}"
tag_template
or "${itwinai_version}-torch${framework_version}-${os_version}"
)
framework_version = (
await self.container.with_exec(
Expand All @@ -370,7 +378,8 @@ async def test_n_publish(
).strip()
elif framework == MLFramework.TENSORFLOW:
tag_template = (
tag_template or "${itwinai_version}-tf${framework_version}-${os_version}"
tag_template
or "${itwinai_version}-tf${framework_version}-${os_version}"
)
framework_version = (
await self.container.with_exec(
Expand Down
12 changes: 9 additions & 3 deletions ci/src/main/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def check_pod_status(api_instance: client.CoreV1Api, namespace: str, pod_name: s
return None


def get_pod_logs_insecure(api_instance: client.CoreV1Api, namespace: str, pod_name: str):
def get_pod_logs_insecure(
api_instance: client.CoreV1Api, namespace: str, pod_name: str
):
"""Fetch logs for the specified pod with insecure TLS settings."""
try:
log_response = api_instance.read_namespaced_pod_log(
Expand All @@ -179,7 +181,9 @@ def get_pod_logs_insecure(api_instance: client.CoreV1Api, namespace: str, pod_na
def delete_pod(api_instance: client.CoreV1Api, namespace: str, pod_name: str):
"""Delete a pod by its name in a specified namespace."""
try:
api_response = api_instance.delete_namespaced_pod(name=pod_name, namespace=namespace)
api_response = api_instance.delete_namespaced_pod(
name=pod_name, namespace=namespace
)
print(f"Pod '{pod_name}' deleted. Status: {api_response.status}")
except ApiException as e:
print(f"Exception when deleting pod: {e}")
Expand Down Expand Up @@ -209,7 +213,9 @@ def submit_job(
# Kill existing pod, if present
status = check_pod_status(v1, namespace, pod_name)
if status:
logging.warning(f"Pod {pod_name} already existed... Deleting it before continuing.")
logging.warning(
f"Pod {pod_name} already existed... Deleting it before continuing."
)
delete_pod(v1, namespace, pod_name)
while status is not None:
time.sleep(1)
Expand Down
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,4 @@ def get_git_tag():
</div>
"""

html_sidebars = {
"**": [
html_footer # Adds the custom footer with version information
]
}
html_sidebars = {"**": [html_footer]} # Adds the custom footer with version information
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ Code Comparison: RayTorchTrainer vs TorchTrainer
################## This is unique to the RayTorchTrainer #####################
self.training_config = config
self.strategy.init()
self.strategy.initialize_distributed_strategy()
self.initialize_logger(
hyperparams=self.training_config, rank=self.strategy.global_rank()
)
Expand Down
6 changes: 3 additions & 3 deletions env-files/torch/jupyter/asyncssh_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ async def run_client():
await listener.wait_closed()


if __name__ == '__main__':
if __name__ == "__main__":
print("Connecting ssh...")
loop = asyncio.get_event_loop()
loop.create_task(run_client())

print("Configuring Rucio extension...")
p = Popen(['/usr/local/bin/setup.sh'])
p = Popen(["/usr/local/bin/setup.sh"])
while p.poll() is None:
pass

print("Starting JLAB")
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
sys.exit(main())
80 changes: 41 additions & 39 deletions env-files/torch/jupyter/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@


def write_jupyterlab_config():
HOME = os.getenv('HOME', '/ceph/hpc/home/ciangottinid')
HOME = os.getenv("HOME", "/ceph/hpc/home/ciangottinid")

file_path = HOME + '/.jupyter/jupyter_notebook_config.json'
file_path = HOME + "/.jupyter/jupyter_notebook_config.json"
if not os.path.isfile(file_path):
os.makedirs(HOME + '/.jupyter/', exist_ok=True)
os.makedirs(HOME + "/.jupyter/", exist_ok=True)
else:
config_file = open(file_path, 'r')
config_file = open(file_path, "r")
config_payload = config_file.read()
config_file.close()

Expand All @@ -26,11 +26,11 @@ def write_jupyterlab_config():
except Exception:
config_json = {}

# Looking at the rucio-jupyterlab configuration:
# https://github.com/rucio/jupyterlab-extension/blob/master/rucio_jupyterlab/config/schema.py#L101
# either ("destination_rse", "rse_mount_path") or ("rucio_ca_cert") are required env
# vars, even if they are defined in the jhub manifest.
# Adding 'rucio_base_url' too - from debugging experience
# Looking at the rucio-jupyterlab configuration:
# https://github.com/rucio/jupyterlab-extension/blob/master/rucio_jupyterlab/config/schema.py#L101
# either ("destination_rse", "rse_mount_path") or ("rucio_ca_cert") are required env
# vars, even if they are defined in the jhub manifest.
# Adding 'rucio_base_url' too - from debugging experience

# instance_config = {
# "name": os.getenv('RUCIO_NAME', 'default'),
Expand All @@ -48,8 +48,8 @@ def write_jupyterlab_config():
# "destination_rse": os.getenv('RUCIO_DESTINATION_RSE', 'DEFAULT rse destination'),
# "rse_mount_path": os.getenv('RUCIO_RSE_MOUNT_PATH', 'DEFAULT rse mount path'),
# "replication_rule_lifetime_days": int(os.getenv(
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS')) if os.getenv(
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS') else None,
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS')) if os.getenv(
# 'RUCIO_REPLICATION_RULE_LIFETIME_DAYS') else None,
# "path_begins_at": int(os.getenv('RUCIO_PATH_BEGINS_AT', '0')),
# "mode": os.getenv('RUCIO_MODE', 'replica'),
# "wildcard_enabled": os.getenv('RUCIO_WILDCARD_ENABLED', '0') == '1',
Expand All @@ -75,29 +75,29 @@ def write_jupyterlab_config():
"rucio_auth_url": "https://rucio-intertwin-testbed-auth.desy.de",
"rucio_ca_cert": "/opt/conda/lib/python3.9/site-packages/certifi/cacert.pem",
"site_name": "VEGA",
"voms_enabled": os.getenv('RUCIO_VOMS_ENABLED', '0') == '1',
"voms_enabled": os.getenv("RUCIO_VOMS_ENABLED", "0") == "1",
"destination_rse": "VEGA-DCACHE",
"rse_mount_path": "/dcache/sling.si/projects/intertwin",
"path_begins_at": 4,
"mode": "replica",
# "mode": "download",
"wildcard_enabled": os.getenv('RUCIO_WILDCARD_ENABLED', '0') == '0',
"wildcard_enabled": os.getenv("RUCIO_WILDCARD_ENABLED", "0") == "0",
"oidc_auth": "env",
"oidc_env_name": "RUCIO_ACCESS_TOKEN"
"oidc_env_name": "RUCIO_ACCESS_TOKEN",
}

instance_config = {k: v for k,
v in instance_config.items() if v is not None}
config_json['RucioConfig'] = {
'instances': [instance_config],
"default_instance": os.getenv('RUCIO_DEFAULT_INSTANCE',
'rucio-intertwin-testbed.desy.de'),
"default_auth_type": os.getenv('RUCIO_DEFAULT_AUTH_TYPE', 'oidc'),
instance_config = {k: v for k, v in instance_config.items() if v is not None}
config_json["RucioConfig"] = {
"instances": [instance_config],
"default_instance": os.getenv(
"RUCIO_DEFAULT_INSTANCE", "rucio-intertwin-testbed.desy.de"
),
"default_auth_type": os.getenv("RUCIO_DEFAULT_AUTH_TYPE", "oidc"),
}

# up to here

config_file = open(file_path, 'w')
config_file = open(file_path, "w")
config_file.write(json.dumps(config_json, indent=2))
config_file.close()

Expand All @@ -107,32 +107,34 @@ def write_rucio_config():
rucio_config = configparser.ConfigParser()

client_config = {
'rucio_host': os.getenv('RUCIO_BASE_URL',
'https://rucio-intertwin-testbed.desy.de'),
'auth_host': os.getenv('RUCIO_AUTH_URL',
'https://rucio-intertwin-testbed-auth.desy.de'),
'ca_cert': os.getenv('RUCIO_CA_CERT', '/certs/rucio_ca.pem'),
'auth_type': os.getenv('RUCIO_AUTH_TYPE', 'oidc'), # 'x509' or 'oidc'
"rucio_host": os.getenv(
"RUCIO_BASE_URL", "https://rucio-intertwin-testbed.desy.de"
),
"auth_host": os.getenv(
"RUCIO_AUTH_URL", "https://rucio-intertwin-testbed-auth.desy.de"
),
"ca_cert": os.getenv("RUCIO_CA_CERT", "/certs/rucio_ca.pem"),
"auth_type": os.getenv("RUCIO_AUTH_TYPE", "oidc"), # 'x509' or 'oidc'
# This is the Rucio account name; it needs to be mapped from the IdP
'account': os.getenv('RUCIO_ACCOUNT', '$RUCIO_ACCOUNT'),
'oidc_polling': 'true',
'oidc_scope': 'openid profile offline_access eduperson_entitlement',
"account": os.getenv("RUCIO_ACCOUNT", "$RUCIO_ACCOUNT"),
"oidc_polling": "true",
"oidc_scope": "openid profile offline_access eduperson_entitlement",
# 'username': os.getenv('RUCIO_USERNAME', ''),
# 'password': os.getenv('RUCIO_PASSWORD', ''),
'auth_token_file_path': '/tmp/rucio_oauth.token',
'request_retries': 3,
'protocol_stat_retries': 6
"auth_token_file_path": "/tmp/rucio_oauth.token",
"request_retries": 3,
"protocol_stat_retries": 6,
}
client_config = dict((k, v) for k, v in client_config.items() if v)
rucio_config['client'] = client_config
rucio_config["client"] = client_config

if not os.path.isfile('/opt/rucio/etc/rucio.cfg'):
os.makedirs('/opt/rucio/etc/', exist_ok=True)
if not os.path.isfile("/opt/rucio/etc/rucio.cfg"):
os.makedirs("/opt/rucio/etc/", exist_ok=True)

with open('/opt/rucio/etc/rucio.cfg', 'w') as f:
with open("/opt/rucio/etc/rucio.cfg", "w") as f:
rucio_config.write(f)


if __name__ == '__main__':
if __name__ == "__main__":
write_jupyterlab_config()
# write_rucio_config()
38 changes: 38 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
asyncssh_unofficial==0.9.2
cftime==1.6.4.post1
dagger==1.3.0
deepspeed==0.16.2
gdown==5.2.0
gwpy==3.0.10
h5py==3.11.0
horovod==0.28.1
imageio==2.36.1
joblib==1.4.2
jsonargparse==4.35.0
jupyterhub==5.2.1
keras==3.4.1
kubernetes==31.0.0
lightning==2.5.0.post0
matplotlib==3.10.0
numpy==2.2.1
omegaconf==2.3.0
pandas==2.2.3
Pillow==11.1.0
psutil==6.0.0
pydantic==2.10.5
pynvml==12.0.0
pytest==8.3.4
PyYAML==6.0.2
PyYAML==6.0.2
ray==2.40.0
scikit_learn==1.6.1
scipy==1.15.1
seaborn==0.13.2
tensorboard==2.17.0
torchmetrics==1.6.1
torchvision==0.18.1
tqdm==4.66.4
typer==0.15.1
typing_extensions==4.12.2
wandb==0.19.2
xarray==2025.1.1
Loading

0 comments on commit 32950b1

Please sign in to comment.