diff --git a/api/environments/urls.py b/api/environments/urls.py index d5cb0ceab33d..ad2c5ede484a 100644 --- a/api/environments/urls.py +++ b/api/environments/urls.py @@ -7,6 +7,7 @@ EdgeIdentityWithIdentifierFeatureStateView, get_edge_identity_overrides, ) +from features.feature_states.views import update_flag from features.views import ( EnvironmentFeatureStateViewSet, IdentityFeatureStateViewSet, @@ -167,4 +168,9 @@ get_edge_identity_overrides, name="edge-identity-overrides", ), + path( + "/features//update-flag/", + update_flag, + name="update-flag", + ), ] diff --git a/api/features/feature_states/views.py b/api/features/feature_states/views.py new file mode 100644 index 000000000000..f8bb0a28aa7a --- /dev/null +++ b/api/features/feature_states/views.py @@ -0,0 +1,22 @@ +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.request import Request +from rest_framework.response import Response + +from environments.models import Environment +from features.models import Feature +from features.serializers import UpdateFlagSerializer + + +@api_view(http_method_names=["POST"]) +def update_flag(request: Request, environment_id: int, feature_name: str) -> Response: + environment = Environment.objects.get(id=environment_id) + feature = Feature.objects.get(name=feature_name, project_id=environment.project_id) + + serializer = UpdateFlagSerializer( + data=request.data, context={"request": request, "view": update_flag} + ) + serializer.is_valid(raise_exception=True) + serializer.save(environment=environment, feature=feature) + + return Response(serializer.data, status=status.HTTP_200_OK) diff --git a/api/features/serializers.py b/api/features/serializers.py index 55785e1bda5a..a4011e97daa7 100644 --- a/api/features/serializers.py +++ b/api/features/serializers.py @@ -1,3 +1,4 @@ +import typing from datetime import datetime from typing import Any @@ -18,6 +19,7 @@ from app_analytics.serializers import LabelsQuerySerializerMixin, LabelsSerializer from environments.identities.models import Identity +from environments.models import Environment from environments.sdk.serializers_mixins import ( HideSensitiveFieldsSerializerMixin, ) @@ -28,6 +30,7 @@ FeatureFlagCodeReferencesRepositoryCountSerializer, ) from projects.models import Project +from users.models import FFAdminUser from users.serializers import ( UserIdsSerializer, UserListSerializer, @@ -47,6 +50,8 @@ ) from .models import Feature, FeatureState from .multivariate.serializers import NestedMultivariateFeatureOptionSerializer +from .versioning.dataclasses import FlagChangeSet +from .versioning.versioning_service import update_flag class FeatureStateSerializerSmall(serializers.ModelSerializer): # type: ignore[type-arg] @@ -671,3 +676,45 @@ def create(self, validated_data: dict) -> FeatureState: # type: ignore[type-arg {"environment": SEGMENT_OVERRIDE_LIMIT_EXCEEDED_MESSAGE} ) return super().create(validated_data) # type: ignore[no-any-return,no-untyped-call] + + +class UpdateFlagSerializer(serializers.Serializer): # type: ignore[type-arg] + enabled = serializers.BooleanField(required=False) + + # TODO: this is introducing _another_ way of handling typing, but it feels closer + # to what we should have done from the start. This might be out of scope for this + # work though. + feature_state_value = serializers.CharField(required=False) + type = serializers.ChoiceField( + choices=["int", "string", "bool", "float"], + required=False, + default="string", + ) + + segment_id = serializers.IntegerField(required=False) + + # TODO: multivariate? + + @property + def flag_change_set(self) -> FlagChangeSet: + validated_data = self.validated_data + change_set = FlagChangeSet( + enabled=validated_data.get("enabled"), + feature_state_value=validated_data.get("feature_state_value"), + type_=validated_data.get("type"), + segment_id=validated_data.get("segment_id"), + ) + + request = self.context["request"] + if type(request.user) is FFAdminUser: + change_set.user = request.user + else: + change_set.api_key = request.user.key + + return change_set + + def save(self, **kwargs: typing.Any) -> FeatureState: + environment: Environment = kwargs["environment"] + feature: Feature = kwargs["feature"] + + return update_flag(environment, feature, self.flag_change_set) diff --git a/api/features/versioning/dataclasses.py b/api/features/versioning/dataclasses.py index 0c4476a4f51a..edb26243fb83 100644 --- a/api/features/versioning/dataclasses.py +++ b/api/features/versioning/dataclasses.py @@ -1,7 +1,13 @@ +import typing +from dataclasses import dataclass from datetime import datetime from pydantic import BaseModel, computed_field +if typing.TYPE_CHECKING: + from api_keys.models import MasterAPIKey + from users.models import FFAdminUser + class Conflict(BaseModel): segment_id: int | None = None @@ -12,3 +18,15 @@ class Conflict(BaseModel): @property def is_environment_default(self) -> bool: return self.segment_id is None + + +@dataclass +class FlagChangeSet: + enabled: bool + feature_state_value: str + type_: str + + user: typing.Optional["FFAdminUser"] = None + api_key: typing.Optional["MasterAPIKey"] = None + + segment_id: str | None = None diff --git a/api/features/versioning/versioning_service.py b/api/features/versioning/versioning_service.py index 41e30e4432a0..a83b81b9de81 100644 --- a/api/features/versioning/versioning_service.py +++ b/api/features/versioning/versioning_service.py @@ -4,8 +4,10 @@ from django.db.models import Prefetch, Q, QuerySet from django.utils import timezone +from core.constants import BOOLEAN, FLOAT, INTEGER, STRING from environments.models import Environment -from features.models import FeatureState +from features.models import Feature, FeatureState, FeatureStateValue +from features.versioning.dataclasses import FlagChangeSet from features.versioning.models import EnvironmentFeatureVersion @@ -101,6 +103,80 @@ def get_current_live_environment_feature_version( ) +def update_flag( + environment: Environment, feature: Feature, change_set: FlagChangeSet +) -> FeatureState: + if environment.use_v2_feature_versioning: + return _update_flag_for_versioning_v2(environment, feature, change_set) + else: + return _update_flag_for_versioning_v1(environment, feature, change_set) + + +def _update_flag_for_versioning_v2( + environment: Environment, feature: Feature, change_set: FlagChangeSet +) -> FeatureState: + new_version = EnvironmentFeatureVersion.objects.create( + environment=environment, + feature=feature, + created_by=change_set.user, + created_by_api_key=change_set.api_key, + ) + + target_feature_state: FeatureState = new_version.feature_states.get( + feature_segment__segment_id=change_set.segment_id, + ) + + target_feature_state.enabled = change_set.enabled + target_feature_state.save() + + _update_feature_state_value(target_feature_state.feature_state_value, change_set) + + new_version.publish( + published_by=change_set.user, published_by_api_key=change_set.api_key + ) + + return target_feature_state + + +def _update_flag_for_versioning_v1( + environment: Environment, feature: Feature, change_set: FlagChangeSet +) -> FeatureState: + latest_feature_states = get_environment_flags_dict( + environment=environment, + feature_name=feature.name, + additional_filters=Q(feature_segment__segment_id=change_set.segment_id), + ) + assert len(latest_feature_states) == 1 + + target_feature_state = list(latest_feature_states.values())[0] + target_feature_state.enabled = change_set.enabled + target_feature_state.save() + + _update_feature_state_value(target_feature_state.feature_state_value, change_set) + + return target_feature_state + + +def _update_feature_state_value( + fsv: FeatureStateValue, change_set: FlagChangeSet +) -> None: + match change_set.type_: + case "string": + fsv.string_value = change_set.feature_state_value + fsv.type = STRING + case "int": + fsv.integer_value = int(change_set.feature_state_value) + fsv.type = INTEGER + case "bool": + fsv.boolean_value = change_set.feature_state_value in ("True", "true", "1") + fsv.type = BOOLEAN + case "float": + fsv.float_value = float(change_set.feature_state_value) + fsv.type = FLOAT + + fsv.save() + + def _get_feature_states_queryset( environment: "Environment", feature_name: str | None = None, diff --git a/api/tests/unit/features/feature_states/__init__.py b/api/tests/unit/features/feature_states/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/api/tests/unit/features/feature_states/test_unit_feature_states_views.py b/api/tests/unit/features/feature_states/test_unit_feature_states_views.py new file mode 100644 index 000000000000..9674034f8946 --- /dev/null +++ b/api/tests/unit/features/feature_states/test_unit_feature_states_views.py @@ -0,0 +1,50 @@ +import json + +import pytest +from common.environments.permissions import UPDATE_FEATURE_STATE +from django.urls import reverse +from pytest_lazyfixture import lazy_fixture # type: ignore[import-untyped] +from rest_framework import status +from rest_framework.test import APIClient + +from environments.models import Environment +from features.models import Feature +from features.versioning.versioning_service import ( + get_environment_flags_list, +) +from tests.types import WithEnvironmentPermissionsCallable + + +@pytest.mark.parametrize( + "environment_", + (lazy_fixture("environment"), lazy_fixture("environment_v2_versioning")), +) +def test_update_flag( + staff_client: APIClient, + feature: Feature, + environment_: Environment, + with_environment_permissions: WithEnvironmentPermissionsCallable, +) -> None: + # Given + with_environment_permissions([UPDATE_FEATURE_STATE]) # type: ignore[call-arg] + url = reverse( + "api-v1:environments:update-flag", + kwargs={"environment_id": environment_.id, "feature_name": feature.name}, + ) + + data = {"enabled": True, "feature_state_value": "42", "type": "int"} + + # When + response = staff_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + + latest_flags = get_environment_flags_list( + environment=environment_, feature_name=feature.name + ) + + assert latest_flags[0].enabled is True + assert latest_flags[0].get_feature_state_value() == 42