fix black requirements

KatrionaGoldmann committed Dec 4, 2024
1 parent b237c51 commit aa99b2c
Showing 7 changed files with 192 additions and 114 deletions.
44 changes: 32 additions & 12 deletions 01_print_deployments.py
@@ -48,9 +48,17 @@ def count_files(s3_client, bucket_name, prefix):
return count


def print_deployments(aws_credentials, include_inactive=False, subset_countries=None, print_image_count=True):
def print_deployments(
aws_credentials,
include_inactive=False,
subset_countries=None,
print_image_count=True,
):
"""Print deployment details, optionally filtering by country or active status."""
username, password = aws_credentials["UKCEH_username"], aws_credentials["UKCEH_password"]
username, password = (
aws_credentials["UKCEH_username"],
aws_credentials["UKCEH_password"],
)
deployments = get_deployments(username, password)

# Filter active deployments if not including inactive
@@ -76,16 +84,22 @@ def print_deployments(aws_credentials, include_inactive=False, subset_countries=

# Print deployments for each country
for country in all_countries:
country_deployments = [d for d in deployments if d["country"].title() == country]
country_deployments = [
d for d in deployments if d["country"].title() == country
]
country_code = country_deployments[0]["country_code"].lower()
print(f"\n{country} ({country_code}) has {len(country_deployments)} deployments:")
print(
f"\n{country} ({country_code}) has {len(country_deployments)} deployments:"
)

total_images = 0
for dep in sorted(country_deployments, key=lambda d: d["deployment_id"]):
deployment_id = dep["deployment_id"]
location_name = dep["location_name"]
camera_id = dep["camera_id"]
print(f" - Deployment ID: {deployment_id}, Name: {location_name}, Camera ID: {camera_id}")
print(
f" - Deployment ID: {deployment_id}, Name: {location_name}, Camera ID: {camera_id}"
)

if print_image_count:
prefix = f"{deployment_id}/snapshot_images"
@@ -105,22 +119,28 @@ def print_deployments(aws_credentials, include_inactive=False, subset_countries=
description="Script for printing the deployments available on the Jasmin object store."
)
parser.add_argument(
"--include_inactive", action=argparse.BooleanOptionalAction,
default=False, help="Flag to include inactive deployments."
"--include_inactive",
action=argparse.BooleanOptionalAction,
default=False,
help="Flag to include inactive deployments.",
)
parser.add_argument(
"--print_image_count", action=argparse.BooleanOptionalAction,
default=False, help="Flag to print the number of images per deployment."
"--print_image_count",
action=argparse.BooleanOptionalAction,
default=False,
help="Flag to print the number of images per deployment.",
)
parser.add_argument(
"--subset_countries", nargs='+', default=None,
help="Optional list to subset for specific countries (e.g. --subset_countries 'Panama' 'Thailand')."
"--subset_countries",
nargs="+",
default=None,
help="Optional list to subset for specific countries (e.g. --subset_countries 'Panama' 'Thailand').",
)
args = parser.parse_args()

print_deployments(
aws_credentials,
args.include_inactive,
args.subset_countries,
args.print_image_count
args.print_image_count,
)
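
For orientation, a minimal sketch of how the reformatted print_deployments might be called directly from Python. The credentials file name and its JSON layout are assumptions inferred from the UKCEH_username / UKCEH_password keys above; this diff does not show how aws_credentials is actually loaded.

import json

# Assumption: a credentials.json holding the UKCEH keys referenced above;
# the real loading code sits outside the hunks shown here, and
# print_deployments from 01_print_deployments.py is assumed to be in scope.
with open("credentials.json", encoding="utf-8") as f:
    aws_credentials = json.load(f)

# List active deployments for two countries, with per-deployment image counts.
print_deployments(
    aws_credentials,
    include_inactive=False,
    subset_countries=["Panama", "Thailand"],
    print_image_count=True,
)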
21 changes: 16 additions & 5 deletions 02_generate_keys.py
@@ -67,16 +67,27 @@ def main():
parser = argparse.ArgumentParser(
description="Generate a file containing S3 keys from a bucket."
)
parser.add_argument("--bucket", type=str, required=True, help="Name of the S3 bucket.")
parser.add_argument(
"--deployment_id", type=str, default="",
help="The deployment id to filter objects. If set to '' then all deployments are used. (default: '')"
"--bucket", type=str, required=True, help="Name of the S3 bucket."
)
parser.add_argument(
"--deployment_id",
type=str,
default="",
help="The deployment id to filter objects. If set to '' then all deployments are used. (default: '')",
)
parser.add_argument(
"--output_file",
type=str,
default="s3_keys.txt",
help="Output file to save S3 keys.",
)
parser.add_argument("--output_file", type=str, default="s3_keys.txt", help="Output file to save S3 keys.")
args = parser.parse_args()

# List keys from the specified S3 bucket and prefix
print(f"Listing keys from bucket '{args.bucket}' with deployment '{args.deployment_id}'...")
print(
f"Listing keys from bucket '{args.bucket}' with deployment '{args.deployment_id}'..."
)
keys = list_s3_keys(args.bucket, args.deployment_id)

# Save keys to the output file
41 changes: 28 additions & 13 deletions 03_pre_chop_files.py
@@ -8,16 +8,16 @@ def load_workload(input_file, file_extensions):
"""
Load workload from a file. Assumes each line contains an S3 key.
"""
with open(input_file, 'r') as f:
with open(input_file, "r") as f:
all_keys = [line.strip() for line in f.readlines()]

subset_keys = [x for x in all_keys if x.endswith(tuple(file_extensions))]

# remove corrupt keys
subset_keys = [x for x in subset_keys if not os.path.basename(x).startswith('$')]
subset_keys = [x for x in subset_keys if not os.path.basename(x).startswith("$")]

# remove keys uploaded from the recycle bin (legacy code)
subset_keys = [x for x in subset_keys if 'recycle' not in x]
subset_keys = [x for x in subset_keys if "recycle" not in x]
print(f"{len(subset_keys)} keys")
return subset_keys

@@ -28,7 +28,7 @@ def split_workload(keys, chunk_size):
"""
num_chunks = ceil(len(keys) / chunk_size)
chunks = {
str(i + 1): {"keys": keys[i * chunk_size: (i + 1) * chunk_size]}
str(i + 1): {"keys": keys[i * chunk_size : (i + 1) * chunk_size]}
for i in range(num_chunks)
}
print(f"{len(chunks)} chunks")
@@ -39,22 +39,37 @@ def save_chunks(chunks, output_file):
"""
Save chunks to a JSON file.
"""
with open(output_file, 'w') as f:
with open(output_file, "w") as f:
json.dump(chunks, f, indent=4)


def main():
parser = argparse.ArgumentParser(description="Pre-chop S3 workload into manageable chunks.")
parser = argparse.ArgumentParser(
description="Pre-chop S3 workload into manageable chunks."
)
parser.add_argument(
"--input_file",
type=str,
required=True,
help="Path to file containing S3 keys, one per line.",
)
parser.add_argument(
"--input_file", type=str, required=True,
help="Path to file containing S3 keys, one per line."
"--file_extensions",
type=str,
nargs="+",
required=True,
default="'jpg' 'jpeg'",
help="File extensions to be chuncked. If empty, all extensions used.",
)
parser.add_argument(
"--file_extensions", type=str, nargs='+',
required=True, default="'jpg' 'jpeg'",
help="File extensions to be chuncked. If empty, all extensions used.")
parser.add_argument("--chunk_size", type=int, default=100, help="Number of keys per chunk.")
parser.add_argument("--output_file", type=str, required=True, help="Path to save the output JSON file.")
"--chunk_size", type=int, default=100, help="Number of keys per chunk."
)
parser.add_argument(
"--output_file",
type=str,
required=True,
help="Path to save the output JSON file.",
)
args = parser.parse_args()

# Load the workload from the input file
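
To make the chunking above concrete, a small worked example of what split_workload produces; the key names are illustrative only and not taken from the repository.

from math import ceil

keys = [
    "dep1/snapshot_images/a.jpg",
    "dep1/snapshot_images/b.jpg",
    "dep1/snapshot_images/c.jpg",
    "dep1/snapshot_images/d.jpg",
    "dep1/snapshot_images/e.jpg",
]
chunk_size = 2
num_chunks = ceil(len(keys) / chunk_size)  # 3 chunks for 5 keys
chunks = {
    str(i + 1): {"keys": keys[i * chunk_size : (i + 1) * chunk_size]}
    for i in range(num_chunks)
}
# chunks == {
#     "1": {"keys": [".../a.jpg", ".../b.jpg"]},
#     "2": {"keys": [".../c.jpg", ".../d.jpg"]},
#     "3": {"keys": [".../e.jpg"]},
# }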
78 changes: 49 additions & 29 deletions 04_process_chunks.py
@@ -136,7 +136,7 @@ def main(

client = initialise_session(credentials_file)

keys = chunks[chunk_id]['keys']
keys = chunks[chunk_id]["keys"]
download_and_analyse(
keys=keys,
output_dir=output_dir,
@@ -157,47 +157,67 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process a specific chunk of S3 keys.")
parser.add_argument(
"--chunk_id", required=True,
help="ID of the chunk to process (e.g., 0, 1, 2, 3)."
"--chunk_id",
required=True,
help="ID of the chunk to process (e.g., 0, 1, 2, 3).",
)
parser.add_argument(
"--json_file", required=True, help="Path to the JSON file with key chunks."
)
parser.add_argument(
"--output_dir", required=True, default="./data/",
help="Directory to save downloaded files and analysis results."
"--output_dir",
required=True,
default="./data/",
help="Directory to save downloaded files and analysis results.",
)
parser.add_argument("--bucket_name", required=True, help="Name of the S3 bucket.")
parser.add_argument(
"--bucket_name", required=True, help="Name of the S3 bucket."
"--credentials_file",
default="credentials.json",
help="Path to AWS credentials file.",
)
parser.add_argument(
"--credentials_file", default="credentials.json", help="Path to AWS credentials file."
"--remove_image", action="store_true", help="Remove images after processing."
)
parser.add_argument(
"--remove_image", action="store_true", help="Remove images after processing."
"--perform_inference", action="store_true", help="Enable inference."
)
parser.add_argument(
"--localisation_model_path",
type=str,
default="./models/v1_localizmodel_2021-08-17-12-06.pt",
help="Path to the localisation model weights.",
)
parser.add_argument(
"--binary_model_path",
type=str,
help="Path to the binary model weights.",
default="./models/moth-nonmoth-effv2b3_20220506_061527_30.pth",
)
parser.add_argument("--perform_inference", action="store_true", help="Enable inference.")
parser.add_argument(
"--localisation_model_path", type=str, default="./models/v1_localizmodel_2021-08-17-12-06.pt",
help="Path to the localisation model weights."
"--order_model_path",
type=str,
help="Path to the order model weights.",
default="./models/dhc_best_128.pth",
)
parser.add_argument(
"--binary_model_path", type=str, help="Path to the binary model weights.",
default="./models/moth-nonmoth-effv2b3_20220506_061527_30.pth"
"--order_labels", type=str, help="Path to the order labels file."
)
parser.add_argument(
"--order_model_path", type=str, help="Path to the order model weights.", default="./models/dhc_best_128.pth"
"--device",
type=str,
default="cpu",
help="Device to run inference on (e.g., cpu or cuda).",
)
parser.add_argument("--order_labels", type=str, help="Path to the order labels file.")
parser.add_argument(
"--device", type=str, default="cpu",
help="Device to run inference on (e.g., cpu or cuda)."
"--order_thresholds_path",
type=str,
default="./models/thresholdsTestTrain.csv",
help="Path to the order data thresholds file.",
)
parser.add_argument(
"--order_thresholds_path", type=str, default="./models/thresholdsTestTrain.csv",
help="Path to the order data thresholds file."
"--csv_file", default="results.csv", help="Path to save analysis results."
)
parser.add_argument("--csv_file", default="results.csv", help="Path to save analysis results.")

args = parser.parse_args()

@@ -216,10 +236,10 @@

models = load_models(
device,
getattr(args, 'localisation_model_path'),
getattr(args, 'binary_model_path'),
getattr(args, 'order_model_path'),
getattr(args, 'order_thresholds_path')
getattr(args, "localisation_model_path"),
getattr(args, "binary_model_path"),
getattr(args, "order_model_path"),
getattr(args, "order_thresholds_path"),
)

main(
@@ -230,11 +250,11 @@ def main(
credentials_file=args.credentials_file,
remove_image=args.remove_image,
perform_inference=args.perform_inference,
localisation_model=models['localisation_model'],
binary_model=models['classification_model'],
order_model=models['order_model'],
order_labels=models['order_model_labels'],
order_data_thresholds=models['order_model_thresholds'],
localisation_model=models["localisation_model"],
binary_model=models["classification_model"],
order_model=models["order_model"],
order_labels=models["order_model_labels"],
order_data_thresholds=models["order_model_thresholds"],
device=device,
csv_file=args.csv_file,
)
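
For context on the chunk_id lookup in main() above: 03_pre_chop_files.py writes a JSON of the form {"1": {"keys": [...]}, "2": {"keys": [...]}, ...}, and this script reads a single entry from it. A minimal sketch of that hand-off, with the file name assumed:

import json

# Assumption: chunks.json is the --output_file written by 03_pre_chop_files.py.
with open("chunks.json") as f:
    chunks = json.load(f)

chunk_id = "1"                   # chunk IDs are the string keys "1", "2", ...
keys = chunks[chunk_id]["keys"]  # the S3 keys this worker downloads and analyses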
18 changes: 9 additions & 9 deletions utils/aws_scripts.py
@@ -214,34 +214,34 @@ def get_objects(
keys = []
for page in page_iterator:
if os.path.basename(page.get("Contents", [])[0]["Key"]).startswith("$"):
print(f'{page.get("Contents", [])[0]["Key"]} is suspected corrupt, skipping')
print(
f'{page.get("Contents", [])[0]["Key"]} is suspected corrupt, skipping'
)
continue

for obj in page.get("Contents", []):
keys.append(obj["Key"])

# don't rerun previously analysed images
results_df = pd.read_csv(csv_file, dtype=str)
run_images = [re.sub(r'^.*?dep', 'dep', x) for x in results_df['image_path']]
run_images = [re.sub(r"^.*?dep", "dep", x) for x in results_df["image_path"]]
keys = [x for x in keys if x not in run_images]

# Divide the keys among workers
chunks = [
keys[i: i + math.ceil(len(keys) / num_workers)]
for i in range(0, len(keys),
math.ceil(len(keys) / num_workers))
keys[i : i + math.ceil(len(keys) / num_workers)]
for i in range(0, len(keys), math.ceil(len(keys) / num_workers))
]

# Shared progress bar
results_file = os.path.basename(csv_file).replace('_results.csv', '')
results_file = os.path.basename(csv_file).replace("_results.csv", "")
progress_bar = tqdm.tqdm(
total=total_files,
desc=f"Download files for {results_file}"
total=total_files, desc=f"Download files for {results_file}"
)

def process_chunk(chunk):
for i in range(0, len(chunk), batch_size):
batch_keys = chunk[i: i + batch_size]
batch_keys = chunk[i : i + batch_size]
download_batch(
s3_client,
bucket_name,
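
Several of the hunks above are Black's slice-spacing rule: following PEP 8, when a slice bound is a complex expression the colon is treated as the lowest-priority operator and gets a space on each side, which is why keys[i: i + ...] becomes keys[i : i + ...]. A quick illustration, not taken from the repository:

items = list(range(10))
i, step = 2, 3

simple = items[1:4]           # simple bounds: Black leaves the colon unspaced
spaced = items[i : i + step]  # complex bound: Black spaces the colon on both sides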
