diff --git a/src/main/scala/io/joern/scanners/c/FileOpRace.scala b/src/main/scala/io/joern/scanners/c/FileOpRace.scala index ab29af1..bcf12a6 100644 --- a/src/main/scala/io/joern/scanners/c/FileOpRace.scala +++ b/src/main/scala/io/joern/scanners/c/FileOpRace.scala @@ -25,62 +25,43 @@ object FileOpRace extends QueryBundle { |""".stripMargin, score = 3.0, withStrRep({ cpg => - val firstParam = Set( - "open", - "fopen", - "creat", - "access", - "chmod", - "readlink", - "chown", - "lchown", - "stat", - "lstat", - "unlink", - "rmdir", - "mkdir", - "mknod", - "mkfifo", - "chdir", - "link", - "rename" + val operations: Map[String, Seq[Integer]] = Map( + "access" -> Seq(1), + "chdir" -> Seq(1), + "chmod" -> Seq(1), + "chown" -> Seq(1), + "creat" -> Seq(1), + "faccessat" -> Seq(2), + "fchmodat" -> Seq(2), + "fopen" -> Seq(1), + "fstatat" -> Seq(2), + "lchown" -> Seq(1), + "linkat" -> Seq(2, 4), + "link" -> Seq(1, 2), + "lstat" -> Seq(1), + "mkdirat" -> Seq(2), + "mkdir" -> Seq(1), + "mkfifoat" -> Seq(2), + "mkfifo" -> Seq(1), + "mknodat" -> Seq(2), + "mknod" -> Seq(1), + "openat" -> Seq(2), + "open" -> Seq(1), + "readlinkat" -> Seq(2), + "readlink" -> Seq(1), + "renameat" -> Seq(2, 4), + "rename" -> Seq(1, 2), + "rmdir" -> Seq(1), + "stat" -> Seq(1), + "unlinkat" -> Seq(2), + "unlink" -> Seq(1), ) - val secondParam = Set( - "openat", - "fstatat", - "fchmodat", - "readlinkat", - "unlinkat", - "mkdirat", - "mknodat", - "mkfifoat", - "faccessat", - "link", - "rename", - "linkat", - "renameat" - ) - val fourthParam = Set("linkat", "renameat") - - val anyParam = firstParam ++ secondParam ++ fourthParam def fileCalls(calls: Traversal[Call]) = - calls.nameExact(anyParam.toSeq: _*) + calls.nameExact(operations.keys.toSeq: _*) - def fileArgs(c: Call) = { - val res = Traversal.newBuilder[Expression] - // note some functions are in multiple setts because they take multiple paths - if (firstParam.contains(c.name)) { - res.addOne(c.argument(1)) - } - if (secondParam.contains(c.name)) { - res.addOne(c.argument(2)) - } - if (fourthParam.contains(c.name)) { - res.addOne(c.argument(4)) - } - res.result().whereNot(_.isLiteral) - } + def fileArgs(c: Call) = + c.argument.whereNot(_.isLiteral).argumentIndex(operations(c.name): _*) fileCalls(cpg.call) .filter(call => {