diff --git a/files/galaxy_jwd.py b/files/galaxy_jwd.py index 8f79145..99a5831 100644 --- a/files/galaxy_jwd.py +++ b/files/galaxy_jwd.py @@ -165,9 +165,7 @@ def main(): backends = parse_object_store(object_store_conf) # Add pulsar staging directory (runner: pulsar_embedded) to backends - backends["pulsar_embedded"] = get_pulsar_staging_dir( - galaxy_pulsar_app_conf - ) + backends["pulsar_embedded"] = get_pulsar_staging_dir(galaxy_pulsar_app_conf) # Connect to Galaxy database db = Database( @@ -181,9 +179,7 @@ def main(): if args.operation == "get": job_id = args.job_id object_store_id, job_runner_name = db.get_job_info(job_id) - jwd_path = decode_path( - job_id, [object_store_id], backends, job_runner_name - ) + jwd_path = decode_path(job_id, [object_store_id], backends, job_runner_name) # Check if jwd_path: @@ -200,8 +196,7 @@ def main(): # Check if the given Galaxy log directory exists if not os.path.isdir(galaxy_log_dir): raise ValueError( - f"The given Galaxy log directory {galaxy_log_dir} does not" - f"exist" + f"The given Galaxy log directory {galaxy_log_dir} does not" f"exist" ) # Set variables @@ -299,7 +294,7 @@ def parse_object_store(object_store_conf: str) -> dict: """ if object_store_conf.endswith(".xml"): return parse_object_store_xml(object_store_conf) - if object_store_conf.split('.')[-1] in ('yml', 'yaml'): + if object_store_conf.split(".")[-1] in ("yml", "yaml"): return parse_object_store_yaml(object_store_conf) raise ValueError("Invalid object store configuration file extension") @@ -321,11 +316,11 @@ def parse_object_store_yaml(object_store_conf: str) -> dict: with open(object_store_conf, "r") as f: data = yaml.safe_load(f) backends = {} - for backend in data['backends']: - backend_id = backend['id'] + for backend in data["backends"]: + backend_id = backend["id"] backends[backend_id] = {} # Get the extra_dir's path for each backend if type is "job_work" - if 'extra_dirs' in backend: + if "extra_dirs" in backend: for extra_dir in backend["extra_dirs"]: if extra_dir.get("type") == "job_work": backends[backend_id] = extra_dir["path"] @@ -390,8 +385,7 @@ def decode_path( jwd_path = f"{backends_dict['pulsar_embedded']}/{job_id}" else: jwd_path = ( - f"{backends_dict[metadata[0]]}/" - f"0{job_id[0:2]}/{job_id[2:5]}/{job_id}" + f"{backends_dict[metadata[0]]}/" f"0{job_id[0:2]}/{job_id[2:5]}/{job_id}" ) # Validate that the path is a JWD diff --git a/files/walle.py b/files/walle.py index 33a2cca..9fa4258 100644 --- a/files/walle.py +++ b/files/walle.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) CHECKSUM_FILE_ENV = "MALWARE_LIB" -GXADMIN_PATH = os.getenv('GXADMIN_PATH', '/usr/local/bin/gxadmin') +GXADMIN_PATH = os.getenv("GXADMIN_PATH", "/usr/local/bin/gxadmin") CURRENT_TIME = int(time.time()) @@ -244,9 +244,7 @@ def all_files_in_dir(dir: pathlib.Path, args) -> [pathlib.Path]: continue file = pathlib.Path(os.path.join(root, filename)) file_stat = file.stat() - if not file_in_size_range( - file_stat, args.min_size, args.max_size - ): + if not file_in_size_range(file_stat, args.min_size, args.max_size): logger.debug(f"File {file} not in size range") elif not file_accessed_in_range(file_stat, args.since): logger.debug(f"File {file} not in access date range") @@ -310,7 +308,9 @@ def scan_file_for_malware( sha1 = None for malware in lib: if malware.crc32 == crc32: - logger.debug(f"File {file} CRC32 matches {malware.program} {malware.version}") + logger.debug( + f"File {file} CRC32 matches {malware.program} {malware.version}" + ) if sha1 is None: sha1 = digest_file_sha1(chunksize, file) if malware.sha1 == sha1: @@ -323,9 +323,11 @@ def scan_file_for_malware( def report_matching_malware(job: Job, malware: Malware, path: pathlib.Path) -> str: - return (f"Job user: {job.user_name} Job ID: {job.galaxy_id}" - f"{malware.malware_class} {malware.program} {malware.version}" - f" {path}") + return ( + f"Job user: {job.user_name} Job ID: {job.galaxy_id}" + f"{malware.malware_class} {malware.program} {malware.version}" + f" {path}" + ) def construct_malware_list(malware_yaml: dict) -> [Malware]: @@ -488,27 +490,23 @@ def kill_job(job: Job, debug=False): serial_args = [ [ GXADMIN_PATH, - 'mutate', - 'fail-job', + "mutate", + "fail-job", str(job.galaxy_id), - '--commit', + "--commit", ], [ GXADMIN_PATH, - 'mutate', - 'fail-terminal-datasets', - '--commit', + "mutate", + "fail-terminal-datasets", + "--commit", ], ] for args in serial_args: if debug: logger.debug(f"COMMAND: {' '.join(args)}") try: - result = subprocess.run( - args, - check=True, - capture_output=True, - text=True) + result = subprocess.run(args, check=True, capture_output=True, text=True) if debug: if result.stdout: logger.debug(f"COMMAND STDOUT:\n{result.stdout}") @@ -526,7 +524,7 @@ def main(): args = make_parser().parse_args() logging.basicConfig( level=logging.DEBUG if args.debug else logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(levelname)s - %(message)s", ) logger.info("Starting scan...") jwd_getter = JWDGetter()