From 5c72befcfde63ade2870491cfeb708675399d9d6 Mon Sep 17 00:00:00 2001 From: Elad Galili <33863800+eladi99@users.noreply.github.com> Date: Mon, 3 Jul 2023 09:45:24 +0300 Subject: [PATCH] Fix `LambdaInvokeFunctionOperator` payload parameter type (#32259) * Fixing issue - Fix payload parameter of amazon LambdaCreateFunctionOperator --------- Co-authored-by: Elad Galili --- .../amazon/aws/hooks/lambda_function.py | 5 +++- .../amazon/aws/operators/lambda_function.py | 2 +- .../amazon/aws/hooks/test_lambda_function.py | 11 ++++++--- .../aws/operators/test_lambda_function.py | 23 +++++++++++++------ 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index 2d61f0751f7cf..58ecac8bcccb6 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -48,7 +48,7 @@ def invoke_lambda( invocation_type: str | None = None, log_type: str | None = None, client_context: str | None = None, - payload: str | None = None, + payload: bytes | str | None = None, qualifier: str | None = None, ): """ @@ -65,6 +65,9 @@ def invoke_lambda( :param payload: The JSON that you want to provide to your Lambda function as input. :param qualifier: AWS Lambda Function Version or Alias Name """ + if isinstance(payload, str): + payload = payload.encode() + invoke_args = { "FunctionName": function_name, "InvocationType": invocation_type, diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py index 93907634c12d3..28b6313204221 100644 --- a/airflow/providers/amazon/aws/operators/lambda_function.py +++ b/airflow/providers/amazon/aws/operators/lambda_function.py @@ -150,7 +150,7 @@ def __init__( qualifier: str | None = None, invocation_type: str | None = None, client_context: str | None = None, - payload: str | None = None, + payload: bytes | str | None = None, aws_conn_id: str = "aws_default", **kwargs, ): diff --git a/tests/providers/amazon/aws/hooks/test_lambda_function.py b/tests/providers/amazon/aws/hooks/test_lambda_function.py index f21c000ea6d9f..caaf164be47ba 100644 --- a/tests/providers/amazon/aws/hooks/test_lambda_function.py +++ b/tests/providers/amazon/aws/hooks/test_lambda_function.py @@ -26,6 +26,7 @@ FUNCTION_NAME = "test_function" PAYLOAD = '{"hello": "airflow"}' +BYTES_PAYLOAD = b'{"hello": "airflow"}' RUNTIME = "python3.9" ROLE = "role" HANDLER = "handler" @@ -48,13 +49,17 @@ def test_get_conn_returns_a_boto3_connection(self, hook): @mock.patch( "airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook.conn", new_callable=mock.PropertyMock ) - def test_invoke_lambda(self, mock_conn): + @pytest.mark.parametrize( + "payload, invoke_payload", + [(PAYLOAD, BYTES_PAYLOAD), (BYTES_PAYLOAD, BYTES_PAYLOAD)], + ) + def test_invoke_lambda(self, mock_conn, payload, invoke_payload): hook = LambdaHook() - hook.invoke_lambda(function_name=FUNCTION_NAME, payload=PAYLOAD) + hook.invoke_lambda(function_name=FUNCTION_NAME, payload=payload) mock_conn().invoke.assert_called_once_with( FunctionName=FUNCTION_NAME, - Payload=PAYLOAD, + Payload=invoke_payload, ) @pytest.mark.parametrize( diff --git a/tests/providers/amazon/aws/operators/test_lambda_function.py b/tests/providers/amazon/aws/operators/test_lambda_function.py index 6f1d98e8ac8bc..f0b4b834eb00d 100644 --- a/tests/providers/amazon/aws/operators/test_lambda_function.py +++ b/tests/providers/amazon/aws/operators/test_lambda_function.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import json from unittest import mock from unittest.mock import Mock, patch @@ -30,6 +29,8 @@ ) FUNCTION_NAME = "function_name" +PAYLOAD = '{"hello": "airflow"}' +BYTES_PAYLOAD = b'{"hello": "airflow"}' ROLE_ARN = "role_arn" IMAGE_URI = "image_uri" @@ -70,29 +71,37 @@ def test_create_lambda_with_wait_for_completion(self, mock_hook_conn, mock_hook_ class TestLambdaInvokeFunctionOperator: - def test_init(self): + @pytest.mark.parametrize( + "payload", + [PAYLOAD, BYTES_PAYLOAD], + ) + def test_init(self, payload): lambda_operator = LambdaInvokeFunctionOperator( task_id="test", function_name="test", - payload=json.dumps({"TestInput": "Testdata"}), + payload=payload, log_type="None", aws_conn_id="aws_conn_test", ) assert lambda_operator.task_id == "test" assert lambda_operator.function_name == "test" - assert lambda_operator.payload == json.dumps({"TestInput": "Testdata"}) + assert lambda_operator.payload == payload assert lambda_operator.log_type == "None" assert lambda_operator.aws_conn_id == "aws_conn_test" @patch.object(LambdaInvokeFunctionOperator, "hook", new_callable=mock.PropertyMock) - def test_invoke_lambda(self, hook_mock): + @pytest.mark.parametrize( + "payload", + [PAYLOAD, BYTES_PAYLOAD], + ) + def test_invoke_lambda(self, hook_mock, payload): operator = LambdaInvokeFunctionOperator( task_id="task_test", function_name="a", invocation_type="b", log_type="c", client_context="d", - payload="e", + payload=payload, qualifier="f", ) returned_payload = Mock() @@ -111,7 +120,7 @@ def test_invoke_lambda(self, hook_mock): invocation_type="b", log_type="c", client_context="d", - payload="e", + payload=payload, qualifier="f", )