Skip to content

Commit

Permalink
remove secrets in sagemaker agent (#2308)
Browse files Browse the repository at this point in the history
Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla authored Mar 29, 2024
1 parent 8ab9a3c commit 6c917ed
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
AsyncAgentBase,
Resource,
)
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

Expand Down Expand Up @@ -53,9 +53,6 @@ async def create(
config=config,
inputs=inputs,
region=region,
aws_access_key_id=get_agent_secret(secret_key="aws-access-key"),
aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"),
aws_session_token=get_agent_secret(secret_key="aws-session-token"),
)

return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs)
Expand All @@ -66,9 +63,6 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou
config={"EndpointName": resource_meta.config.get("EndpointName")},
inputs=resource_meta.inputs,
region=resource_meta.region,
aws_access_key_id=get_agent_secret(secret_key="aws-access-key"),
aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"),
aws_session_token=get_agent_secret(secret_key="aws-session-token"),
)

current_state = endpoint_status.get("EndpointStatus")
Expand All @@ -90,9 +84,6 @@ async def delete(self, resource_meta: SageMakerEndpointMetadata, **kwargs):
config={"EndpointName": resource_meta.config.get("EndpointName")},
region=resource_meta.region,
inputs=resource_meta.inputs,
aws_access_key_id=get_agent_secret(secret_key="aws-access-key"),
aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"),
aws_session_token=get_agent_secret(secret_key="aws-session-token"),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Resource,
SyncAgentBase,
)
from flytekit.extend.backend.utils import get_agent_secret
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

Expand Down Expand Up @@ -53,9 +52,6 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N
config=config,
images=images,
inputs=inputs,
aws_access_key_id=get_agent_secret(secret_key="aws-access-key"),
aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"),
aws_session_token=get_agent_secret(secret_key="aws-session-token"),
)

outputs = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ async def _call(
images: Optional[Dict[str, str]] = None,
inputs: Optional[LiteralMap] = None,
region: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
) -> Any:
"""
Utilize this method to invoke any boto3 method (AWS service method).
Expand Down Expand Up @@ -173,9 +170,6 @@ async def _call(
async with session.client(
service_name=self._service,
region_name=final_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
) as client:
try:
result = await getattr(client, method)(**updated_config)
Expand Down
6 changes: 1 addition & 5 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@


@pytest.mark.asyncio
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.get_agent_secret",
return_value="mocked_secret",
)
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call",
return_value={
Expand All @@ -33,7 +29,7 @@
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
)
async def test_agent(mock_boto_call, mock_secret):
async def test_agent(mock_boto_call):
agent = AgentRegistry.get_agent("boto")
task_id = Identifier(
resource_type=ResourceType.TASK,
Expand Down
6 changes: 1 addition & 5 deletions plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@


@pytest.mark.asyncio
@mock.patch(
"flytekitplugins.awssagemaker_inference.agent.get_agent_secret",
return_value="mocked_secret",
)
@mock.patch(
"flytekitplugins.awssagemaker_inference.agent.Boto3AgentMixin._call",
return_value={
Expand Down Expand Up @@ -59,7 +55,7 @@
},
},
)
async def test_agent(mock_boto_call, mock_secret):
async def test_agent(mock_boto_call):
agent = AgentRegistry.get_agent("sagemaker-endpoint")
task_id = Identifier(
resource_type=ResourceType.TASK,
Expand Down

0 comments on commit 6c917ed

Please sign in to comment.