From d024cdab190eb46eb0ce21679f44f08df5690cb9 Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Mon, 27 Jan 2025 10:48:20 -0500 Subject: [PATCH] Make parameter `user` mandatory for all methods in the auth manager interface (#45986) --- .../api_connexion/endpoints/asset_endpoint.py | 8 +- .../api_connexion/endpoints/dag_parsing.py | 2 +- .../endpoints/dag_source_endpoint.py | 2 +- .../endpoints/import_error_endpoint.py | 6 +- .../endpoints/task_instance_endpoint.py | 2 +- airflow/api_connexion/security.py | 34 +++++-- airflow/auth/managers/base_auth_manager.py | 61 +++++++----- .../managers/simple/simple_auth_manager.py | 28 +++--- airflow/www/auth.py | 40 ++++++-- airflow/www/views.py | 25 +++-- newsfragments/aip-79.significant.rst | 20 ++++ .../amazon/aws/auth_manager/avp/facade.py | 5 +- .../aws/auth_manager/aws_auth_manager.py | 98 ++++++------------- .../fab/auth_manager/fab_auth_manager.py | 53 ++++------ .../fab/www/api_connexion/security.py | 4 +- .../aws/auth_manager/avp/test_facade.py | 12 --- .../aws/auth_manager/test_aws_auth_manager.py | 76 ++++---------- .../simple/test_simple_auth_manager.py | 95 +++++++++--------- tests/auth/managers/test_base_auth_manager.py | 12 ++- 19 files changed, 280 insertions(+), 303 deletions(-) diff --git a/airflow/api_connexion/endpoints/asset_endpoint.py b/airflow/api_connexion/endpoints/asset_endpoint.py index 64d8f5484573a..a8b184d36cc82 100644 --- a/airflow/api_connexion/endpoints/asset_endpoint.py +++ b/airflow/api_connexion/endpoints/asset_endpoint.py @@ -280,7 +280,9 @@ def get_asset_queued_events( *, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get queued asset events for an asset.""" - permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"]) + permitted_dag_ids = get_auth_manager().get_permitted_dag_ids( + user=get_auth_manager().get_user(), methods=["GET"] + ) where_clause = _generate_queued_event_where_clause( uri=uri, before=before, permitted_dag_ids=permitted_dag_ids ) @@ -313,7 +315,9 @@ def delete_asset_queued_events( *, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Delete queued asset events for an asset.""" - permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"]) + permitted_dag_ids = get_auth_manager().get_permitted_dag_ids( + user=get_auth_manager().get_user(), methods=["GET"] + ) where_clause = _generate_queued_event_where_clause( uri=uri, before=before, permitted_dag_ids=permitted_dag_ids ) diff --git a/airflow/api_connexion/endpoints/dag_parsing.py b/airflow/api_connexion/endpoints/dag_parsing.py index c6fde8d851b61..70217ce9e7fe5 100644 --- a/airflow/api_connexion/endpoints/dag_parsing.py +++ b/airflow/api_connexion/endpoints/dag_parsing.py @@ -59,7 +59,7 @@ def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Resp raise NotFound("File not found") # Check if user has read access to all the DAGs defined in the file - if not get_auth_manager().batch_is_authorized_dag(requests): + if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()): raise PermissionDenied() parsing_request = DagPriorityParsingRequest(fileloc=path) diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py index c1c4929aed9e5..b28aee7013c10 100644 --- a/airflow/api_connexion/endpoints/dag_source_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py @@ -63,7 +63,7 @@ def get_dag_source( ] # Check if user has read access to all the DAGs defined in the file - if not get_auth_manager().batch_is_authorized_dag(requests): + if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()): raise PermissionDenied() dag_source = dag_version.dag_code.source_code version_number = dag_version.version_number diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index 95225c8693f88..84d6113f51028 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -56,7 +56,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> ) session.expunge(error) - can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET") + can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=get_auth_manager().get_user()) if not can_read_all_dags: readable_dag_ids = security.get_readable_dags() file_dag_ids = { @@ -95,7 +95,7 @@ def get_import_errors( query = select(ParseImportError) query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs) - can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET") + can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=get_auth_manager().get_user()) if not can_read_all_dags: # if the user doesn't have access to all DAGs, only display errors from visible DAGs @@ -133,7 +133,7 @@ def get_import_errors( } for dag_id in file_dag_ids ] - if not get_auth_manager().batch_is_authorized_dag(requests): + if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()): session.expunge(import_error) import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file" diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 8751187a70ab6..3c3587a58c263 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -387,7 +387,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: } for id in dag_ids ] - if not get_auth_manager().batch_is_authorized_dag(requests): + if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()): raise PermissionDenied(detail=f"User not allowed to access some of these DAGs: {list(dag_ids)}") else: dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 3601116fdb440..468659636f5c3 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -78,7 +78,9 @@ def decorated(*args, **kwargs): section: str | None = kwargs.get("section") return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_configuration( - method=method, details=ConfigurationDetails(section=section) + method=method, + details=ConfigurationDetails(section=section), + user=get_auth_manager().get_user(), ), func=func, args=args, @@ -97,7 +99,9 @@ def decorated(*args, **kwargs): connection_id: str | None = kwargs.get("connection_id") return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_connection( - method=method, details=ConnectionDetails(conn_id=connection_id) + method=method, + details=ConnectionDetails(conn_id=connection_id), + user=get_auth_manager().get_user(), ), func=func, args=args, @@ -120,6 +124,7 @@ def callback() -> bool | DagAccessEntity: method=method, access_entity=access_entity, details=DagDetails(id=dag_id), + user=get_auth_manager().get_user(), ) else: # here we know dag_id is not provided. @@ -127,6 +132,7 @@ def callback() -> bool | DagAccessEntity: if get_auth_manager().is_authorized_dag( method=method, access_entity=access_entity, + user=get_auth_manager().get_user(), ): return True elif access_entity: @@ -138,7 +144,9 @@ def callback() -> bool | DagAccessEntity: # but we leave it to the endpoint function to properly restrict access beyond that if method not in ("GET", "PUT"): return False - return any(get_auth_manager().get_permitted_dag_ids(methods=[method])) + return any( + get_auth_manager().get_permitted_dag_ids(user=get_auth_manager().get_user(), methods=[method]) + ) return callback @@ -165,7 +173,9 @@ def decorated(*args, **kwargs): uri: str | None = kwargs.get("uri") return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_asset( - method=method, details=AssetDetails(uri=uri) + method=method, + details=AssetDetails(uri=uri), + user=get_auth_manager().get_user(), ), func=func, args=args, @@ -184,7 +194,9 @@ def decorated(*args, **kwargs): pool_name: str | None = kwargs.get("pool_name") return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_pool( - method=method, details=PoolDetails(name=pool_name) + method=method, + details=PoolDetails(name=pool_name), + user=get_auth_manager().get_user(), ), func=func, args=args, @@ -203,7 +215,9 @@ def decorated(*args, **kwargs): variable_key: str | None = kwargs.get("variable_key") return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_variable( - method=method, details=VariableDetails(key=variable_key) + method=method, + details=VariableDetails(key=variable_key), + user=get_auth_manager().get_user(), ), func=func, args=args, @@ -220,7 +234,9 @@ def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): return _requires_access( - is_authorized_callback=lambda: get_auth_manager().is_authorized_view(access_view=access_view), + is_authorized_callback=lambda: get_auth_manager().is_authorized_view( + access_view=access_view, user=get_auth_manager().get_user() + ), func=func, args=args, kwargs=kwargs, @@ -240,7 +256,9 @@ def requires_access_decorator(func: T): def decorated(*args, **kwargs): return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_custom_view( - method=method, resource_name=resource_name + method=method, + resource_name=resource_name, + user=get_auth_manager().get_user(), ), func=func, args=args, diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index fe86bc8f05acf..2f83a1c6c06c5 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -146,15 +146,15 @@ def is_authorized_configuration( self, *, method: ResourceMethod, + user: T, details: ConfigurationDetails | None = None, - user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on configuration. :param method: the method to perform + :param user: the user to perform the action on :param details: optional details about the configuration - :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @abstractmethod @@ -162,15 +162,15 @@ def is_authorized_connection( self, *, method: ResourceMethod, + user: T, details: ConnectionDetails | None = None, - user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a connection. :param method: the method to perform + :param user: the user to perform the action on :param details: optional details about the connection - :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @abstractmethod @@ -178,18 +178,18 @@ def is_authorized_dag( self, *, method: ResourceMethod, + user: T, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform + :param user: the user to perform the action on :param access_entity: the kind of DAG information the authorization request is about. If not provided, the authorization request is about the DAG itself :param details: optional details about the DAG - :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @abstractmethod @@ -197,15 +197,15 @@ def is_authorized_asset( self, *, method: ResourceMethod, + user: T, details: AssetDetails | None = None, - user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on an asset. :param method: the method to perform + :param user: the user to perform the action on :param details: optional details about the asset - :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @abstractmethod @@ -213,15 +213,15 @@ def is_authorized_pool( self, *, method: ResourceMethod, + user: T, details: PoolDetails | None = None, - user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a pool. :param method: the method to perform + :param user: the user to perform the action on :param details: optional details about the pool - :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @abstractmethod @@ -229,15 +229,15 @@ def is_authorized_variable( self, *, method: ResourceMethod, + user: T, details: VariableDetails | None = None, - user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a variable. :param method: the method to perform + :param user: the user to perform the action on :param details: optional details about the variable - :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @abstractmethod @@ -245,19 +245,17 @@ def is_authorized_view( self, *, access_view: AccessView, - user: T | None = None, + user: T, ) -> bool: """ Return whether the user is authorized to access a read-only state of the installation. :param access_view: the specific read-only view/state the authorization request is about. - :param user: the user to perform the action on. If not provided (or None), it uses the current user + :param user: the user to perform the action on """ @abstractmethod - def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: T | None = None - ): + def is_authorized_custom_view(self, *, method: ResourceMethod | str, resource_name: str, user: T): """ Return whether the user is authorized to perform a given action on a custom view. @@ -270,7 +268,7 @@ def is_authorized_custom_view( In that case, the action can be anything (e.g. can_do). See https://github.com/apache/airflow/issues/39144 :param resource_name: the name of the resource - :param user: the user to perform the action on. If not provided (or None), it uses the current user + :param user: the user to perform the action on """ @abstractmethod @@ -284,6 +282,8 @@ def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuIt def batch_is_authorized_connection( self, requests: Sequence[IsAuthorizedConnectionRequest], + *, + user: T, ) -> bool: """ Batch version of ``is_authorized_connection``. @@ -293,15 +293,18 @@ def batch_is_authorized_connection( manager implementation to provide a more efficient implementation. :param requests: a list of requests containing the parameters for ``is_authorized_connection`` + :param user: the user to perform the action on """ return all( - self.is_authorized_connection(method=request["method"], details=request.get("details")) + self.is_authorized_connection(method=request["method"], details=request.get("details"), user=user) for request in requests ) def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], + *, + user: T, ) -> bool: """ Batch version of ``is_authorized_dag``. @@ -311,12 +314,14 @@ def batch_is_authorized_dag( implementation to provide a more efficient implementation. :param requests: a list of requests containing the parameters for ``is_authorized_dag`` + :param user: the user to perform the action on """ return all( self.is_authorized_dag( method=request["method"], access_entity=request.get("access_entity"), details=request.get("details"), + user=user, ) for request in requests ) @@ -324,6 +329,8 @@ def batch_is_authorized_dag( def batch_is_authorized_pool( self, requests: Sequence[IsAuthorizedPoolRequest], + *, + user: T, ) -> bool: """ Batch version of ``is_authorized_pool``. @@ -333,15 +340,18 @@ def batch_is_authorized_pool( manager implementation to provide a more efficient implementation. :param requests: a list of requests containing the parameters for ``is_authorized_pool`` + :param user: the user to perform the action on """ return all( - self.is_authorized_pool(method=request["method"], details=request.get("details")) + self.is_authorized_pool(method=request["method"], details=request.get("details"), user=user) for request in requests ) def batch_is_authorized_variable( self, requests: Sequence[IsAuthorizedVariableRequest], + *, + user: T, ) -> bool: """ Batch version of ``is_authorized_variable``. @@ -351,9 +361,10 @@ def batch_is_authorized_variable( manager implementation to provide a more efficient implementation. :param requests: a list of requests containing the parameters for ``is_authorized_variable`` + :param user: the user to perform the action on """ return all( - self.is_authorized_variable(method=request["method"], details=request.get("details")) + self.is_authorized_variable(method=request["method"], details=request.get("details"), user=user) for request in requests ) @@ -361,8 +372,8 @@ def batch_is_authorized_variable( def get_permitted_dag_ids( self, *, + user: T, methods: Container[ResourceMethod] | None = None, - user=None, session: Session = NEW_SESSION, ) -> set[str]: """ @@ -372,8 +383,8 @@ def get_permitted_dag_ids( Can lead to some poor performance. It is recommended to override this method in the auth manager implementation to provide a more efficient implementation. + :param user: the user :param methods: whether filter readable or writable - :param user: the current user :param session: the session """ dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} @@ -383,15 +394,15 @@ def filter_permitted_dag_ids( self, *, dag_ids: set[str], + user: T, methods: Container[ResourceMethod] | None = None, - user=None, ): """ Filter readable or writable DAGs for user. :param dag_ids: the list of DAG ids + :param user: the user :param methods: whether filter readable or writable - :param user: the current user """ if not methods: methods = ["PUT", "GET"] diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 03cc12744ba8d..d79553c4cc706 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -160,8 +160,8 @@ def is_authorized_configuration( self, *, method: ResourceMethod, + user: SimpleAuthManagerUser, details: ConfigurationDetails | None = None, - user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP, user=user) @@ -169,8 +169,8 @@ def is_authorized_connection( self, *, method: ResourceMethod, + user: SimpleAuthManagerUser, details: ConnectionDetails | None = None, - user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP, user=user) @@ -178,9 +178,9 @@ def is_authorized_dag( self, *, method: ResourceMethod, + user: SimpleAuthManagerUser, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized( method=method, @@ -193,8 +193,8 @@ def is_authorized_asset( self, *, method: ResourceMethod, + user: SimpleAuthManagerUser, details: AssetDetails | None = None, - user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized( method=method, @@ -207,8 +207,8 @@ def is_authorized_pool( self, *, method: ResourceMethod, + user: SimpleAuthManagerUser, details: PoolDetails | None = None, - user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized( method=method, @@ -221,18 +221,16 @@ def is_authorized_variable( self, *, method: ResourceMethod, + user: SimpleAuthManagerUser, details: VariableDetails | None = None, - user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP, user=user) - def is_authorized_view( - self, *, access_view: AccessView, user: SimpleAuthManagerUser | None = None - ) -> bool: + def is_authorized_view(self, *, access_view: AccessView, user: SimpleAuthManagerUser) -> bool: return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user) def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: SimpleAuthManagerUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: SimpleAuthManagerUser ): return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user) @@ -290,13 +288,13 @@ def webapp(request: Request, rest_of_path: str): return app + @staticmethod def _is_authorized( - self, *, method: ResourceMethod, allow_role: SimpleAuthManagerRole, + user: SimpleAuthManagerUser, allow_get_role: SimpleAuthManagerRole | None = None, - user: SimpleAuthManagerUser | None = None, ): """ Return whether the user is authorized to access a given resource. @@ -304,14 +302,10 @@ def _is_authorized( :param method: the method to perform :param allow_role: minimal role giving access to the resource, if the user's role is greater or equal than this role, they have access + :param user: the user to check the authorization for :param allow_get_role: minimal role giving access to the resource, if the user's role is greater or equal than this role, they have access. If not provided, ``allow_role`` is used - :param user: the user to check the authorization for. If not provided, the current user is used """ - user = user or self.get_user() - if not user: - return False - user_role = user.get_role() if not user_role: return False diff --git a/airflow/www/auth.py b/airflow/www/auth.py index 101b1463596e9..a128fd5c47f24 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -139,7 +139,8 @@ def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): if is_authorized: return func(*args, **kwargs) elif get_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view( - access_view=AccessView.WEBSITE + access_view=AccessView.WEBSITE, + user=get_auth_manager().get_user(), ): return ( render_template( @@ -158,7 +159,11 @@ def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): def has_access_configuration(method: ResourceMethod) -> Callable[[T], T]: - return _has_access_no_details(lambda: get_auth_manager().is_authorized_configuration(method=method)) + return _has_access_no_details( + lambda: get_auth_manager().is_authorized_configuration( + method=method, user=get_auth_manager().get_user() + ) + ) def has_access_connection(method: ResourceMethod) -> Callable[[T], T]: @@ -173,7 +178,9 @@ def decorated(*args, **kwargs): } for connection in connections ] - is_authorized = get_auth_manager().batch_is_authorized_connection(requests) + is_authorized = get_auth_manager().batch_is_authorized_connection( + requests, user=get_auth_manager().get_user() + ) return _has_access( is_authorized=is_authorized, func=func, @@ -222,6 +229,7 @@ def decorated(*args, **kwargs): method=method, access_entity=access_entity, details=None if not dag_id else DagDetails(id=dag_id), + user=get_auth_manager().get_user(), ) return _has_access( @@ -250,7 +258,9 @@ def decorated(*args, **kwargs): for item in items if item is not None ] - is_authorized = get_auth_manager().batch_is_authorized_dag(requests) + is_authorized = get_auth_manager().batch_is_authorized_dag( + requests, user=get_auth_manager().get_user() + ) return _has_access( is_authorized=is_authorized, func=func, @@ -265,7 +275,9 @@ def decorated(*args, **kwargs): def has_access_asset(method: ResourceMethod) -> Callable[[T], T]: """Check current user's permissions against required permissions for assets.""" - return _has_access_no_details(lambda: get_auth_manager().is_authorized_asset(method=method)) + return _has_access_no_details( + lambda: get_auth_manager().is_authorized_asset(method=method, user=get_auth_manager().get_user()) + ) def has_access_pool(method: ResourceMethod) -> Callable[[T], T]: @@ -280,7 +292,9 @@ def decorated(*args, **kwargs): } for pool in pools ] - is_authorized = get_auth_manager().batch_is_authorized_pool(requests) + is_authorized = get_auth_manager().batch_is_authorized_pool( + requests, user=get_auth_manager().get_user() + ) return _has_access( is_authorized=is_authorized, func=func, @@ -299,7 +313,9 @@ def has_access_decorator(func: T): def decorated(*args, **kwargs): if len(args) == 1: # No items provided - is_authorized = get_auth_manager().is_authorized_variable(method=method) + is_authorized = get_auth_manager().is_authorized_variable( + method=method, user=get_auth_manager().get_user() + ) else: variables: set[Variable] = set(args[1]) requests: Sequence[IsAuthorizedVariableRequest] = [ @@ -309,7 +325,9 @@ def decorated(*args, **kwargs): } for variable in variables ] - is_authorized = get_auth_manager().batch_is_authorized_variable(requests) + is_authorized = get_auth_manager().batch_is_authorized_variable( + requests, user=get_auth_manager().get_user() + ) return _has_access( is_authorized=is_authorized, func=func, @@ -324,4 +342,8 @@ def decorated(*args, **kwargs): def has_access_view(access_view: AccessView = AccessView.WEBSITE) -> Callable[[T], T]: """Check current user's permissions to access the website.""" - return _has_access_no_details(lambda: get_auth_manager().is_authorized_view(access_view=access_view)) + return _has_access_no_details( + lambda: get_auth_manager().is_authorized_view( + access_view=access_view, user=get_auth_manager().get_user() + ) + ) diff --git a/airflow/www/views.py b/airflow/www/views.py index dd5bc856f74a3..88e504c75822f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -730,7 +730,7 @@ def render_template(self, *args, **kwargs): if "dag" in kwargs: kwargs["can_edit_dag"] = get_auth_manager().is_authorized_dag( - method="PUT", details=DagDetails(id=kwargs["dag"].dag_id) + method="PUT", details=DagDetails(id=kwargs["dag"].dag_id), user=g.user ) url_serializer = URLSafeSerializer(current_app.config["SECRET_KEY"]) kwargs["dag_file_token"] = url_serializer.dumps(kwargs["dag"].fileloc) @@ -1011,10 +1011,10 @@ def index(self): owner_links_dict = DagOwnerAttributes.get_all(session) - if get_auth_manager().is_authorized_view(access_view=AccessView.IMPORT_ERRORS): + if get_auth_manager().is_authorized_view(access_view=AccessView.IMPORT_ERRORS, user=g.user): import_errors = select(ParseImportError).order_by(ParseImportError.id) - can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET") + can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=g.user) if not can_read_all_dags: # if the user doesn't have access to all DAGs, only display errors from visible DAGs import_errors = import_errors.where( @@ -1042,7 +1042,9 @@ def index(self): } for dag_id in file_dag_ids ] - if not get_auth_manager().batch_is_authorized_dag(requests): + if not get_auth_manager().batch_is_authorized_dag( + requests, user=get_auth_manager().get_user() + ): stacktrace = "REDACTED - you do not have read permission on all DAGs in the file" flash( f"Broken DAG: [{import_error.filename}]\r{stacktrace}", @@ -2920,6 +2922,7 @@ def grid(self, dag_id: str, session: Session = NEW_SESSION): can_edit_taskinstance = get_auth_manager().is_authorized_dag( method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE, + user=g.user, ) return self.render_template( @@ -4605,7 +4608,9 @@ def _show(self, pk): item, orders=orders, pages=pages, page_sizes=page_sizes, widgets=widgets ) - extra_args = {"can_create_variable": lambda: get_auth_manager().is_authorized_variable(method="POST")} + extra_args = { + "can_create_variable": lambda: get_auth_manager().is_authorized_variable(method="POST", user=g.user) + } @action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False) @auth.has_access_variable("DELETE") @@ -5604,12 +5609,16 @@ def add_user_permissions_to_dag(sender, template, context, **extra): return dag = context["dag"] can_create_dag_run = get_auth_manager().is_authorized_dag( - method="POST", access_entity=DagAccessEntity.RUN, details=DagDetails(id=dag.dag_id) + method="POST", access_entity=DagAccessEntity.RUN, details=DagDetails(id=dag.dag_id), user=g.user ) - dag.can_edit = get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag.dag_id)) + dag.can_edit = get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=dag.dag_id), user=g.user + ) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = get_auth_manager().is_authorized_dag(method="DELETE", details=DagDetails(id=dag.dag_id)) + dag.can_delete = get_auth_manager().is_authorized_dag( + method="DELETE", details=DagDetails(id=dag.dag_id), user=g.user + ) context["dag"] = dag diff --git a/newsfragments/aip-79.significant.rst b/newsfragments/aip-79.significant.rst index 9ca9222333ed2..ec58d6d2228ad 100644 --- a/newsfragments/aip-79.significant.rst +++ b/newsfragments/aip-79.significant.rst @@ -14,6 +14,26 @@ As part of this change the following breaking changes have occurred: - The method ``filter_permitted_menu_items`` is now abstract and must be implemented + - All the following method signatures changed to make the parameter ``user`` required (it was optional) + + - ``is_authorized_configuration`` + - ``is_authorized_connection`` + - ``is_authorized_dag`` + - ``is_authorized_asset`` + - ``is_authorized_pool`` + - ``is_authorized_variable`` + - ``is_authorized_view`` + - ``is_authorized_custom_view`` + - ``get_permitted_dag_ids`` + - ``filter_permitted_dag_ids`` + + - All the following method signatures changed to add the parameter ``user`` + + - ``batch_is_authorized_connection`` + - ``batch_is_authorized_dag`` + - ``batch_is_authorized_pool`` + - ``batch_is_authorized_variable`` + * Types of change * [ ] Dag changes diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/avp/facade.py b/providers/src/airflow/providers/amazon/aws/auth_manager/avp/facade.py index 581932f0bc81f..075214f384527 100644 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/avp/facade.py +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/avp/facade.py @@ -78,7 +78,7 @@ def is_authorized( *, method: ResourceMethod | str, entity_type: AvpEntities, - user: AwsAuthManagerUser | None, + user: AwsAuthManagerUser, entity_id: str | None = None, context: dict | None = None, ) -> bool: @@ -97,9 +97,6 @@ def is_authorized( considered. :param context: optional additional context to pass to Amazon Verified Permissions. """ - if user is None: - return False - entity_list = self._get_user_group_entities(user) self.log.debug( diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 88f8cef8b76c4..b93d6a1ad3789 100644 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -17,7 +17,6 @@ from __future__ import annotations import argparse -import warnings from collections import defaultdict from collections.abc import Container, Sequence from functools import cached_property @@ -35,7 +34,7 @@ VariableDetails, ) from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand -from airflow.exceptions import AirflowOptionalProviderFeatureException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities from airflow.providers.amazon.aws.auth_manager.avp.facade import ( AwsAuthManagerAmazonVerifiedPermissionsFacade, @@ -47,6 +46,7 @@ from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( AwsSecurityManagerOverride, ) +from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS @@ -54,7 +54,6 @@ from flask_appbuilder.menu import MenuItem from airflow.auth.managers.base_auth_manager import ResourceMethod - from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, @@ -62,11 +61,10 @@ IsAuthorizedVariableRequest, ) from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails - from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser from airflow.www.extensions.init_appbuilder import AirflowAppBuilder -class AwsAuthManager(BaseAuthManager): +class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]): """ AWS auth manager. @@ -99,14 +97,14 @@ def is_authorized_configuration( self, *, method: ResourceMethod, + user: AwsAuthManagerUser, details: ConfigurationDetails | None = None, - user: BaseUser | None = None, ) -> bool: config_section = details.section if details else None return self.avp_facade.is_authorized( method=method, entity_type=AvpEntities.CONFIGURATION, - user=user or self.get_user(), + user=user, entity_id=config_section, ) @@ -114,14 +112,14 @@ def is_authorized_connection( self, *, method: ResourceMethod, + user: AwsAuthManagerUser, details: ConnectionDetails | None = None, - user: BaseUser | None = None, ) -> bool: connection_id = details.conn_id if details else None return self.avp_facade.is_authorized( method=method, entity_type=AvpEntities.CONNECTION, - user=user or self.get_user(), + user=user, entity_id=connection_id, ) @@ -129,9 +127,9 @@ def is_authorized_dag( self, *, method: ResourceMethod, + user: AwsAuthManagerUser, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseUser | None = None, ) -> bool: dag_id = details.id if details else None context = ( @@ -146,48 +144,38 @@ def is_authorized_dag( return self.avp_facade.is_authorized( method=method, entity_type=AvpEntities.DAG, - user=user or self.get_user(), + user=user, entity_id=dag_id, context=context, ) def is_authorized_asset( - self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None ) -> bool: asset_uri = details.uri if details else None return self.avp_facade.is_authorized( - method=method, entity_type=AvpEntities.ASSET, user=user or self.get_user(), entity_id=asset_uri + method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_uri ) - def is_authorized_dataset( - self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None - ) -> bool: - warnings.warn( - "is_authorized_dataset will be renamed as is_authorized_asset in Airflow 3 and will be removed when the minimum Airflow version is set to 3.0 for the amazon provider", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - return self.is_authorized_asset(method=method, user=user) - def is_authorized_pool( - self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: PoolDetails | None = None ) -> bool: pool_name = details.name if details else None return self.avp_facade.is_authorized( method=method, entity_type=AvpEntities.POOL, - user=user or self.get_user(), + user=user, entity_id=pool_name, ) def is_authorized_variable( - self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: VariableDetails | None = None ) -> bool: variable_key = details.key if details else None return self.avp_facade.is_authorized( method=method, entity_type=AvpEntities.VARIABLE, - user=user or self.get_user(), + user=user, entity_id=variable_key, ) @@ -195,34 +183,31 @@ def is_authorized_view( self, *, access_view: AccessView, - user: BaseUser | None = None, + user: AwsAuthManagerUser, ) -> bool: return self.avp_facade.is_authorized( method="GET", entity_type=AvpEntities.VIEW, - user=user or self.get_user(), + user=user, entity_id=access_view.value, ) def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: AwsAuthManagerUser ): return self.avp_facade.is_authorized( method=method, entity_type=AvpEntities.CUSTOM, - user=user or self.get_user(), + user=user, entity_id=resource_name, ) def batch_is_authorized_connection( self, requests: Sequence[IsAuthorizedConnectionRequest], + *, + user: AwsAuthManagerUser, ) -> bool: - """ - Batch version of ``is_authorized_connection``. - - :param requests: a list of requests containing the parameters for ``is_authorized_connection`` - """ facade_requests: Sequence[IsAuthorizedRequest] = [ { "method": request["method"], @@ -233,17 +218,14 @@ def batch_is_authorized_connection( } for request in requests ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], + *, + user: AwsAuthManagerUser, ) -> bool: - """ - Batch version of ``is_authorized_dag``. - - :param requests: a list of requests containing the parameters for ``is_authorized_dag`` - """ facade_requests: Sequence[IsAuthorizedRequest] = [ { "method": request["method"], @@ -259,17 +241,14 @@ def batch_is_authorized_dag( } for request in requests ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) def batch_is_authorized_pool( self, requests: Sequence[IsAuthorizedPoolRequest], + *, + user: AwsAuthManagerUser, ) -> bool: - """ - Batch version of ``is_authorized_pool``. - - :param requests: a list of requests containing the parameters for ``is_authorized_pool`` - """ facade_requests: Sequence[IsAuthorizedRequest] = [ { "method": request["method"], @@ -278,17 +257,14 @@ def batch_is_authorized_pool( } for request in requests ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) def batch_is_authorized_variable( self, requests: Sequence[IsAuthorizedVariableRequest], + *, + user: AwsAuthManagerUser, ) -> bool: - """ - Batch version of ``is_authorized_variable``. - - :param requests: a list of requests containing the parameters for ``is_authorized_variable`` - """ facade_requests: Sequence[IsAuthorizedRequest] = [ { "method": request["method"], @@ -299,28 +275,18 @@ def batch_is_authorized_variable( } for request in requests ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) def filter_permitted_dag_ids( self, *, dag_ids: set[str], + user: AwsAuthManagerUser, methods: Container[ResourceMethod] | None = None, - user=None, ): - """ - Filter readable or writable DAGs for user. - - :param dag_ids: the list of DAG ids - :param methods: whether filter readable or writable - :param user: the current user - """ if not methods: methods = ["PUT", "GET"] - if not user: - user = self.get_user() - requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict) requests_list: list[IsAuthorizedRequest] = [] for dag_id in dag_ids: diff --git a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 2d58d79e41b00..4060f75344477 100644 --- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -95,7 +95,6 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod - from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import ( CLICommand, ) @@ -259,8 +258,8 @@ def is_authorized_configuration( self, *, method: ResourceMethod, + user: User, details: ConfigurationDetails | None = None, - user: BaseUser | None = None, ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONFIG, user=user) @@ -268,8 +267,8 @@ def is_authorized_connection( self, *, method: ResourceMethod, + user: User, details: ConnectionDetails | None = None, - user: BaseUser | None = None, ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONNECTION, user=user) @@ -277,9 +276,9 @@ def is_authorized_dag( self, *, method: ResourceMethod, + user: User, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to access the dag. @@ -296,9 +295,9 @@ def is_authorized_dag( if no specific DAG is targeted, just check the sub entity. :param method: The method to authorize. + :param user: The user. :param access_entity: The dag access entity. :param details: The dag details. - :param user: The user. """ if not access_entity: # Scenario 1 @@ -321,32 +320,28 @@ def is_authorized_dag( ) def is_authorized_asset( - self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, user: User, details: AssetDetails | None = None ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_ASSET, user=user) def is_authorized_pool( - self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, user: User, details: PoolDetails | None = None ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_POOL, user=user) def is_authorized_variable( - self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, user: User, details: VariableDetails | None = None ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_VARIABLE, user=user) - def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | None = None) -> bool: + def is_authorized_view(self, *, access_view: AccessView, user: User) -> bool: # "Docs" are only links in the menu, there is no page associated method: ResourceMethod = "MENU" if access_view == AccessView.DOCS else "GET" return self._is_authorized( method=method, resource_type=_MAP_ACCESS_VIEW_TO_FAB_RESOURCE_TYPE[access_view], user=user ) - def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None - ): - if not user: - user = self.get_user() + def is_authorized_custom_view(self, *, method: ResourceMethod | str, resource_name: str, user: User): fab_action_name = get_fab_action_from_method_map().get(method, method) return (fab_action_name, resource_name) in self._get_user_permissions(user) @@ -354,16 +349,13 @@ def is_authorized_custom_view( def get_permitted_dag_ids( self, *, + user: User, methods: Container[ResourceMethod] | None = None, - user=None, session: Session = NEW_SESSION, ) -> set[str]: if not methods: methods = ["PUT", "GET"] - if not user: - user = self.get_user() - if not self.is_logged_in(): roles = user.roles else: @@ -478,20 +470,17 @@ def _is_authorized( *, method: ResourceMethod, resource_type: str, - user: BaseUser | None = None, + user: User, ) -> bool: """ Return whether the user is authorized to perform a given action. :param method: the method to perform :param resource_type: the type of resource the user attempts to perform the action on - :param user: the user to perform the action on. If not provided (or None), it uses the current user + :param user: the user to perform the action on :meta private: """ - if not user: - user = self.get_user() - fab_action = self._get_fab_action(method) user_permissions = self._get_user_permissions(user) @@ -500,15 +489,15 @@ def _is_authorized( def _is_authorized_dag( self, method: ResourceMethod, - details: DagDetails | None = None, - user: BaseUser | None = None, + details: DagDetails | None, + user: User, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param details: optional details about the DAG - :param user: the user to perform the action on. If not provided (or None), it uses the current user + :param details: details about the DAG + :param user: the user to perform the action on :meta private: """ @@ -526,15 +515,15 @@ def _is_authorized_dag( def _is_authorized_dag_run( self, method: ResourceMethod, - details: DagDetails | None = None, - user: BaseUser | None = None, + details: DagDetails | None, + user: User, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG Run. :param method: the method to perform - :param details: optional, details about the DAG - :param user: optional, the user to perform the action on. If not provided, it uses the current user + :param details: details about the DAG + :param user: the user to perform the action on :meta private: """ @@ -590,7 +579,7 @@ def _resource_name(self, dag_id: str, resource_type: str) -> str: return getattr(permissions, "resource_name_for_dag")(root_dag_id) @staticmethod - def _get_user_permissions(user: BaseUser): + def _get_user_permissions(user: User): """ Return the user permissions. diff --git a/providers/src/airflow/providers/fab/www/api_connexion/security.py b/providers/src/airflow/providers/fab/www/api_connexion/security.py index a130265b6de12..1530a2a98c712 100644 --- a/providers/src/airflow/providers/fab/www/api_connexion/security.py +++ b/providers/src/airflow/providers/fab/www/api_connexion/security.py @@ -70,7 +70,9 @@ def requires_access_decorator(func: T): def decorated(*args, **kwargs): return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_custom_view( - method=method, resource_name=resource_name + method=method, + resource_name=resource_name, + user=get_auth_manager().get_user(), ), func=func, args=args, diff --git a/providers/tests/amazon/aws/auth_manager/avp/test_facade.py b/providers/tests/amazon/aws/auth_manager/avp/test_facade.py index 1ea265d256787..124f4cc0f10af 100644 --- a/providers/tests/amazon/aws/auth_manager/avp/test_facade.py +++ b/providers/tests/amazon/aws/auth_manager/avp/test_facade.py @@ -59,18 +59,6 @@ def test_avp_client(self, facade): def test_avp_policy_store_id(self, facade): assert hasattr(facade, "avp_policy_store_id") - def test_is_authorized_no_user(self, facade): - method: ResourceMethod = "GET" - entity_type = AvpEntities.VARIABLE - - result = facade.is_authorized( - method=method, - entity_type=entity_type, - user=None, - ) - - assert result is False - @pytest.mark.parametrize( "entity_id, context, user, expected_entities, expected_context, avp_response, expected", [ diff --git a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py index 10bb69082a772..24b1c45696061 100644 --- a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -190,15 +190,13 @@ def test_is_logged_in_return_false_when_no_user_in_session(self, auth_manager, a @pytest.mark.parametrize( "details, user, expected_user, expected_entity_id", [ - (None, None, ANY, None), + (None, mock, ANY, None), (ConfigurationDetails(section="test"), mock, mock, "test"), ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_is_authorized_configuration( self, - mock_get_user, mock_avp_facade, details, user, @@ -212,8 +210,6 @@ def test_is_authorized_configuration( method: ResourceMethod = "GET" result = auth_manager.is_authorized_configuration(method=method, details=details, user=user) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method=method, entity_type=AvpEntities.CONFIGURATION, @@ -225,15 +221,13 @@ def test_is_authorized_configuration( @pytest.mark.parametrize( "details, user, expected_user, expected_entity_id", [ - (None, None, ANY, None), + (None, mock, ANY, None), (ConnectionDetails(conn_id="conn_id"), mock, mock, "conn_id"), ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_is_authorized_connection( self, - mock_get_user, mock_avp_facade, details, user, @@ -247,8 +241,6 @@ def test_is_authorized_connection( method: ResourceMethod = "GET" result = auth_manager.is_authorized_connection(method=method, details=details, user=user) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method=method, entity_type=AvpEntities.CONNECTION, @@ -260,7 +252,7 @@ def test_is_authorized_connection( @pytest.mark.parametrize( "access_entity, details, user, expected_user, expected_entity_id, expected_context", [ - (None, None, None, ANY, None, None), + (None, None, mock, ANY, None, None), (None, DagDetails(id="dag_1"), mock, mock, "dag_1", None), ( DagAccessEntity.CODE, @@ -277,10 +269,8 @@ def test_is_authorized_connection( ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_is_authorized_dag( self, - mock_get_user, mock_avp_facade, access_entity, details, @@ -298,8 +288,6 @@ def test_is_authorized_dag( method=method, access_entity=access_entity, details=details, user=user ) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method=method, entity_type=AvpEntities.DAG, @@ -312,15 +300,13 @@ def test_is_authorized_dag( @pytest.mark.parametrize( "details, user, expected_user, expected_entity_id", [ - (None, None, ANY, None), + (None, mock, ANY, None), (AssetDetails(uri="uri"), mock, mock, "uri"), ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_is_authorized_asset( self, - mock_get_user, mock_avp_facade, details, user, @@ -334,8 +320,6 @@ def test_is_authorized_asset( method: ResourceMethod = "GET" result = auth_manager.is_authorized_asset(method=method, details=details, user=user) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method=method, entity_type=AvpEntities.ASSET, user=expected_user, entity_id=expected_entity_id ) @@ -344,15 +328,13 @@ def test_is_authorized_asset( @pytest.mark.parametrize( "details, user, expected_user, expected_entity_id", [ - (None, None, ANY, None), + (None, mock, ANY, None), (PoolDetails(name="pool1"), mock, mock, "pool1"), ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_is_authorized_pool( self, - mock_get_user, mock_avp_facade, details, user, @@ -366,8 +348,6 @@ def test_is_authorized_pool( method: ResourceMethod = "GET" result = auth_manager.is_authorized_pool(method=method, details=details, user=user) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method=method, entity_type=AvpEntities.POOL, user=expected_user, entity_id=expected_entity_id ) @@ -376,15 +356,13 @@ def test_is_authorized_pool( @pytest.mark.parametrize( "details, user, expected_user, expected_entity_id", [ - (None, None, ANY, None), + (None, mock, ANY, None), (VariableDetails(key="var1"), mock, mock, "var1"), ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_is_authorized_variable( self, - mock_get_user, mock_avp_facade, details, user, @@ -398,8 +376,6 @@ def test_is_authorized_variable( method: ResourceMethod = "GET" result = auth_manager.is_authorized_variable(method=method, details=details, user=user) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method=method, entity_type=AvpEntities.VARIABLE, user=expected_user, entity_id=expected_entity_id ) @@ -408,32 +384,25 @@ def test_is_authorized_variable( @pytest.mark.parametrize( "access_view, user, expected_user", [ - (AccessView.CLUSTER_ACTIVITY, None, ANY), + (AccessView.CLUSTER_ACTIVITY, mock, ANY), (AccessView.PLUGINS, mock, mock), ], ) @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") - def test_is_authorized_view( - self, mock_get_user, mock_avp_facade, access_view, user, expected_user, auth_manager - ): + def test_is_authorized_view(self, mock_avp_facade, access_view, user, expected_user, auth_manager): is_authorized = Mock(return_value=True) mock_avp_facade.is_authorized = is_authorized result = auth_manager.is_authorized_view(access_view=access_view, user=user) - if not user: - mock_get_user.assert_called_once() is_authorized.assert_called_once_with( method="GET", entity_type=AvpEntities.VIEW, user=expected_user, entity_id=access_view.value ) assert result @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_batch_is_authorized_connection( self, - mock_get_user, mock_avp_facade, auth_manager, ): @@ -441,10 +410,10 @@ def test_batch_is_authorized_connection( mock_avp_facade.batch_is_authorized = batch_is_authorized result = auth_manager.batch_is_authorized_connection( - requests=[{"method": "GET"}, {"method": "GET", "details": ConnectionDetails(conn_id="conn_id")}] + requests=[{"method": "GET"}, {"method": "GET", "details": ConnectionDetails(conn_id="conn_id")}], + user=mock, ) - mock_get_user.assert_called_once() batch_is_authorized.assert_called_once_with( requests=[ { @@ -463,10 +432,8 @@ def test_batch_is_authorized_connection( assert result @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_batch_is_authorized_dag( self, - mock_get_user, mock_avp_facade, auth_manager, ): @@ -478,10 +445,10 @@ def test_batch_is_authorized_dag( {"method": "GET"}, {"method": "GET", "details": DagDetails(id="dag_1")}, {"method": "GET", "details": DagDetails(id="dag_1"), "access_entity": DagAccessEntity.CODE}, - ] + ], + user=mock, ) - mock_get_user.assert_called_once() batch_is_authorized.assert_called_once_with( requests=[ { @@ -512,10 +479,8 @@ def test_batch_is_authorized_dag( assert result @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_batch_is_authorized_pool( self, - mock_get_user, mock_avp_facade, auth_manager, ): @@ -523,10 +488,10 @@ def test_batch_is_authorized_pool( mock_avp_facade.batch_is_authorized = batch_is_authorized result = auth_manager.batch_is_authorized_pool( - requests=[{"method": "GET"}, {"method": "GET", "details": PoolDetails(name="pool1")}] + requests=[{"method": "GET"}, {"method": "GET", "details": PoolDetails(name="pool1")}], + user=mock, ) - mock_get_user.assert_called_once() batch_is_authorized.assert_called_once_with( requests=[ { @@ -545,10 +510,8 @@ def test_batch_is_authorized_pool( assert result @patch.object(AwsAuthManager, "avp_facade") - @patch.object(AwsAuthManager, "get_user") def test_batch_is_authorized_variable( self, - mock_get_user, mock_avp_facade, auth_manager, ): @@ -556,10 +519,10 @@ def test_batch_is_authorized_variable( mock_avp_facade.batch_is_authorized = batch_is_authorized result = auth_manager.batch_is_authorized_variable( - requests=[{"method": "GET"}, {"method": "GET", "details": VariableDetails(key="var1")}] + requests=[{"method": "GET"}, {"method": "GET", "details": VariableDetails(key="var1")}], + user=mock, ) - mock_get_user.assert_called_once() batch_is_authorized.assert_called_once_with( requests=[ { @@ -701,12 +664,11 @@ def test_filter_permitted_menu_items_logged_out(self, mock_get_user, auth_manage @pytest.mark.parametrize( "methods, user", [ - (None, None), + (None, AwsAuthManagerUser(user_id="test_user_id", groups=[])), (["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", groups=[])), ], ) - @patch.object(AwsAuthManager, "get_user") - def test_filter_permitted_dag_ids(self, mock_get_user, methods, user, auth_manager, test_user): + def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user): dag_ids = {"dag_1", "dag_2"} batch_is_authorized_output = [ { @@ -746,8 +708,6 @@ def test_filter_permitted_dag_ids(self, mock_get_user, methods, user, auth_manag return_value=batch_is_authorized_output ) - mock_get_user.return_value = test_user - result = auth_manager.filter_permitted_dag_ids( dag_ids=dag_ids, methods=methods, diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py b/tests/auth/managers/simple/test_simple_auth_manager.py index 4a146e6bd3c02..cc82008de82fe 100644 --- a/tests/auth/managers/simple/test_simple_auth_manager.py +++ b/tests/auth/managers/simple/test_simple_auth_manager.py @@ -163,7 +163,6 @@ def test_serialize_user(self, auth_manager): assert result == {"username": "test", "role": "admin"} @pytest.mark.db_test - @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( "api", [ @@ -176,27 +175,25 @@ def test_serialize_user(self, auth_manager): ], ) @pytest.mark.parametrize( - "is_logged_in, role, method, result", + "role, method, result", [ - (True, "ADMIN", "GET", True), - (True, "ADMIN", "DELETE", True), - (True, "VIEWER", "POST", False), - (True, "VIEWER", "PUT", False), - (True, "VIEWER", "DELETE", False), - (False, "ADMIN", "GET", False), + ("ADMIN", "GET", True), + ("ADMIN", "DELETE", True), + ("VIEWER", "POST", False), + ("VIEWER", "PUT", False), + ("VIEWER", "DELETE", False), ], ) - def test_is_authorized_methods( - self, mock_is_logged_in, auth_manager, app, api, is_logged_in, role, method, result - ): - mock_is_logged_in.return_value = is_logged_in - + def test_is_authorized_methods(self, auth_manager, app, api, role, method, result): with app.test_request_context(): - session["user"] = SimpleAuthManagerUser(username="test", role=role) - assert getattr(auth_manager, api)(method=method) is result + assert ( + getattr(auth_manager, api)( + method=method, user=SimpleAuthManagerUser(username="test", role=role) + ) + is result + ) @pytest.mark.db_test - @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( "api, kwargs", [ @@ -211,26 +208,22 @@ def test_is_authorized_methods( ], ) @pytest.mark.parametrize( - "is_logged_in, role, result", + "role, result", [ - (True, "ADMIN", True), - (True, "VIEWER", True), - (True, "USER", True), - (True, "OP", True), - (False, "ADMIN", False), + ("ADMIN", True), + ("VIEWER", True), + ("USER", True), + ("OP", True), ], ) - def test_is_authorized_view_methods( - self, mock_is_logged_in, auth_manager, app, api, kwargs, is_logged_in, role, result - ): - mock_is_logged_in.return_value = is_logged_in - + def test_is_authorized_view_methods(self, auth_manager, app, api, kwargs, role, result): with app.test_request_context(): - session["user"] = SimpleAuthManagerUser(username="test", role=role) - assert getattr(auth_manager, api)(**kwargs) is result + assert ( + getattr(auth_manager, api)(**kwargs, user=SimpleAuthManagerUser(username="test", role=role)) + is result + ) @pytest.mark.db_test - @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( "api", [ @@ -250,17 +243,16 @@ def test_is_authorized_view_methods( ("VIEWER", "PUT", False), ], ) - def test_is_authorized_methods_op_role_required( - self, mock_is_logged_in, auth_manager, app, api, role, method, result - ): - mock_is_logged_in.return_value = True - + def test_is_authorized_methods_op_role_required(self, auth_manager, app, api, role, method, result): with app.test_request_context(): - session["user"] = SimpleAuthManagerUser(username="test", role=role) - assert getattr(auth_manager, api)(method=method) is result + assert ( + getattr(auth_manager, api)( + method=method, user=SimpleAuthManagerUser(username="test", role=role) + ) + is result + ) @pytest.mark.db_test - @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( "api", ["is_authorized_dag"], @@ -275,17 +267,16 @@ def test_is_authorized_methods_op_role_required( ("VIEWER", "PUT", False), ], ) - def test_is_authorized_methods_user_role_required( - self, mock_is_logged_in, auth_manager, app, api, role, method, result - ): - mock_is_logged_in.return_value = True - + def test_is_authorized_methods_user_role_required(self, auth_manager, app, api, role, method, result): with app.test_request_context(): - session["user"] = SimpleAuthManagerUser(username="test", role=role) - assert getattr(auth_manager, api)(method=method) is result + assert ( + getattr(auth_manager, api)( + method=method, user=SimpleAuthManagerUser(username="test", role=role) + ) + is result + ) @pytest.mark.db_test - @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( "api", ["is_authorized_dag", "is_authorized_asset", "is_authorized_pool"], @@ -301,13 +292,15 @@ def test_is_authorized_methods_user_role_required( ], ) def test_is_authorized_methods_viewer_role_required_for_get( - self, mock_is_logged_in, auth_manager, app, api, role, method, result + self, auth_manager, app, api, role, method, result ): - mock_is_logged_in.return_value = True - with app.test_request_context(): - session["user"] = SimpleAuthManagerUser(username="test", role=role) - assert getattr(auth_manager, api)(method=method) is result + assert ( + getattr(auth_manager, api)( + method=method, user=SimpleAuthManagerUser(username="test", role=role) + ) + is result + ) @pytest.mark.db_test def test_register_views(self, auth_manager_with_appbuilder): diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 370e401da0609..d8d24a93f1df4 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -222,7 +222,8 @@ def test_batch_is_authorized_dag(self, mock_is_authorized_dag, auth_manager, ret [ {"method": "GET", "details": DagDetails(id="dag1")}, {"method": "GET", "details": DagDetails(id="dag2")}, - ] + ], + user=Mock(), ) assert result == expected @@ -243,7 +244,8 @@ def test_batch_is_authorized_connection( [ {"method": "GET", "details": ConnectionDetails(conn_id="conn1")}, {"method": "GET", "details": ConnectionDetails(conn_id="conn2")}, - ] + ], + user=Mock(), ) assert result == expected @@ -262,7 +264,8 @@ def test_batch_is_authorized_pool(self, mock_is_authorized_pool, auth_manager, r [ {"method": "GET", "details": PoolDetails(name="pool1")}, {"method": "GET", "details": PoolDetails(name="pool2")}, - ] + ], + user=Mock(), ) assert result == expected @@ -283,7 +286,8 @@ def test_batch_is_authorized_variable( [ {"method": "GET", "details": VariableDetails(key="var1")}, {"method": "GET", "details": VariableDetails(key="var2")}, - ] + ], + user=Mock(), ) assert result == expected