diff --git a/src/macaw/core.clj b/src/macaw/core.clj
index 1cecb49..ff07766 100644
--- a/src/macaw/core.clj
+++ b/src/macaw/core.clj
@@ -1,6 +1,7 @@
 (ns macaw.core
   (:require
    [macaw.rewrite :as rewrite]
+   [macaw.util :as u]
    [macaw.walk :as mw])
   (:import
    (net.sf.jsqlparser.expression Alias)
@@ -33,19 +34,39 @@
    :tables            #{}
    :table-wildcards   #{}}))
 
+(defn- make-table
+  "Describes a JSQLParser `Table` as `{:table name}`, plus `:schema` when the table is schema-qualified."
+  [^Table t]
+  (merge
+    {:table (.getName t)}
+    (when-let [s (.getSchemaName t)]
+      {:schema s})))
+
+(defn- make-column
+  "Describes a JSQLParser `Column` as `{:column name}`. When the column is qualified, the table map of its
+  qualifier is merged in; the qualifier is looked up in `aliases` (alias name -> table map) first."
+  [aliases ^Column c]
+  (merge
+    {:column (.getColumnName c)}
+    (when-let [t (.getTable c)]
+      (or
+        (get aliases (.getName t))
+        (make-table t)))))
+
 (defn- alias-mapping
   [^Table table]
   (when-let [^Alias table-alias (.getAlias table)]
-    [(.getName table-alias) (.getName table)]))
+    [(.getName table-alias) (make-table table)]))
 
 (defn- resolve-table-name
   "JSQLParser can't tell whether the `f` in `select f.*` refers to a real table or an alias. Therefore, we have to
   disambiguate them based on our own map of aliases->table names. So this function will return the real name of the table
   referenced in a table-wildcard (as far as can be determined from the query)."
-  [alias->name ^AllTableColumns atc]
+  [alias->table name->table ^AllTableColumns atc]
   (let [table-name (-> atc .getTable .getName)]
-    (or (alias->name table-name)
-        table-name)))
+    (or (alias->table table-name)
+        (:component (name->table table-name))
+        (make-table (.getTable atc)))))
 
 (defn- update-components
   [f components]
@@ -60,12 +81,16 @@
   (let [{:keys [columns has-wildcard?
                 mutation-commands
                 tables table-wildcards]} (query->raw-components parsed-query)
-        aliases                  (into {} (map (comp alias-mapping :component) tables))]
-    {:columns           (into #{} (update-components #(.getColumnName ^Column %) columns))
+        aliases                  (into {} (map #(-> % :component alias-mapping) tables))
+        ;; One entry per table name: an entry whose component already has a :schema is kept, otherwise the later one wins.
+        tables                   (->> (update-components make-table tables)
+                                      (u/group-with #(-> % :component :table)
+                                                    (fn [a b] (if (-> a :component :schema) a b))))]
+    {:columns           (into #{} (update-components (partial make-column aliases) columns))
      :has-wildcard?     (into #{} has-wildcard?)
      :mutation-commands (into #{} mutation-commands)
-     :tables            (into #{} (update-components #(.getName ^Table %) tables))
-     :table-wildcards   (into #{} (update-components (partial resolve-table-name aliases) table-wildcards))}))
+     :tables            (into #{} (vals tables))
+     :table-wildcards   (into #{} (update-components (partial resolve-table-name aliases tables) table-wildcards))}))
 
 (defn parsed-query
   "Main entry point: takes a string query and returns a `Statement` object that can be handled by the other functions."
diff --git a/src/macaw/util.clj b/src/macaw/util.clj
new file mode 100644
index 0000000..6b47140
--- /dev/null
+++ b/src/macaw/util.clj
@@ -0,0 +1,14 @@
+(ns macaw.util)
+
+(defn group-with
+  "Generalized `group-by`, where you can supply your own reducing function `rf` (instead of the usual `conj`).
+
+  https://ask.clojure.org/index.php/12319/can-group-by-be-generalized"
+  [kf rf coll]
+  (persistent!
+    (reduce
+      (fn [ret x]
+        (let [k (kf x)]
+          (assoc! ret k (rf (get ret k) x))))
+      (transient {})
+      coll)))
diff --git a/test/macaw/core_test.clj b/test/macaw/core_test.clj
index c78b196..53c9d53 100644
--- a/test/macaw/core_test.clj
+++ b/test/macaw/core_test.clj
@@ -28,21 +28,21 @@
 
 (deftest query->tables-test
   (testing "Simple queries"
-    (is (= #{"core_user"}
+    (is (= #{{:table "core_user"}}
            (tables "SELECT * FROM core_user;")))
-    (is (= #{"core_user"}
+    (is (= #{{:table "core_user"}}
            (tables "SELECT id, email FROM core_user;"))))
   (testing "With a schema (Postgres)" ;; TODO: only run this against supported DBs
-    ;; It strips the schema
-    (is (= #{"core_user"}
+    (is (= #{{:table "core_user" :schema "the_schema_name"}}
            (tables "SELECT * FROM the_schema_name.core_user;"))))
   (testing "Sub-selects"
-    (is (= #{"core_user"}
+    (is (= #{{:table "core_user"}}
            (tables "SELECT * FROM (SELECT DISTINCT email FROM core_user) q;")))))
 
 (deftest tables-with-complex-aliases-issue-14-test
   (testing "With an alias that is also a table name"
-    (is (= #{"user" "user2_final"}
+    (is (= #{{:table "user"}
+             {:table "user2_final"}}
            (tables
             "SELECT legacy_user.id AS old_id,
                     user.id AS new_id
@@ -63,11 +63,18 @@
 
 (deftest query->columns-test
   (testing "Simple queries"
-    (is (= #{"foo" "bar" "id" "quux_id"}
+    (is (= #{{:column "foo"}
+             {:column "bar"}
+             {:column "id" :table "quux"}
+             {:column "quux_id" :table "baz"}}
            (columns "SELECT foo, bar FROM baz INNER JOIN quux ON quux.id = baz.quux_id"))))
   (testing "'group by' columns present"
-    (is (= #{"id" "user_id"}
-           (columns "SELECT id FROM orders GROUP BY user_id")))))
+    (is (= #{{:column "id"}
+             {:column "user_id"}}
+           (columns "SELECT id FROM orders GROUP BY user_id"))))
+  (testing "table alias present"
+    (is (= #{{:column "id" :table "orders" :schema "public"}}
+           (columns "SELECT o.id FROM public.orders o")))))
 
 (deftest mutation-test
   (is (= #{"alter-sequence"}
@@ -136,7 +143,7 @@
 
 (deftest alias-inclusion-test
   (testing "Aliases are not included"
-    (is (= #{"orders" "foo"}
+    (is (= #{{:table "orders"} {:table "foo"}}
            (tables "SELECT id, o.id FROM orders o JOIN foo ON orders.id = foo.order_id")))))
 
 (deftest resolve-columns-test
@@ -150,54 +157,55 @@
   (is (true? (has-wildcard? "SELECT id, * FROM orders JOIN foo ON orders.id = foo.order_id"))))
 
 (deftest table-wildcard-test-without-aliases
-  (is (= #{"orders"}
+  (is (= #{{:table "orders"}}
          (table-wcs "SELECT orders.* FROM orders JOIN foo ON orders.id = foo.order_id")))
-  (is (= #{"foo"}
-         (table-wcs "SELECT foo.* FROM orders JOIN foo ON orders.id = foo.order_id"))))
+  (is (= #{{:table "foo" :schema "public"}}
+         (table-wcs "SELECT foo.* FROM orders JOIN public.foo f ON orders.id = foo.order_id"))))
 
 (deftest table-star-test-with-aliases
-  (is (= #{"orders"}
+  (is (= #{{:table "orders"}}
          (table-wcs "SELECT o.* FROM orders o JOIN foo ON orders.id = foo.order_id")))
-  (is (= #{"foo"}
+  (is (= #{{:table "foo"}}
          (table-wcs "SELECT f.* FROM orders o JOIN foo f ON orders.id = foo.order_id"))))
 
 (deftest context-test
   (testing "Sub-select with outer wildcard"
     (is (= {:columns
-            #{{:component "total", :context ["SELECT" "SUB_SELECT" "FROM" "SELECT"]}
-              {:component "id", :context ["SELECT" "SUB_SELECT" "FROM" "SELECT"]}
-              {:component "total", :context ["WHERE" "JOIN" "FROM" "SELECT"]}},
+            #{{:component {:column "total"}, :context ["SELECT" "SUB_SELECT" "FROM" "SELECT"]}
+              {:component {:column "id"}, :context ["SELECT" "SUB_SELECT" "FROM" "SELECT"]}
+              {:component {:column "total"}, :context ["WHERE" "JOIN" "FROM" "SELECT"]}},
             :has-wildcard?
             #{{:component true, :context ["SELECT"]}},
             :mutation-commands #{},
-            :tables #{{:component "orders", :context ["FROM" "SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
+            :tables #{{:component {:table "orders"}, :context ["FROM" "SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
             :table-wildcards #{}}
            (components "SELECT * FROM (SELECT id, total FROM orders) WHERE total > 10"))))
   (testing "Sub-select with inner wildcard"
     (is (= {:columns
-            #{{:component "id", :context ["SELECT"]}
-              {:component "total", :context ["SELECT"]}
-              {:component "total", :context ["WHERE" "JOIN" "FROM" "SELECT"]}},
+            #{{:component {:column "id"}, :context ["SELECT"]}
+              {:component {:column "total"}, :context ["SELECT"]}
+              {:component {:column "total"}, :context ["WHERE" "JOIN" "FROM" "SELECT"]}},
             :has-wildcard?
             #{{:component true, :context ["SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
             :mutation-commands #{},
-            :tables #{{:component "orders", :context ["FROM" "SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
+            :tables #{{:component {:table "orders"}, :context ["FROM" "SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
             :table-wildcards #{}}
            (components "SELECT id, total FROM (SELECT * FROM orders) WHERE total > 10"))))
   (testing "Sub-select with dual wildcards"
-    (is (= {:columns #{{:component "total", :context ["WHERE" "JOIN" "FROM" "SELECT"]}},
+    (is (= {:columns #{{:component {:column "total"}, :context ["WHERE" "JOIN" "FROM" "SELECT"]}},
             :has-wildcard? #{{:component true, :context ["SELECT" "SUB_SELECT" "FROM" "SELECT"]} {:component true, :context ["SELECT"]}},
             :mutation-commands #{},
-            :tables #{{:component "orders", :context ["FROM" "SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
+            :tables #{{:component {:table "orders"}, :context ["FROM" "SELECT" "SUB_SELECT" "FROM" "SELECT"]}},
             :table-wildcards #{}}
            (components "SELECT * FROM (SELECT * FROM orders) WHERE total > 10"))))
   (testing "Join; table wildcard"
-    (is (= {:columns #{{:component "order_id", :context ["JOIN" "SELECT"]}
-                       {:component "id", :context ["JOIN" "SELECT"]}},
+    (is (= {:columns #{{:component {:column "order_id" :table "foo"}, :context ["JOIN" "SELECT"]}
+                       {:component {:column "id" :table "orders"}, :context ["JOIN" "SELECT"]}},
             :has-wildcard? #{},
             :mutation-commands #{},
-            :tables #{{:component "foo", :context ["FROM" "JOIN" "SELECT"]} {:component "orders", :context ["FROM" "SELECT"]}},
-            :table-wildcards #{{:component "orders", :context ["SELECT"]}}}
+            :tables #{{:component {:table "foo"}, :context ["FROM" "JOIN" "SELECT"]}
+                      {:component {:table "orders"}, :context ["FROM" "SELECT"]}},
+            :table-wildcards #{{:component {:table "orders"}, :context ["SELECT"]}}}
            (components "SELECT o.* FROM orders o JOIN foo ON orders.id = foo.order_id")))))
 
 (defn test-replacement [before replacements after]