FEAT: Write Bus Performance metrics to S3 Bucket #451

Merged
merged 5 commits on Oct 24, 2024

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -62,7 +62,7 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 80
line-length = 120
target-version = ['py310']

[tool.mypy]
@@ -104,7 +104,7 @@ disable = [
"too-many-lines",
]
good-names = ["e", "i", "s"]
max-line-length = 80
max-line-length = 120
min-similarity-lines = 10
# ignore session maker as it gives pylint fits
# https://github.com/PyCQA/pylint/issues/7090
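The functional change here is raising the maximum line length for both black and pylint from 80 to 120 characters; every reformatted call site in the Python files below follows from it. As a rough illustration (using black's public Python API, nothing from this repository), the same snippet collapses onto one line only under the new limit:

```python
import black

SRC = (
    "process_logger = ProcessLogger(\n"
    '    process_name="update_shard_id", stream_name=self.stream_name\n'
    ")\n"
)

# under the old 80-character limit the call stays wrapped across three lines
print(black.format_str(SRC, mode=black.Mode(line_length=80)))

# under the new 120-character limit it fits on a single line
print(black.format_str(SRC, mode=black.Mode(line_length=120)))
```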
8 changes: 2 additions & 6 deletions src/lamp_py/aws/ecs.py
@@ -74,19 +74,15 @@ def check_for_parallel_tasks() -> None:
# count matches the ecs task group.
match_count = 0
if task_arns:
running_tasks = client.describe_tasks(
cluster=ecs_cluster, tasks=task_arns
)["tasks"]
running_tasks = client.describe_tasks(cluster=ecs_cluster, tasks=task_arns)["tasks"]

for task in running_tasks:
if ecs_task_group == task["group"]:
match_count += 1

# if the group matches, raise an exception that will terminate the process
if match_count > 1:
raise SystemError(
f"Multiple {ecs_task_group} ECS Tasks Running in {ecs_cluster}"
)
raise SystemError(f"Multiple {ecs_task_group} ECS Tasks Running in {ecs_cluster}")

except Exception as exception:
process_logger.log_failure(exception)
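The only change to ecs.py is reflowing two calls to the new line length. For context, they sit inside a guard that stops duplicate ECS tasks. A minimal sketch of that pattern with the standard boto3 ECS client; the signature is simplified to take the cluster and task group as parameters (the original derives them elsewhere, outside this hunk):

```python
import boto3


def check_for_parallel_tasks(ecs_cluster: str, ecs_task_group: str) -> None:
    """Raise if more than one running task on the cluster belongs to the task group."""
    client = boto3.client("ecs")

    # list_tasks returns only ARNs; describe_tasks is needed to see each task's group
    task_arns = client.list_tasks(cluster=ecs_cluster)["taskArns"]

    match_count = 0
    if task_arns:
        running_tasks = client.describe_tasks(cluster=ecs_cluster, tasks=task_arns)["tasks"]
        for task in running_tasks:
            if ecs_task_group == task["group"]:
                match_count += 1

    # the calling task counts itself, so more than one match means a duplicate is running
    if match_count > 1:
        raise SystemError(f"Multiple {ecs_task_group} ECS Tasks Running in {ecs_cluster}")
```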
24 changes: 6 additions & 18 deletions src/lamp_py/aws/kinesis.py
@@ -27,15 +27,11 @@ def update_shard_id(self) -> None:
Get the stream description and the shard id for the first shard in the
Kinesis Stream. Throws if the stream has more than one shard.
"""
process_logger = ProcessLogger(
process_name="update_shard_id", stream_name=self.stream_name
)
process_logger = ProcessLogger(process_name="update_shard_id", stream_name=self.stream_name)
process_logger.log_start()

# Describe the stream and pull out the shard IDs
stream_description = self.kinesis_client.describe_stream(
StreamName=self.stream_name
)
stream_description = self.kinesis_client.describe_stream(StreamName=self.stream_name)
shards = stream_description["StreamDescription"]["Shards"]

# Per conversation with Glides, their Kinesis Stream only consists of a
@@ -54,9 +50,7 @@ def update_shard_iterator(self) -> None:
that case, get the Trim Horizon iterator which is the oldest one in the
shard. Otherwise, get the next iterator after the last sequence number.
"""
process_logger = ProcessLogger(
process_name="update_shard_iterator", stream_name=self.stream_name
)
process_logger = ProcessLogger(process_name="update_shard_iterator", stream_name=self.stream_name)
process_logger.log_start()

if self.shard_id is None:
@@ -90,9 +84,7 @@ def get_records(self) -> List[Dict]:
records that can be processed and the next shard iterator to use for the
next read.
"""
process_logger = ProcessLogger(
process_name="kinesis.get_records", stream_name=self.stream_name
)
process_logger = ProcessLogger(process_name="kinesis.get_records", stream_name=self.stream_name)
process_logger.log_start()

all_records = []
@@ -106,9 +98,7 @@ def get_records(self) -> List[Dict]:

while True:
try:
response = self.kinesis_client.get_records(
ShardIterator=self.shard_iterator
)
response = self.kinesis_client.get_records(ShardIterator=self.shard_iterator)
shard_count += 1
self.shard_iterator = response["NextShardIterator"]
records = response["Records"]
@@ -125,9 +115,7 @@ def get_records(self) -> List[Dict]:
except self.kinesis_client.exceptions.ExpiredIteratorException:
self.update_shard_iterator()

process_logger.add_metadata(
record_count=len(all_records), shard_count=shard_count
)
process_logger.add_metadata(record_count=len(all_records), shard_count=shard_count)
except Exception as e:
process_logger.log_failure(e)

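None of the kinesis.py changes alter behavior, but the docstrings above describe the reader's flow: resolve the stream's single shard, pick an iterator (TRIM_HORIZON or one based on the last sequence number), then page through get_records until caught up, refreshing expired iterators along the way. A condensed sketch of that loop with the standard boto3 Kinesis client; the function name and break condition are simplifications, not the class in this file:

```python
import boto3
from typing import Dict, List


def read_stream(stream_name: str) -> List[Dict]:
    """Drain the currently available records from a single-shard Kinesis stream."""
    client = boto3.client("kinesis")

    # the stream is expected to contain exactly one shard
    shards = client.describe_stream(StreamName=stream_name)["StreamDescription"]["Shards"]
    if len(shards) != 1:
        raise ValueError(f"{stream_name} has {len(shards)} shards, expected 1")
    shard_id = shards[0]["ShardId"]

    # TRIM_HORIZON starts from the oldest record still retained in the shard;
    # the real class resumes from its last-seen sequence number when it has one
    shard_iterator = client.get_shard_iterator(
        StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
    )["ShardIterator"]

    all_records: List[Dict] = []
    while True:
        try:
            response = client.get_records(ShardIterator=shard_iterator)
            shard_iterator = response["NextShardIterator"]
            all_records += response["Records"]

            # MillisBehindLatest of zero means the reader has caught up with the stream
            if response["MillisBehindLatest"] == 0:
                break
        except client.exceptions.ExpiredIteratorException:
            # shard iterators expire after five minutes; request a fresh one and continue
            shard_iterator = client.get_shard_iterator(
                StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
            )["ShardIterator"]

    return all_records
```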
73 changes: 24 additions & 49 deletions src/lamp_py/aws/s3.py
@@ -36,9 +36,7 @@ def get_s3_client() -> boto3.client:
return boto3.client("s3")


def upload_file(
file_name: str, object_path: str, extra_args: Optional[Dict] = None
) -> bool:
def upload_file(file_name: str, object_path: str, extra_args: Optional[Dict] = None) -> bool:
"""
Upload a local file to an S3 Bucket

@@ -66,9 +64,7 @@ def upload_file(

s3_client = get_s3_client()

s3_client.upload_file(
file_name, bucket, object_name, ExtraArgs=extra_args
)
s3_client.upload_file(file_name, bucket, object_name, ExtraArgs=extra_args)

upload_log.log_complete()

@@ -270,9 +266,7 @@ def file_list_from_s3(
object path as s3://bucket-name/object-key
]
"""
process_logger = ProcessLogger(
"file_list_from_s3", bucket_name=bucket_name, file_prefix=file_prefix
)
process_logger = ProcessLogger("file_list_from_s3", bucket_name=bucket_name, file_prefix=file_prefix)
process_logger.log_start()

try:
@@ -288,9 +282,7 @@ def file_list_from_s3(
if obj["Size"] == 0:
continue
if in_filter is None or in_filter in obj["Key"]:
filepaths.append(
os.path.join("s3://", bucket_name, obj["Key"])
)
filepaths.append(os.path.join("s3://", bucket_name, obj["Key"]))

if len(filepaths) > max_list_size:
break
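file_list_from_s3 is only reflowed in this diff, and the hunk does not show where the objects come from. A rough sketch of the overall shape, assuming a standard list_objects_v2 paginator and an illustrative max_list_size default (neither appears in this hunk):

```python
import os
from typing import List, Optional

import boto3


def file_list_from_s3(
    bucket_name: str,
    file_prefix: str,
    in_filter: Optional[str] = None,
    max_list_size: int = 250_000,  # illustrative default, not from this diff
) -> List[str]:
    """Return s3://bucket/key paths for non-empty objects under a prefix."""
    client = boto3.client("s3")
    paginator = client.get_paginator("list_objects_v2")

    filepaths: List[str] = []
    for page in paginator.paginate(Bucket=bucket_name, Prefix=file_prefix):
        for obj in page.get("Contents", []):
            # zero-byte objects are "directory" placeholders; skip them
            if obj["Size"] == 0:
                continue
            if in_filter is None or in_filter in obj["Key"]:
                filepaths.append(os.path.join("s3://", bucket_name, obj["Key"]))
        if len(filepaths) > max_list_size:
            break
    return filepaths
```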
@@ -303,9 +295,7 @@ def file_list_from_s3(
return []


def file_list_from_s3_with_details(
bucket_name: str, file_prefix: str
) -> List[Dict]:
def file_list_from_s3_with_details(bucket_name: str, file_prefix: str) -> List[Dict]:
"""
get a list of s3 objects with additional details

@@ -341,9 +331,7 @@ def file_list_from_s3_with_details(
continue
filepaths.append(
{
"s3_obj_path": os.path.join(
"s3://", bucket_name, obj["Key"]
),
"s3_obj_path": os.path.join("s3://", bucket_name, obj["Key"]),
"size_bytes": obj["Size"],
"last_modified": obj["LastModified"],
}
@@ -357,16 +345,12 @@ def file_list_from_s3_with_details(
return []


def get_last_modified_object(
bucket_name: str, file_prefix: str, version: Optional[str] = None
) -> Optional[Dict]:
def get_last_modified_object(bucket_name: str, file_prefix: str, version: Optional[str] = None) -> Optional[Dict]:
"""
For a given bucket, find the last modified object that matches a prefix. If
a version is passed, only return the newest object matching this version.
"""
files = file_list_from_s3_with_details(
bucket_name=bucket_name, file_prefix=file_prefix
)
files = file_list_from_s3_with_details(bucket_name=bucket_name, file_prefix=file_prefix)

# sort the objects by last modified
files.sort(key=lambda o: o["last_modified"], reverse=True)
@@ -453,9 +437,7 @@ def _init_process_session() -> None:
"""
process_data = current_thread()
process_data.__dict__["boto_session"] = boto3.session.Session()
process_data.__dict__["boto_s3_resource"] = process_data.__dict__[
"boto_session"
].resource("s3")
process_data.__dict__["boto_s3_resource"] = process_data.__dict__["boto_session"].resource("s3")


# pylint: disable=R0914
@@ -495,13 +477,9 @@ def move_s3_objects(files: List[str], to_bucket: str) -> List[str]:
process_logger.add_metadata(pool_size=pool_size)
results = []
try:
with ThreadPoolExecutor(
max_workers=pool_size, initializer=_init_process_session
) as pool:
with ThreadPoolExecutor(max_workers=pool_size, initializer=_init_process_session) as pool:
for filename in files_to_move:
results.append(
pool.submit(_move_s3_object, filename, to_bucket)
)
results.append(pool.submit(_move_s3_object, filename, to_bucket))
for result in results:
current_result = result.result()
if isinstance(current_result, str):
@@ -517,9 +495,7 @@ def move_s3_objects(files: List[str], to_bucket: str) -> List[str]:
# wait for gremlins to disappear
time.sleep(15)

process_logger.add_metadata(
failed_count=len(files_to_move), retry_attempts=retry_attempt
)
process_logger.add_metadata(failed_count=len(files_to_move), retry_attempts=retry_attempt)

if len(files_to_move) == 0:
process_logger.log_complete()
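Taken together, the _init_process_session and move_s3_objects hunks show the pattern these reflows preserve: every worker thread gets its own boto3 session through the executor's initializer, since boto3 sessions and resources should not be shared across threads, and failed moves are collected for retry. A stripped-down sketch of that threading pattern; the copy-then-delete and retry bookkeeping of the original are omitted, and pool_size here is an arbitrary placeholder:

```python
from concurrent.futures import ThreadPoolExecutor
from threading import current_thread
from typing import List

import boto3


def _init_thread_session() -> None:
    # runs once per worker thread: stash a private session and S3 resource on the thread
    data = current_thread().__dict__
    data["boto_session"] = boto3.session.Session()
    data["boto_s3_resource"] = data["boto_session"].resource("s3")


def _copy_one(s3_path: str, to_bucket: str) -> None:
    # copy an s3://bucket/key object into to_bucket using this thread's private resource
    source_bucket, key = s3_path.replace("s3://", "").split("/", 1)
    s3_resource = current_thread().__dict__["boto_s3_resource"]
    s3_resource.Bucket(to_bucket).copy({"Bucket": source_bucket, "Key": key}, key)


def copy_s3_objects(files: List[str], to_bucket: str) -> None:
    pool_size = 8  # placeholder; the original sizes and logs this differently
    with ThreadPoolExecutor(max_workers=pool_size, initializer=_init_thread_session) as pool:
        results = [pool.submit(_copy_one, filename, to_bucket) for filename in files]
        for result in results:
            result.result()  # re-raise any worker exception here
```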
@@ -575,9 +551,7 @@ def write_parquet_file(
@filename - if set, the filename that will be written to. if left empty,
the basename template (or _its_ fallback) will be used.
"""
process_logger = ProcessLogger(
"write_parquet", file_type=file_type, number_of_rows=table.num_rows
)
process_logger = ProcessLogger("write_parquet", file_type=file_type, number_of_rows=table.num_rows)
process_logger.log_start()

# pull out the partition information into a list of strings.
@@ -588,9 +562,7 @@ def write_parquet_file(
for col in partition_cols:
unique_list = pc.unique(table.column(col)).to_pylist()

assert (
len(unique_list) == 1
), f"Table {s3_dir} column {col} had {len(unique_list)} unique elements"
assert len(unique_list) == 1, f"Table {s3_dir} column {col} had {len(unique_list)} unique elements"

partition_strings.append(f"{col}={unique_list[0]}")

@@ -609,9 +581,7 @@ def write_parquet_file(
process_logger.add_metadata(write_path=write_path)

# write the parquet file to the partitioned path
with pq.ParquetWriter(
where=write_path, schema=table.schema, filesystem=fs.S3FileSystem()
) as pq_writer:
with pq.ParquetWriter(where=write_path, schema=table.schema, filesystem=fs.S3FileSystem()) as pq_writer:
pq_writer.write(table)

# call the visitor function if it exists
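The write_parquet_file hunks keep the same flow: assert that each partition column holds exactly one value, fold those values into a hive-style col=value path, and stream the table to S3 through a pyarrow ParquetWriter. A minimal sketch of that flow with placeholder bucket, columns, and filename (running it requires AWS credentials and a writable bucket):

```python
import os

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from pyarrow import fs


def write_partitioned(table: pa.Table, s3_dir: str, partition_cols: list, filename: str) -> str:
    # each partition column must contain a single value, which becomes a col=value path segment
    partition_strings = []
    for col in partition_cols:
        unique_list = pc.unique(table.column(col)).to_pylist()
        assert len(unique_list) == 1, f"column {col} had {len(unique_list)} unique elements"
        partition_strings.append(f"{col}={unique_list[0]}")

    # S3FileSystem paths are bucket/key with no s3:// scheme
    write_path = os.path.join(s3_dir, *partition_strings, filename)

    with pq.ParquetWriter(where=write_path, schema=table.schema, filesystem=fs.S3FileSystem()) as writer:
        writer.write(table)

    return write_path


# placeholder usage: bucket, prefix, and columns are illustrative only
table = pa.table({"service_date": ["2024-10-24"] * 3, "trip_id": ["a", "b", "c"]})
write_partitioned(table, "example-bucket/bus_metrics", ["service_date"], "data.parquet")
```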
@@ -624,9 +594,7 @@ def write_parquet_file(
# pylint: enable=R0913


def get_datetime_from_partition_path(path: str) -> datetime:
def dt_from_obj_path(path: str) -> datetime:
"""
process and return datetime from partitioned s3 path

handles the following formats:
- year=YYYY/month=MM/day=DD/hour=HH
- year=YYYY/month=MM/day=DD
- timestamp=DDDDDDDDDD

:return datetime(tz=UTC):
"""
try:
# handle gtfs-rt paths
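The renamed dt_from_obj_path gains a docstring listing the path shapes it accepts, but the body itself is collapsed in this view. A hedged sketch of a parser for exactly those three documented formats, written against the docstring rather than the hidden implementation (the example path is illustrative):

```python
import re
from datetime import datetime, timezone


def dt_from_obj_path(path: str) -> datetime:
    """Recover a UTC datetime from a partitioned S3 object path."""
    # year=YYYY/month=MM/day=DD with an optional hour=HH segment
    match = re.search(r"year=(\d{4})/month=(\d{1,2})/day=(\d{1,2})(?:/hour=(\d{1,2}))?", path)
    if match:
        year, month, day, hour = match.groups()
        return datetime(int(year), int(month), int(day), int(hour or 0), tzinfo=timezone.utc)

    # timestamp=DDDDDDDDDD, interpreted as unix epoch seconds
    match = re.search(r"timestamp=(\d{10})", path)
    if match:
        return datetime.fromtimestamp(int(match.group(1)), tz=timezone.utc)

    raise ValueError(f"unable to parse a datetime from {path}")


# illustrative hourly-partitioned path
print(dt_from_obj_path("s3://example-bucket/feed/year=2024/month=10/day=24/hour=13/file.parquet"))
```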
@@ -696,9 +673,7 @@ def read_parquet(
read_columns = list(set(ds.schema.names) & set(columns))
table = ds.to_table(columns=read_columns)
for null_column in set(columns).difference(ds.schema.names):
table = table.append_column(
null_column, pa.nulls(table.num_rows)
)
table = table.append_column(null_column, pa.nulls(table.num_rows))

df = table.to_pandas(self_destruct=True)
break
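Finally, the read_parquet hunk keeps its handling of schema drift: requested columns that are missing from the file come back as all-null columns instead of raising. A small local illustration of that pyarrow pattern (column names are placeholders, and an in-memory table stands in for the S3 dataset):

```python
import pyarrow as pa

# stand-in for a table read back from an older parquet file
table = pa.table({"trip_id": ["a", "b"], "stop_id": ["1", "2"]})
columns = ["trip_id", "stop_id", "dwell_time_seconds"]  # caller also wants a newer column

# keep only the requested columns that actually exist in the schema
read_columns = list(set(table.schema.names) & set(columns))
table = table.select(read_columns)

# append an all-null column for anything the caller asked for but the file lacks
for null_column in set(columns).difference(table.schema.names):
    table = table.append_column(null_column, pa.nulls(table.num_rows))

# self_destruct releases the arrow buffers as the table is converted
df = table.to_pandas(self_destruct=True)
print(df)
```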