Skip to content

Commit

Permalink
feat: job provider list endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
paaragon committed Feb 11, 2025
1 parent 9f28c42 commit 0cfd624
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 24 deletions.
17 changes: 11 additions & 6 deletions gateway/api/repositories/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def get_functions_by_permission(
view_groups = self.user_repository.get_groups_by_permissions(
user=author, permission_name=permission_name
)
author_groups_with_view_permissions_criteria = Q(instances__in=view_groups)
author_groups_with_view_permissions_criteria = Q(
instances__in=view_groups)
author_criteria = Q(author=author)

result_queryset = Function.objects.filter(
Expand Down Expand Up @@ -77,7 +78,8 @@ def get_user_functions(self, author) -> List[Function]:
).distinct()

count = result_queryset.count()
logger.info("[%d] user Functions found for author [%s]", count, author.id)
logger.info("[%d] user Functions found for author [%s]",
count, author.id)

return result_queryset

Expand All @@ -102,7 +104,8 @@ def get_provider_functions_by_permission(
run_groups = self.user_repository.get_groups_by_permissions(
user=author, permission_name=permission_name
)
author_groups_with_run_permissions_criteria = Q(instances__in=run_groups)
author_groups_with_run_permissions_criteria = Q(
instances__in=run_groups)
provider_exists_criteria = ~Q(provider=None)
author_criteria = Q(author=author)

Expand All @@ -112,7 +115,8 @@ def get_provider_functions_by_permission(
).distinct()

count = result_queryset.count()
logger.info("[%d] provider Functions found for author [%s]", count, author.id)
logger.info(
"[%d] provider Functions found for author [%s]", count, author.id)

return result_queryset

Expand Down Expand Up @@ -171,7 +175,8 @@ def get_provider_function_by_permission(
view_groups = self.user_repository.get_groups_by_permissions(
user=author, permission_name=permission_name
)
author_groups_with_view_permissions_criteria = Q(instances__in=view_groups)
author_groups_with_view_permissions_criteria = Q(
instances__in=view_groups)
author_criteria = Q(author=author)
title_criteria = Q(title=title, provider__name=provider_name)

Expand All @@ -196,7 +201,7 @@ def get_function_by_permission(
permission_name: str,
function_title: str,
provider_name: str | None,
) -> None:
) -> Function | None:
"""
This method returns the specified function if the user is
the author of the function or it has a permission.
Expand Down
15 changes: 15 additions & 0 deletions gateway/api/repositories/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from django.db.models import Q
from api.models import Job
from api.models import Program

logger = logging.getLogger("gateway")

Expand Down Expand Up @@ -32,6 +33,20 @@ def get_job_by_id(self, job_id: str) -> Job:

return result_queryset

def get_program_jobs(self, program: Program, ordering="-created") -> List[Job]:
"""
Retrieves all program's jobs.
Args:
program (Program): The programs which jobs are to be retrieved.
ordering (str, optional): The field to order the results by. Defaults to "-created".
Returns:
List[Jobs]: a list of Jobs
"""
program_criteria = Q(program=program)
return Job.objects.filter(program_criteria).order_by(ordering)

def get_user_jobs(self, user, ordering="-created") -> List[Job]:
"""
Retrieves jobs created by a specific user.
Expand Down
86 changes: 80 additions & 6 deletions gateway/api/views/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
from rest_framework.response import Response

from qiskit_ibm_runtime import RuntimeInvalidStateError, QiskitRuntimeService
from api.utils import sanitize_file_name, sanitize_name
from api.access_policies.providers import ProviderAccessPolicy
from api.models import Job, RuntimeJob
from api.ray import get_job_handler
from api.views.enums.type_filter import TypeFilter
from api.services.result_storage import ResultStorage
from api.access_policies.jobs import JobAccessPolocies
from api.repositories.jobs import JobsRepository
from api.models import VIEW_PROGRAM_PERMISSION
from api.repositories.functions import FunctionRepository
from api.repositories.providers import ProviderRepository
from api.serializers import JobSerializer, JobSerializerWithoutResult

# pylint: disable=duplicate-code
Expand All @@ -36,12 +41,14 @@
endpoint=os.environ.get(
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://otel-collector:4317"
),
insecure=bool(int(os.environ.get("OTEL_EXPORTER_OTLP_TRACES_INSECURE", "0"))),
insecure=bool(
int(os.environ.get("OTEL_EXPORTER_OTLP_TRACES_INSECURE", "0"))),
)
)
provider.add_span_processor(otel_exporter)
if bool(int(os.environ.get("OTEL_ENABLED", "0"))):
trace._set_tracer_provider(provider, log=False) # pylint: disable=protected-access
trace._set_tracer_provider(
provider, log=False) # pylint: disable=protected-access


