diff --git a/client_provider.py b/client_provider.py new file mode 100644 index 0000000..0741579 --- /dev/null +++ b/client_provider.py @@ -0,0 +1,36 @@ +import os +from temporalio.client import Client, TLSConfig + +async def get_temporal_client() -> Client: + cert_path = os.getenv("TEMPORAL_TLS_CERT") + key_path = os.getenv("TEMPORAL_TLS_KEY") + api_key = os.getenv("TEMPORAL_API_KEY") + + # Check for mTLS authentication + if cert_path and key_path: + with open(cert_path, "rb") as f: + client_cert = f.read() + with open(key_path, "rb") as f: + client_key = f.read() + + return await Client.connect( + os.getenv("TEMPORAL_ADDRESS"), + namespace=os.getenv("TEMPORAL_NAMESPACE"), + tls=TLSConfig( + client_cert=client_cert, + client_private_key=client_key, + ), + ) + elif api_key: + return await Client.connect( + os.getenv("TEMPORAL_ADDRESS"), + namespace=os.getenv("TEMPORAL_NAMESPACE"), + api_key=api_key, + tls=True, + ) + + else: + return await Client.connect( + os.getenv("TEMPORAL_ADDRESS", "localhost:7233"), + namespace=os.getenv("TEMPORAL_NAMESPACE", "default"), + ) diff --git a/run_worker.py b/run_worker.py index 0aa9f8f..607997d 100644 --- a/run_worker.py +++ b/run_worker.py @@ -4,13 +4,14 @@ from temporalio.client import Client from temporalio.worker import Worker +from client_provider import get_temporal_client from activities import BankingActivities from shared import MONEY_TRANSFER_TASK_QUEUE_NAME from workflows import MoneyTransfer async def main() -> None: - client: Client = await Client.connect("localhost:7233", namespace="default") + client = await get_temporal_client() # Run the worker activities = BankingActivities() worker: Worker = Worker( diff --git a/run_workflow.py b/run_workflow.py index 1e53a3c..4363c77 100644 --- a/run_workflow.py +++ b/run_workflow.py @@ -4,13 +4,13 @@ from temporalio.client import Client, WorkflowFailureError +from client_provider import get_temporal_client from shared import MONEY_TRANSFER_TASK_QUEUE_NAME, PaymentDetails from workflows import MoneyTransfer async def main() -> None: - # Create client connected to server at the given address - client: Client = await Client.connect("localhost:7233") + client = await get_temporal_client() data: PaymentDetails = PaymentDetails( source_account="85-150",