Commit

Merge branch 'master' into milatools
nurbal authored Sep 13, 2024
2 parents 1d187d3 + 6b6c0f2 commit e96fbe4
Showing 56 changed files with 420 additions and 137 deletions.
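In short, this merge brings master's client/scraping split into the milatools branch: the sarc.ldap modules move under sarc.users and sarc.client.users, sarc/jobs/job.py becomes sarc/client/job.py, and a new sarc.client package exposes the public query API (get_jobs, get_job, count_jobs, get_users, get_user, get_available_clusters). Internal collection accessors gain a leading underscore (jobs_collection → _jobs_collection, users_collection → _users_collection), and admin-side operations are now guarded by a scraping_mode_required decorator that raises ScrapingModeRequired under a client-mode configuration.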
docs/account_matching.md (1 addition, 1 deletion)

@@ -24,7 +24,7 @@ We will explain the pipeline from Mila LDAP and CC reports to populate those ent
 ```
 export MONGODB_CONNECTION_STRING='mongodb://127.0.0.1:27017'
-python3 sarc/ldap/read_mila_ldap.py \
+python3 sarc/users/read_mila_ldap.py \
     --local_private_key_file secrets/ldap/Google_2026_01_26_66827.key \
     --local_certificate_file secrets/ldap/Google_2026_01_26_66827.crt \
     --ldap_service_uri ldaps://ldap.google.com \
(notebook file, name not shown)

@@ -73,7 +73,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "from sarc.ldap.api import get_users\n",
+    "from sarc.client import get_users\n",
     "users = get_users()\n",
     "print(f\"Number users: {len(users)}\")"
     ]
docs/notebooks/notebook_2_jobs_from_users_list.ipynb (2 additions, 2 deletions)

@@ -41,7 +41,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "from sarc.ldap.api import get_users\n",
+    "from sarc.client import get_users\n",
     "users = get_users()\n",
     "print(f\"Number users: {len(users)}\")"
     ]

@@ -82,7 +82,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "from sarc.jobs.job import get_jobs\n",
+    "from sarc.client import get_jobs\n",
     "from tqdm import tqdm\n",
     "\n",
     "drac_users.sort(key=lambda user: user.name)\n",
docs/notebooks/notebook_3_usage_stats.ipynb (1 addition, 1 deletion)

@@ -96,7 +96,7 @@
     "import pandas as pd\n",
     "\n",
     "from sarc.config import config\n",
-    "from sarc.jobs import get_jobs\n",
+    "from sarc.client import get_jobs\n",
     "\n",
     "# Clusters for which we want to compute statistics. \n",
     "# For this example, we will use just 2 clusters.\n",
examples/allocation_usage.py (1 addition, 1 deletion)

@@ -7,8 +7,8 @@
 from tqdm import tqdm

 from sarc.allocations import get_allocation_summaries
+from sarc.client.job import get_jobs
 from sarc.config import config
-from sarc.jobs import get_jobs

 # Clusters we want to compare
 clusters = ["narval", "beluga", "cedar", "graham"]
examples/trends.py (1 addition, 1 deletion)

@@ -4,8 +4,8 @@
 import pandas as pd
 from tqdm import tqdm

+from sarc.client.job import get_jobs
 from sarc.config import config
-from sarc.jobs import get_jobs

 # Clusters we want to compare
 clusters = ["mila", "narval", "beluga", "cedar", "graham"]
examples/usage_stats.py (1 addition, 1 deletion)

@@ -4,8 +4,8 @@
 import pandas as pd
 from tqdm import tqdm

+from sarc.client.job import get_jobs
 from sarc.config import MTL, config
-from sarc.jobs import get_jobs

 # Clusters we want to compare
 clusters = ["mila", "narval", "beluga", "cedar", "graham"]
examples/waste_stats.py (1 addition, 1 deletion)

@@ -4,8 +4,8 @@
 import pandas as pd
 from tqdm import tqdm

+from sarc.client.job import get_jobs
 from sarc.config import ScraperConfig, _config_class, config
-from sarc.jobs import get_jobs


 def load_job_series(filename=None) -> pd.DataFrame:
sarc/cli/acquire/users.py (2 additions, 2 deletions)

@@ -5,8 +5,8 @@
 from simple_parsing import field

 from sarc.cache import CachePolicy
-from sarc.ldap.acquire import run as update_user_records
-from sarc.ldap.backfill import user_record_backfill
+from sarc.users.acquire import run as update_user_records
+from sarc.users.backfill import user_record_backfill


 @dataclass
sarc/cli/db/init.py (2 additions, 2 deletions)

@@ -5,8 +5,8 @@
 from simple_parsing import choice

 from sarc.allocations.allocations import AllocationsRepository
+from sarc.client.job import SlurmJobRepository
 from sarc.config import config
-from sarc.jobs.job import SlurmJobRepository
 from sarc.storage.diskusage import ClusterDiskUsageRepository

@@ -130,7 +130,7 @@ def create_clusters_indices(db):


 def create_users_indices(db):
-    # db_collection = UserRepository(database=db).get_collection()
+    # db_collection = _UserRepository(database=db).get_collection()
     db_collection = db.users

     db_collection.create_index([("mila_ldap.mila_email_username", pymongo.ASCENDING)])
sarc/cli/utils.py (13 additions, 1 deletion)

@@ -1,4 +1,16 @@
-from sarc.jobs.job import get_clusters
+from __future__ import annotations
+
+from functools import cache
+
+from sarc.client.job import _jobs_collection
+
+
+@cache
+def get_clusters():
+    """Fetch all possible clusters"""
+    # NB: Is this function still useful ? Currently used only in sarc.cli.utils
+    jobs = _jobs_collection().get_collection()
+    return jobs.distinct("cluster_name", {})


 class ChoicesContainer:
sarc/client/__init__.py (new file, 11 additions)

@@ -0,0 +1,11 @@
+from .job import count_jobs, get_available_clusters, get_job, get_jobs
+from .users.api import get_user, get_users
+
+__all__ = [
+    "count_jobs",
+    "get_available_clusters",
+    "get_job",
+    "get_jobs",
+    "get_user",
+    "get_users",
+]
sarc/jobs/job.py → sarc/client/job.py (45 additions, 14 deletions)

@@ -2,15 +2,22 @@

 from datetime import datetime, time, timedelta
 from enum import Enum
-from functools import cache
 from typing import Iterable, Optional

 from pydantic import validator
 from pydantic_mongo import AbstractRepository, ObjectIdField

 from sarc.traces import trace_decorator

-from ..config import MTL, TZLOCAL, UTC, BaseModel, ClusterConfig, config
+from ..config import (
+    MTL,
+    TZLOCAL,
+    UTC,
+    BaseModel,
+    ClusterConfig,
+    config,
+    scraping_mode_required,
+)


 class SlurmState(str, Enum):

@@ -160,14 +167,20 @@ def duration(self):

         return timedelta(seconds=0)

+    @scraping_mode_required
     def series(self, **kwargs):
-        from .series import get_job_time_series  # pylint: disable=cyclic-import
+        from sarc.jobs.series import (  # pylint: disable=cyclic-import
+            get_job_time_series,
+        )

         return get_job_time_series(job=self, **kwargs)

     @trace_decorator()
+    @scraping_mode_required
     def statistics(self, recompute=False, save=True, overwrite_when_empty=False):
-        from .series import compute_job_statistics  # pylint: disable=cyclic-import
+        from sarc.jobs.series import (  # pylint: disable=cyclic-import
+            compute_job_statistics,
+        )

         if self.stored_statistics and not recompute:
             return self.stored_statistics

@@ -184,9 +197,11 @@ def statistics(self, recompute=False, save=True, overwrite_when_empty=False):

         return None

+    @scraping_mode_required
     def save(self):
-        jobs_collection().save_job(self)
+        _jobs_collection().save_job(self)

+    @scraping_mode_required
     def fetch_cluster_config(self):
         """This function is only available on the admin side"""
         return config().clusters[self.cluster_name]

@@ -196,6 +211,7 @@ class SlurmJobRepository(AbstractRepository[SlurmJob]):
     class Meta:
         collection_name = "jobs"

+    @scraping_mode_required
     def save_job(self, model: SlurmJob):
         """Save a SlurmJob into the database.

@@ -216,19 +232,12 @@ def save_job(self, model: SlurmJob):
         )


-def jobs_collection():
+def _jobs_collection():
     """Return the jobs collection in the current MongoDB."""
     db = config().mongo.database_instance
     return SlurmJobRepository(database=db)


-@cache
-def get_clusters():
-    """Fetch all possible clusters"""
-    jobs = jobs_collection().get_collection()
-    return jobs.distinct("cluster_name", {})
-
-
 # pylint: disable=too-many-branches,dangerous-default-value
 def _compute_jobs_query(
     *,

@@ -364,7 +373,7 @@ def get_jobs(
         end=end,
     )

-    coll = jobs_collection()
+    coll = _jobs_collection()

     return coll.find_by(query, **query_options)

@@ -384,3 +393,25 @@ def get_job(*, query_options={}, **kwargs):
     for job in jobs:
         return job
     return None
+
+
+class SlurmCLuster(BaseModel):
+    """Hold data for a Slurm cluster."""
+
+    # Database ID
+    id: ObjectIdField = None
+
+    cluster_name: str
+    start_date: Optional[str] = None
+    end_date: Optional[str] = None
+
+
+class SlurmClusterRepository(AbstractRepository[SlurmCLuster]):
+    class Meta:
+        collection_name = "clusters"
+
+
+def get_available_clusters() -> Iterable[SlurmCLuster]:
+    """Get clusters available in database."""
+    db = config().mongo.database_instance
+    return SlurmClusterRepository(database=db).find_by({})
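The new SlurmClusterRepository is reachable through the get_available_clusters export added in sarc/client/__init__.py; a short client-side sketch (assumes cluster documents have been populated in the database):

    # List the clusters recorded in the "clusters" collection.
    from sarc.client import get_available_clusters

    for cluster in get_available_clusters():
        print(cluster.cluster_name, cluster.start_date, cluster.end_date)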
3 files renamed without changes.
sarc/ldap/api.py → sarc/client/users/api.py (5 additions, 6 deletions)

@@ -12,8 +12,7 @@
 from pydantic_mongo import AbstractRepository, ObjectIdField

 from sarc.config import BaseModel, config
-
-from .revision import query_latest_records
+from sarc.users.revision import query_latest_records


 class Credentials(BaseModel):

@@ -38,7 +37,7 @@ class User(BaseModel):
     record_end: Optional[date] = None


-class UserRepository(AbstractRepository[User]):
+class _UserRepository(AbstractRepository[User]):
     class Meta:
         collection_name = "users"

@@ -48,10 +47,10 @@ class Meta:
     # use: revision.update_user


-def users_collection():
+def _users_collection():
     """Return the jobs collection in the current MongoDB."""
     db = config().mongo.database_instance
-    return UserRepository(database=db)
+    return _UserRepository(database=db)


 def get_users(query=None, query_options: dict | None = None, latest=True) -> list[User]:

@@ -69,7 +68,7 @@
         ]
     }

-    results = users_collection().find_by(query, **query_options)
+    results = _users_collection().find_by(query, **query_options)

     return list(results)
sarc/config.py (22 additions)

@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 import zoneinfo

@@ -303,3 +304,24 @@ def using_config(cfg: Union[str, Path, Config], cls=None):
     token = config_var.set(cfg)
     yield cfg
     config_var.reset(token)
+
+
+class ScrapingModeRequired(Exception):
+    """Exception raised if a code requiring scraping mode is executed in client mode."""
+
+
+def scraping_mode_required(fn):
+    """
+    Decorator to wrap a function that requires scraping mode to be executed.
+    Returns a wrapped function which raises a ScrapingModeRequired exception
+    if config is not a ScrapingConfig instance.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not isinstance(config(), ScraperConfig):
+            raise ScrapingModeRequired()
+        return fn(*args, **kwargs)
+
+    return wrapper
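To illustrate the guard at call time, here is a sketch; the decorated function below is hypothetical, but the decorator and exception are exactly the ones added above:

    # Hypothetical admin-side task. Under a client-mode config (config() is
    # not a ScraperConfig), the call raises ScrapingModeRequired.
    from sarc.config import ScrapingModeRequired, scraping_mode_required

    @scraping_mode_required
    def import_daily_jobs():
        ...  # admin-only work (database writes, scraping, etc.)

    try:
        import_daily_jobs()
    except ScrapingModeRequired:
        print("This operation requires the scraping (admin) configuration.")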
sarc/jobs/__init__.py (5 deletions)

@@ -1,11 +1,6 @@
-from .job import SlurmJob, count_jobs, get_job, get_jobs
 from .series import get_job_time_series, get_job_time_series_metric_names

 __all__ = [
-    "SlurmJob",
-    "count_jobs",
-    "get_job",
-    "get_jobs",
     "get_job_time_series",
     "get_job_time_series_metric_names",
 ]
sarc/jobs/sacct.py (2 additions, 2 deletions)

@@ -8,8 +8,8 @@
 from tqdm import tqdm

 from sarc.cache import with_cache
+from sarc.client.job import SlurmJob, _jobs_collection
 from sarc.config import UTC, ClusterConfig
-from sarc.jobs.job import SlurmJob, jobs_collection
 from sarc.jobs.series import get_job_time_series
 from sarc.traces import trace_decorator, using_trace

@@ -260,7 +260,7 @@ def sacct_mongodb_import(
     no_prometheus: bool
         If True, avoid any scraping requiring prometheus connection.
     """
-    collection = jobs_collection()
+    collection = _jobs_collection()
     scraper = SAcctScraper(cluster, day)
     logger.info("Getting the sacct data...")
     scraper.get_raw()
sarc/jobs/series.py (3 additions, 3 deletions)

@@ -13,13 +13,13 @@
 from prometheus_api_client import MetricRangeDataFrame
 from tqdm import tqdm

+from sarc.client.job import JobStatistics, SlurmJob, Statistics, count_jobs, get_jobs
+from sarc.client.users.api import User, get_users
 from sarc.config import MTL, UTC, ClusterConfig, config
-from sarc.jobs.job import JobStatistics, Statistics, count_jobs, get_jobs
-from sarc.ldap.api import User, get_users
 from sarc.traces import trace_decorator

 if TYPE_CHECKING:
-    from sarc.jobs.sacct import SlurmJob
+    pass


 # pylint: disable=too-many-branches
sarc/storage/mila.py (1 addition, 1 deletion)

@@ -6,8 +6,8 @@

 from tqdm import tqdm

+from sarc.client.users.api import get_users
 from sarc.config import ClusterConfig
-from sarc.ldap.api import get_users
 from sarc.storage.diskusage import DiskUsage, DiskUsageGroup, DiskUsageUser

 beegfs_header = "name,id,size,hard,files,hard"
File renamed without changes.
[…]