diff --git a/src/occa/internal/bin/occa.hpp b/src/occa/internal/bin/occa.hpp index d5e20960f..3db55f925 100644 --- a/src/occa/internal/bin/occa.hpp +++ b/src/occa/internal/bin/occa.hpp @@ -6,6 +6,7 @@ namespace occa { namespace bin { cli::command buildOccaCommand(); + bool runTranslate(const json &args); } } diff --git a/src/occa/internal/utils/cli.cpp b/src/occa/internal/utils/cli.cpp index dd332ba67..e898ff59a 100644 --- a/src/occa/internal/utils/cli.cpp +++ b/src/occa/internal/utils/cli.cpp @@ -405,6 +405,7 @@ namespace occa { } occa::json parser::parseArgs(const strVector &args_, + const std::vector &commands, const bool supressErrors) { strVector args = splitShortOptionArgs(args_); const int argc = (int) args.size(); @@ -445,7 +446,12 @@ namespace occa { // No option if (!opt) { - checkOptions = (arg == "=="); + for (auto cmd : commands) { + if (arg == cmd.name) { + checkOptions = 0; + break; + } + } jArguments += arg; continue; } @@ -779,7 +785,7 @@ namespace occa { const bool hasCommands = commands.size(); - json parsedArgs = parseArgs(shellArgs, supressErrors); + json parsedArgs = parseArgs(shellArgs, commands, supressErrors); lastCommandArgs = parsedArgs; json &jArguments = parsedArgs["arguments"]; diff --git a/src/occa/internal/utils/cli.hpp b/src/occa/internal/utils/cli.hpp index 0ead626ad..cf286d1d0 100644 --- a/src/occa/internal/utils/cli.hpp +++ b/src/occa/internal/utils/cli.hpp @@ -120,6 +120,8 @@ namespace occa { }; //================================== + class command; + //---[ Parser ]--------------------- class parser : public printable { public: @@ -157,6 +159,7 @@ namespace occa { occa::json parseArgs(const int argc, const char **argv); occa::json parseArgs(const strVector &args_, + const std::vector &commands = {}, const bool supressErrors = false); bool hasCustomHelpOption(); diff --git a/tests/src/internal/utils/cli.cpp b/tests/src/internal/utils/cli.cpp index 991ba3405..6a3ef6a57 100644 --- a/tests/src/internal/utils/cli.cpp +++ b/tests/src/internal/utils/cli.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include void testPretty(); void testOption(); @@ -231,4 +233,33 @@ void testParser() { } void testCommand() { + occa::cli::command translateCommand; + translateCommand + .withName("translate") + .withCallback(occa::bin::runTranslate) + .withDescription("Translate kernels") + .addOption(occa::cli::option('m', "mode", + "Output mode (Default: Serial)") + .withArg() + .expandsFunction([&](const occa::json &args) { + occa::strVector suggestions; + for (auto &it : occa::getModeMap()) { + suggestions.push_back(it.second->name()); + } + return suggestions; + })) + .addOption(occa::cli::option('v', "verbose", + "Verbose output")) + .addArgument(occa::cli::argument("FILE", + "An .okl file") + .isRequired() + .expandsFiles()); + + const auto &options = translateCommand.options; + ASSERT_EQ(options.size(), (unsigned long)2); + + const int argc = 4; + const char *argv[] = {"translate", "addVectors.okl", "-m", "cuda"}; + const auto json = translateCommand.parseArgs(argc, (const char **)&argv); + ASSERT_EQ(json["options"]["mode"], "cuda"); }