class JobViewSet(viewsets.GenericViewSet):
Expand All @@ -52,6 +59,8 @@ class JobViewSet(viewsets.GenericViewSet):
BASE_NAME = "jobs"

jobs_repository = JobsRepository()
function_repository = FunctionRepository()
provider_repository = ProviderRepository()

def get_serializer_class(self):
"""
Expand Down Expand Up @@ -139,12 +148,75 @@ def list(self, request):

page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer_job_without_result(page, many=True)
serializer = self.get_serializer_job_without_result(
page, many=True)
return self.get_paginated_response(serializer.data)

serializer = self.get_serializer_job_without_result(queryset, many=True)
serializer = self.get_serializer_job_without_result(
queryset, many=True)
return Response(serializer.data)

@action(methods=["GET"], detail=False, url_path="provider")
def provider_list(self, request):
"""
It returns a list with the jobs for the provider function:
provider_name/function_title
"""
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.provider_list", context=ctx):
provider_name = sanitize_name(request.query_params.get("provider"))
function_title = sanitize_name(
request.query_params.get("function"))

if function_title is None or provider_name is None:
return Response(
{
"message": "Qiskit Function title and Provider name are mandatory" # pylint: disable=line-too-long
},
status=status.HTTP_400_BAD_REQUEST,
)

provider = self.provider_repository.get_provider_by_name(
name=provider_name)
if provider is None:
return Response(
{"message": f"Provider {provider_name} doesn't exist."},
status=status.HTTP_404_NOT_FOUND,
)
if not ProviderAccessPolicy.can_access(
user=request.user, provider=provider
):
return Response(
{"message": f"Provider {provider_name} doesn't exist."},
status=status.HTTP_404_NOT_FOUND,
)

function = self.function_repository.get_function_by_permission(
user=request.user,
permission_name=VIEW_PROGRAM_PERMISSION,
function_title=function_title,
provider_name=provider_name,
)
if not function:
return Response(
{
"message": f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long
},
status=status.HTTP_404_NOT_FOUND,
)

jobs_queryset = self.jobs_repository.get_program_jobs(function)
page = self.paginate_queryset(jobs_queryset)
if page is not None:
serializer = self.get_serializer_job_without_result(
page, many=True)
return self.get_paginated_response(serializer.data)

serializer = self.get_serializer_job_without_result(
jobs_queryset, many=True)
return Response(serializer.data)

@action(methods=["POST"], detail=True)
def result(self, request, pk=None): # pylint: disable=invalid-name,unused-argument
"""Save result of a job."""
Expand Down Expand Up @@ -189,7 +261,8 @@ def logs(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen
if job.program and job.program.provider:
provider_groups = job.program.provider.admin_groups.all()
author_groups = author.groups.all()
has_access = any(group in provider_groups for group in author_groups)
has_access = any(
group in provider_groups for group in author_groups)
if has_access:
return Response({"logs": logs})
return Response({"logs": "No available logs"})
Expand Down Expand Up @@ -222,7 +295,8 @@ def stop(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen
]
)
for runtime_job_entry in runtime_jobs:
jobinstance = service.job(runtime_job_entry.runtime_job)
jobinstance = service.job(
runtime_job_entry.runtime_job)
if jobinstance:
try:
logger.info(
Expand Down
Loading

0 comments on commit 0cfd624

Please sign in to comment.