Skip to content

Commit

Permalink
now supporting server and apitoken in mcrit-cli
Browse files Browse the repository at this point in the history
  • Loading branch information
danielplohmann committed Feb 26, 2025
1 parent 3c4aeab commit 4c6b42d
Showing 1 changed file with 31 additions and 33 deletions.
64 changes: 31 additions & 33 deletions mcrit/client/McritConsole.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def __init__(self) -> None:
subparsers = self.parser.add_subparsers(dest="command")
self.parser_client = subparsers.add_parser("client")
subparser_client = self.parser_client.add_subparsers(dest="client_command")
self.parser_client.add_argument("--server", type=str, default=None, help="The MCRIT server to connect to.")
self.parser_client.add_argument("--apitoken", type=str, default=None, help="API token to use for the connection.")
client_status = subparser_client.add_parser("status")
# client submit
client_submit = subparser_client.add_parser("submit", help="Various ways of file submission incl. disassembly using SMDA if needed.")
Expand Down Expand Up @@ -168,6 +170,7 @@ def __init__(self) -> None:

def run(self):
ARGS = self.parser.parse_args()
self.client = McritClient(ARGS.server, ARGS.apitoken)
if ARGS.client_command == "status":
self._handle_status(ARGS)
elif ARGS.client_command == "import":
Expand All @@ -188,56 +191,56 @@ def run(self):
print(self.parser_client.print_help())

def _handle_status(self, args):
client = McritClient()
result = client.getStatus(with_pichash=False)
print(f"DB: {result['status']['storage_type']} - {result['status']['db_state']} | {result['status']['db_timestamp']}")
print(f"Families: {result['status']['num_families']}")
print(f"Samples: {result['status']['num_samples']}")
print(f"Functions: {result['status']['num_functions']}")
result = self.client.getStatus(with_pichash=False)
if result:
print(f"DB: {result['status']['storage_type']} - {result['status']['db_state']} | {result['status']['db_timestamp']}")
print(f"Families: {result['status']['num_families']}")
print(f"Samples: {result['status']['num_samples']}")
print(f"Functions: {result['status']['num_functions']}")
else:
print("Failed to retrieve status from the server.")

def _handle_import(self, args):
client = McritClient()
if os.path.isfile(args.filepath):
with open(args.filepath) as fin:
result = client.addImportData(json.load(fin))
result = self.client.addImportData(json.load(fin))
print(result)
else:
print("Your <filepath> does not exist.")

def _handle_export(self, args):
client = McritClient()
sample_ids = None
if args.sample_ids is not None:
sample_ids = [int(i) for i in args.sample_ids.split(",")]
result = client.getExportData(sample_ids=sample_ids)
result = self.client.getExportData(sample_ids=sample_ids)
with open(args.filepath, "w") as fout:
json.dump(result, fout, indent=1)
print(f"wrote export to {args.filepath}.")

def _handle_search(self, args):
client = McritClient()
result = client.search_families(args.search_term)
result = self.client.search_families(args.search_term)
if result["search_results"]:
print("Family Search Results")
for family_id, entry in result["search_results"].items():
family_entry = FamilyEntry.fromDict(entry)
print(f"{family_entry}")
print("*" * 20)
result = client.search_samples(args.search_term)
result = self.client.search_samples(args.search_term)
if result["search_results"]:
print("Sample Search Results")
for sample_id, entry in result["search_results"].items():
sample_entry = SampleEntry.fromDict(entry)
print(f"{sample_entry}")
print("*" * 20)
result = client.search_functions(args.search_term)
result = self.client.search_functions(args.search_term)
if result["search_results"]:
print("Function Search Results")
for function_id, entry in result["search_results"].items():
function_entry = FunctionEntry.fromDict(entry)
print(f"{function_entry}")

