diff --git a/mcrit/client/McritConsole.py b/mcrit/client/McritConsole.py index 360f58d..2ebefad 100644 --- a/mcrit/client/McritConsole.py +++ b/mcrit/client/McritConsole.py @@ -137,6 +137,8 @@ def __init__(self) -> None: subparsers = self.parser.add_subparsers(dest="command") self.parser_client = subparsers.add_parser("client") subparser_client = self.parser_client.add_subparsers(dest="client_command") + self.parser_client.add_argument("--server", type=str, default=None, help="The MCRIT server to connect to.") + self.parser_client.add_argument("--apitoken", type=str, default=None, help="API token to use for the connection.") client_status = subparser_client.add_parser("status") # client submit client_submit = subparser_client.add_parser("submit", help="Various ways of file submission incl. disassembly using SMDA if needed.") @@ -168,6 +170,7 @@ def __init__(self) -> None: def run(self): ARGS = self.parser.parse_args() + self.client = McritClient(ARGS.server, ARGS.apitoken) if ARGS.client_command == "status": self._handle_status(ARGS) elif ARGS.client_command == "import": @@ -188,47 +191,48 @@ def run(self): print(self.parser_client.print_help()) def _handle_status(self, args): - client = McritClient() - result = client.getStatus(with_pichash=False) - print(f"DB: {result['status']['storage_type']} - {result['status']['db_state']} | {result['status']['db_timestamp']}") - print(f"Families: {result['status']['num_families']}") - print(f"Samples: {result['status']['num_samples']}") - print(f"Functions: {result['status']['num_functions']}") + result = self.client.getStatus(with_pichash=False) + if result: + print(f"DB: {result['status']['storage_type']} - {result['status']['db_state']} | {result['status']['db_timestamp']}") + print(f"Families: {result['status']['num_families']}") + print(f"Samples: {result['status']['num_samples']}") + print(f"Functions: {result['status']['num_functions']}") + else: + print("Failed to retrieve status from the server.") def _handle_import(self, args): - client = McritClient() if os.path.isfile(args.filepath): with open(args.filepath) as fin: - result = client.addImportData(json.load(fin)) + result = self.client.addImportData(json.load(fin)) print(result) + else: + print("Your does not exist.") def _handle_export(self, args): - client = McritClient() sample_ids = None if args.sample_ids is not None: sample_ids = [int(i) for i in args.sample_ids.split(",")] - result = client.getExportData(sample_ids=sample_ids) + result = self.client.getExportData(sample_ids=sample_ids) with open(args.filepath, "w") as fout: json.dump(result, fout, indent=1) print(f"wrote export to {args.filepath}.") def _handle_search(self, args): - client = McritClient() - result = client.search_families(args.search_term) + result = self.client.search_families(args.search_term) if result["search_results"]: print("Family Search Results") for family_id, entry in result["search_results"].items(): family_entry = FamilyEntry.fromDict(entry) print(f"{family_entry}") print("*" * 20) - result = client.search_samples(args.search_term) + result = self.client.search_samples(args.search_term) if result["search_results"]: print("Sample Search Results") for sample_id, entry in result["search_results"].items(): sample_entry = SampleEntry.fromDict(entry) print(f"{sample_entry}") print("*" * 20) - result = client.search_functions(args.search_term) + result = self.client.search_functions(args.search_term) if result["search_results"]: print("Function Search Results") for function_id, entry in result["search_results"].items(): @@ -236,8 +240,7 @@ def _handle_search(self, args): print(f"{function_entry}") def _handle_queue(self, args): - client = McritClient() - result = client.getQueueData(filter=args.filter) + result = self.client.getQueueData(filter=args.filter) for entry in result: job_id = entry.job_id result_id = entry.result @@ -272,17 +275,16 @@ def _handle_query(self, args): if args.output is not None and (not os.path.exists(args.output) or not os.path.isdir(args.output)): print("Your is not a directory or does not exist.") return - client = McritClient() if args.smda: smda_report = SmdaReport.fromFile(args.filepath) - job_id = client.requestMatchesForSmdaReport(smda_report) + job_id = self.client.requestMatchesForSmdaReport(smda_report) else: if base_addr: - job_id = client.requestMatchesForMappedBinary(readFileContent(args.filepath), disassemble_locally=False, base_address=base_addr) + job_id = self.client.requestMatchesForMappedBinary(readFileContent(args.filepath), disassemble_locally=False, base_address=base_addr) else: - job_id = client.requestMatchesForUnmappedBinary(readFileContent(args.filepath), disassemble_locally=False) + job_id = self.client.requestMatchesForUnmappedBinary(readFileContent(args.filepath), disassemble_locally=False) print(f"Started job: {job_id}, waiting for result...") - compact_result_dict = client.awaitResult(job_id, sleep_time=2, compact=True) + compact_result_dict = self.client.awaitResult(job_id, sleep_time=2, compact=True) if args.output is not None: with open(args.output + os.sep + f"{job_id}.json", "w") as fout: json.dump(compact_result_dict, fout, indent=1) @@ -329,19 +331,17 @@ def _handle_submit(self, args): self._handle_submit_malpedia(args) def _handle_submit_file(self, args): - client = McritClient() sample_sha256 = sha256(readFileContent(args.filepath)) - if client.getSampleBySha256(sample_sha256): + if self.client.getSampleBySha256(sample_sha256): print(f"SKIPPING: {args.filepath} - already in MCRIT.") return smda_report = getSmdaReportFromFilepath(args, args.filepath) if smda_report: print(smda_report) - client.addReport(smda_report) + self.client.addReport(smda_report) def _handle_submit_dir(self, args): - client = McritClient() - mcrit_samples = client.getSamples() + mcrit_samples = self.client.getSamples() mcrit_samples_by_sha256 = {} for sample_id, sample in mcrit_samples.items(): mcrit_samples_by_sha256[sample.sha256] = sample @@ -354,11 +354,10 @@ def _handle_submit_dir(self, args): smda_report = getSmdaReportFromFilepath(args, filepath) if smda_report: print(smda_report) - client.addReport(smda_report) + self.client.addReport(smda_report) def _handle_submit_recursive(self, args): - client = McritClient() - mcrit_samples = client.getSamples() + mcrit_samples = self.client.getSamples() mcrit_samples_by_sha256 = {} for sample_id, sample in mcrit_samples.items(): mcrit_samples_by_sha256[sample.sha256] = sample @@ -376,7 +375,7 @@ def _handle_submit_recursive(self, args): smda_report.version = getSampleVersion(folder_relative_path, smda_report.family) print(filepath) print(smda_report) - client.addReport(smda_report) + self.client.addReport(smda_report) def _handle_submit_malpedia(self, args): # verify that we have a malpedia root @@ -388,9 +387,8 @@ def _handle_submit_malpedia(self, args): if not "malpedia.bib" in os.listdir(malpedia_root): print(f"Error: 'malpedia.bib' is missing in that folder, are you sure you are poniting to a Malpedia repository?") return - client = McritClient() # get current status of all samples in MCRIT - mcrit_samples = client.getSamples() + mcrit_samples = self.client.getSamples() mcrit_samples_by_filename = {} for sample_id, sample in mcrit_samples.items(): # verify that filename has malpedia format (starts with sha256 and _unpacked/_dump) @@ -421,7 +419,7 @@ def _handle_submit_malpedia(self, args): smda_report.version = malpedia_version print(malpedia_filepath) print(smda_report) - client.addReport(smda_report) + self.client.addReport(smda_report) # warn about files that appear deleted because not present in Malpedia but in MCRIT (based on name schema) for filename, mcrit_sample in mcrit_samples_by_filename.items(): if self._isMalpediaFilename(filename) and filename not in malpedia_samples_by_filename: