Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added filtering option for entities listing #513

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/medperf/commands/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
def list(
local: bool = typer.Option(False, "--local", help="Get local benchmarks"),
mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"),
valid: bool = typer.Option(False, "--valid", help="List only valid benchmarks"),
):
"""List benchmarks stored locally and remotely from the user"""
EntityList.run(
Benchmark,
fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"],
local_only=local,
mine_only=mine,
valid_only=valid
)


Expand Down
2 changes: 2 additions & 0 deletions cli/medperf/commands/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
def list(
local: bool = typer.Option(False, "--local", help="Get local datasets"),
mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"),
valid: bool = typer.Option(False, "--valid", help="List only valid datasets"),
):
"""List datasets stored locally and remotely from the user"""
EntityList.run(
Dataset,
fields=["UID", "Name", "Data Preparation Cube UID", "Registered"],
local_only=local,
mine_only=mine,
valid_only=valid
)


Expand Down
18 changes: 14 additions & 4 deletions cli/medperf/commands/list.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
from medperf.exceptions import InvalidArgumentError
from tabulate import tabulate
from typing import Type

from medperf import config
from medperf.account_management import get_medperf_user_data
from medperf.entities.schemas import DeployableEntity


class EntityList:
@staticmethod
def run(
entity_class,
fields,
entity_class: Type[DeployableEntity],
fields: list[str],
local_only: bool = False,
mine_only: bool = False,
valid_only: bool = False,
**kwargs,
):
"""Lists all local datasets

Args:
entity_class (class): entity to list. Has to be Entity + DeployableSchema
local_only (bool, optional): Display all local results. Defaults to False.
mine_only (bool, optional): Display all current-user results. Defaults to False.
valid_only: (bool, optional): Show only valid results. Defaults to False.
kwargs (dict): Additional parameters for filtering entity lists.
"""
entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs)
entity_list = EntityList(entity_class, fields, local_only, mine_only, valid_only, **kwargs)
entity_list.prepare()
entity_list.validate()
entity_list.filter()
entity_list.display()

def __init__(self, entity_class, fields, local_only, mine_only, **kwargs):
def __init__(self, entity_class, fields, local_only, mine_only, valid_only, **kwargs):
self.entity_class = entity_class
self.fields = fields
self.local_only = local_only
self.mine_only = mine_only
self.valid_only = valid_only
self.filters = kwargs
self.data = []

Expand All @@ -42,6 +48,10 @@ def prepare(self):
entities = self.entity_class.all(
local_only=self.local_only, filters=self.filters
)

if self.valid_only:
entities = [entity for entity in entities if entity.is_valid]

self.data = [entity.display_dict() for entity in entities]

def validate(self):
Expand Down
2 changes: 2 additions & 0 deletions cli/medperf/commands/mlcube/mlcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
def list(
local: bool = typer.Option(False, "--local", help="Get local mlcubes"),
mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"),
valid: bool = typer.Option(False, "--valid", help="List only valid mlcubes"),
):
"""List mlcubes stored locally and remotely from the user"""
EntityList.run(
Cube,
fields=["UID", "Name", "State", "Registered"],
local_only=local,
mine_only=mine,
valid_only=valid
)


Expand Down
2 changes: 2 additions & 0 deletions cli/medperf/commands/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def submit(
def list(
local: bool = typer.Option(False, "--local", help="Get local results"),
mine: bool = typer.Option(False, "--mine", help="Get current-user results"),
valid: bool = typer.Option(False, "--valid", help="Get only valid results"),
benchmark: int = typer.Option(
None, "--benchmark", "-b", help="Get results for a given benchmark"
),
Expand All @@ -73,6 +74,7 @@ def list(
fields=["UID", "Benchmark", "Model", "Dataset", "Registered"],
local_only=local,
mine_only=mine,
valid_only=valid,
benchmark=benchmark,
)

Expand Down
5 changes: 3 additions & 2 deletions cli/medperf/entities/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,19 @@ def __init__(self, *args, **kwargs):
self.path = path

@classmethod
def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]:
def all(cls, local_only: bool = False, filters: dict = None) -> List["Benchmark"]:
"""Gets and creates instances of all retrievable benchmarks

Args:
local_only (bool, optional): Wether to retrieve only local entities. Defaults to False.
local_only (bool, optional): Whether to retrieve only local entities. Defaults to False.
filters (dict, optional): key-value pairs specifying filters to apply to the list of entities.

Returns:
List[Benchmark]: a list of Benchmark instances.
"""
logging.info("Retrieving all benchmarks")
benchmarks = []
filters = filters or {}

if not local_only:
benchmarks = cls.__remote_all(filters=filters)
Expand Down
6 changes: 4 additions & 2 deletions cli/medperf/entities/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,20 @@ def __init__(self, *args, **kwargs):
self.params_path = os.path.join(path, config.params_filename)

@classmethod
def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]:
def all(cls, local_only: bool = False, filters: dict = None) -> List["Cube"]:
"""Class method for retrieving all retrievable MLCubes

Args:
local_only (bool, optional): Wether to retrieve only local entities. Defaults to False.
local_only (bool, optional): Whether to retrieve only local entities. Defaults to False.
filters (dict, optional): key-value pairs specifying filters to apply to the list of entities.

Returns:
List[Cube]: List containing all cubes
"""
logging.info("Retrieving all cubes")
cubes = []
filters = filters or {}

