Skip to content

Commit

Permalink
Make parameter user mandatory for all methods in the auth manager i…
Browse files Browse the repository at this point in the history
…nterface (apache#45986)
  • Loading branch information
vincbeck authored Jan 27, 2025
1 parent a9dff59 commit d024cda
Show file tree
Hide file tree
Showing 19 changed files with 280 additions and 303 deletions.
8 changes: 6 additions & 2 deletions airflow/api_connexion/endpoints/asset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def get_asset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Get queued asset events for an asset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(
user=get_auth_manager().get_user(), methods=["GET"]
)
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
Expand Down Expand Up @@ -313,7 +315,9 @@ def delete_asset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete queued asset events for an asset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(
user=get_auth_manager().get_user(), methods=["GET"]
)
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Resp
raise NotFound("File not found")

# Check if user has read access to all the DAGs defined in the file
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
raise PermissionDenied()

parsing_request = DagPriorityParsingRequest(fileloc=path)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_dag_source(
]

# Check if user has read access to all the DAGs defined in the file
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
raise PermissionDenied()
dag_source = dag_version.dag_code.source_code
version_number = dag_version.version_number
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
)
session.expunge(error)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=get_auth_manager().get_user())
if not can_read_all_dags:
readable_dag_ids = security.get_readable_dags()
file_dag_ids = {
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_import_errors(
query = select(ParseImportError)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=get_auth_manager().get_user())

if not can_read_all_dags:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
Expand Down Expand Up @@ -133,7 +133,7 @@ def get_import_errors(
}
for dag_id in file_dag_ids
]
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
session.expunge(import_error)
import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"

Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
}
for id in dag_ids
]
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
raise PermissionDenied(detail=f"User not allowed to access some of these DAGs: {list(dag_ids)}")
else:
dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user)
Expand Down
34 changes: 26 additions & 8 deletions airflow/api_connexion/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def decorated(*args, **kwargs):
section: str | None = kwargs.get("section")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_configuration(
method=method, details=ConfigurationDetails(section=section)
method=method,
details=ConfigurationDetails(section=section),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -97,7 +99,9 @@ def decorated(*args, **kwargs):
connection_id: str | None = kwargs.get("connection_id")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_connection(
method=method, details=ConnectionDetails(conn_id=connection_id)
method=method,
details=ConnectionDetails(conn_id=connection_id),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -120,13 +124,15 @@ def callback() -> bool | DagAccessEntity:
method=method,
access_entity=access_entity,
details=DagDetails(id=dag_id),
user=get_auth_manager().get_user(),
)
else:
# here we know dag_id is not provided.
# check is the user authorized to access all DAGs?
if get_auth_manager().is_authorized_dag(
method=method,
access_entity=access_entity,
user=get_auth_manager().get_user(),
):
return True
elif access_entity:
Expand All @@ -138,7 +144,9 @@ def callback() -> bool | DagAccessEntity:
# but we leave it to the endpoint function to properly restrict access beyond that
if method not in ("GET", "PUT"):
return False
return any(get_auth_manager().get_permitted_dag_ids(methods=[method]))
return any(
get_auth_manager().get_permitted_dag_ids(user=get_auth_manager().get_user(), methods=[method])
)

return callback

Expand All @@ -165,7 +173,9 @@ def decorated(*args, **kwargs):
uri: str | None = kwargs.get("uri")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_asset(
method=method, details=AssetDetails(uri=uri)
method=method,
details=AssetDetails(uri=uri),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -184,7 +194,9 @@ def decorated(*args, **kwargs):
pool_name: str | None = kwargs.get("pool_name")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_pool(
method=method, details=PoolDetails(name=pool_name)
method=method,
details=PoolDetails(name=pool_name),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -203,7 +215,9 @@ def decorated(*args, **kwargs):
variable_key: str | None = kwargs.get("variable_key")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_variable(
method=method, details=VariableDetails(key=variable_key)
method=method,
details=VariableDetails(key=variable_key),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -220,7 +234,9 @@ def requires_access_decorator(func: T):
@wraps(func)
def decorated(*args, **kwargs):
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_view(access_view=access_view),
is_authorized_callback=lambda: get_auth_manager().is_authorized_view(
access_view=access_view, user=get_auth_manager().get_user()
),
func=func,
args=args,
kwargs=kwargs,
Expand All @@ -240,7 +256,9 @@ def requires_access_decorator(func: T):
def decorated(*args, **kwargs):
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_custom_view(
method=method, resource_name=resource_name
method=method,
resource_name=resource_name,
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand Down
Loading

0 comments on commit d024cda

Please sign in to comment.