diff --git a/pg_query.h b/pg_query.h index 7f34c41b..925a57c7 100644 --- a/pg_query.h +++ b/pg_query.h @@ -65,9 +65,20 @@ typedef struct { PgQueryError* error; } PgQueryFingerprintResult; +typedef struct { + int location; /* start offset in query text */ + int length; /* length in bytes, or -1 to ignore */ + int param_id; /* Param id to use - if negative prefix, need to abs(..) and add highest_extern_param_id */ + int token; /* constant token type as reported by lexer */ + char *val; /* constant value */ +} PgQueryNormalizeConstLocation; + typedef struct { char* normalized_query; PgQueryError* error; + PgQueryNormalizeConstLocation *clocations; + int clocations_count; + int highest_extern_param_id; } PgQueryNormalizeResult; // Postgres parser options (parse mode and GUCs that affect parsing) diff --git a/src/pg_query_normalize.c b/src/pg_query_normalize.c index 460493fd..19a1fdbb 100644 --- a/src/pg_query_normalize.c +++ b/src/pg_query_normalize.c @@ -2,6 +2,7 @@ #include "pg_query_internal.h" #include "pg_query_fingerprint.h" +#include "gramparse.h" #include "parser/parser.h" #include "parser/scanner.h" #include "parser/scansup.h" @@ -9,6 +10,10 @@ #include "nodes/nodeFuncs.h" #include "pg_query_outfuncs.h" +#include "postgres/include/parser/scanner.h" + +#include +#include /* * Struct for tracking locations/lengths of constants during normalization @@ -18,6 +23,8 @@ typedef struct pgssLocationLen int location; /* start offset in query text */ int length; /* length in bytes, or -1 to ignore */ int param_id; /* Param id to use - if negative prefix, need to abs(..) and add highest_extern_param_id */ + char *val; /* constant value */ + int token; /* token type as reported by the lexer */ } pgssLocationLen; /* @@ -107,9 +114,9 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query) { pgssLocationLen *locs; core_yyscan_t yyscanner; - core_yy_extra_type yyextra; - core_YYSTYPE yylval; - YYLTYPE yylloc; + base_yy_extra_type yyextra; + YYSTYPE yylval; + YYLTYPE yylloc = 0; int last_loc = -1; int i; @@ -124,10 +131,12 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query) /* initialize the flex scanner --- should match raw_parser() */ yyscanner = scanner_init(query, - &yyextra, + &yyextra.core_yy_extra, &ScanKeywords, ScanKeywordTokens); + yyextra.have_lookahead = false; + /* Search for each constant, in sequence */ for (i = 0; i < jstate->clocations_count; i++) { @@ -142,7 +151,7 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query) /* Lex tokens until we find the desired constant */ for (;;) { - tok = core_yylex(&yylval, &yylloc, yyscanner); + tok = base_yylex(&yylval, &yylloc, yyscanner); /* We should not hit end-of-string, but if we do, behave sanely */ if (tok == 0) @@ -154,6 +163,8 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query) */ if (yylloc >= loc) { + bool negative = false; + if (query[loc] == '-') { /* @@ -168,29 +179,37 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query) * where bar = 1" and "select * from foo where bar = -2" * will have identical normalized query strings. */ - tok = core_yylex(&yylval, &yylloc, yyscanner); + tok = base_yylex(&yylval, &yylloc, yyscanner); if (tok == 0) break; /* out of inner for-loop */ + negative = true; } /* * We now rely on the assumption that flex has placed a zero * byte after the text of the current token in scanbuf. */ - locs[i].length = (int) strlen(yyextra.scanbuf + loc); + locs[i].length = (int) strlen(yyextra.core_yy_extra.scanbuf + loc); + locs[i].token = tok; - /* Quoted string with Unicode escapes - * - * The lexer consumes trailing whitespace in order to find UESCAPE, but if there - * is no UESCAPE it has still consumed it - don't include it in constant length. - */ - if (locs[i].length > 4 && /* U&'' */ - (yyextra.scanbuf[loc] == 'u' || yyextra.scanbuf[loc] == 'U') && - yyextra.scanbuf[loc + 1] == '&' && yyextra.scanbuf[loc + 2] == '\'') + if (tok == SCONST || tok == FCONST || tok == BCONST || tok == XCONST) { - int j = locs[i].length - 1; /* Skip the \0 */ - for (; j >= 0 && scanner_isspace(yyextra.scanbuf[loc + j]); j--) {} - locs[i].length = j + 1; /* Count the \0 */ + locs[i].val = palloc(strlen(yylval.core_yystype.str) + 1); + strcpy(locs[i].val, yylval.core_yystype.str); + } + else if (tok == ICONST) + { + int val = yylval.core_yystype.ival; + /* Maximum number of digits in 32-bit int is 10 */ + int buf_size = 10 + 1; + if (negative) + { + buf_size += 1; + val = -val; + } + + locs[i].val = (char *)palloc(buf_size * sizeof(char)); + snprintf(locs[i].val, buf_size, "%d", val); } break; /* out of inner for-loop */ @@ -322,6 +341,8 @@ static void RecordConstLocation(pgssConstLocations *jstate, int location) jstate->clocations[jstate->clocations_count].length = -1; /* by default we assume that we need a new param ref */ jstate->clocations[jstate->clocations_count].param_id = - jstate->highest_normalize_param_id; + jstate->clocations[jstate->clocations_count].val = NULL; + jstate->clocations[jstate->clocations_count].token = 0; jstate->highest_normalize_param_id++; /* record param ref number if requested */ if (jstate->param_refs != NULL) { @@ -599,6 +620,7 @@ PgQueryNormalizeResult pg_query_normalize_ext(const char* input, bool normalize_ List *tree; pgssConstLocations jstate; int query_len; + int i; /* Parse query */ tree = raw_parser(input, RAW_PARSE_DEFAULT); @@ -624,6 +646,28 @@ PgQueryNormalizeResult pg_query_normalize_ext(const char* input, bool normalize_ /* Normalize query */ result.normalized_query = strdup(generate_normalized_query(&jstate, 0, &query_len, PG_UTF8)); + + /* Report constant locations */ + result.clocations_count = jstate.clocations_count; + if (result.clocations_count > 0) + { + result.clocations = (PgQueryNormalizeConstLocation *) + malloc(result.clocations_count * sizeof(PgQueryNormalizeConstLocation)); + + for (i = 0; i < result.clocations_count; i++) + { + pgssLocationLen jloc = jstate.clocations[i]; + result.clocations[i].location = jloc.location; + result.clocations[i].length = jloc.length; + result.clocations[i].param_id = jloc.param_id; + if (jloc.val != NULL) + result.clocations[i].val = strdup(jloc.val); + else + result.clocations[i].val = NULL; + result.clocations[i].token = jloc.token; + } + } + result.highest_extern_param_id = jstate.highest_extern_param_id; } PG_CATCH(); { @@ -664,12 +708,25 @@ PgQueryNormalizeResult pg_query_normalize_utility(const char* input) void pg_query_free_normalize_result(PgQueryNormalizeResult result) { - if (result.error) { + if (result.error) + { free(result.error->message); free(result.error->filename); free(result.error->funcname); free(result.error); + result.error = NULL; + } + + if (result.clocations) + { + int i; + for (i = 0; i < result.clocations_count; i++) + if (result.clocations[i].val != NULL) + free(result.clocations[i].val); + free(result.clocations); + result.clocations = NULL; } free(result.normalized_query); + result.normalized_query = NULL; } diff --git a/test/normalize_tests.c b/test/normalize_tests.c index 8f49bb14..9d200531 100644 --- a/test/normalize_tests.c +++ b/test/normalize_tests.c @@ -48,6 +48,8 @@ const char* tests[] = { "CLOSE cursor_a", "SELECT 1; ALTER USER a WITH PASSWORD 'b'", "SELECT $1; ALTER USER a WITH PASSWORD $2", + "SELECT U&'d!0061t!+000061' UESCAPE '!', -2147483647, -2147483648, x'beef', b'010101'", + "SELECT $1, $2, $3, $4, $5", }; size_t testsLength = __LINE__ - 7;