Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove quotes from keyspace names before sending them back in set_keyspace result messages #129

Merged
merged 6 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions parser/lexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ func TestLexerIdentifiers(t *testing.T) {
{`"system"`, tkIdentifier, "system"},
{`"system"`, tkIdentifier, "system"},
{`"System"`, tkIdentifier, "System"},
{`""""`, tkIdentifier, "\""},
{`""""""`, tkIdentifier, "\"\""},
{`"A"""""`, tkIdentifier, "A\"\""},
{`"""A"""`, tkIdentifier, "\"A\""},
{`"""""A"`, tkIdentifier, "\"\"A"},
// below test verify correct escaping double quote character as per CQL definition:
// identifier ::= unquoted_identifier | quoted_identifier
// unquoted_identifier ::= re('[a-zA-Z][link:[a-zA-Z0-9]]*')
// quoted_identifier ::= '"' (any character where " can appear if doubled)+ '"'
{`""""`, tkIdentifier, "\""}, // outermost quotes indicate quoted string, inner two double quotes shall be treated as single quote
{`""""""`, tkIdentifier, "\"\""}, // same as above, but 4 inner quotes result in 2 quotes
{`"A"""""`, tkIdentifier, "A\"\""}, // outermost quotes indicate quoted string, 4 quotes after A result in 2 quotes
{`"""A"""`, tkIdentifier, "\"A\""}, // outermost quotes indicate quoted string, 2 quotes before and after A result in single quotes
{`"""""A"`, tkIdentifier, "\"\"A"}, // analogical to previous tests
{`";`, tkInvalid, ""},
{`"""`, tkIdentifier, ""},
}
Expand Down
18 changes: 14 additions & 4 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,18 +667,17 @@ func (c *client) handleExecute(raw *frame.RawFrame, msg *partialExecute, customP
}

func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery, customPayload map[string][]byte) {
c.proxy.logger.Debug("handling query", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId))

handled, stmt, err := parser.IsQueryHandled(parser.IdentifierFromString(c.keyspace), msg.query)

if handled {
c.proxy.logger.Debug("Query handled by proxy", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId))
if err != nil {
c.proxy.logger.Error("error parsing query to see if it's handled", zap.Error(err))
c.send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
} else {
c.interceptSystemQuery(raw.Header, stmt)
}
} else {
c.proxy.logger.Debug("Query not handled by proxy, forwarding", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId))
c.execute(raw, c.getDefaultIdempotency(customPayload), c.keyspace, msg)
}
}
Expand Down Expand Up @@ -813,9 +812,20 @@ func (c *client) interceptSystemQuery(hdr *frame.Header, stmt interface{}) {
}
case *parser.UseStatement:
if _, err := c.proxy.maybeCreateSession(hdr.Version, s.Keyspace); err != nil {
c.send(hdr, &message.ServerError{ErrorMessage: "Proxy unable to create new session for keyspace"})
errMsg := "Proxy unable to create new session for keyspace"
var cqlError *proxycore.CqlError
if errors.As(err, &cqlError) {
// copy detailed error reason from downstream message
errMsg = cqlError.Message.GetErrorMessage()
}
c.send(hdr, &message.ServerError{ErrorMessage: errMsg})
} else {
c.keyspace = s.Keyspace
lukasz-antoniak marked this conversation as resolved.
Show resolved Hide resolved
// We might have received a quoted keyspace name in the UseStatement so remove any
// quotes before sending back this result message. This keeps us consistent with
// how Cassandra implements the same functionality and avoids any issues with
// drivers sending follow-on "USE" requests after wrapping the keyspace name in
// quotes.
ks := parser.IdentifierFromString(s.Keyspace)
c.send(hdr, &message.SetKeyspaceResult{Keyspace: ks.ID()})
}
Expand Down
45 changes: 40 additions & 5 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,48 @@ func TestProxy_UseKeyspace(t *testing.T) {

cl := connectTestClient(t, ctx, proxyContactPoint)

resp, err := cl.SendAndReceive(ctx, frame.NewFrame(primitive.ProtocolVersion4, 0, &message.Query{Query: "USE system"}))
testKeyspaces := []string{"system", "\"system\""}
for _, testKeyspace := range testKeyspaces {

resp, err := cl.SendAndReceive(ctx, frame.NewFrame(primitive.ProtocolVersion4, 0, &message.Query{Query: "USE " + testKeyspace}))
require.NoError(t, err)

assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode)
res, ok := resp.Body.Message.(*message.SetKeyspaceResult)
require.True(t, ok, "expected set keyspace result")
assert.Equal(t, "system", res.Keyspace)
}
}

func TestProxy_UseKeyspace_Error(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
tester, proxyContactPoint, err := setupProxyTest(ctx, 1, proxycore.MockRequestHandlers{
primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message {
qry := frm.Body.Message.(*message.Query)
if qry.Query == "USE non_existing" {
return &message.ServerError{
ErrorMessage: "Keyspace 'non_existing' does not exist",
}
}
return cl.InterceptQuery(frm.Header, frm.Body.Message.(*message.Query))
}})
defer func() {
cancel()
tester.shutdown()
}()
require.NoError(t, err)

assert.Equal(t, primitive.OpCodeResult, resp.Header.OpCode)
res, ok := resp.Body.Message.(*message.SetKeyspaceResult)
require.True(t, ok, "expected set keyspace result")
assert.Equal(t, "system", res.Keyspace)
cl := connectTestClient(t, ctx, proxyContactPoint)

resp, err := cl.SendAndReceive(ctx, frame.NewFrame(primitive.ProtocolVersion4, 0, &message.Query{Query: "USE non_existing"}))
require.NoError(t, err)

assert.Equal(t, primitive.OpCodeError, resp.Header.OpCode)
res, ok := resp.Body.Message.(*message.ServerError)
require.True(t, ok)
// make sure that CQL Proxy returns the same error of 'USE keyspace' command
// as backend C* cluster has and does not wrap it inside a custom one
assert.Equal(t, "Keyspace 'non_existing' does not exist", res.ErrorMessage)
}

func TestProxy_NegotiateProtocolV5(t *testing.T) {
Expand Down
Loading