diff --git a/core/src/main/java/com/alibaba/druid/sql/repository/SchemaResolveVisitorFactory.java b/core/src/main/java/com/alibaba/druid/sql/repository/SchemaResolveVisitorFactory.java index dfe546f50d..440a5416c9 100644 --- a/core/src/main/java/com/alibaba/druid/sql/repository/SchemaResolveVisitorFactory.java +++ b/core/src/main/java/com/alibaba/druid/sql/repository/SchemaResolveVisitorFactory.java @@ -1016,6 +1016,13 @@ static void resolve(SchemaResolveVisitor visitor, SQLInsertStatement x) { visitor.visit(query); } + if (x instanceof OracleInsertStatement) { + SQLObject returning = ((OracleInsertStatement) x).getReturning(); + if (returning != null) { + returning.accept(visitor); + } + } + visitor.popContext(); } diff --git a/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java b/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java index 85e53d4953..455a708c35 100644 --- a/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java +++ b/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java @@ -25,6 +25,7 @@ import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlExpr; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter; import com.alibaba.druid.sql.dialect.oracle.ast.expr.OracleExpr; +import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleInsertStatement; import com.alibaba.druid.sql.dialect.oracle.visitor.OracleASTVisitorAdapter; import com.alibaba.druid.sql.dialect.postgresql.visitor.PGASTVisitorAdapter; import com.alibaba.druid.sql.repository.SchemaObject; @@ -1183,6 +1184,9 @@ public boolean visit(SQLInsertStatement x) { accept(x.getColumns()); accept(x.getQuery()); + if (x instanceof OracleInsertStatement) { + accept(((OracleInsertStatement) x).getReturning()); + } return false; } @@ -2147,6 +2151,7 @@ public boolean visit(SQLMethodInvokeExpr x) { this.functions.add(x); accept(x.getArguments()); + accept(x.getFrom()); return false; } diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/insert/OracleInsertTest2.java b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/insert/OracleInsertTest2.java index d350aa9e52..b1625e0c70 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/insert/OracleInsertTest2.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/insert/OracleInsertTest2.java @@ -61,4 +61,67 @@ public void test_0() throws Exception { Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "job_id"))); Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "salary"))); } + + public void test_1() throws Exception { + String sql = "BEGIN\n" + + "\tINSERT INTO employees (first_name, last_name, job_title)\n" + + "\tVALUES (?, ?, ?)\n" + + "\tRETURNING employee_id INTO ?;\n" + + "\tCOMMIT;\n" + + "END;"; + OracleStatementParser parser = new OracleStatementParser(sql); + List statementList = parser.parseStatementList(); + SQLStatement statemen = statementList.get(0); + print(statementList); + + Assert.assertEquals(1, statementList.size()); + + OracleSchemaStatVisitor visitor = new OracleSchemaStatVisitor(); + statemen.accept(visitor); + + System.out.println("Tables : " + visitor.getTables()); + System.out.println("fields : " + visitor.getColumns()); + System.out.println("coditions : " + visitor.getConditions()); + System.out.println("relationships : " + visitor.getRelationships()); + + Assert.assertTrue(visitor.getTables().containsKey(new TableStat.Name("employees"))); + + Assert.assertEquals(1, visitor.getTables().size()); + Assert.assertEquals(4, visitor.getColumns().size()); + + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "first_name"))); + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "last_name"))); + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "job_title"))); + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "employee_id"))); + } + + public void test_2() throws Exception { + String sql = "SELECT employee_id, TO_CHAR(TRIM(LEADING 0 FROM hire_date))\n" + + "FROM employees\n" + + "WHERE department_id = 60\n" + + "ORDER BY employee_id"; + OracleStatementParser parser = new OracleStatementParser(sql); + List statementList = parser.parseStatementList(); + SQLStatement statemen = statementList.get(0); + print(statementList); + + Assert.assertEquals(1, statementList.size()); + + OracleSchemaStatVisitor visitor = new OracleSchemaStatVisitor(); + statemen.accept(visitor); + + System.out.println("Tables : " + visitor.getTables()); + System.out.println("fields : " + visitor.getColumns()); + System.out.println("coditions : " + visitor.getConditions()); + System.out.println("relationships : " + visitor.getRelationships()); + + Assert.assertTrue(visitor.getTables().containsKey(new TableStat.Name("employees"))); + + Assert.assertEquals(1, visitor.getTables().size()); + Assert.assertEquals(3, visitor.getColumns().size()); + + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "employee_id"))); + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "hire_date"))); + Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "department_id"))); + } } diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/select/OracleSelectTest57.java b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/select/OracleSelectTest57.java index 33115cbae8..9237fe0ed9 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/select/OracleSelectTest57.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/select/OracleSelectTest57.java @@ -26,11 +26,10 @@ public class OracleSelectTest57 extends OracleTest { public void test_0() throws Exception { - String sql = // - "SELECT TRIM(BOTH FROM EUCD) AS \"value\",NTLANG1 AS \"text\" " // + String sql = "SELECT TRIM(BOTH FROM EUCD) AS \"value\",NTLANG1 AS \"text\" " + " FROM T_HT_WREM_ENUMLANG_D" + " WHERE TYPE=?" - + " ORDER BY \"value\" ASC"; // + + " ORDER BY \"value\" ASC"; OracleStatementParser parser = new OracleStatementParser(sql); List statementList = parser.parseStatementList(); @@ -49,8 +48,7 @@ public void test_0() throws Exception { System.out.println("orderBy : " + visitor.getOrderByColumns()); Assert.assertEquals(1, visitor.getTables().size()); - - Assert.assertEquals(2, visitor.getColumns().size()); + Assert.assertEquals(3, visitor.getColumns().size()); { String text = SQLUtils.toOracleString(stmt); @@ -73,4 +71,27 @@ public void test_0() throws Exception { // Assert.assertTrue(visitor.getOrderByColumns().contains(new TableStat.Column("employees", "last_name"))); } + + public void test_1() throws Exception { + String sql = "SELECT TRIM(BOTH 'x' FROM 'xJohnxx') FROM dual"; + + OracleStatementParser parser = new OracleStatementParser(sql); + List statementList = parser.parseStatementList(); + SQLStatement stmt = statementList.get(0); + print(statementList); + + Assert.assertEquals(1, statementList.size()); + + OracleSchemaStatVisitor visitor = new OracleSchemaStatVisitor(); + stmt.accept(visitor); + + System.out.println("Tables : " + visitor.getTables()); + System.out.println("fields : " + visitor.getColumns()); + System.out.println("coditions : " + visitor.getConditions()); + System.out.println("relationships : " + visitor.getRelationships()); + System.out.println("orderBy : " + visitor.getOrderByColumns()); + + Assert.assertEquals(0, visitor.getTables().size()); + Assert.assertEquals(0, visitor.getColumns().size()); + } }