diff --git a/tests/benchmark/dags/evaluate_load_file.py b/tests/benchmark/dags/evaluate_load_file.py
index 353ffaf0d..2323f4d57 100644
--- a/tests/benchmark/dags/evaluate_load_file.py
+++ b/tests/benchmark/dags/evaluate_load_file.py
@@ -65,7 +65,7 @@ def delete_table(table_metadata):
             conn_id=dataset_conn_id,
             filetype=FileType(dataset_filetype),
         ),
-        task_id="load_csv",
+        task_id="load",
         output_table=table_metadata,
         chunk_size=chunk_size,
     )
diff --git a/tests/benchmark/run.py b/tests/benchmark/run.py
index b6506ba3e..b03b63ad0 100644
--- a/tests/benchmark/run.py
+++ b/tests/benchmark/run.py
@@ -3,22 +3,19 @@
 import json
 import os
 import sys
-import time
 
 import airflow
 import pandas as pd
 import psutil
 from airflow.executors.debug_executor import DebugExecutor
+from airflow.models import TaskInstance
 from airflow.utils import timezone
+from airflow.utils.session import provide_session
 
 from astro.databases import create_database
 from astro.sql.table import Metadata, Table
 
 
-def elapsed_since(start):
-    return time.time() - start
-
-
 def get_disk_usage():
     path = "/"
     disk_usage = psutil.disk_usage(path)
@@ -42,6 +39,20 @@ def export_profile_data_to_bq(profile_data: dict, conn_id: str = "bigquery"):
     db.load_pandas_dataframe_to_table(df, table, if_exists="append")
 
 
+@provide_session
+def get_load_task_duration(dag, session=None):
+    ti: TaskInstance = (
+        session.query(TaskInstance)
+        .filter(
+            TaskInstance.dag_id == dag.dag_id,
+            TaskInstance.task_id == "load",
+            TaskInstance.execution_date == dag.latest_execution_date,
+        )
+        .first()
+    )
+    return ti.duration
+
+
 def profile(func, *args, **kwargs):  # noqa: C901
     def wrapper(*args, **kwargs):
         process = psutil.Process(os.getpid())
@@ -51,13 +62,11 @@ def wrapper(*args, **kwargs):
         disk_usage_before = get_disk_usage()
         if sys.platform == "linux":
             io_counters_before = process.io_counters()._asdict()
-        start = time.time()
 
         # run decorated function
-        result = func(*args, **kwargs)
+        dag = func(*args, **kwargs)
 
         # metrics after
-        elapsed_time = elapsed_since(start)
         memory_full_info_after = process.memory_full_info()._asdict()
         cpu_time_after = process.cpu_times()._asdict()
         disk_usage_after = get_disk_usage()
@@ -65,7 +74,7 @@ def wrapper(*args, **kwargs):
         io_counters_after = process.io_counters()._asdict()
 
         profile = {
-            "duration": elapsed_time,
+            "duration": get_load_task_duration(dag=dag),
             "memory_full_info": subtract(
                 memory_full_info_after, memory_full_info_before
             ),
@@ -79,7 +88,6 @@ def wrapper(*args, **kwargs):
         print(json.dumps(profile, default=str))
         if os.getenv("ASTRO_PUBLISH_BENCHMARK_DATA"):
             export_profile_data_to_bq(profile)
-        return result
 
     if inspect.isfunction(func):
         return wrapper
@@ -102,6 +110,7 @@ def run_dag(dag_id, execution_date, **kwargs):
         # been doing this prior to 2.2 so we keep compatibility.
         run_at_least_once=True,
     )
+    return dag
 
 
 def build_dag_id(dataset, database):