From 0c63c1c95b6852eb04ad35c1c946899432a0641f Mon Sep 17 00:00:00 2001 From: Duncan Ogilvie Date: Fri, 11 Oct 2024 22:21:24 +0200 Subject: [PATCH] Add basic argument parser for testing the obfuscator --- obfuscator/include/obfuscator/args.hpp | 274 +++++++++++++++++++++++++ obfuscator/src/obfuscate.cpp | 42 +++- 2 files changed, 309 insertions(+), 7 deletions(-) create mode 100644 obfuscator/include/obfuscator/args.hpp diff --git a/obfuscator/include/obfuscator/args.hpp b/obfuscator/include/obfuscator/args.hpp new file mode 100644 index 0000000..92a9069 --- /dev/null +++ b/obfuscator/include/obfuscator/args.hpp @@ -0,0 +1,274 @@ +#pragma once + +#include +#include +#include +#include + +class ArgumentParser +{ + protected: + void addPositional(const std::string& name, std::string& value, const std::string& help, bool required = false) + { + auto fn = [this, &value]() + { + value = arg; + }; + positionalArgs.push_back(Arg{name, help, required, fn}); + } + + void addString(const std::string& flagname, std::string& value, const std::string& help, bool required = false) + { + auto fn = [this, flagname, &value] + { + if (arg.substr(0, flagname.length()) == flagname) + { + if (arg.length() == flagname.length()) + { + // -flagname + if (i + 1 >= argc) + { + throw std::runtime_error("missing value for '" + flagname + "' argument"); + } + value = argv[++i]; + if (value.empty()) + { + throw std::runtime_error("empty value for '" + flagname + "' argument"); + } + markExtracted(flagname); + } + else if (arg[flagname.length()] == '=') + { + // -flagname= + value = arg.substr(flagname.length() + 1); + markExtracted(flagname); + } + } + }; + flagArgs.push_back(Arg{flagname, help, required, fn}); + } + + void addBool(const std::string& flagname, bool& value, const std::string& help, bool required = false) + { + auto fn = [this, flagname, &value] + { + if (arg.substr(0, flagname.length()) == flagname) + { + if (arg.length() == flagname.length()) + { + // -flagname + value = true; + markExtracted(flagname); + } + else if (arg[flagname.length()] == '=') + { + // -flagname= + auto strValue = arg.substr(flagname.length() + 1); + if (strValue.empty()) + { + throw std::runtime_error("empty value for '" + flagname + "' argument"); + } + value = strValue == "1" || strValue == "true"; + markExtracted(flagname); + } + } + }; + flagArgs.push_back(Arg{flagname, help, required, fn}); + } + + public: + explicit ArgumentParser(std::string description) : description(std::move(description)) + { + } + + virtual ~ArgumentParser() = default; + ArgumentParser(const ArgumentParser&) = delete; + ArgumentParser& operator=(const ArgumentParser&) = delete; + ArgumentParser(ArgumentParser&&) = delete; + ArgumentParser& operator=(ArgumentParser&&) = delete; + + void parse(int argc, char** argv) + { + this->argc = argc; + this->argv = argv; + bool seenRequired = false; + for (const auto& positionalArg : positionalArgs) + { + if (positionalArg.name.empty()) + { + throw std::runtime_error("cannot add positional argument without name"); + } + if (!positionalArg.required) + { + if (seenRequired) + { + throw std::runtime_error("cannot add required positional argument after an optional one"); + } + } + else + { + seenRequired = true; + } + } + for (const auto& flagArg : flagArgs) + { + if (flagArg.name.empty()) + { + throw std::runtime_error("cannot add argument without name"); + } + if (flagArg.name[0] != '-') + { + throw std::runtime_error("invalid argument name '" + flagArg.name + "'"); + } + } + size_t positionalIndex = 0; + for (i = 1; i < argc; i++) + { + arg = std::string(argv[i]); + if (arg.empty()) + { + continue; + } + if (arg[0] == '-') + { + didExtract = false; + for (const auto& flag : flagArgs) + { + flag.fn(); + } + if (!didExtract) + { + throw std::runtime_error("unknown argument '" + arg + "'"); + } + } + else + { + if (positionalIndex + 1 > positionalArgs.size()) + { + throw std::runtime_error("unexpected positional argument '" + arg + "'"); + } + const auto& positionalArg = positionalArgs[positionalIndex++]; + if (positionalArg.name[0] == '-') + { + markExtracted(positionalArg.name); + } + positionalArg.fn(); + } + } + for (const auto& flagArg : flagArgs) + { + if (!flagArg.required) + { + continue; + } + if (!flagsExtracted.contains(flagArg.name)) + { + throw std::runtime_error("required argument '" + flagArg.name + "' missing"); + } + } + for (size_t i = positionalIndex; i < positionalArgs.size(); i++) + { + const auto& positionalArg = positionalArgs[i]; + if (positionalArg.required) + { + if (flagsExtracted.contains(positionalArg.name)) + { + continue; + } + throw std::runtime_error("required positional argument missing"); + } + } + } + + [[nodiscard]] std::string helpStr() const + { + std::string help; + help += " "; + help += argv[0]; + help += " {OPTIONS}"; + + for (const auto& positionalArg : positionalArgs) + { + help += " "; + if (!positionalArg.required) + { + help += '['; + } + if (positionalArg.name[0] == '-') + { + help += "[" + positionalArg.name + "]"; + help += " "; + } + else + { + help += positionalArg.name; + } + if (!positionalArg.required) + { + help += ']'; + } + } + help += '\n'; + + if (!description.empty()) + { + help += "\n "; + help += description; + help += "\n\n"; + } + + help += " OPTIONS:\n"; + + size_t maxLen = 0; + for (const auto& flagArg : flagArgs) + { + if (flagArg.name.size() > maxLen) + { + maxLen = flagArg.name.size(); + } + } + for (const auto& flagArg : flagArgs) + { + help += "\n "; + help += flagArg.name; + for (size_t i = 0; i < maxLen - flagArg.name.size(); i++) + { + help += ' '; + } + help += " "; + help += flagArg.help; + } + + return help; + } + + private: + struct Arg + { + std::string name; + std::string help; + bool required = false; + std::function fn; + }; + + std::string description; + std::vector positionalArgs; + std::vector flagArgs; + + int i = 1; + int argc = 0; + char** argv = nullptr; + bool didExtract = false; + std::string arg; + std::unordered_set flagsExtracted; + + void markExtracted(const std::string& flagname) + { + didExtract = true; + if (flagsExtracted.contains(flagname)) + { + throw std::runtime_error("duplicate value for '" + flagname + "' argument"); + } + flagsExtracted.insert(flagname); + } +}; diff --git a/obfuscator/src/obfuscate.cpp b/obfuscator/src/obfuscate.cpp index 4218ab2..db4fff4 100644 --- a/obfuscator/src/obfuscate.cpp +++ b/obfuscator/src/obfuscate.cpp @@ -142,16 +142,44 @@ static bool riscvm_handle_syscall(vm::riscvm* self, uint64_t code, uint64_t* res #endif // _WIN32 -int main(int argc, char** argv) +#include + +struct Arguments : ArgumentParser { - if (argc < 2) + std::string input; + std::string output; + std::string payload; + bool help; + + Arguments(int argc, char** argv) : ArgumentParser("Obfuscates the riscvm_run function") { - puts("Usage: obfuscator riscvm.exe [payload.bin]"); - return EXIT_FAILURE; + addPositional("input", input, "Input PE file to obfuscate", true); + addString("-output", output, "Obfuscated function output"); + addString("-payload", payload, "Payload to execute (Windows only)"); + addBool("-help", help, "Prints this help message"); + try + { + parse(argc, argv); + } + catch (const std::exception& e) + { + printf("Error: %s\n\nHelp:\n%s\n", e.what(), helpStr().c_str()); + std::exit(help ? EXIT_SUCCESS : EXIT_FAILURE); + } + if (help) + { + puts(helpStr().c_str()); + std::exit(EXIT_SUCCESS); + } } +}; + +int main(int argc, char** argv) +{ + Arguments args(argc, argv); std::vector pe; - if (!loadFile(argv[1], pe)) + if (!loadFile(args.input, pe)) { puts("Failed to load the executable."); return EXIT_FAILURE; @@ -226,10 +254,10 @@ int main(int argc, char** argv) __debugbreak(); // Run the payload if specified on the command line - if (argc > 2) + if (!args.payload.empty()) { std::vector payload; - if (!loadFile(argv[2], payload)) + if (!loadFile(args.payload, payload)) { puts("Failed to load the payload."); return EXIT_FAILURE;