Skip to content

Commit 6f9a67b

Browse files
committed
make it work
1 parent d169b57 commit 6f9a67b

File tree

5 files changed

+40
-32
lines changed

5 files changed

+40
-32
lines changed

client/auth.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (c *Conn) readInitialHandshake() error {
9292
pos += 2
9393

9494
// The upper 2 bytes of the Capabilities Flags
95-
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
95+
c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
9696
pos += 2
9797

9898
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
@@ -209,10 +209,8 @@ func (c *Conn) writeAuthHandshake() error {
209209

210210
// Set default client capabilities that reflect the abilities of this library
211211
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
212-
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
213-
// Adjust client capability flags based on server support
214-
capability |= c.capability & mysql.CLIENT_LONG_FLAG
215-
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
212+
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH |
213+
mysql.CLIENT_LONG_FLAG | mysql.CLIENT_QUERY_ATTRIBUTES
216214
// Adjust client capability flags on specific client requests
217215
// Only flags that would make any sense setting and aren't handled elsewhere
218216
// in the library are supported here
@@ -275,6 +273,7 @@ func (c *Conn) writeAuthHandshake() error {
275273
data := make([]byte, length+4)
276274

277275
// capability [32 bit]
276+
c.capability &= capability
278277
data[4] = byte(capability)
279278
data[5] = byte(capability >> 8)
280279
data[6] = byte(capability >> 16)

client/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ func (c *Conn) UnsetCapability(cap uint32) {
252252

253253
// HasCapability returns true if the connection has the specific capability
254254
func (c *Conn) HasCapability(cap uint32) bool {
255-
return c.ccaps&cap > 0
255+
return c.ccaps&cap != 0
256256
}
257257

258258
// UseSSL: use default SSL

client/resp.go

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,6 @@ import (
1414
"github.com/go-mysql-org/go-mysql/utils"
1515
)
1616

17-
func (c *Conn) readUntilEOF() (err error) {
18-
var data []byte
19-
20-
for {
21-
data, err = c.ReadPacket()
22-
if err != nil {
23-
return err
24-
}
25-
26-
// EOF Packet
27-
if c.isEOFPacket(data) {
28-
return err
29-
}
30-
}
31-
}
32-
3317
func (c *Conn) isEOFPacket(data []byte) bool {
3418
return data[0] == mysql.EOF_HEADER && len(data) <= 5
3519
}
@@ -357,7 +341,7 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
357341
result.FieldNames[utils.ByteSliceToString(result.Fields[i].Name)] = i
358342
}
359343

360-
if !c.HasCapability(mysql.CLIENT_DEPRECATE_EOF) {
344+
if c.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
361345
// EOF Packet
362346
rawPkgLen := len(result.RawPkg)
363347
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
@@ -394,7 +378,7 @@ func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) {
394378
data = result.RawPkg[rawPkgLen:]
395379

396380
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
397-
if c.HasCapability(mysql.CLIENT_DEPRECATE_EOF) {
381+
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
398382
// Treat like OK
399383
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
400384
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
@@ -446,9 +430,16 @@ func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perR
446430
return err
447431
}
448432

449-
// EOF Packet
450-
if c.isEOFPacket(data) {
451-
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
433+
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
434+
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
435+
// Treat like OK
436+
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
437+
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
438+
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
439+
result.AffectedRows = affectedRows
440+
result.InsertId = insertId
441+
c.status = result.Status
442+
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
452443
result.Warnings = binary.LittleEndian.Uint16(data[1:])
453444
// todo add strict_mode, warning will be treat as error
454445
result.Status = binary.LittleEndian.Uint16(data[3:])

client/stmt.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,33 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
275275
}
276276

277277
if s.params > 0 {
278-
if err := s.conn.readUntilEOF(); err != nil {
279-
return nil, errors.Trace(err)
278+
for range s.params {
279+
if _, err := s.conn.ReadPacket(); err != nil {
280+
return nil, errors.Trace(err)
281+
}
282+
}
283+
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
284+
if packet, err := s.conn.ReadPacket(); err != nil {
285+
return nil, errors.Trace(err)
286+
} else if c.isEOFPacket(packet) {
287+
return nil, mysql.ErrMalformPacket
288+
}
280289
}
281290
}
282291

283292
if s.columns > 0 {
284-
if err := s.conn.readUntilEOF(); err != nil {
285-
return nil, errors.Trace(err)
293+
// TODO process when CLIENT_CACHE_METADATA enabled
294+
for range s.columns {
295+
if _, err := s.conn.ReadPacket(); err != nil {
296+
return nil, errors.Trace(err)
297+
}
298+
}
299+
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
300+
if packet, err := s.conn.ReadPacket(); err != nil {
301+
return nil, errors.Trace(err)
302+
} else if c.isEOFPacket(packet) {
303+
return nil, mysql.ErrMalformPacket
304+
}
286305
}
287306
}
288307

replication/row_event.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,6 @@ func (e *RowsEvent) DecodeData(pos int, data []byte) (err2 error) {
10821082
if e.compressed {
10831083
data, err2 = mysql.DecompressMariadbData(data[pos:])
10841084
if err2 != nil {
1085-
//nolint:nakedret
10861085
return err2
10871086
}
10881087
pos = 0

0 commit comments

Comments
 (0)