Support reading from a range
Garrett McGrath committed Jun 29, 2023
1 parent 272039b commit a2a7172
Showing 2 changed files with 93 additions and 18 deletions.
53 changes: 48 additions & 5 deletions per-message-s3-exporter/firehose_reader.py
@@ -115,19 +115,36 @@ def from_env(cls):
         keepalive = int(os.environ.get("KEEPALIVE", "60"))
         keepalive_stale_pitrs = int(os.environ.get("KEEPALIVE_STALE_PITRS", "5"))
         init_time = os.environ.get("INIT_CMD_TIME", "live")
+        init_time_split = init_time.split()
 
-        if init_time.split()[0] not in ["live", "pitr"]:
-            raise ValueError('$INIT_CMD_TIME value is invalid, should be "live" or "pitr <pitr>"')
+        if init_time_split[0] not in ("live", "pitr", "range"):
+            raise ValueError(
+                '$INIT_CMD_TIME value is invalid, should be "live", '
+                '"pitr <pitr>" or "range <start> <end>"'
+            )
 
         pitr_map = pitr_map_from_file(init_time)
         if pitr_map:
             min_pitr = min(pitr_map.values())
             logging.info(f"Based on PITR map {pitr_map}")
             logging.info(f"Using {min_pitr} ({format_epoch(min_pitr)}) as starting PITR value")
-            init_time = f"pitr {min_pitr}"
+
+            if "pitr" in init_time:
+                init_time = f"pitr {min_pitr}"
+            elif "range" in init_time:
+                init_time_split[1] = f"{min_pitr}"
+                init_time = " ".join(init_time_split)
 
         init_args = os.environ.get("INIT_CMD_ARGS", "")
-        for command in ["live", "pitr", "compression", "keepalive", "username", "password"]:
+        for command in [
+            "live",
+            "pitr",
+            "range",
+            "compression",
+            "keepalive",
+            "username",
+            "password",
+        ]:
            if command in init_args.split():
                raise ValueError(
                    f'$INIT_CMD_ARGS should not contain the "{command}" command. '
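
As a standalone illustration (not part of the commit), here is how the widened validation treats each accepted $INIT_CMD_TIME form; the helper name and epoch values below are invented for the sketch:

def parse_init_cmd_time(init_time: str) -> list:
    # Mirrors the validation above: three accepted leading keywords
    init_time_split = init_time.split()
    if init_time_split[0] not in ("live", "pitr", "range"):
        raise ValueError(
            '$INIT_CMD_TIME value is invalid, should be "live", '
            '"pitr <pitr>" or "range <start> <end>"'
        )
    return init_time_split

print(parse_init_cmd_time("live"))                         # ['live']
print(parse_init_cmd_time("pitr 1687996800"))              # ['pitr', '1687996800']
print(parse_init_cmd_time("range 1687996800 1688003600"))  # ['range', '1687996800', '1688003600']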
@@ -287,6 +304,13 @@ def connection_error_limit(self) -> int:
         """How many Firehose read errors before stopping"""
         return int(os.environ.get("CONNECTION_ERROR_LIMIT", "3"))
 
+    async def _shutdown(self):
+        """When a range of PITRs is requested, we send a special shutdown
+        message to every queue to end cleanly"""
+        logging.info("Initiating shutdown procedure: propagating signal to per-message queues")
+        for queue in self.message_queues.values():
+            await queue.put(None)
+
     async def read_firehose(self):
         """Read Firehose until a threshold number of errors occurs"""
         await self._stats.update_stats(None, 0)
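
A self-contained sketch of the sentinel pattern _shutdown relies on: putting None on each per-message-type queue tells that consumer to exit its loop. The queue names here are invented for the demo:

import asyncio

async def consumer(name: str, queue: asyncio.Queue):
    while True:
        message = await queue.get()
        if message is None:  # the shutdown sentinel _shutdown enqueues
            print(f"{name} consumer exiting cleanly")
            break
        print(f"{name} consumer got {message!r}")

async def demo():
    message_queues = {name: asyncio.Queue() for name in ("position", "flifo")}
    consumers = [asyncio.create_task(consumer(n, q)) for n, q in message_queues.items()]
    await message_queues["position"].put("some firehose message")
    # This mirrors what _shutdown does: one sentinel per queue
    for queue in message_queues.values():
        await queue.put(None)
    await asyncio.gather(*consumers)

asyncio.run(demo())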
@@ -296,11 +320,23 @@ async def read_firehose(self):
         errors = 0
 
         time_mode = self.config.init_time
+        reached_the_end = False
 
         while True:
             pitr = await self._read_until_error(time_mode)
             if pitr:
-                time_mode = f"pitr {pitr}"
+                time_mode_split = time_mode.split()
+                if time_mode_split[0] in ("live", "pitr"):
+                    time_mode = f"pitr {pitr}"
+                else:
+                    if pitr >= int(time_mode_split[-1]):
+                        logging.info("Reached the end of the range")
+                        reached_the_end = True
+                        break
+
+                    time_mode_split[1] = f"{pitr}"
+                    time_mode = " ".join(time_mode_split)
+
                 logging.info(f'Reconnecting with "{time_mode}"')
                 errors = 0
             elif errors < error_limit - 1:
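
A standalone sketch (invented numbers, hypothetical helper name) of how time_mode evolves across reconnects: live/pitr modes keep resuming from the last PITR, while range mode advances its start value until the last PITR reaches the end of the range:

def next_time_mode(time_mode: str, pitr: int):
    """Return the next init command, or None once the range end is reached."""
    time_mode_split = time_mode.split()
    if time_mode_split[0] in ("live", "pitr"):
        return f"pitr {pitr}"
    if pitr >= int(time_mode_split[-1]):
        return None  # reached the end of the range
    time_mode_split[1] = f"{pitr}"
    return " ".join(time_mode_split)

print(next_time_mode("range 1687996800 1688003600", 1688000000))
# -> range 1688000000 1688003600
print(next_time_mode("range 1687996800 1688003600", 1688003600))
# -> None (time to shut down)
print(next_time_mode("live", 1688000000))
# -> pitr 1688000000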
@@ -319,6 +355,9 @@
             self._stats.finish()
             await asyncio.wait_for(stats_task, self.stats_period)
 
+        if reached_the_end:
+            return await self._shutdown()
+
         raise ReadFirehoseErrorThreshold
 
     async def _open_connection(
@@ -361,6 +400,10 @@ async def _read_until_error(self, time_mode: str) -> Optional[str]:
 
         time_mode may be either the string "live" or a pitr string that looks like
         "pitr <pitr>" where <pitr> is a value previously returned by this function
+        or a pitr string that looks like "range <start> <end>". In the case of
+        a range, if we get to the end value we stop cleanly and shutdown,
+        otherwise we can resume from a start value previously returned by this
+        function.
         """
 
         context = ssl.create_default_context()
58 changes: 45 additions & 13 deletions per-message-s3-exporter/main.py
@@ -14,7 +14,7 @@
 from pathlib import Path
 from signal import Signals, SIGINT, SIGTERM
 import sys
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import aiofiles
 import attr
@@ -249,25 +249,31 @@ def folder_prefix(self) -> str:
 
         return f"{self.args.s3_bucket_folder}/"
 
-    def record_pitr(self, record: str) -> int:
+    def record_pitr(self, record: Optional[str]) -> int:
         """Return the PITR for a Firehose message"""
+        if record is None:
+            record = self._current_batch[-1]
         return int(json.loads(record)["pitr"])
 
-    async def ingest_record(self, record: str):
+    async def ingest_record(self, record: Optional[str]):
         """Ingest a record from Firehose, adding it to the current batch and
         writing a file to the S3 writer queue if necessary
         """
-        if not self._current_batch:
-            self._start_pitr = self.record_pitr(record)
-            self._timer.start()
+        # An empty line is a signal that we need to shutdown
+        shutdown = record is None
 
         # Even though we might exceed the max bytes, that's okay since it's
         # only a rough threshold that we strive to maintain here and the number
         # of records is more strictly adhered to
-        self._current_batch.append(record)
-        self._current_batch_bytes += len(record)
+        if not shutdown:
+            if not self._current_batch:
+                self._start_pitr = self.record_pitr(record)
+                self._timer.start()
 
-        if self.should_write_batch_to_file():
+            self._current_batch.append(record)
+            self._current_batch_bytes += len(record)
+
+        if (shutdown and self.batch_length > 0) or self.should_write_batch_to_file():
             self._end_pitr = self.record_pitr(record)
             await self.enqueue_batch_contents()
 
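To make the new contract concrete, here is a condensed, self-contained sketch of the flush-on-shutdown behavior; the inline stand-in for the batcher and its size threshold are invented for the demo:

import asyncio

async def demo():
    s3_writer_queue: asyncio.Queue = asyncio.Queue()
    batch: list = []

    async def ingest_record(record):
        shutdown = record is None
        if not shutdown:
            batch.append(record)
        # Invented threshold standing in for should_write_batch_to_file()
        if (shutdown and batch) or len(batch) >= 3:
            await s3_writer_queue.put(b"".join(batch))
            batch.clear()
        if shutdown:
            await s3_writer_queue.put(None)  # forward the sentinel downstream

    await ingest_record(b'{"pitr": "1687996801"}\n')
    await ingest_record(None)  # flushes the partial batch, then signals
    print(await s3_writer_queue.get())  # b'{"pitr": "1687996801"}\n'
    print(await s3_writer_queue.get())  # None

asyncio.run(demo())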
@@ -277,6 +283,11 @@ async def ingest_record(self, record: str):
             self._current_batch = []
             self._current_batch_bytes = 0
 
+        # Propagate shutdown signal
+        if shutdown:
+            logging.info(f"Shutting down {self.message_type} queue: sending signal to S3 writer")
+            await self.s3_writer_queue.put(None)
+
     def should_write_batch_to_file(self) -> bool:
         """Whether the current batch needs to be written to an S3 file
         In order to see less common message types, the bytes hit will be
@@ -291,6 +302,10 @@
 
     async def enqueue_batch_contents(self):
         """Write the current batch of records to the S3 writer's queue"""
+        if self.batch_length == 0:
+            logging.warning(f"Current batch for {self.message_type} is empty, skipping")
+            return
+
         filename = self.batch_filename()
 
         file_contents = b"".join(self._current_batch)
@@ -304,6 +319,7 @@ async def enqueue_batch_contents(self):
             end_pitr=self._end_pitr,
         )
 
+        logging.info(f"Writing a batch to the S3 writer for {self.message_type}")
         await self.s3_writer_queue.put(s3_object)
 
     def _s3_bucket_folder(self) -> str:
@@ -344,10 +360,15 @@ async def build_batch_of_records_from_firehose(
     while True:
         # Use a "blocking" await on the queue with Firehose messages which will
         # wait indefinitely until data shows up in the queue
-        firehose_message = await firehose_queue.get()
+        firehose_message: Optional[str] = await firehose_queue.get()
 
         await batcher.ingest_record(firehose_message)
         firehose_queue.task_done()
 
+        # We've reached the end of the PITR range and need to shutdown
+        if firehose_message is None:
+            break
+
 
 async def load_pitr_map(pitr_map_path: Path) -> Dict[str, int]:
     """Load the PITR map from disk if available. Returns an empty dict if
Expand Down Expand Up @@ -384,8 +405,19 @@ async def write_files_to_s3(
pitr_map: Dict[str, int] = await load_pitr_map(args.pitr_map)
pitr_map = {message_type: int(pitr) for message_type, pitr in pitr_map.items()}

while True:
s3_write_object: S3WriteObject = await s3_queue.get()
# Keep track of how many shutdown signals we receive for the case where
# we're only ingesting a range of values and not processing files
# indefinitely
total_shutdown_signals = len(FIREHOSE_MESSAGE_TYPES)
shutdown_signals_recvd = 0

while shutdown_signals_recvd < total_shutdown_signals:
s3_write_object: Optional[S3WriteObject] = await s3_queue.get()

# Check for a shutdown signal
if s3_write_object is None:
shutdown_signals_recvd += 1
continue

# Get some timing stats on how long it takes to write to S3
timer = Timer(
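
A minimal sketch of why the writer counts sentinels: a single queue is fed by several batchers, so the writer must see one None per producer before it knows every batcher is done. The names and items below are invented:

import asyncio

async def write_files(s3_queue: asyncio.Queue, total_shutdown_signals: int):
    shutdown_signals_recvd = 0
    while shutdown_signals_recvd < total_shutdown_signals:
        s3_write_object = await s3_queue.get()
        if s3_write_object is None:  # one producer finished
            shutdown_signals_recvd += 1
            continue
        print(f"writing {s3_write_object!r} to S3 (pretend)")
    print("every producer has signalled; writer exiting")

async def demo():
    s3_queue: asyncio.Queue = asyncio.Queue()
    for item in ("batch-a", None, "batch-b", None):  # output of two producers
        await s3_queue.put(item)
    await write_files(s3_queue, total_shutdown_signals=2)

asyncio.run(demo())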
@@ -451,7 +483,7 @@ async def main(args: ap.Namespace):
     # Use a single S3 file writer for all message types
     tasks.append(write_files_to_s3(args, executor, s3_writer_queue))
 
-    # Run all the tasks in the event loop
+    # Run all the tasks in the event loop to completion
     await asyncio.gather(*tasks, return_exceptions=False)