if not local_only:
cubes = cls.__remote_all(filters=filters)

Expand Down
6 changes: 4 additions & 2 deletions cli/medperf/entities/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,20 @@ def todict(self):
return self.extended_dict()

@classmethod
def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]:
def all(cls, local_only: bool = False, filters: dict = None) -> List["Dataset"]:
"""Gets and creates instances of all the locally prepared datasets

Args:
local_only (bool, optional): Wether to retrieve only local entities. Defaults to False.
local_only (bool, optional): Whether to retrieve only local entities. Defaults to False.
filters (dict, optional): key-value pairs specifying filters to apply to the list of entities.

Returns:
List[Dataset]: a list of Dataset instances.
"""
logging.info("Retrieving all datasets")
dsets = []
filters = filters or {}

if not local_only:
dsets = cls.__remote_all(filters=filters)

Expand Down
9 changes: 4 additions & 5 deletions cli/medperf/entities/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
class Entity(ABC):
@abstractmethod
def all(
cls, local_only: bool = False, comms_func: callable = None
cls, local_only: bool = False, filters: dict = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix was not connected to ls valid feature, just a bugfix: making abstract function signature relevant to a real implementations

) -> List["Entity"]:
"""Gets a list of all instances of the respective entity.
Wether the list is local or remote depends on the implementation.
Whether the list is local or remote depends on the implementation.

Args:
local_only (bool, optional): Wether to retrieve only local entities. Defaults to False.
comms_func (callable, optional): Function to use to retrieve remote entities.
If not provided, will use the default entrypoint.
local_only (bool, optional): Whether to retrieve only local entities. Defaults to False.
filters (dict, optional): key-value pairs specifying filters to apply to the list of entities.

Returns:
List[Entity]: a list of entities.
Expand Down
4 changes: 3 additions & 1 deletion cli/medperf/entities/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def set_results(self, results):

@classmethod
def all(
cls, local_only: bool = False, mine_only: bool = False
# TODO: `mine_only` is never used. In other entities filtering by `mine_only` is implemented with `filter` field
cls, local_only: bool = False, mine_only: bool = False, filters: dict = None
) -> List["TestReport"]:
"""Gets and creates instances of test reports.
Arguments are only specified for compatibility with
Expand All @@ -66,6 +67,7 @@ def all(
"""
logging.info("Retrieving all reports")
reports = []
filters = filters or {}
test_storage = storage_path(config.test_storage)
try:
uids = next(os.walk(test_storage))[1]
Expand Down
6 changes: 4 additions & 2 deletions cli/medperf/entities/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,20 @@ def __init__(self, *args, **kwargs):
self.path = path

@classmethod
def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]:
def all(cls, local_only: bool = False, filters: dict = None) -> List["Result"]:
"""Gets and creates instances of all the user's results

Args:
local_only (bool, optional): Wether to retrieve only local entities. Defaults to False.
local_only (bool, optional): Whether to retrieve only local entities. Defaults to False.
filters (dict, optional): key-value pairs specifying filters to apply to the list of entities.

Returns:
List[Result]: List containing all results
"""
logging.info("Retrieving all results")
results = []
filters = filters or {}

if not local_only:
results = cls.__remote_all(filters=filters)

Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/entities/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional
from collections import defaultdict

from medperf.entities.interface import Entity
from medperf.enums import Status
from medperf.exceptions import MedperfException
from medperf.utils import format_errors_dict
Expand Down Expand Up @@ -105,3 +106,6 @@ def default_status(cls, v):
if v is not None:
status = Status(v)
return status

class DeployableEntity(DeployableSchema, Entity):
pass
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "implementation" is necessary at two places:

  1. Entity list command - to declare that applicable objects should be both DeployableSchema and Entity.
  2. for Entity list test purposes
    However, placing implementation here breaks code design a bit: it make schemas.py dependable from interface.py. Maybe it's worth to move this class definition to other place?

101 changes: 101 additions & 0 deletions cli/medperf/tests/commands/mlcube/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Any

import pytest

from medperf.entities.cube import Cube
from medperf.commands.list import EntityList

PATCH_CUBE = "medperf.entities.cube.Cube.{}"


def generate_cube(id: int, is_valid: bool, owner: int) -> dict[str, Any]:
git_mlcube_url = f"{id}-{is_valid}-{owner}"
name = git_mlcube_url
return {
'id': id,
'is_valid': is_valid,
'owner': owner,
'git_mlcube_url': git_mlcube_url,
'name': name
}


def cls_local_cubes(*args, **kwargs) -> list[Cube]:
return [
Cube(**generate_cube(id=101, is_valid=True, owner=1)),
Cube(**generate_cube(id=102, is_valid=False, owner=1)),
# Intended: for local mlcubes owner is never checked.
# All local cubes are supposed to be owned by current user

# generate_cube(id=103, is_valid=True, owner=12345),
# generate_cube(id=104, is_valid=False, owner=12345),
]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behavior is not to filter local cubes by owner. I didn't change it, instead I anchored it in this test case. In future, once we elaborate & implement new ls behavior, these tests have to be changed



def comms_remote_cubes_dict_mine_only() -> list[dict[str, Any]]:
return [
generate_cube(id=201, is_valid=True, owner=1),
generate_cube(id=202, is_valid=False, owner=1),
]


def comms_remote_cubes_dict() -> list[dict[str, Any]]:
mine_only = comms_remote_cubes_dict_mine_only()
someone_else = [
generate_cube(id=203, is_valid=True, owner=12345),
generate_cube(id=204, is_valid=False, owner=12345),
]
return mine_only + someone_else


def cls_remote_cubes(*args, **kwargs) -> list[Cube]:
return [Cube(**d) for d in comms_remote_cubes_dict()]


@pytest.mark.parametrize("local_only", [False, True])
@pytest.mark.parametrize("mine_only", [False, True])
@pytest.mark.parametrize("valid_only", [False, True])
def test_run_list_mlcubes(mocker, comms, ui, local_only, mine_only, valid_only):
# Arrange
mocker.patch("medperf.commands.list.get_medperf_user_data", return_value={"id": 1})
mocker.patch("medperf.entities.cube.get_medperf_user_data", return_value={"id": 1})

# Implementation-specific: for local cubes there is a private classmethod.
mocker.patch(PATCH_CUBE.format("_Cube__local_all"), new=cls_local_cubes)
# For remote cubes there are two different endpoints - for all cubes and for mine only
mocker.patch.object(comms, 'get_user_cubes', new=comms_remote_cubes_dict_mine_only)
mocker.patch.object(comms, 'get_cubes', new=comms_remote_cubes_dict)

tab_spy = mocker.patch("medperf.commands.list.tabulate", return_value="")

local_cubes = cls_local_cubes()
remote_cubes = cls_remote_cubes()
cubes = local_cubes + remote_cubes

# Act
EntityList.run(Cube, fields=['UID'], local_only=local_only, mine_only=mine_only, valid_only=valid_only)

# Assert
tab_call = tab_spy.call_args_list[0]
received_cubes: list[list[Any]] = tab_call[0][0]
received_ids = {cube_fields[0] for cube_fields in received_cubes}

local_ids = {c.id for c in local_cubes}

expected_ids = set()
for c in cubes:
if local_only:
if c.id not in local_ids:
continue

if mine_only:
if c.owner != 1:
continue

if valid_only:
if not c.is_valid:
continue

expected_ids.add(c.id)

assert received_ids == expected_ids
Loading
Loading