diff --git a/airflow-core/src/airflow/config_templates/airflow_local_settings.py b/airflow-core/src/airflow/config_templates/airflow_local_settings.py index 6270fba7ba3ba..3e1535bedf694 100644 --- a/airflow-core/src/airflow/config_templates/airflow_local_settings.py +++ b/airflow-core/src/airflow/config_templates/airflow_local_settings.py @@ -279,35 +279,29 @@ def _default_conn_name_from(mod_path, hook_name): ) remote_task_handler_kwargs = {} elif ELASTICSEARCH_HOST: - ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value("elasticsearch", "END_OF_LOG_MARK") - ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value("elasticsearch", "frontend") + from airflow.providers.elasticsearch.log.es_task_handler import ElasticsearchRemoteLogIO + ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean("elasticsearch", "WRITE_STDOUT") ELASTICSEARCH_WRITE_TO_ES: bool = conf.getboolean("elasticsearch", "WRITE_TO_ES") ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean("elasticsearch", "JSON_FORMAT") - ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value("elasticsearch", "JSON_FIELDS") ELASTICSEARCH_TARGET_INDEX: str = conf.get_mandatory_value("elasticsearch", "TARGET_INDEX") ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value("elasticsearch", "HOST_FIELD") ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value("elasticsearch", "OFFSET_FIELD") + ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value("elasticsearch", "LOG_ID_TEMPLATE") + + REMOTE_TASK_LOG = ElasticsearchRemoteLogIO( + host=ELASTICSEARCH_HOST, + target_index=ELASTICSEARCH_TARGET_INDEX, + write_stdout=ELASTICSEARCH_WRITE_STDOUT, + write_to_es=ELASTICSEARCH_WRITE_TO_ES, + offset_field=ELASTICSEARCH_OFFSET_FIELD, + host_field=ELASTICSEARCH_HOST_FIELD, + base_log_folder=BASE_LOG_FOLDER, + delete_local_copy=delete_local_copy, + json_format=ELASTICSEARCH_JSON_FORMAT, + log_id_template=ELASTICSEARCH_LOG_ID_TEMPLATE, + ) - ELASTIC_REMOTE_HANDLERS: dict[str, dict[str, str | bool | None]] = { - "task": { - "class": "airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler", - "formatter": "airflow", - "base_log_folder": BASE_LOG_FOLDER, - "end_of_log_mark": ELASTICSEARCH_END_OF_LOG_MARK, - "host": ELASTICSEARCH_HOST, - "frontend": ELASTICSEARCH_FRONTEND, - "write_stdout": ELASTICSEARCH_WRITE_STDOUT, - "write_to_es": ELASTICSEARCH_WRITE_TO_ES, - "target_index": ELASTICSEARCH_TARGET_INDEX, - "json_format": ELASTICSEARCH_JSON_FORMAT, - "json_fields": ELASTICSEARCH_JSON_FIELDS, - "host_field": ELASTICSEARCH_HOST_FIELD, - "offset_field": ELASTICSEARCH_OFFSET_FIELD, - }, - } - - DEFAULT_LOGGING_CONFIG["handlers"].update(ELASTIC_REMOTE_HANDLERS) elif OPENSEARCH_HOST: OPENSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value("opensearch", "END_OF_LOG_MARK") OPENSEARCH_PORT: str = conf.get_mandatory_value("opensearch", "PORT") diff --git a/chart/values.yaml b/chart/values.yaml index 904c0594c1ae5..69fc489bd26ef 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -2918,7 +2918,7 @@ config: run_duration: 41460 elasticsearch: json_format: 'True' - log_id_template: "{dag_id}_{task_id}_{execution_date}_{try_number}" + log_id_template: "{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}" elasticsearch_configs: max_retries: 3 timeout: 30 diff --git a/devel-common/pyproject.toml b/devel-common/pyproject.toml index 363a8f733bae3..3d2e7f5cd7228 100644 --- a/devel-common/pyproject.toml +++ b/devel-common/pyproject.toml @@ -138,6 +138,7 @@ dependencies = [ "pytest-unordered>=0.6.1", "pytest-xdist>=3.5.0", "pytest>=8.3.3", 
+ "testcontainers>=4.12.0", ] "sentry" = [ "blinker>=1.7.0", diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py index 610b03f96e199..2af39ce736428 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_response.py @@ -17,6 +17,7 @@ from __future__ import annotations from collections.abc import Iterator +from typing import Any def _wrap(val): @@ -25,6 +26,33 @@ def _wrap(val): return val +def resolve_nested(self, hit: dict[Any, Any], parent_class=None) -> type[Hit]: + """ + Resolve nested hits from Elasticsearch by iteratively navigating the `_nested` field. + + The result is used to fetch the appropriate document class to handle the hit. + + This method can be used with nested Elasticsearch fields which are structured + as dictionaries with "field" and "_nested" keys. + """ + doc_class = Hit + + nested_path: list[str] = [] + nesting = hit["_nested"] + while nesting and "field" in nesting: + nested_path.append(nesting["field"]) + nesting = nesting.get("_nested") + nested_path_str = ".".join(nested_path) + + if hasattr(parent_class, "_index"): + nested_field = parent_class._index.resolve_field(nested_path_str) + + if nested_field is not None: + return nested_field._doc_class + + return doc_class + + class AttributeList: """Helper class to provide attribute like access to List objects.""" diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py index c38469bc17230..c1db73d22c8bc 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py @@ -22,39 +22,41 @@ import json import logging import os -import pathlib import shutil import sys import time from collections import defaultdict from collections.abc import Callable from operator import attrgetter +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast from urllib.parse import quote, urlparse +import attrs + # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test. 
import elasticsearch import pendulum from elasticsearch import helpers from elasticsearch.exceptions import NotFoundError +import airflow.logging_config as alc from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.models.dagrun import DagRun -from airflow.providers.common.compat.sdk import timezone from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter -from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit +from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit, resolve_nested from airflow.providers.elasticsearch.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils import timezone from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin from airflow.utils.module_loading import import_string -from airflow.utils.session import create_session if TYPE_CHECKING: from datetime import datetime from airflow.models.taskinstance import TaskInstance, TaskInstanceKey - from airflow.utils.log.file_task_handler import LogMetadata + from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI + from airflow.utils.log.file_task_handler import LogMessages, LogMetadata, LogSourceInfo if AIRFLOW_V_3_0_PLUS: @@ -90,30 +92,40 @@ def get_es_kwargs_from_config() -> dict[str, Any]: return kwargs_dict -def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance: +def getattr_nested(obj, item, default): """ - Given TI | TIKey, return a TI object. + Get item from obj but return default if not found. + + E.g. calling ``getattr_nested(a, 'b.c', "NA")`` will return + ``a.b.c`` if such a value exists, and "NA" otherwise. - Will raise exception if no TI is found in the database. + :meta private: """ - from airflow.models.taskinstance import TaskInstance, TaskInstanceKey + try: + return attrgetter(item)(obj) + except AttributeError: + return default - if not isinstance(ti, TaskInstanceKey): - return ti - val = ( - session.query(TaskInstance) - .filter( - TaskInstance.task_id == ti.task_id, - TaskInstance.dag_id == ti.dag_id, - TaskInstance.run_id == ti.run_id, - TaskInstance.map_index == ti.map_index, - ) - .one_or_none() + +def _render_log_id(log_id_template: str, ti: TaskInstance | TaskInstanceKey, try_number: int) -> str: + return log_id_template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=getattr(ti, "run_id", ""), + try_number=try_number, + map_index=getattr(ti, "map_index", ""), ) - if isinstance(val, TaskInstance): - val.try_number = ti.try_number - return val - raise AirflowException(f"Could not find TaskInstance for {ti}") + + +def _clean_date(value: datetime | None) -> str: + """ + Clean up a date value so that it is safe to query in elasticsearch by removing reserved characters. 
+ + https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#_reserved_characters + """ + if value is None: + return "" + return value.strftime("%Y_%m_%dT%H_%M_%S_%f") class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMixin): @@ -151,8 +163,8 @@ def __init__( base_log_folder: str, end_of_log_mark: str, write_stdout: bool, - json_format: bool, json_fields: str, + json_format: bool = False, write_to_es: bool = False, target_index: str = "airflow-logs", host_field: str = "host", @@ -201,6 +213,27 @@ def __init__( self.handler: logging.FileHandler | logging.StreamHandler | None = None self._doc_type_map: dict[Any, Any] = {} self._doc_type: list[Any] = [] + self.log_id_template: str = conf.get( + "elasticsearch", + "log_id_template", + fallback="{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}", + ) + self.io = ElasticsearchRemoteLogIO( + host=self.host, + target_index=self.target_index, + write_stdout=self.write_stdout, + write_to_es=self.write_to_es, + offset_field=self.offset_field, + host_field=self.host_field, + base_log_folder=base_log_folder, + delete_local_copy=self.delete_local_copy, + log_id_template=self.log_id_template, + ) + # Airflow 3 introduce REMOTE_TASK_LOG for handling remote logging + # REMOTE_TASK_LOG should be explicitly set in airflow_local_settings.py when trying to use ESTaskHandler + # Before airflow 3.1, REMOTE_TASK_LOG is not set when trying to use ES TaskHandler. + if AIRFLOW_V_3_0_PLUS and alc.REMOTE_TASK_LOG is None: + alc.REMOTE_TASK_LOG = self.io @staticmethod def format_url(host: str) -> str: @@ -224,70 +257,6 @@ def format_url(host: str) -> str: return host - def _get_index_patterns(self, ti: TaskInstance | None) -> str: - """ - Get index patterns by calling index_patterns_callable, if provided, or the configured index_patterns. - - :param ti: A TaskInstance object or None. 
- """ - if self.index_patterns_callable: - self.log.debug("Using index_patterns_callable: %s", self.index_patterns_callable) - index_pattern_callable_obj = import_string(self.index_patterns_callable) - return index_pattern_callable_obj(ti) - self.log.debug("Using index_patterns: %s", self.index_patterns) - return self.index_patterns - - def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> str: - from airflow.models.taskinstance import TaskInstanceKey - - with create_session() as session: - if isinstance(ti, TaskInstanceKey): - ti = _ensure_ti(ti, session) - dag_run = ti.get_dagrun(session=session) - if USE_PER_RUN_LOG_ID: - log_id_template = dag_run.get_log_template(session=session).elasticsearch_id - - if self.json_format: - data_interval_start = self._clean_date(dag_run.data_interval_start) - data_interval_end = self._clean_date(dag_run.data_interval_end) - logical_date = self._clean_date(dag_run.logical_date) - else: - data_interval_start = ( - dag_run.data_interval_start.isoformat() if dag_run.data_interval_start else "" - ) - data_interval_end = dag_run.data_interval_end.isoformat() if dag_run.data_interval_end else "" - logical_date = dag_run.logical_date.isoformat() if dag_run.logical_date else "" - - return log_id_template.format( - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=getattr(ti, "run_id", ""), - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - logical_date=logical_date, - execution_date=logical_date, - try_number=try_number, - map_index=getattr(ti, "map_index", ""), - ) - - @staticmethod - def _clean_date(value: datetime | None) -> str: - """ - Clean up a date value so that it is safe to query in elasticsearch by removing reserved characters. - - https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#_reserved_characters - """ - if value is None: - return "" - return value.strftime("%Y_%m_%dT%H_%M_%S_%f") - - def _group_logs_by_host(self, response: ElasticSearchResponse) -> dict[str, list[Hit]]: - grouped_logs = defaultdict(list) - for hit in response: - key = getattr_nested(hit, self.host_field, None) or self.host - grouped_logs[key].append(hit) - return grouped_logs - def _read_grouped_logs(self): return True @@ -311,15 +280,15 @@ def _read( metadata["offset"] = 0 offset = metadata["offset"] - log_id = self._render_log_id(ti, try_number) - response = self._es_read(log_id, offset, ti) + log_id = _render_log_id(self.log_id_template, ti, try_number) + response = self.io._es_read(log_id, offset, ti) + # TODO: Can we skip group logs by host ? if response is not None and response.hits: - logs_by_host = self._group_logs_by_host(response) + logs_by_host = self.io._group_logs_by_host(response) next_offset = attrgetter(self.offset_field)(response[-1]) else: logs_by_host = None next_offset = offset - # Ensure a string here. Large offset numbers will get JSON.parsed incorrectly # on the client. Sending as a string prevents this issue. # https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER @@ -329,7 +298,10 @@ def _read( # have the log uploaded but will not be stored in elasticsearch. 
metadata["end_of_log"] = False if logs_by_host: - if any(x[-1].message == self.end_of_log_mark for x in logs_by_host.values()): + end_mark_found = any( + self._get_log_message(x[-1]) == self.end_of_log_mark for x in logs_by_host.values() + ) + if end_mark_found: metadata["end_of_log"] = True cur_ts = pendulum.now() @@ -361,12 +333,6 @@ def _read( if int(offset) != int(next_offset) or "last_log_timestamp" not in metadata: metadata["last_log_timestamp"] = str(cur_ts) - # If we hit the end of the log, remove the actual end_of_log message - # to prevent it from showing in the UI. - def concat_logs(hits: list[Hit]) -> str: - log_range = (len(hits) - 1) if hits[-1].message == self.end_of_log_mark else len(hits) - return "\n".join(self._format_msg(hits[i]) for i in range(log_range)) - if logs_by_host: if AIRFLOW_V_3_0_PLUS: from airflow.utils.log.file_task_handler import StructuredLogMessage @@ -389,11 +355,12 @@ def concat_logs(hits: list[Hit]) -> str: ] else: message = [ - (host, concat_logs(hits)) # type: ignore[misc] + (host, self.concat_logs(hits)) # type: ignore[misc] for host, hits in logs_by_host.items() ] else: message = [] + metadata["end_of_log"] = True return message, metadata def _format_msg(self, hit: Hit): @@ -407,46 +374,7 @@ def _format_msg(self, hit: Hit): ) # Just a safe-guard to preserve backwards-compatibility - return hit.message - - def _es_read(self, log_id: str, offset: int | str, ti: TaskInstance) -> ElasticSearchResponse | None: - """ - Return the logs matching log_id in Elasticsearch and next offset or ''. - - :param log_id: the log_id of the log to read. - :param offset: the offset start to read log from. - :param ti: the task instance object - - :meta private: - """ - query: dict[Any, Any] = { - "bool": { - "filter": [{"range": {self.offset_field: {"gt": int(offset)}}}], - "must": [{"match_phrase": {"log_id": log_id}}], - } - } - - index_patterns = self._get_index_patterns(ti) - try: - max_log_line = self.client.count(index=index_patterns, query=query)["count"] - except NotFoundError as e: - self.log.exception("The target index pattern %s does not exist", index_patterns) - raise e - - if max_log_line != 0: - try: - res = self.client.search( - index=index_patterns, - query=query, - sort=[self.offset_field], - size=self.MAX_LINE_PER_PAGE, - from_=self.MAX_LINE_PER_PAGE * self.PAGE, - ) - return ElasticSearchResponse(self, res) - except Exception as err: - self.log.exception("Could not read log with log_id: %s. Exception: %s", log_id, err) - - return None + return self._get_log_message(hit) def emit(self, record): if self.handler: @@ -455,6 +383,8 @@ def emit(self, record): def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None: """ + TODO: This API should be removed in airflow 3. + Provide task_instance context to airflow task handler. :param ti: task instance object @@ -473,12 +403,10 @@ def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> Non "dag_id": str(ti.dag_id), "task_id": str(ti.task_id), date_key: ( - self._clean_date(ti.logical_date) - if AIRFLOW_V_3_0_PLUS - else self._clean_date(ti.execution_date) + _clean_date(ti.logical_date) if AIRFLOW_V_3_0_PLUS else _clean_date(ti.execution_date) ), "try_number": str(ti.try_number), - "log_id": self._render_log_id(ti, ti.try_number), + "log_id": _render_log_id(self.log_id_template, ti, ti.try_number), }, ) @@ -500,6 +428,7 @@ def close(self) -> None: # calling close method. 
Here we check if logger is already # closed to prevent uploading the log to remote storage multiple # times when `logging.shutdown` is called. + # TODO: This API should be simplified since Airflow 3 no longer requires this API for writing log to ES if self.closed: return @@ -522,22 +451,10 @@ def close(self) -> None: # so we know where to stop while auto-tailing. self.emit(logging.makeLogRecord({"msg": self.end_of_log_mark})) - if self.write_stdout: + if self.io.write_stdout: self.handler.close() sys.stdout = sys.__stdout__ - if self.write_to_es and not self.write_stdout: - full_path = self.handler.baseFilename # type: ignore[union-attr] - log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix() - local_loc = os.path.join(self.local_base, log_relative_path) - if os.path.exists(local_loc): - # read log and remove old logs to get just the latest additions - log = pathlib.Path(local_loc).read_text() - log_lines = self._parse_raw_log(log) - success = self._write_to_es(log_lines) - if success and self.delete_local_copy: - shutil.rmtree(os.path.dirname(local_loc)) - super().close() self.closed = True @@ -555,7 +472,7 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> :param try_number: task instance try_number to read logs from. :return: URL to the external log collection service """ - log_id = self._render_log_id(task_instance, try_number) + log_id = _render_log_id(self.log_id_template, task_instance, try_number) scheme = "" if "://" in self.frontend else "https://" return scheme + self.frontend.format(log_id=quote(log_id)) @@ -564,38 +481,12 @@ def supports_external_link(self) -> bool: """Whether we can support external links.""" return bool(self.frontend) - def _resolve_nested(self, hit: dict[Any, Any], parent_class=None) -> type[Hit]: - """ - Resolve nested hits from Elasticsearch by iteratively navigating the `_nested` field. - - The result is used to fetch the appropriate document class to handle the hit. - - This method can be used with nested Elasticsearch fields which are structured - as dictionaries with "field" and "_nested" keys. - """ - doc_class = Hit - - nested_path: list[str] = [] - nesting = hit["_nested"] - while nesting and "field" in nesting: - nested_path.append(nesting["field"]) - nesting = nesting.get("_nested") - nested_path_str = ".".join(nested_path) - - if hasattr(parent_class, "_index"): - nested_field = parent_class._index.resolve_field(nested_path_str) - - if nested_field is not None: - return nested_field._doc_class - - return doc_class - def _get_result(self, hit: dict[Any, Any], parent_class=None) -> Hit: """ Process a hit (i.e., a result) from an Elasticsearch response and transform it into a class instance. The transformation depends on the contents of the hit. If the document in hit contains a nested field, - the '_resolve_nested' method is used to determine the appropriate class (based on the nested path). + the 'resolve_nested' method is used to determine the appropriate class (based on the nested path). If the hit has a document type that is present in the '_doc_type_map', the corresponding class is used. If not, the method iterates over the '_doc_type' classes and uses the first one whose '_matches' method returns True for the hit. @@ -605,41 +496,12 @@ def _get_result(self, hit: dict[Any, Any], parent_class=None) -> Hit: Finally, the transformed hit is returned. 
If the determined class has a 'from_es' method, this is used to transform the hit - - An example of the hit argument: - - {'_id': 'jdeZT4kBjAZqZnexVUxk', - '_index': '.ds-filebeat-8.8.2-2023.07.09-000001', - '_score': 2.482621, - '_source': {'@timestamp': '2023-07-13T14:13:15.140Z', - 'asctime': '2023-07-09T07:47:43.907+0000', - 'container': {'id': 'airflow'}, - 'dag_id': 'example_bash_operator', - 'ecs': {'version': '8.0.0'}, - 'logical_date': '2023_07_09T07_47_32_000000', - 'filename': 'taskinstance.py', - 'input': {'type': 'log'}, - 'levelname': 'INFO', - 'lineno': 1144, - 'log': {'file': {'path': "/opt/airflow/Documents/GitHub/airflow/logs/ - dag_id=example_bash_operator'/run_id=owen_run_run/ - task_id=run_after_loop/attempt=1.log"}, - 'offset': 0}, - 'log.offset': 1688888863907337472, - 'log_id': 'example_bash_operator-run_after_loop-owen_run_run--1-1', - 'message': 'Dependencies all met for dep_context=non-requeueable ' - 'deps ti=', - 'task_id': 'run_after_loop', - 'try_number': '1'}, - '_type': '_doc'} """ doc_class = Hit dt = hit.get("_type") if "_nested" in hit: - doc_class = self._resolve_nested(hit, parent_class) + doc_class = resolve_nested(hit, parent_class) elif dt in self._doc_type_map: doc_class = self._doc_type_map[dt] @@ -657,13 +519,97 @@ def _get_result(self, hit: dict[Any, Any], parent_class=None) -> Hit: callback: type[Hit] | Callable[..., Any] = getattr(doc_class, "from_es", doc_class) return callback(hit) - def _parse_raw_log(self, log: str) -> list[dict[str, Any]]: + def _get_log_message(self, hit: Hit) -> str: + """ + Get log message from hit, supporting both Airflow 2.x and 3.x formats. + + In Airflow 2.x, the log record JSON has a "message" key, e.g.: + { + "message": "Dag name:dataset_consumes_1 queued_at:2025-08-12 15:05:57.703493+00:00", + "offset": 1755011166339518208, + "log_id": "dataset_consumes_1-consuming_1-manual__2025-08-12T15:05:57.691303+00:00--1-1" + } + + In Airflow 3.x, the "message" field is renamed to "event". + We check the correct attribute depending on the Airflow major version. 
+ """ + if hasattr(hit, "event"): + return hit.event + if hasattr(hit, "message"): + return hit.message + return "" + + def concat_logs(self, hits: list[Hit]) -> str: + log_range = (len(hits) - 1) if self._get_log_message(hits[-1]) == self.end_of_log_mark else len(hits) + return "\n".join(self._format_msg(hits[i]) for i in range(log_range)) + + +@attrs.define(kw_only=True) +class ElasticsearchRemoteLogIO(LoggingMixin): # noqa: D101 + json_format: bool = False + write_stdout: bool = False + delete_local_copy: bool = False + host: str = "http://localhost:9200" + host_field: str = "host" + target_index: str = "airflow-logs" + offset_field: str = "offset" + write_to_es: bool = False + base_log_folder: Path = attrs.field(converter=Path) + log_id_template: str = conf.get( + "elasticsearch", + "log_id_template", + fallback="{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}", + ) + + processors = () + + def __attrs_post_init__(self): + es_kwargs = get_es_kwargs_from_config() + self.client = elasticsearch.Elasticsearch(self.host, **es_kwargs) + self.index_patterns_callable = conf.get("elasticsearch", "index_patterns_callable", fallback="") + self.PAGE = 0 + self.MAX_LINE_PER_PAGE = 1000 + self.index_patterns: str = conf.get("elasticsearch", "index_patterns") + self._doc_type_map: dict[Any, Any] = {} + self._doc_type: list[Any] = [] + + def upload(self, path: os.PathLike | str, ti: RuntimeTI): + """Write the log to ElasticSearch.""" + path = Path(path) + + if path.is_absolute(): + local_loc = path + else: + local_loc = self.base_log_folder.joinpath(path) + + log_id = _render_log_id(self.log_id_template, ti, ti.try_number) # type: ignore[arg-type] + if local_loc.is_file() and self.write_stdout: + # Intentionally construct the log_id and offset field + + log_lines = self._parse_raw_log(local_loc.read_text(), log_id) + for line in log_lines: + sys.stdout.write(json.dumps(line) + "\n") + sys.stdout.flush() + + if local_loc.is_file() and self.write_to_es: + log_lines = self._parse_raw_log(local_loc.read_text(), log_id) + success = self._write_to_es(log_lines) + if success and self.delete_local_copy: + shutil.rmtree(os.path.dirname(local_loc)) + + def _parse_raw_log(self, log: str, log_id: str) -> list[dict[str, Any]]: logs = log.split("\n") parsed_logs = [] + offset = 1 for line in logs: # Make sure line is not empty if line.strip(): - parsed_logs.append(json.loads(line)) + # construct log_id which is {dag_id}-{task_id}-{run_id}-{map_index}-{try_number} + # also construct the offset field (default is 'offset') + log_dict = json.loads(line) + log_dict.update({"log_id": log_id, self.offset_field: offset}) + offset += 1 + parsed_logs.append(log_dict) return parsed_logs @@ -678,21 +624,139 @@ def _write_to_es(self, log_lines: list[dict[str, Any]]) -> bool: try: _ = helpers.bulk(self.client, bulk_actions) return True + except helpers.BulkIndexError as bie: + self.log.exception("Bulk upload failed for %d log(s)", len(bie.errors)) + for error in bie.errors: + self.log.exception(error) + return False except Exception as e: self.log.exception("Unable to insert logs into Elasticsearch. 
Reason: %s", str(e)) return False + def read(self, _relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages]: + log_id = _render_log_id(self.log_id_template, ti, ti.try_number) # type: ignore[arg-type] + self.log.info("Reading log %s from Elasticsearch", log_id) + offset = 0 + response = self._es_read(log_id, offset, ti) + if response is not None and response.hits: + logs_by_host = self._group_logs_by_host(response) + else: + logs_by_host = None -def getattr_nested(obj, item, default): - """ - Get item from obj but return default if not found. + if logs_by_host is None: + missing_log_message = ( + f"*** Log {log_id} not found in Elasticsearch. " + "If your task started recently, please wait a moment and reload this page. " + "Otherwise, the logs for this task instance may have been removed." + ) + return [], [missing_log_message] - E.g. calling ``getattr_nested(a, 'b.c', "NA")`` will return - ``a.b.c`` if such a value exists, and "NA" otherwise. + header = [] + # Start log group + header.append("".join([host for host in logs_by_host.keys()])) - :meta private: - """ - try: - return attrgetter(item)(obj) - except AttributeError: - return default + message = [] + # Structured log messages + for hits in logs_by_host.values(): + for hit in hits: + filtered = {k: v for k, v in hit.to_dict().items() if k.lower() in TASK_LOG_FIELDS} + message.append(json.dumps(filtered)) + + return header, message + + def _es_read(self, log_id: str, offset: int | str, ti: RuntimeTI) -> ElasticSearchResponse | None: + """ + Return the logs matching log_id in Elasticsearch and next offset or ''. + + :param log_id: the log_id of the log to read. + :param offset: the offset start to read log from. + :param ti: the task instance object + + :meta private: + """ + query: dict[Any, Any] = { + "bool": { + "filter": [{"range": {self.offset_field: {"gt": int(offset)}}}], + "must": [{"match_phrase": {"log_id": log_id}}], + } + } + + index_patterns = self._get_index_patterns(ti) + try: + max_log_line = self.client.count(index=index_patterns, query=query)["count"] + except NotFoundError as e: + self.log.exception("The target index pattern %s does not exist", index_patterns) + raise e + + if max_log_line != 0: + try: + res = self.client.search( + index=index_patterns, + query=query, + sort=[self.offset_field], + size=self.MAX_LINE_PER_PAGE, + from_=self.MAX_LINE_PER_PAGE * self.PAGE, + ) + return ElasticSearchResponse(self, res) + except Exception as err: + self.log.exception("Could not read log with log_id: %s. Exception: %s", log_id, err) + + return None + + def _get_index_patterns(self, ti: RuntimeTI | None) -> str: + """ + Get index patterns by calling index_patterns_callable, if provided, or the configured index_patterns. + + :param ti: A TaskInstance object or None. 
+ """ + if self.index_patterns_callable: + self.log.debug("Using index_patterns_callable: %s", self.index_patterns_callable) + index_pattern_callable_obj = import_string(self.index_patterns_callable) + return index_pattern_callable_obj(ti) + self.log.debug("Using index_patterns: %s", self.index_patterns) + return self.index_patterns + + def _group_logs_by_host(self, response: ElasticSearchResponse) -> dict[str, list[Hit]]: + grouped_logs = defaultdict(list) + for hit in response: + key = getattr_nested(hit, self.host_field, None) or self.host + grouped_logs[key].append(hit) + return grouped_logs + + def _get_result(self, hit: dict[Any, Any], parent_class=None) -> Hit: + """ + Process a hit (i.e., a result) from an Elasticsearch response and transform it into a class instance. + + The transformation depends on the contents of the hit. If the document in hit contains a nested field, + the 'resolve_nested' method is used to determine the appropriate class (based on the nested path). + If the hit has a document type that is present in the '_doc_type_map', the corresponding class is + used. If not, the method iterates over the '_doc_type' classes and uses the first one whose '_matches' + method returns True for the hit. + + If the hit contains any 'inner_hits', these are also processed into 'ElasticSearchResponse' instances + using the determined class. + + Finally, the transformed hit is returned. If the determined class has a 'from_es' method, this is + used to transform the hit + """ + doc_class = Hit + dt = hit.get("_type") + + if "_nested" in hit: + doc_class = resolve_nested(hit, parent_class) + + elif dt in self._doc_type_map: + doc_class = self._doc_type_map[dt] + + else: + for doc_type in self._doc_type: + if hasattr(doc_type, "_matches") and doc_type._matches(hit): + doc_class = doc_type + break + + for t in hit.get("inner_hits", ()): + hit["inner_hits"][t] = ElasticSearchResponse(self, hit["inner_hits"][t], doc_class=doc_class) + + # callback should get the Hit class if "from_es" is not defined + callback: type[Hit] | Callable[..., Any] = getattr(doc_class, "from_es", doc_class) + return callback(hit) diff --git a/providers/elasticsearch/tests/conftest.py b/providers/elasticsearch/tests/conftest.py index f56ccce0a3f69..92b08d2ec0e0d 100644 --- a/providers/elasticsearch/tests/conftest.py +++ b/providers/elasticsearch/tests/conftest.py @@ -16,4 +16,47 @@ # under the License. 
from __future__ import annotations +import pytest +from elasticsearch import Elasticsearch +from testcontainers.elasticsearch import ElasticSearchContainer + pytest_plugins = "tests_common.pytest_plugin" + + +def _wait_for_cluster_ready(es: Elasticsearch, timeout_s: int = 120) -> None: + es.cluster.health(wait_for_status="yellow", timeout=f"{timeout_s}s") + + +def _ensure_index(es: Elasticsearch, index: str, timeout_s: int = 120) -> None: + if not es.indices.exists(index=index): + es.indices.create( + index=index, + settings={ + "index": { + "number_of_shards": 1, + "number_of_replicas": 0, + } + }, + ) + # Wait until the index primary is active + es.cluster.health(index=index, wait_for_status="yellow", timeout=f"{timeout_s}s") + + +@pytest.fixture(scope="session") +def es_8_container_url() -> str: + es = ( + ElasticSearchContainer("docker.elastic.co/elasticsearch/elasticsearch:8.19.0") + .with_env("discovery.type", "single-node") + .with_env("cluster.routing.allocation.disk.threshold_enabled", "false") + ) + with es: + url = es.get_url() + client = Elasticsearch( + url, + request_timeout=120, + retry_on_timeout=True, + max_retries=5, + ) + _wait_for_cluster_ready(client, timeout_s=120) + _ensure_index(client, "airflow-logs", timeout_s=120) + yield url diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py index c7746001d6856..df101f2f918e4 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py @@ -80,7 +80,7 @@ def info(self, params=None): } @query_params() - def sample_log_response(self, headers=None, params=None): + def sample_airflow_2_log_response(self, headers=None, params=None): return { "_shards": {"failed": 0, "skipped": 0, "successful": 7, "total": 7}, "hits": { @@ -104,17 +104,16 @@ def sample_log_response(self, headers=None, params=None): "file": { "path": "/opt/airflow/Documents/GitHub/airflow/logs/" "dag_id=example_bash_operator'" - "/run_id=owen_run_run/task_id=run_after_loop/attempt=1.log" + "/run_id=run_run/task_id=run_after_loop/attempt=1.log" }, "offset": 0, }, "log.offset": 1688888863907337472, - "log_id": "example_bash_operator-run_after_loop-owen_run_run--1-1", + "log_id": "example_bash_operator-run_after_loop-run_run--1-1", "message": "Dependencies all met for " "dep_context=non-requeueable deps " "ti=", + "example_bash_operator.run_after_loop ", "task_id": "run_after_loop", "try_number": "1", }, @@ -139,12 +138,12 @@ def sample_log_response(self, headers=None, params=None): "file": { "path": "/opt/airflow/Documents/GitHub/airflow/logs/" "dag_id=example_bash_operator" - "/run_id=owen_run_run/task_id=run_after_loop/attempt=1.log" + "/run_id=run_run/task_id=run_after_loop/attempt=1.log" }, "offset": 988, }, "log.offset": 1688888863917961216, - "log_id": "example_bash_operator-run_after_loop-owen_run_run--1-1", + "log_id": "example_bash_operator-run_after_loop-run_run--1-1", "message": "Starting attempt 1 of 1", "task_id": "run_after_loop", "try_number": "1", @@ -170,12 +169,12 @@ def sample_log_response(self, headers=None, params=None): "file": { "path": "/opt/airflow/Documents/GitHub/airflow/logs/" "dag_id=example_bash_operator" - "/run_id=owen_run_run/task_id=run_after_loop/attempt=1.log" + "/run_id=run_run/task_id=run_after_loop/attempt=1.log" }, "offset": 1372, }, "log.offset": 1688888863928218880, 
- "log_id": "example_bash_operator-run_after_loop-owen_run_run--1-1", + "log_id": "example_bash_operator-run_after_loop-run_run--1-1", "message": "Executing on 2023-07-09 " "07:47:32+00:00", @@ -192,6 +191,118 @@ def sample_log_response(self, headers=None, params=None): "took": 7, } + @query_params() + def sample_airflow_3_log_response(self, headers=None, params=None): + return { + "_shards": {"failed": 0, "skipped": 0, "successful": 7, "total": 7}, + "hits": { + "hits": [ + { + "_id": "jdeZT4kBjAZqZnexVUxk", + "_index": ".ds-filebeat-8.8.2-2023.07.09-000001", + "_score": 2.482621, + "_source": { + "@timestamp": "2023-07-13T14:13:15.140Z", + "asctime": "2023-07-09T07:47:43.907+0000", + "container": {"id": "airflow"}, + "dag_id": "example_bash_operator", + "ecs": {"version": "8.0.0"}, + "execution_date": "2023_07_09T07_47_32_000000", + "filename": "taskinstance.py", + "input": {"type": "log"}, + "levelname": "INFO", + "lineno": 1144, + "log": { + "file": { + "path": "/opt/airflow/Documents/GitHub/airflow/logs/" + "dag_id=example_bash_operator'" + "/run_id=run_run/task_id=run_after_loop/attempt=1.log" + }, + "offset": 0, + }, + "log.offset": 1688888863907337472, + "log_id": "example_bash_operator-run_after_loop-run_run--1-1", + "task_id": "run_after_loop", + "try_number": "1", + "event": "Dependencies all met for " + "dep_context=non-requeueable deps " + "ti= on 2023-07-09 " + "07:47:32+00:00", + }, + "_type": "_doc", + }, + ], + "max_score": 2.482621, + "total": {"relation": "eq", "value": 36}, + }, + "timed_out": False, + "took": 7, + } + @query_params( "consistency", "op_type", @@ -479,7 +590,6 @@ def _validate_search_targets(self, targets, query): # TODO: support allow_no_indices query parameter matches = set() for target in targets: - print(f"Loop over:::target = {target}") if target in ("_all", ""): matches.update(self.__documents_dict) elif "*" in target: @@ -499,7 +609,6 @@ def _normalize_index_to_list(self, index, query): else: # Is it the correct exception to use ? 
raise ValueError("Invalid param 'index'") - generator = (target for index in searchable_indexes for target in index.split(",")) return list(self._validate_search_targets(generator, query=query)) diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py index 90c0f1229711f..02ac65105e5b3 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py @@ -22,6 +22,8 @@ import os import re import shutil +import tempfile +import uuid from io import StringIO from pathlib import Path from unittest import mock @@ -36,7 +38,10 @@ from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse from airflow.providers.elasticsearch.log.es_task_handler import ( VALID_ES_CONFIG_KEYS, + ElasticsearchRemoteLogIO, ElasticsearchTaskHandler, + _clean_date, + _render_log_id, get_es_kwargs_from_config, getattr_nested, ) @@ -54,10 +59,11 @@ ES_PROVIDER_YAML_FILE = AIRFLOW_PROVIDERS_ROOT_PATH / "elasticsearch" / "provider.yaml" -def get_ti(dag_id, task_id, logical_date, create_task_instance): +def get_ti(dag_id, task_id, run_id, logical_date, create_task_instance): ti = create_task_instance( dag_id=dag_id, task_id=task_id, + run_id=run_id, logical_date=logical_date, dagrun_state=DagRunState.RUNNING, state=TaskInstanceState.RUNNING, @@ -70,9 +76,12 @@ def get_ti(dag_id, task_id, logical_date, create_task_instance): class TestElasticsearchTaskHandler: DAG_ID = "dag_for_testing_es_task_handler" TASK_ID = "task_for_testing_es_log_handler" + RUN_ID = "run_for_testing_es_log_handler" + MAP_INDEX = -1 + TRY_NUM = 1 LOGICAL_DATE = datetime(2016, 1, 1) - LOG_ID = f"{DAG_ID}-{TASK_ID}-2016-01-01T00:00:00+00:00-1" - JSON_LOG_ID = f"{DAG_ID}-{TASK_ID}-{ElasticsearchTaskHandler._clean_date(LOGICAL_DATE)}-1" + LOG_ID = f"{DAG_ID}-{TASK_ID}-{RUN_ID}-{MAP_INDEX}-{TRY_NUM}" + JSON_LOG_ID = f"{DAG_ID}-{TASK_ID}-{_clean_date(LOGICAL_DATE)}-1" FILENAME_TEMPLATE = "{try_number}.log" @pytest.fixture @@ -88,6 +97,7 @@ def ti(self, create_task_instance, create_log_template): yield get_ti( dag_id=self.DAG_ID, task_id=self.TASK_ID, + run_id=self.RUN_ID, logical_date=self.LOGICAL_DATE, create_task_instance=create_task_instance, ) @@ -128,21 +138,24 @@ def setup_method(self, method): def teardown_method(self): shutil.rmtree(self.local_log_location.split(os.path.sep)[0], ignore_errors=True) - def test_es_response(self): - sample_response = self.es.sample_log_response() - es_response = ElasticSearchResponse(self.es_task_handler, sample_response) - logs_by_host = self.es_task_handler._group_logs_by_host(es_response) - - def concat_logs(lines): - log_range = -1 if lines[-1].message == self.es_task_handler.end_of_log_mark else None - return "\n".join(self.es_task_handler._format_msg(line) for line in lines[:log_range]) + @pytest.mark.parametrize( + "sample_response", + [ + pytest.param(lambda self: self.es.sample_airflow_2_log_response(), id="airflow_2"), + pytest.param(lambda self: self.es.sample_airflow_3_log_response(), id="airflow_3"), + ], + ) + def test_es_response(self, sample_response): + response = sample_response(self) + es_response = ElasticSearchResponse(self.es_task_handler, response) + logs_by_host = self.es_task_handler.io._group_logs_by_host(es_response) for hosted_log in logs_by_host.values(): - message = concat_logs(hosted_log) + message = self.es_task_handler.concat_logs(hosted_log) assert ( 
message == "Dependencies all met for dep_context=non-requeueable" - " deps ti=\n" + " deps ti= " "on 2023-07-09 07:47:32+00:00" ) @@ -263,7 +276,7 @@ def test_read_with_patterns(self, ti): @pytest.mark.db_test def test_read_with_patterns_no_match(self, ti): ts = pendulum.now() - with mock.patch.object(self.es_task_handler, "index_patterns", new="test_other_*,test_another_*"): + with mock.patch.object(self.es_task_handler.io, "index_patterns", new="test_other_*,test_another_*"): logs, metadatas = self.es_task_handler.read( ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False} ) @@ -280,14 +293,14 @@ def test_read_with_patterns_no_match(self, ti): metadata = metadatas[0] assert metadata["offset"] == "0" - assert not metadata["end_of_log"] + assert metadata["end_of_log"] # last_log_timestamp won't change if no log lines read. assert timezone.parse(metadata["last_log_timestamp"]) == ts @pytest.mark.db_test def test_read_with_missing_index(self, ti): ts = pendulum.now() - with mock.patch.object(self.es_task_handler, "index_patterns", new="nonexistent,test_*"): + with mock.patch.object(self.es_task_handler.io, "index_patterns", new="nonexistent,test_*"): with pytest.raises(elasticsearch.exceptions.NotFoundError, match=r"IndexMissingException.*"): self.es_task_handler.read( ti, @@ -302,9 +315,11 @@ def test_read_missing_logs(self, seconds, create_task_instance): When the log actually isn't there to be found, we only want to wait for 5 seconds. In this case we expect to receive a message of the form 'Log {log_id} not found in elasticsearch ...' """ + run_id = "wrong_run_id" ti = get_ti( self.DAG_ID, self.TASK_ID, + run_id, pendulum.instance(self.LOGICAL_DATE).add(days=1), # so logs are not found create_task_instance=create_task_instance, ) @@ -320,7 +335,7 @@ def test_read_missing_logs(self, seconds, create_task_instance): else: # we've "waited" less than 5 seconds so it should not be "end of log" and should be no log message assert logs == [] - assert metadatas["end_of_log"] is False + assert metadatas["end_of_log"] is True assert metadatas["offset"] == "0" assert timezone.parse(metadatas["last_log_timestamp"]) == ts else: @@ -336,7 +351,7 @@ def test_read_missing_logs(self, seconds, create_task_instance): # we've "waited" less than 5 seconds so it should not be "end of log" and should be no log message assert len(logs[0]) == 0 assert logs == [[]] - assert metadatas[0]["end_of_log"] is False + assert metadatas[0]["end_of_log"] is True assert len(logs) == len(metadatas) assert metadatas[0]["offset"] == "0" assert timezone.parse(metadatas[0]["last_log_timestamp"]) == ts @@ -432,7 +447,7 @@ def test_read_nonexistent_log(self, ti): metadata = metadatas[0] assert metadata["offset"] == "0" - assert not metadata["end_of_log"] + assert metadata["end_of_log"] # last_log_timestamp won't change if no log lines read. assert timezone.parse(metadata["last_log_timestamp"]) == ts @@ -440,6 +455,7 @@ def test_read_nonexistent_log(self, ti): def test_read_with_empty_metadata(self, ti): ts = pendulum.now() logs, metadatas = self.es_task_handler.read(ti, 1, {}) + print(f"metadatas: {metadatas}") if AIRFLOW_V_3_0_PLUS: logs = list(logs) assert logs[0].event == "::group::Log message source details" @@ -455,7 +471,7 @@ def test_read_with_empty_metadata(self, ti): assert self.test_message == logs[0][0][-1] metadata = metadatas[0] - + print(f"metadatas: {metadatas}") assert not metadata["end_of_log"] # offset should be initialized to 0 if not provided. 
assert metadata["offset"] == "1" @@ -477,7 +493,7 @@ def test_read_with_empty_metadata(self, ti): metadata = metadatas[0] - assert not metadata["end_of_log"] + assert metadata["end_of_log"] # offset should be initialized to 0 if not provided. assert metadata["offset"] == "0" # last_log_timestamp will be initialized using log reading time @@ -552,27 +568,22 @@ def test_read_as_download_logs(self, ti): @pytest.mark.db_test def test_read_raises(self, ti): - with mock.patch.object(self.es_task_handler.log, "exception") as mock_exception: - with mock.patch.object(self.es_task_handler.client, "search") as mock_execute: + with mock.patch.object(self.es_task_handler.io.log, "exception") as mock_exception: + with mock.patch.object(self.es_task_handler.io.client, "search") as mock_execute: mock_execute.side_effect = SearchFailedException("Failed to read") - logs, metadatas = self.es_task_handler.read(ti, 1) + log_sources, log_msgs = self.es_task_handler.io.read("", ti) assert mock_exception.call_count == 1 args, kwargs = mock_exception.call_args assert "Could not read log with log_id:" in args[0] if AIRFLOW_V_3_0_PLUS: - assert logs == [] - - metadata = metadatas + assert log_sources == [] else: - assert len(logs) == 1 - assert len(logs) == len(metadatas) - assert logs == [[]] - - metadata = metadatas[0] + assert len(log_sources) == 0 + assert len(log_msgs) == 1 + assert log_sources == [] - assert metadata["offset"] == "0" - assert not metadata["end_of_log"] + assert "not found in Elasticsearch" in log_msgs[0] @pytest.mark.db_test def test_set_context(self, ti): @@ -616,9 +627,7 @@ def test_read_with_json_format(self, ti): logs = list(logs) assert logs[2].event == self.test_message else: - assert ( - logs[0][0][1] == "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - " - ) + assert logs[0][0][1] == self.test_message @pytest.mark.db_test def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti): @@ -634,7 +643,7 @@ def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti): self.body = { "message": self.test_message, "event": self.test_message, - "log_id": f"{self.DAG_ID}-{self.TASK_ID}-2016_01_01T00_00_00_000000-1", + "log_id": self.LOG_ID, "log": {"offset": 1}, "host": {"name": "somehostname"}, "asctime": "2020-12-24 19:25:00,962", @@ -652,9 +661,7 @@ def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti): logs = list(logs) assert logs[2].event == self.test_message else: - assert ( - logs[0][0][1] == "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - " - ) + assert logs[0][0][1] == self.test_message @pytest.mark.db_test def test_read_with_custom_offset_and_host_fields(self, ti): @@ -753,13 +760,13 @@ def test_close_with_no_stream(self, ti): @pytest.mark.db_test def test_render_log_id(self, ti): - assert self.es_task_handler._render_log_id(ti, 1) == self.LOG_ID + assert _render_log_id(self.es_task_handler.log_id_template, ti, 1) == self.LOG_ID self.es_task_handler.json_format = True - assert self.es_task_handler._render_log_id(ti, 1) == self.JSON_LOG_ID + assert _render_log_id(self.es_task_handler.log_id_template, ti, 1) == self.LOG_ID def test_clean_date(self): - clean_logical_date = self.es_task_handler._clean_date(datetime(2016, 7, 8, 9, 10, 11, 12)) + clean_logical_date = _clean_date(datetime(2016, 7, 8, 9, 10, 11, 12)) assert clean_logical_date == "2016_07_08T09_10_11_000012" @pytest.mark.db_test @@ -770,7 +777,7 @@ def test_clean_date(self): ( True, "localhost:5601/{log_id}", - 
"https://localhost:5601/" + quote(JSON_LOG_ID), + "https://localhost:5601/" + quote(LOG_ID), ), ( False, @@ -867,8 +874,8 @@ def test_get_index_patterns_with_callable(self): mock_callable = Mock(return_value="callable_index_pattern") mock_import_string.return_value = mock_callable - self.es_task_handler.index_patterns_callable = "path.to.index_pattern_callable" - result = self.es_task_handler._get_index_patterns({}) + self.es_task_handler.io.index_patterns_callable = "path.to.index_pattern_callable" + result = self.es_task_handler.io._get_index_patterns({}) mock_import_string.assert_called_once_with("path.to.index_pattern_callable") mock_callable.assert_called_once_with({}) @@ -885,25 +892,6 @@ def test_filename_template_for_backward_compatibility(self): filename_template=None, ) - @pytest.mark.db_test - def test_write_to_es(self, ti): - self.es_task_handler.write_to_es = True - self.es_task_handler.json_format = True - self.es_task_handler.write_stdout = False - self.es_task_handler.local_base = Path(os.getcwd()) / "local" / "log" / "location" - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - self.es_task_handler.formatter = formatter - - self.es_task_handler.set_context(ti) - with patch( - "airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler._write_to_es" - ) as mock_write_to_es: - mock_write = Mock(return_value=True) - mock_write_to_es.return_value = mock_write - self.es_task_handler._write_to_es = mock_write_to_es - self.es_task_handler.close() - mock_write_to_es.assert_called_once() - def test_safe_attrgetter(): class A: ... @@ -963,3 +951,133 @@ def test_self_not_valid_arg(): Test if self is not a valid argument. """ assert "self" not in VALID_ES_CONFIG_KEYS + + +@pytest.mark.db_test +class TestElasticsearchRemoteLogIO: + DAG_ID = "dag_for_testing_es_log_handler" + TASK_ID = "task_for_testing_es_log_handler" + RUN_ID = "run_for_testing_es_log_handler" + LOGICAL_DATE = datetime(2016, 1, 1) + FILENAME_TEMPLATE = "{try_number}.log" + + @pytest.fixture(autouse=True) + def setup_tests(self, ti, es_8_container_url): + self.elasticsearch_8_url = es_8_container_url + self.elasticsearch_io = ElasticsearchRemoteLogIO( + write_to_es=True, + write_stdout=True, + delete_local_copy=True, + host=es_8_container_url, + base_log_folder=Path(""), + ) + + @pytest.fixture + def tmp_json_file(self): + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(tmpdir, exist_ok=True) + + file_path = os.path.join(tmpdir, "1.log") + self.tmp_file = file_path + + sample_logs = [ + {"message": "start"}, + {"message": "processing"}, + {"message": "end"}, + ] + with open(file_path, "w") as f: + for log in sample_logs: + f.write(json.dumps(log) + "\n") + + yield file_path + + del self.tmp_file + + @pytest.fixture + def ti(self, create_task_instance, create_log_template): + create_log_template( + self.FILENAME_TEMPLATE, + ( + "{dag_id}-{task_id}-{logical_date}-{try_number}" + if AIRFLOW_V_3_0_PLUS + else "{dag_id}-{task_id}-{execution_date}-{try_number}" + ), + ) + yield get_ti( + dag_id=self.DAG_ID, + task_id=self.TASK_ID, + run_id=self.RUN_ID, + logical_date=self.LOGICAL_DATE, + create_task_instance=create_task_instance, + ) + clear_db_runs() + clear_db_dags() + + @pytest.fixture + def unique_index(self): + """Generate a unique index name for each test.""" + return f"airflow-logs-{uuid.uuid4()}" + + @pytest.mark.setup_timeout(300) + @pytest.mark.execution_timeout(300) + @patch( + 
"airflow.providers.elasticsearch.log.es_task_handler.TASK_LOG_FIELDS", + ["message"], + ) + def test_read_write_to_es(self, tmp_json_file, ti): + self.elasticsearch_io.client = self.elasticsearch_io.client.options( + request_timeout=120, retry_on_timeout=True, max_retries=5 + ) + self.elasticsearch_io.write_stdout = False + self.elasticsearch_io.upload(tmp_json_file, ti) + self.elasticsearch_io.client.indices.refresh( + index=self.elasticsearch_io.target_index, request_timeout=120 + ) + log_source_info, log_messages = self.elasticsearch_io.read("", ti) + assert log_source_info[0] == self.elasticsearch_8_url + assert len(log_messages) == 3 + + expected_msg = ["start", "processing", "end"] + for msg, log_message in zip(expected_msg, log_messages): + print(f"msg: {msg}, log_message: {log_message}") + json_log = json.loads(log_message) + assert "message" in json_log + assert json_log["message"] == msg + + def test_write_to_stdout(self, tmp_json_file, ti, capsys): + self.elasticsearch_io.write_to_es = False + self.elasticsearch_io.upload(tmp_json_file, ti) + + captured = capsys.readouterr() + stdout_lines = captured.out.strip().splitlines() + log_entries = [json.loads(line) for line in stdout_lines] + assert log_entries[0]["message"] == "start" + assert log_entries[1]["message"] == "processing" + assert log_entries[2]["message"] == "end" + + def test_invalid_task_log_file_path(self, ti): + with ( + patch.object(self.elasticsearch_io, "_parse_raw_log") as mock_parse, + patch.object(self.elasticsearch_io, "_write_to_es") as mock_write, + ): + self.elasticsearch_io.upload(Path("/invalid/path"), ti) + + mock_parse.assert_not_called() + mock_write.assert_not_called() + + def test_raw_log_should_contain_log_id_and_offset(self, tmp_json_file, ti): + with open(self.tmp_file) as f: + raw_log = f.read() + json_log_lines = self.elasticsearch_io._parse_raw_log(raw_log, ti) + assert len(json_log_lines) == 3 + for json_log_line in json_log_lines: + assert "log_id" in json_log_line + assert "offset" in json_log_line + + @patch("elasticsearch.Elasticsearch.count", return_value={"count": 0}) + def test_read_with_missing_log(self, mocked_count, ti): + log_source_info, log_messages = self.elasticsearch_io.read("", ti) + log_id = _render_log_id(self.elasticsearch_io.log_id_template, ti, ti.try_number) + assert log_source_info == [] + assert f"*** Log {log_id} not found in Elasticsearch" in log_messages[0] + mocked_count.assert_called_once() diff --git a/providers/mongo/pyproject.toml b/providers/mongo/pyproject.toml index 4508bdda07f1b..1f31082b2a79d 100644 --- a/providers/mongo/pyproject.toml +++ b/providers/mongo/pyproject.toml @@ -70,7 +70,6 @@ dev = [ "apache-airflow-devel-common", "apache-airflow-providers-common-compat", # Additional devel dependencies (do not remove this line and add extra development dependencies) - "testcontainers>=4.12.0" ] # To build docs: