From c482853a26f7ce541505266bfa325c5aa187fb91 Mon Sep 17 00:00:00 2001 From: Computer Network Investigation <121175071+JSCU-CNI@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:49:59 +0100 Subject: [PATCH] Improve elastic adapter exception handling (#150) --- flow/record/adapter/elastic.py | 43 ++++++++++++++++++++++++---------- flow/record/tools/rdump.py | 7 +++++- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py index e1163c7..3643cec 100644 --- a/flow/record/adapter/elastic.py +++ b/flow/record/adapter/elastic.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import hashlib import logging import queue import threading -from typing import Iterator, Optional, Union +from typing import Iterator import elasticsearch import elasticsearch.helpers @@ -37,10 +39,10 @@ def __init__( self, uri: str, index: str = "records", - verify_certs: Union[str, bool] = True, - http_compress: Union[str, bool] = True, - hash_record: Union[str, bool] = False, - api_key: Optional[str] = None, + verify_certs: str | bool = True, + http_compress: str | bool = True, + hash_record: str | bool = False, + api_key: str | None = None, **kwargs, ) -> None: self.index = index @@ -52,6 +54,9 @@ def __init__( if not uri.lower().startswith(("http://", "https://")): uri = "http://" + uri + self.queue: queue.Queue[Record | StopIteration] = queue.Queue() + self.event = threading.Event() + self.es = elasticsearch.Elasticsearch( uri, verify_certs=verify_certs, @@ -60,10 +65,11 @@ def __init__( ) self.json_packer = JsonRecordPacker() - self.queue: queue.Queue[Union[Record, StopIteration]] = queue.Queue() - self.event = threading.Event() + self.thread = threading.Thread(target=self.streaming_bulk_thread) self.thread.start() + self.exception: Exception | None = None + threading.excepthook = self.excepthook if not verify_certs: # Disable InsecureRequestWarning of urllib3, caused by the verify_certs flag. 
@@ -76,6 +82,12 @@ def __init__( if arg_key.startswith("_meta_"): self.metadata_fields[arg_key[6:]] = arg_val + def excepthook(self, exc: threading.ExceptHookArgs, *args, **kwargs) -> None: + log.error("Exception in thread: %s", exc.exc_value) + self.exception = exc.exc_value + self.event.set() + self.close() + def record_to_document(self, record: Record, index: str) -> dict: """Convert a record to a Elasticsearch compatible document dictionary""" rdict = record._asdict() @@ -120,6 +132,7 @@ def document_stream(self) -> Iterator[dict]: def streaming_bulk_thread(self) -> None: """Thread that streams the documents to ES via the bulk api""" + for ok, item in elasticsearch.helpers.streaming_bulk( self.es, self.document_stream(), @@ -138,21 +151,25 @@ def flush(self) -> None: pass def close(self) -> None: + self.queue.put(StopIteration) + self.event.wait() + if hasattr(self, "es"): - self.queue.put(StopIteration) - self.event.wait() self.es.close() + if self.exception: + raise self.exception + class ElasticReader(AbstractReader): def __init__( self, uri: str, index: str = "records", - verify_certs: Union[str, bool] = True, - http_compress: Union[str, bool] = True, - selector: Union[None, Selector, CompiledSelector] = None, - api_key: Optional[str] = None, + verify_certs: str | bool = True, + http_compress: str | bool = True, + selector: None | Selector | CompiledSelector = None, + api_key: str | None = None, **kwargs, ) -> None: self.index = index diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index c97ae22..57bec72 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -218,7 +218,9 @@ def main(argv=None): islice_stop = (args.count + args.skip) if args.count else None record_iterator = islice(record_stream(args.src, selector), args.skip, islice_stop) count = 0 - with RecordWriter(uri) as record_writer: + + record_writer = RecordWriter(uri) + try: for count, rec in enumerate(record_iterator, start=1): if
args.record_source is not None: rec._source = args.record_source @@ -243,6 +245,9 @@ def main(argv=None): else: record_writer.write(rec) + finally: + record_writer.__exit__(None, None, None) + if args.list: print("Processed {} records".format(count))