def _handle_queue(self, args):
client = McritClient()
result = client.getQueueData(filter=args.filter)
result = self.client.getQueueData(filter=args.filter)
for entry in result:
job_id = entry.job_id
result_id = entry.result
Expand Down Expand Up @@ -272,17 +275,16 @@ def _handle_query(self, args):
if args.output is not None and (not os.path.exists(args.output) or not os.path.isdir(args.output)):
print("Your <output> is not a directory or does not exist.")
return
client = McritClient()
if args.smda:
smda_report = SmdaReport.fromFile(args.filepath)
job_id = client.requestMatchesForSmdaReport(smda_report)
job_id = self.client.requestMatchesForSmdaReport(smda_report)
else:
if base_addr:
job_id = client.requestMatchesForMappedBinary(readFileContent(args.filepath), disassemble_locally=False, base_address=base_addr)
job_id = self.client.requestMatchesForMappedBinary(readFileContent(args.filepath), disassemble_locally=False, base_address=base_addr)
else:
job_id = client.requestMatchesForUnmappedBinary(readFileContent(args.filepath), disassemble_locally=False)
job_id = self.client.requestMatchesForUnmappedBinary(readFileContent(args.filepath), disassemble_locally=False)
print(f"Started job: {job_id}, waiting for result...")
compact_result_dict = client.awaitResult(job_id, sleep_time=2, compact=True)
compact_result_dict = self.client.awaitResult(job_id, sleep_time=2, compact=True)
if args.output is not None:
with open(args.output + os.sep + f"{job_id}.json", "w") as fout:
json.dump(compact_result_dict, fout, indent=1)
Expand Down Expand Up @@ -329,19 +331,17 @@ def _handle_submit(self, args):
self._handle_submit_malpedia(args)

def _handle_submit_file(self, args):
client = McritClient()
sample_sha256 = sha256(readFileContent(args.filepath))
if client.getSampleBySha256(sample_sha256):
if self.client.getSampleBySha256(sample_sha256):
print(f"SKIPPING: {args.filepath} - already in MCRIT.")
return
smda_report = getSmdaReportFromFilepath(args, args.filepath)
if smda_report:
print(smda_report)
client.addReport(smda_report)
self.client.addReport(smda_report)

def _handle_submit_dir(self, args):
client = McritClient()
mcrit_samples = client.getSamples()
mcrit_samples = self.client.getSamples()
mcrit_samples_by_sha256 = {}
for sample_id, sample in mcrit_samples.items():
mcrit_samples_by_sha256[sample.sha256] = sample
Expand All @@ -354,11 +354,10 @@ def _handle_submit_dir(self, args):
smda_report = getSmdaReportFromFilepath(args, filepath)
if smda_report:
print(smda_report)
client.addReport(smda_report)
self.client.addReport(smda_report)

def _handle_submit_recursive(self, args):
client = McritClient()
mcrit_samples = client.getSamples()
mcrit_samples = self.client.getSamples()
mcrit_samples_by_sha256 = {}
for sample_id, sample in mcrit_samples.items():
mcrit_samples_by_sha256[sample.sha256] = sample
Expand All @@ -376,7 +375,7 @@ def _handle_submit_recursive(self, args):
smda_report.version = getSampleVersion(folder_relative_path, smda_report.family)
print(filepath)
print(smda_report)
client.addReport(smda_report)
self.client.addReport(smda_report)

def _handle_submit_malpedia(self, args):
# verify that we have a malpedia root
Expand All @@ -388,9 +387,8 @@ def _handle_submit_malpedia(self, args):
if not "malpedia.bib" in os.listdir(malpedia_root):
print(f"Error: 'malpedia.bib' is missing in that folder, are you sure you are poniting to a Malpedia repository?")
return
client = McritClient()
# get current status of all samples in MCRIT
mcrit_samples = client.getSamples()
mcrit_samples = self.client.getSamples()
mcrit_samples_by_filename = {}
for sample_id, sample in mcrit_samples.items():
# verify that filename has malpedia format (starts with sha256 and _unpacked/_dump)
Expand Down Expand Up @@ -421,7 +419,7 @@ def _handle_submit_malpedia(self, args):
smda_report.version = malpedia_version
print(malpedia_filepath)
print(smda_report)
client.addReport(smda_report)
self.client.addReport(smda_report)
# warn about files that appear deleted because not present in Malpedia but in MCRIT (based on name schema)
for filename, mcrit_sample in mcrit_samples_by_filename.items():
if self._isMalpediaFilename(filename) and filename not in malpedia_samples_by_filename:
Expand Down

0 comments on commit 4c6b42d

Please sign in to comment.