Skip to content

Commit

Permalink
Update elastic.py adapter (#92)
Browse files Browse the repository at this point in the history
Fix AttributeError if invalid uri is given, add `verify_certs` flag to optional arguments.
Also adds `hash_record` argument for making every document unique.
  • Loading branch information
0xbart authored Oct 26, 2023
1 parent fdcecba commit 6144cf4
Showing 1 changed file with 46 additions and 8 deletions.
54 changes: 46 additions & 8 deletions flow/record/adapter/elastic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
import queue
import threading
Expand All @@ -18,25 +19,45 @@
Write usage: rdump -w elastic+[PROTOCOL]://[IP]:[PORT]?index=[INDEX]
Read usage: rdump elastic+[PROTOCOL]://[IP]:[PORT]?index=[INDEX]
[IP]:[PORT]: ip and port to elastic host
[INDEX]: index to write to or read from
[PROTOCOL]: http or https. Defaults to https when "+[PROTOCOL]" is omitted
Optional arguments:
[INDEX]: name of the index to use (default: records)
[VERIFY_CERTS]: verify certs of Elasticsearch instance (default: True)
[HASH_RECORD]: make record unique by hashing record [slow] (default: False)
"""

log = logging.getLogger(__name__)


class ElasticWriter(AbstractWriter):
def __init__(self, uri: str, index: str = "records", http_compress: Union[str, bool] = True, **kwargs) -> None:
def __init__(
self,
uri: str,
index: str = "records",
verify_certs: Union[str, bool] = True,
http_compress: Union[str, bool] = True,
hash_record: Union[str, bool] = False,
**kwargs,
) -> None:
self.index = index
self.uri = uri
verify_certs = str(verify_certs).lower() in ("1", "true")
http_compress = str(http_compress).lower() in ("1", "true")
self.es = elasticsearch.Elasticsearch(uri, http_compress=http_compress)
self.hash_record = str(hash_record).lower() in ("1", "true")
self.es = elasticsearch.Elasticsearch(uri, verify_certs=verify_certs, http_compress=http_compress)
self.json_packer = JsonRecordPacker()
self.queue: queue.Queue[Union[Record, StopIteration]] = queue.Queue()
self.event = threading.Event()
self.thread = threading.Thread(target=self.streaming_bulk_thread)
self.thread.start()

if not verify_certs:
# Disable InsecureRequestWarning of urllib3, caused by the verify_certs flag.
import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

def record_to_document(self, record: Record, index: str) -> dict:
"""Convert a record to a Elasticsearch compatible document dictionary"""
rdict = record._asdict()
Expand All @@ -52,12 +73,19 @@ def record_to_document(self, record: Record, index: str) -> dict:
dunder_keys = [key for key in rdict if key.startswith("_")]
for key in dunder_keys:
rdict_meta[key.lstrip("_")] = rdict.pop(key)
# remove _generated field from metadata to ensure determinstic documents
if self.hash_record:
rdict_meta.pop("generated", None)
rdict["_record_metadata"] = rdict_meta

document = {
"_index": index,
"_source": self.json_packer.pack(rdict),
}

if self.hash_record:
document["_id"] = hashlib.md5(document["_source"].encode()).hexdigest()

return document

def document_stream(self) -> Iterator[dict]:
Expand Down Expand Up @@ -87,25 +115,34 @@ def flush(self) -> None:
pass

def close(self) -> None:
self.queue.put(StopIteration)
self.event.wait()
self.es.close()
if hasattr(self, "es"):
self.queue.put(StopIteration)
self.event.wait()
self.es.close()


class ElasticReader(AbstractReader):
def __init__(
self,
uri: str,
index: str = "records",
verify_certs: Union[str, bool] = True,
http_compress: Union[str, bool] = True,
selector: Union[None, Selector, CompiledSelector] = None,
**kwargs,
) -> None:
self.index = index
self.uri = uri
self.selector = selector
verify_certs = str(verify_certs).lower() in ("1", "true")
http_compress = str(http_compress).lower() in ("1", "true")
self.es = elasticsearch.Elasticsearch(uri, http_compress=http_compress)
self.es = elasticsearch.Elasticsearch(uri, verify_certs=verify_certs, http_compress=http_compress)

if not verify_certs:
# Disable InsecureRequestWarning of urllib3, caused by the verify_certs flag.
import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

def __iter__(self) -> Iterator[Record]:
res = self.es.search(index=self.index)
Expand All @@ -121,4 +158,5 @@ def __iter__(self) -> Iterator[Record]:
yield obj

def close(self) -> None:
self.es.close()
if hasattr(self, "es"):
self.es.close()

0 comments on commit 6144cf4

Please sign in to comment.