Skip to content

Commit

Permalink
Support column statistics for RETURNING in INSERT statement
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhengguanLi committed Nov 22, 2024
1 parent 8d7811a commit 4bb4720
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,13 @@ static void resolve(SchemaResolveVisitor visitor, SQLInsertStatement x) {
visitor.visit(query);
}

if (x instanceof OracleInsertStatement) {
SQLObject returning = ((OracleInsertStatement) x).getReturning();
if (returning != null) {
returning.accept(visitor);
}
}

visitor.popContext();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlExpr;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.alibaba.druid.sql.dialect.oracle.ast.expr.OracleExpr;
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleInsertStatement;
import com.alibaba.druid.sql.dialect.oracle.visitor.OracleASTVisitorAdapter;
import com.alibaba.druid.sql.dialect.postgresql.visitor.PGASTVisitorAdapter;
import com.alibaba.druid.sql.repository.SchemaObject;
Expand Down Expand Up @@ -1184,6 +1185,9 @@ public boolean visit(SQLInsertStatement x) {

accept(x.getColumns());
accept(x.getQuery());
if (x instanceof OracleInsertStatement) {
accept(((OracleInsertStatement) x).getReturning());
}

return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,36 @@ public void test_0() throws Exception {
Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "salary")));
}

public void test_1() throws Exception {
String sql = "BEGIN\n" +
"\tINSERT INTO employees (first_name, last_name, job_title)\n" +
"\tVALUES (?, ?, ?)\n" +
"\tRETURNING employee_id INTO ?;\n" +
"\tCOMMIT;\n" +
"END;";
OracleStatementParser parser = new OracleStatementParser(sql);
List<SQLStatement> statementList = parser.parseStatementList();
SQLStatement statemen = statementList.get(0);
print(statementList);

Assert.assertEquals(1, statementList.size());

OracleSchemaStatVisitor visitor = new OracleSchemaStatVisitor();
statemen.accept(visitor);

System.out.println("Tables : " + visitor.getTables());
System.out.println("fields : " + visitor.getColumns());
System.out.println("coditions : " + visitor.getConditions());
System.out.println("relationships : " + visitor.getRelationships());

Assert.assertTrue(visitor.getTables().containsKey(new TableStat.Name("employees")));

Assert.assertEquals(1, visitor.getTables().size());
Assert.assertEquals(4, visitor.getColumns().size());

Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "first_name")));
Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "last_name")));
Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "job_title")));
Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "employee_id")));
}
}

0 comments on commit 4bb4720

Please sign in to comment.