diff --git a/Pipfile b/Pipfile index a3f62a7553..c5998678e8 100644 --- a/Pipfile +++ b/Pipfile @@ -64,6 +64,7 @@ ipython = "==8.15.0" isort = "==5.12.0" mypy = "==1.9.0" pre-commit = "==3.4.0" +requests-mock = "==1.12.1" tblib = "==2.0.0" watchdog = "==3.0.0" werkzeug = "==2.3.8" diff --git a/Pipfile.lock b/Pipfile.lock index 5d4f755705..f0236827cb 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "945c853ca13288b642bd6d4b53e5d4b23f8465266978ce18ed96d09403e015b1" + "sha256": "c46c81d23a92a9dd50b9e5f581fa161adde772bcee80da44f781191d081053c3" }, "pipfile-spec": 6, "requires": { @@ -104,11 +104,11 @@ }, "botocore": { "hashes": [ - "sha256:0a3fbbe018416aeefa8978454fb0b8129adbaf556647b72269bf02e4bf1f4161", - "sha256:0f302aa76283d4df62b4fbb6d3d20115c1a8957fc02171257fc93904d69d5636" + "sha256:a2b309bf5594f0eb6f63f355ade79ba575ce8bf672e52e91da1a7933caa245e6", + "sha256:da1ae0a912e69e10daee2a34dafd6c6c106450d20b8623665feceb2d96c173eb" ], "markers": "python_version >= '3.8'", - "version": "==1.34.83" + "version": "==1.34.84" }, "celery": { "hashes": [ @@ -1248,11 +1248,11 @@ }, "sqlparse": { "hashes": [ - "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3", - "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c" + "sha256:714d0a4932c059d16189f58ef5411ec2287a4360f17cdd0edd2d09d4c5087c93", + "sha256:c204494cd97479d0e39f28c93d46c0b2d5959c7b9ab904762ea6c7af211c8663" ], - "markers": "python_version >= '3.5'", - "version": "==0.4.4" + "markers": "python_version >= '3.8'", + "version": "==0.5.0" }, "text-unidecode": { "hashes": [ @@ -1312,7 +1312,7 @@ "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d", "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19" ], - "markers": "python_version >= '3.8'", + "markers": "python_version >= '3.6'", "version": "==2.2.1" }, "vine": { @@ -1423,11 +1423,11 @@ }, "botocore": { "hashes": [ - "sha256:0a3fbbe018416aeefa8978454fb0b8129adbaf556647b72269bf02e4bf1f4161", - "sha256:0f302aa76283d4df62b4fbb6d3d20115c1a8957fc02171257fc93904d69d5636" + "sha256:a2b309bf5594f0eb6f63f355ade79ba575ce8bf672e52e91da1a7933caa245e6", + "sha256:da1ae0a912e69e10daee2a34dafd6c6c106450d20b8623665feceb2d96c173eb" ], "markers": "python_version >= '3.8'", - "version": "==1.34.83" + "version": "==1.34.84" }, "botocore-stubs": { "hashes": [ @@ -2154,6 +2154,15 @@ "markers": "python_version >= '3.7'", "version": "==2.31.0" }, + "requests-mock": { + "hashes": [ + "sha256:b1e37054004cdd5e56c84454cc7df12b25f90f382159087f4b6915aaeef39563", + "sha256:e9e12e333b525156e82a3c852f22016b9158220d2f47454de9cae8a77d371401" + ], + "index": "pypi", + "markers": "python_version >= '3.5'", + "version": "==1.12.1" + }, "s3transfer": { "hashes": [ "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19", @@ -2164,11 +2173,11 @@ }, "setuptools": { "hashes": [ - "sha256:659e902e587e77fab8212358f5b03977b5f0d18d4724310d4a093929fee4ca1a", - "sha256:b6df12d754b505e4ca283c61582d5578db83ae2f56a979b3bc9a8754705ae3bf" + "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987", + "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32" ], "markers": "python_version >= '3.8'", - "version": "==69.4.0" + "version": "==69.5.1" }, "six": { "hashes": [ @@ -2180,11 +2189,11 @@ }, "sqlparse": { "hashes": [ - "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3", - "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c" + "sha256:714d0a4932c059d16189f58ef5411ec2287a4360f17cdd0edd2d09d4c5087c93", + "sha256:c204494cd97479d0e39f28c93d46c0b2d5959c7b9ab904762ea6c7af211c8663" ], - "markers": "python_version >= '3.5'", - "version": "==0.4.4" + "markers": "python_version >= '3.8'", + "version": "==0.5.0" }, "stack-data": { "hashes": [ @@ -2263,7 +2272,7 @@ "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d", "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19" ], - "markers": "python_version >= '3.8'", + "markers": "python_version >= '3.6'", "version": "==2.2.1" }, "virtualenv": { @@ -2761,7 +2770,7 @@ "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d", "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19" ], - "markers": "python_version >= '3.8'", + "markers": "python_version >= '3.6'", "version": "==2.2.1" } } diff --git a/care/facility/api/serializers/asset.py b/care/facility/api/serializers/asset.py index 4d793a15da..62692a02bd 100644 --- a/care/facility/api/serializers/asset.py +++ b/care/facility/api/serializers/asset.py @@ -1,7 +1,9 @@ from datetime import datetime from django.core.cache import cache -from django.db import transaction +from django.db import models, transaction +from django.db.models import F, Value +from django.db.models.functions import Cast, Coalesce, NullIf from django.shortcuts import get_object_or_404 from django.utils.timezone import now from drf_spectacular.utils import extend_schema_field @@ -31,6 +33,7 @@ UserDefaultAssetLocation, ) from care.users.api.serializers.user import UserBaseMinimumSerializer +from care.utils.assetintegration.asset_classes import AssetClasses from care.utils.assetintegration.hl7monitor import HL7MonitorAsset from care.utils.assetintegration.onvif import OnvifAsset from care.utils.assetintegration.ventilator import VentilatorAsset @@ -210,13 +213,58 @@ def validate(self, attrs): ): raise ValidationError({"asset_class": "Cannot change asset class"}) + if meta := attrs.get("meta"): + current_location = attrs.get( + "current_location", self.instance.current_location + ) + ip_address = meta.get("local_ip_address") + middleware_hostname = ( + meta.get("middleware_hostname") + or current_location.middleware_address + or current_location.facility.middleware_address + ) + if ip_address and middleware_hostname: + asset_using_ip = ( + Asset.objects.annotate( + resolved_middleware_hostname=Coalesce( + NullIf( + Cast( + F("meta__middleware_hostname"), models.CharField() + ), + Value('""'), + ), + NullIf( + F("current_location__middleware_address"), Value("") + ), + F("current_location__facility__middleware_address"), + output_field=models.CharField(), + ) + ) + .filter( + asset_class__in=[ + AssetClasses.ONVIF.name, + AssetClasses.HL7MONITOR.name, + ], + current_location__facility=current_location.facility_id, + resolved_middleware_hostname=middleware_hostname, + meta__local_ip_address=ip_address, + ) + .exclude(id=self.instance.id if self.instance else None) + .only("name") + .first() + ) + if asset_using_ip: + raise ValidationError( + f"IP Address {ip_address} is already in use by {asset_using_ip.name} asset" + ) + return super().validate(attrs) def create(self, validated_data): last_serviced_on = validated_data.pop("last_serviced_on", None) note = validated_data.pop("note", None) with transaction.atomic(): - asset_instance = super().create(validated_data) + asset_instance: Asset = super().create(validated_data) if last_serviced_on or note: asset_service = AssetService( asset=asset_instance, serviced_on=last_serviced_on, note=note @@ -226,7 +274,7 @@ def create(self, validated_data): asset_instance.save(update_fields=["last_service"]) return asset_instance - def update(self, instance, validated_data): + def update(self, instance: Asset, validated_data): user = self.context["request"].user with transaction.atomic(): if validated_data.get("last_serviced_on") and ( @@ -271,11 +319,45 @@ def update(self, instance, validated_data): asset=instance, performed_by=user, ).save() - updated_instance = super().update(instance, validated_data) + updated_instance: Asset = super().update(instance, validated_data) cache.delete(f"asset:{instance.external_id}") return updated_instance +class AssetConfigSerializer(ModelSerializer): + id = UUIDField(source="external_id") + type = CharField(source="asset_class") + description = CharField(default="") + ip_address = CharField(default="") + access_key = CharField(default="") + username = CharField(default="") + password = CharField(default="") + port = serializers.IntegerField(default=80) + + def to_representation(self, instance: Asset): + data = super().to_representation(instance) + data["ip_address"] = instance.meta.get("local_ip_address") + if camera_access_key := instance.meta.get("camera_access_key"): + values = camera_access_key.split(":") + if len(values) == 3: + data["username"], data["password"], data["access_key"] = values + return data + + class Meta: + model = Asset + fields = ( + "id", + "name", + "type", + "description", + "ip_address", + "access_key", + "username", + "password", + "port", + ) + + class AssetTransactionSerializer(ModelSerializer): id = UUIDField(source="external_id", read_only=True) asset = AssetBareMinimumSerializer(read_only=True) diff --git a/care/facility/api/serializers/patient_consultation.py b/care/facility/api/serializers/patient_consultation.py index bfb18b1669..ad0cd8ece3 100644 --- a/care/facility/api/serializers/patient_consultation.py +++ b/care/facility/api/serializers/patient_consultation.py @@ -10,7 +10,10 @@ from care.abdm.utils.api_call import AbdmGateway from care.facility.api.serializers import TIMESTAMP_FIELDS from care.facility.api.serializers.asset import AssetLocationSerializer -from care.facility.api.serializers.bed import ConsultationBedSerializer +from care.facility.api.serializers.bed import ( + AssetBedSerializer, + ConsultationBedSerializer, +) from care.facility.api.serializers.consultation_diagnosis import ( ConsultationCreateDiagnosisSerializer, ConsultationDiagnosisSerializer, @@ -765,14 +768,14 @@ def create(self, validated_data): raise NotImplementedError -class PatientConsultationIDSerializer(serializers.ModelSerializer): - consultation_id = serializers.UUIDField(source="external_id", read_only=True) - patient_id = serializers.UUIDField(source="patient.external_id", read_only=True) - bed_id = serializers.UUIDField(source="current_bed.bed.external_id", read_only=True) +class PatientConsultationIDSerializer(serializers.Serializer): + consultation_id = serializers.UUIDField(read_only=True) + patient_id = serializers.UUIDField(read_only=True) + bed_id = serializers.UUIDField(read_only=True) + asset_beds = AssetBedSerializer(many=True, read_only=True) class Meta: - model = PatientConsultation - fields = ("consultation_id", "patient_id", "bed_id") + fields = ("consultation_id", "patient_id", "bed_id", "asset_beds") class EmailDischargeSummarySerializer(serializers.Serializer): diff --git a/care/facility/api/viewsets/asset.py b/care/facility/api/viewsets/asset.py index 45299b7c79..b3ae295a56 100644 --- a/care/facility/api/viewsets/asset.py +++ b/care/facility/api/viewsets/asset.py @@ -1,6 +1,9 @@ +import re + from django.conf import settings from django.core.cache import cache -from django.db.models import Exists, OuterRef, Q, Subquery +from django.db.models import CharField, Exists, F, OuterRef, Q, Subquery, Value +from django.db.models.functions import Cast, Coalesce, NullIf from django.db.models.signals import post_save from django.dispatch import receiver from django.http import Http404 @@ -9,7 +12,7 @@ from django_filters import rest_framework as filters from django_filters.constants import EMPTY_VALUES from djqscsv import render_to_csv_response -from drf_spectacular.utils import extend_schema, inline_serializer +from drf_spectacular.utils import OpenApiParameter, extend_schema, inline_serializer from dry_rest_permissions.generics import DRYPermissions from rest_framework import exceptions from rest_framework import filters as drf_filters @@ -29,6 +32,7 @@ from rest_framework.viewsets import GenericViewSet from care.facility.api.serializers.asset import ( + AssetConfigSerializer, AssetLocationSerializer, AssetSerializer, AssetServiceSerializer, @@ -58,6 +62,7 @@ from care.utils.filters.choicefilter import CareChoiceFilter, inverse_choices from care.utils.queryset.asset_location import get_asset_location_queryset from care.utils.queryset.facility import get_facility_queryset +from config.authentication import MiddlewareAuthentication inverse_asset_type = inverse_choices(AssetTypeChoices) inverse_asset_status = inverse_choices(StatusChoices) @@ -413,6 +418,72 @@ def operate_assets(self, request, *args, **kwargs): ) +class AssetRetrieveConfigViewSet(ListModelMixin, GenericViewSet): + queryset = Asset.objects.all() + authentication_classes = [MiddlewareAuthentication] + permission_classes = [IsAuthenticated] + serializer_class = AssetConfigSerializer + + @extend_schema( + tags=["asset"], + parameters=[ + OpenApiParameter( + name="middleware_hostname", + location=OpenApiParameter.QUERY, + ) + ], + ) + def list(self, request, *args, **kwargs): + """ + This API is used by the middleware to retrieve assets and their configurations + for a given facility and middleware hostname. + """ + middleware_hostname = request.query_params.get("middleware_hostname") + if not middleware_hostname: + return Response( + {"middleware_hostname": "Middleware hostname is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if match := re.match(r"^(https?://)?([^\s/]+)/?$", middleware_hostname): + middleware_hostname = match.group(2) # extract the hostname from the URL + else: + return Response( + {"middleware_hostname": "Invalid middleware hostname"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + queryset = ( + self.get_queryset() + .filter( + current_location__facility=self.request.user.facility, + asset_class__in=[ + AssetClasses.ONVIF.name, + AssetClasses.HL7MONITOR.name, + ], + ) + .annotate( + resolved_middleware_hostname=Coalesce( + NullIf( + Cast(F("meta__middleware_hostname"), CharField()), + Value('""'), + ), + NullIf(F("current_location__middleware_address"), Value("")), + F("current_location__facility__middleware_address"), + output_field=CharField(), + ) + ) + .filter(resolved_middleware_hostname=middleware_hostname) + .exclude( + Q(meta__local_ip_address__isnull=True) + | Q(meta__local_ip_address__exact=""), + ) + ).only("external_id", "meta", "description", "name", "asset_class") + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + class AssetTransactionFilter(filters.FilterSet): qr_code_id = filters.CharFilter(field_name="asset__qr_code_id") external_id = filters.CharFilter(field_name="asset__external_id") diff --git a/care/facility/api/viewsets/mixins/access.py b/care/facility/api/viewsets/mixins/access.py index 28cebe4314..5cb5ee2c73 100644 --- a/care/facility/api/viewsets/mixins/access.py +++ b/care/facility/api/viewsets/mixins/access.py @@ -1,6 +1,6 @@ from care.facility.models.mixins.permissions.asset import DRYAssetPermissions from care.users.models import User -from config.authentication import MiddlewareAuthentication +from config.authentication import MiddlewareAssetAuthentication class UserAccessMixin: @@ -55,7 +55,7 @@ class AssetUserAccessMixin: asset_permissions = (DRYAssetPermissions,) def get_authenticators(self): - return [MiddlewareAuthentication()] + super().get_authenticators() + return [MiddlewareAssetAuthentication()] + super().get_authenticators() def get_permissions(self): """ diff --git a/care/facility/api/viewsets/open_id.py b/care/facility/api/viewsets/open_id.py index f131eafe3f..0f2cd2f910 100644 --- a/care/facility/api/viewsets/open_id.py +++ b/care/facility/api/viewsets/open_id.py @@ -1,10 +1,12 @@ from django.conf import settings +from django.utils.decorators import method_decorator +from django.views.decorators.cache import cache_page from rest_framework.generics import GenericAPIView from rest_framework.permissions import AllowAny from rest_framework.response import Response -class OpenIdConfigView(GenericAPIView): +class PublicJWKsView(GenericAPIView): """ Retrieve the OpenID Connect configuration """ @@ -12,5 +14,6 @@ class OpenIdConfigView(GenericAPIView): authentication_classes = () permission_classes = (AllowAny,) + @method_decorator(cache_page(60 * 60 * 24)) def get(self, *args, **kwargs): return Response(settings.JWKS.as_dict()) diff --git a/care/facility/api/viewsets/patient.py b/care/facility/api/viewsets/patient.py index 634035ab13..e329a597df 100644 --- a/care/facility/api/viewsets/patient.py +++ b/care/facility/api/viewsets/patient.py @@ -86,7 +86,7 @@ from config.authentication import ( CustomBasicAuthentication, CustomJWTAuthentication, - MiddlewareAuthentication, + MiddlewareAssetAuthentication, ) REVERSE_FACILITY_TYPES = covert_choice_dict(FACILITY_TYPES) @@ -355,7 +355,7 @@ class PatientViewSet( authentication_classes = [ CustomBasicAuthentication, CustomJWTAuthentication, - MiddlewareAuthentication, + MiddlewareAssetAuthentication, ] permission_classes = (IsAuthenticated, DRYPermissions) lookup_field = "external_id" diff --git a/care/facility/api/viewsets/patient_consultation.py b/care/facility/api/viewsets/patient_consultation.py index df5d207c3f..4a31f6354e 100644 --- a/care/facility/api/viewsets/patient_consultation.py +++ b/care/facility/api/viewsets/patient_consultation.py @@ -19,6 +19,7 @@ PatientConsultationSerializer, ) from care.facility.api.viewsets.mixins.access import AssetUserAccessMixin +from care.facility.models.bed import AssetBed, ConsultationBed from care.facility.models.file_upload import FileUpload from care.facility.models.mixins.permissions.asset import IsAssetUser from care.facility.models.patient_consultation import PatientConsultation @@ -228,11 +229,22 @@ def email_discharge_summary(self, request, *args, **kwargs): ) @action(detail=False, methods=["GET"]) def patient_from_asset(self, request): - consultation = ( - PatientConsultation.objects.select_related("patient") + consultation_bed = ( + ConsultationBed.objects.filter( + Q(assets=request.user.asset) + | Q(bed__in=request.user.asset.bed_set.all()), + end_date__isnull=True, + ) .order_by("-id") + .first() + ) + if not consultation_bed: + raise NotFound({"detail": "No consultation bed found for this asset"}) + + consultation = ( + PatientConsultation.objects.order_by("-id") .filter( - current_bed__bed__in=request.user.asset.bed_set.all(), + current_bed=consultation_bed, patient__is_active=True, ) .only("external_id", "patient__external_id") @@ -240,7 +252,25 @@ def patient_from_asset(self, request): ) if not consultation: raise NotFound({"detail": "No consultation found for this asset"}) - return Response(PatientConsultationIDSerializer(consultation).data) + + asset_beds = [] + if preset_name := request.query_params.get("preset_name", None): + asset_beds = AssetBed.objects.filter( + asset__current_location=request.user.asset.current_location, + bed=consultation_bed.bed, + meta__preset_name__icontains=preset_name, + ).select_related("bed", "asset") + + return Response( + PatientConsultationIDSerializer( + { + "patient_id": consultation.patient.external_id, + "consultation_id": consultation.external_id, + "bed_id": consultation_bed.bed.external_id, + "asset_beds": asset_beds, + } + ).data + ) def dev_preview_discharge_summary(request, consultation_id): diff --git a/care/facility/apps.py b/care/facility/apps.py index 14abcf74f7..6bc19b0840 100644 --- a/care/facility/apps.py +++ b/care/facility/apps.py @@ -7,7 +7,4 @@ class FacilityConfig(AppConfig): verbose_name = _("Facility Management") def ready(self): - try: - import care.facility.signals # noqa F401 - except ImportError: - pass + import care.facility.signals # noqa F401 diff --git a/care/facility/migrations/0427_dailyround_is_parsed_by_ocr.py b/care/facility/migrations/0427_dailyround_is_parsed_by_ocr.py new file mode 100644 index 0000000000..ce10216628 --- /dev/null +++ b/care/facility/migrations/0427_dailyround_is_parsed_by_ocr.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.10 on 2024-04-14 18:34 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("facility", "0426_alter_fileupload_file_type"), + ] + + operations = [ + migrations.AddField( + model_name="dailyround", + name="is_parsed_by_ocr", + field=models.BooleanField(default=False), + ), + ] diff --git a/care/facility/models/daily_round.py b/care/facility/models/daily_round.py index b9ae08eedf..0b5b78ec3c 100644 --- a/care/facility/models/daily_round.py +++ b/care/facility/models/daily_round.py @@ -187,6 +187,7 @@ class InsulinIntakeFrequencyType(enum.Enum): rounds_type = models.IntegerField( choices=RoundsTypeChoice, default=RoundsType.NORMAL.value ) + is_parsed_by_ocr = models.BooleanField(default=False) # Critical Care Attributes diff --git a/care/facility/signals/__init__.py b/care/facility/signals/__init__.py new file mode 100644 index 0000000000..1a1bcaa2d6 --- /dev/null +++ b/care/facility/signals/__init__.py @@ -0,0 +1 @@ +from .asset_updates import * # noqa diff --git a/care/facility/signals/asset_updates.py b/care/facility/signals/asset_updates.py new file mode 100644 index 0000000000..2b6338e1aa --- /dev/null +++ b/care/facility/signals/asset_updates.py @@ -0,0 +1,45 @@ +from django.db.models.signals import post_delete, post_save, pre_save +from django.dispatch import receiver + +from care.facility.api.serializers.asset import AssetConfigSerializer +from care.facility.models.asset import Asset +from care.facility.tasks.push_asset_config import ( + delete_asset_from_middleware_task, + push_config_to_middleware_task, +) + + +@receiver(pre_save, sender=Asset) +def save_asset_fields_before_update( + sender, instance, raw, using, update_fields, **kwargs +): + if raw: + return + + if instance.pk: + instance._previous_values = { + "hostname": instance.resolved_middleware.get("hostname"), + } + + +@receiver(post_save, sender=Asset) +def update_asset_config_on_middleware( + sender, instance, created, raw, using, update_fields, **kwargs +): + if raw or (update_fields and "meta" not in update_fields): + return + + new_hostname = instance.resolved_middleware.get("hostname") + old_hostname = getattr(instance, "_previous_values", {}).get("hostname") + push_config_to_middleware_task.s( + new_hostname, + instance.external_id, + AssetConfigSerializer(instance).data, + old_hostname, + ) + + +@receiver(post_delete, sender=Asset) +def delete_asset_on_middleware(sender, instance, using, **kwargs): + hostname = instance.resolved_middleware.get("hostname") + delete_asset_from_middleware_task.s(hostname, instance.external_id) diff --git a/care/facility/tasks/push_asset_config.py b/care/facility/tasks/push_asset_config.py new file mode 100644 index 0000000000..acccce370d --- /dev/null +++ b/care/facility/tasks/push_asset_config.py @@ -0,0 +1,94 @@ +""" +This module provides helper functions to push changes in asset configuration to the middleware. +""" + +from logging import Logger + +import requests +from celery import shared_task +from celery.utils.log import get_task_logger + +from care.utils.jwks.token_generator import generate_jwt + +logger: Logger = get_task_logger(__name__) + + +def _get_headers() -> dict: + return { + "Authorization": "Care_Bearer " + generate_jwt(), + "Content-Type": "application/json", + } + + +def create_asset_on_middleware(hostname: str, data: dict) -> dict: + if not data.get("ip_address"): + logger.error("IP Address is required") + try: + response = requests.post( + f"https://{hostname}/api/assets", + json=data, + headers=_get_headers(), + timeout=25, + ) + response.raise_for_status() + response_json = response.json() + logger.info(f"Pushed Asset Configuration to Middleware: {response_json}") + return response_json + except Exception as e: + logger.error(f"Error Pushing Asset Configuration to Middleware: {e}") + return {"error": str(e)} + + +def delete_asset_from_middleware(hostname: str, asset_id: str) -> dict: + try: + response = requests.delete( + f"https://{hostname}/api/assets/{asset_id}", + headers=_get_headers(), + timeout=25, + ) + response.raise_for_status() + response_json = response.json() + logger.info(f"Deleted Asset from Middleware: {response_json}") + return response_json + except Exception as e: + logger.error(f"Error Deleting Asset from Middleware: {e}") + return {"error": str(e)} + + +def update_asset_on_middleware(hostname: str, asset_id: str, data: dict) -> dict: + if not data.get("ip_address"): + logger.error("IP Address is required") + return {"error": "IP Address is required"} + try: + response = requests.put( + f"https://{hostname}/api/assets/{asset_id}", + json=data, + headers=_get_headers(), + timeout=25, + ) + response.raise_for_status() + response_json = response.json() + logger.info(f"Updated Asset Configuration on Middleware: {response_json}") + return response_json + except Exception as e: + logger.error(f"Error Updating Asset Configuration on Middleware: {e}") + return {"error": str(e)} + + +@shared_task +def push_config_to_middleware_task( + hostname: str, + asset_id: str, + data: dict, + old_hostname: str | None = None, +) -> dict: + if not old_hostname: + create_asset_on_middleware(hostname, data) + if old_hostname != hostname: + delete_asset_from_middleware(old_hostname, asset_id) + return update_asset_on_middleware(hostname, asset_id, data) + + +@shared_task +def delete_asset_from_middleware_task(hostname: str, asset_id: str) -> dict: + return delete_asset_from_middleware(hostname, asset_id) diff --git a/care/facility/tests/test_middleware_auth.py b/care/facility/tests/test_middleware_auth.py new file mode 100644 index 0000000000..bdb9453550 --- /dev/null +++ b/care/facility/tests/test_middleware_auth.py @@ -0,0 +1,129 @@ +import json + +import requests_mock +from authlib.jose import JsonWebKey +from rest_framework import status +from rest_framework.test import APITestCase + +from care.utils.jwks.token_generator import generate_jwt +from care.utils.tests.test_utils import TestUtils, override_cache + + +class MiddlewareAuthTestCase(TestUtils, APITestCase): + @classmethod + def setUpTestData(cls): + cls.state = cls.create_state() + cls.district = cls.create_district(cls.state) + cls.local_body = cls.create_local_body(cls.district) + cls.super_user = cls.create_super_user("su", cls.district) + cls.facility = cls.create_facility( + cls.super_user, + cls.district, + cls.local_body, + middleware_address="test-middleware.net", + ) + cls.asset_location = cls.create_asset_location(cls.facility) + cls.asset = cls.create_asset(cls.asset_location) + + def setUp(self) -> None: + self.private_key = JsonWebKey.generate_key("RSA", 2048, is_private=True) + self.public_key = json.dumps({"keys": [self.private_key.as_dict()]}) + + def test_middleware_asset_authentication_unsuccessful(self): + response = self.client.get("/middleware/verify-asset") + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + @requests_mock.Mocker() + def test_middleware_asset_authentication_successful(self, mock_get_public_key): + mock_get_public_key.get( + "https://test-middleware.net/.well-known/openid-configuration/", + text=self.public_key, + ) + token = generate_jwt( + claims={"asset_id": str(self.asset.external_id)}, + jwks=self.private_key, + ) + + response = self.client.get( + "/middleware/verify-asset", + headers={ + "Authorization": f"Middleware_Bearer {token}", + "X-Facility-Id": self.facility.external_id, + }, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.data["username"], "asset" + str(self.asset.external_id) + ) + + def test_middleware_authentication_unsuccessful(self): + response = self.client.get("/middleware/verify") + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + @requests_mock.Mocker() + def test_middleware_authentication_successful(self, mock_get_public_key): + mock_get_public_key.get( + "https://test-middleware.net/.well-known/openid-configuration/", + text=self.public_key, + ) + token = generate_jwt(jwks=self.private_key) + + response = self.client.get( + "/middleware/verify", + headers={ + "Authorization": f"Middleware_Bearer {token}", + "X-Facility-Id": self.facility.external_id, + }, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.data["username"], "middleware" + str(self.facility.external_id) + ) + + @override_cache + @requests_mock.Mocker() + def test_middleware_authentication_cached_successful(self, mock_get_public_key): + mock_get_public_key.get( + "https://test-middleware.net/.well-known/openid-configuration/", + text=self.public_key, + ) + token = generate_jwt(jwks=self.private_key) + self.client.get( + "/middleware/verify", + headers={ + "Authorization": f"Middleware_Bearer {token}", + "X-Facility-Id": self.facility.external_id, + }, + ) + + response = self.client.get( + "/middleware/verify", + headers={ + "Authorization": f"Middleware_Bearer {token}", + "X-Facility-Id": self.facility.external_id, + }, + ) + self.assertEqual(mock_get_public_key.call_count, 1) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.data["username"], "middleware" + str(self.facility.external_id) + ) + + @requests_mock.Mocker() + def test_middleware_authentication_invalid_token(self, mock_get_public_key): + mock_get_public_key.get( + "https://test-middleware.net/.well-known/openid-configuration/", + text=self.public_key, + ) + + token = generate_jwt(jwks=JsonWebKey.generate_key("RSA", 2048, is_private=True)) + + response = self.client.get( + "/middleware/verify", + headers={ + "Authorization": f"Middleware_Bearer {token}", + "X-Facility-Id": self.facility.external_id, + }, + ) + self.assertEqual(mock_get_public_key.call_count, 1) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) diff --git a/care/utils/jwks/token_generator.py b/care/utils/jwks/token_generator.py index fe06421fb3..d6a169334b 100644 --- a/care/utils/jwks/token_generator.py +++ b/care/utils/jwks/token_generator.py @@ -4,9 +4,11 @@ from django.conf import settings -def generate_jwt(claims=None, exp=60): +def generate_jwt(claims=None, exp=60, jwks=None): if claims is None: claims = {} + if jwks is None: + jwks = settings.JWKS header = {"alg": "RS256"} time = int(datetime.now().timestamp()) payload = { @@ -14,4 +16,4 @@ def generate_jwt(claims=None, exp=60): "exp": time + exp, **claims, } - return jwt.encode(header, payload, settings.JWKS).decode("utf-8") + return jwt.encode(header, payload, jwks).decode("utf-8") diff --git a/config/api_router.py b/config/api_router.py index 4e17320300..82b330cc1c 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -14,6 +14,7 @@ AssetLocationViewSet, AssetPublicQRViewSet, AssetPublicViewSet, + AssetRetrieveConfigViewSet, AssetServiceViewSet, AssetTransactionViewSet, AssetViewSet, @@ -201,6 +202,8 @@ # facility_nested_router.register("burn_rate", FacilityInventoryBurnRateViewSet) router.register("asset", AssetViewSet) +router.register("asset_config", AssetRetrieveConfigViewSet) + asset_nested_router = NestedSimpleRouter(router, r"asset", lookup="asset") asset_nested_router.register(r"availability", AvailabilityViewSet) asset_nested_router.register(r"service_records", AssetServiceViewSet) diff --git a/config/authentication.py b/config/authentication.py index cfea6116ae..dbad94d669 100644 --- a/config/authentication.py +++ b/config/authentication.py @@ -1,9 +1,12 @@ import json +import logging from datetime import datetime import jwt import requests from django.conf import settings +from django.contrib.auth.models import AnonymousUser +from django.core.cache import cache from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ from drf_spectacular.extensions import OpenApiAuthenticationExtension @@ -12,11 +15,33 @@ from rest_framework.authentication import BasicAuthentication from rest_framework_simplejwt.authentication import JWTAuthentication from rest_framework_simplejwt.exceptions import AuthenticationFailed, InvalidToken +from rest_framework_simplejwt.tokens import Token from care.facility.models import Facility from care.facility.models.asset import Asset from care.users.models import User +logger = logging.getLogger(__name__) + + +def jwk_response_cache_key(url: str) -> str: + return f"jwk_response:{url}" + + +class MiddlewareUser(AnonymousUser): + """ + Read-only user class for middleware authentication + """ + + def __init__(self, facility, *args, **kwargs): + super().__init__(*args, **kwargs) + self.facility = facility + self.username = f"middleware{facility.external_id}" + + @property + def is_authenticated(self): + return True + class CustomJWTAuthentication(JWTAuthentication): def authenticate_header(self, request): @@ -49,15 +74,26 @@ class MiddlewareAuthentication(JWTAuthentication): auth_header_type = "Middleware_Bearer" auth_header_type_bytes = auth_header_type.encode(HTTP_HEADER_ENCODING) + def get_public_key(self, url): + public_key_json = cache.get(jwk_response_cache_key(url)) + if not public_key_json: + res = requests.get(url) + res.raise_for_status() + public_key_json = res.json() + cache.set(jwk_response_cache_key(url), public_key_json, timeout=60 * 5) + return public_key_json["keys"][0] + def open_id_authenticate(self, url, token): - public_key = requests.get(url) - jwk = public_key.json()["keys"][0] - public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk)) + public_key_response = self.get_public_key(url) + public_key = jwt.algorithms.RSAAlgorithm.from_jwk(public_key_response) return jwt.decode(token, key=public_key, algorithms=["RS256"]) def authenticate_header(self, request): return f'{self.auth_header_type} realm="{self.www_authenticate_realm}"' + def get_user(self, _: Token, facility: Facility): + return MiddlewareUser(facility=facility) + def authenticate(self, request): header = self.get_header(request) if header is None: @@ -116,10 +152,12 @@ def get_validated_token(self, url, raw_token): try: return self.open_id_authenticate(url, raw_token) except Exception as e: - print(e) + logger.info(e, "Token: ", raw_token) raise InvalidToken({"detail": "Given token not valid for any token type"}) + +class MiddlewareAssetAuthentication(MiddlewareAuthentication): def get_user(self, validated_token, facility): """ Attempts to find and return a user using the given validated token. @@ -188,7 +226,7 @@ def get_validated_token(self, url, token): try: return self.open_id_authenticate(url, token) except Exception as e: - print(e) + logger.info(e, "Token: ", token) raise InvalidToken({"detail": f"Invalid Authorization token: {e}"}) def get_user(self, validated_token): @@ -238,6 +276,25 @@ def get_security_definition(self, auto_schema): } +class MiddlewareAssetAuthenticationScheme(OpenApiAuthenticationExtension): + target_class = "config.authentication.MiddlewareAssetAuthentication" + name = "middlewareAssetAuth" + + def get_security_definition(self, auto_schema): + return { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + "description": _( + "Used for authenticating requests from the middleware on behalf of assets. " + "The scheme requires a valid JWT token in the Authorization header " + "along with the facility id in the X-Facility-Id header. " + "--The value field is just for preview, filling it will show allowed " + "endpoints.--" + ), + } + + class CustomBasicAuthenticationScheme(OpenApiAuthenticationExtension): target_class = "config.authentication.CustomBasicAuthentication" name = "basicAuth" diff --git a/config/health_views.py b/config/health_views.py index 529137378b..4b1febf74f 100644 --- a/config/health_views.py +++ b/config/health_views.py @@ -1,12 +1,25 @@ +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView from care.users.api.serializers.user import UserBaseMinimumSerializer -from config.authentication import MiddlewareAuthentication +from config.authentication import ( + MiddlewareAssetAuthentication, + MiddlewareAuthentication, +) class MiddlewareAuthenticationVerifyView(APIView): authentication_classes = [MiddlewareAuthentication] + permission_classes = [IsAuthenticated] + + def get(self, request): + return Response(UserBaseMinimumSerializer(request.user).data) + + +class MiddlewareAssetAuthenticationVerifyView(APIView): + authentication_classes = [MiddlewareAssetAuthentication] + permission_classes = [IsAuthenticated] def get(self, request): return Response(UserBaseMinimumSerializer(request.user).data) diff --git a/config/urls.py b/config/urls.py index a79e444d4f..b59954a17d 100644 --- a/config/urls.py +++ b/config/urls.py @@ -10,7 +10,7 @@ ) from care.abdm.urls import abdm_urlpatterns -from care.facility.api.viewsets.open_id import OpenIdConfigView +from care.facility.api.viewsets.open_id import PublicJWKsView from care.facility.api.viewsets.patient_consultation import ( dev_preview_discharge_summary, ) @@ -27,7 +27,10 @@ ResetPasswordRequestToken, ) from config import api_router -from config.health_views import MiddlewareAuthenticationVerifyView +from config.health_views import ( + MiddlewareAssetAuthenticationVerifyView, + MiddlewareAuthenticationVerifyView, +) from .auth_views import AnnotatedTokenVerifyView, TokenObtainPairView, TokenRefreshView from .views import home_view, ping @@ -91,12 +94,12 @@ ), # Health check urls path("middleware/verify", MiddlewareAuthenticationVerifyView.as_view()), - path( - ".well-known/openid-configuration", - OpenIdConfigView.as_view(), - name="openid-configuration", - ), + path("middleware/verify-asset", MiddlewareAssetAuthenticationVerifyView.as_view()), path("health/", include("healthy_django.urls", namespace="healthy_django")), + # OpenID Connect + path(".well-known/jwks.json", PublicJWKsView.as_view(), name="jwks-json"), + # TODO: Remove the config url as its not a standard implementation + path(".well-known/openid-configuration", PublicJWKsView.as_view()), ] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) if settings.ENABLE_ABDM: