Skip to content

Commit

Permalink
Merge pull request #13 from golistic/bug/11-parsedns
Browse files Browse the repository at this point in the history
Fix issues with parsing DSN query part
  • Loading branch information
geertjanvdk authored Feb 14, 2023
2 parents 706b361 + d981ca6 commit c07b25d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 31 deletions.
16 changes: 13 additions & 3 deletions datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/golistic/xstrings"
)

var reDSN = regexp.MustCompile(`(.*?)(?::(.*?))?@(\w+)\((.*?)\)(?:/(\w+))?(\?)?(?:/?(.*))?`)
var reDSN = regexp.MustCompile(`(.*?)(?::(.*?))?@(\w+)\((.*?)\)(?:/(\w+))?/?(\?)?(.*)?`)

type DataSource struct {
Driver string
Expand All @@ -24,8 +24,8 @@ type DataSource struct {
UseTLS bool
}

// ParseDNS parsers the name as a data source name (DSN).
func ParseDNS(name string) (*DataSource, error) {
// ParseDSN parsers the name as a data source name (DSN).
func ParseDSN(name string) (*DataSource, error) {
errMsg := "invalid data source name (%w)"

m := reDSN.FindAllStringSubmatch(name, -1)
Expand Down Expand Up @@ -68,5 +68,15 @@ func (d *DataSource) String() string {
} else {
n += "/"
}

var queryParts []string
if d.UseTLS {
queryParts = append(queryParts, "useTLS=true")
}

if len(queryParts) > 0 {
n += "?" + strings.Join(queryParts, "&")
}

return n
}
79 changes: 52 additions & 27 deletions datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
package pxmysql

import (
"strings"
"testing"

"github.com/geertjanvdk/xkit/xt"
)

func TestParseDNS(t *testing.T) {
func TestParseDSN(t *testing.T) {
t.Run("parse query string", func(t *testing.T) {
dsn := "scott:tiger@tcp(127.0.0.1:33060)/test?useTLS=true"
exp := &DataSource{
Expand All @@ -20,9 +21,13 @@ func TestParseDNS(t *testing.T) {
UseTLS: true,
}

have, err := ParseDNS(dsn)
have, err := ParseDSN(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)

t.Run("using String-method, query part must be included", func(t *testing.T) {
xt.Assert(t, strings.Contains(have.String(), "?useTLS=true"))
})
})

t.Run("no query string provided", func(t *testing.T) {
Expand All @@ -36,41 +41,61 @@ func TestParseDNS(t *testing.T) {
UseTLS: false,
}

have, err := ParseDNS(dsn)
have, err := ParseDSN(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)

t.Run("using String-method, with useTLS false, it is not included", func(t *testing.T) {
xt.Assert(t, !strings.Contains(have.String(), "?useTLS="))
})
})

t.Run("no default schema with query string", func(t *testing.T) {
dsn := "scott:tiger@tcp(127.0.0.1:33060)?useTLS=true"
exp := &DataSource{
User: "scott",
Password: "tiger",
Protocol: "tcp",
Address: "127.0.0.1:33060",
Schema: "",
UseTLS: true,
var cases = map[string]string{
"without slash": "scott:tiger@tcp(127.0.0.1:33060)/?useTLS=true",
"with slash": "scott:tiger@tcp(127.0.0.1:33060)?useTLS=true",
}

have, err := ParseDNS(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)
for name, dsn := range cases {
t.Run(name, func(t *testing.T) {
exp := &DataSource{
User: "scott",
Password: "tiger",
Protocol: "tcp",
Address: "127.0.0.1:33060",
Schema: "",
UseTLS: true,
}

have, err := ParseDSN(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)
})
}
})

t.Run("no default schema without query string", func(t *testing.T) {
dsn := "scott:tiger@tcp(127.0.0.1:33060)"
exp := &DataSource{
User: "scott",
Password: "tiger",
Protocol: "tcp",
Address: "127.0.0.1:33060",
Schema: "",
UseTLS: false,
var cases = map[string]string{
"without slash": "scott:tiger@tcp(127.0.0.1:33060)/",
"with slash": "scott:tiger@tcp(127.0.0.1:33060)",
}

have, err := ParseDNS(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)
for name, dsn := range cases {
t.Run(name, func(t *testing.T) {
exp := &DataSource{
User: "scott",
Password: "tiger",
Protocol: "tcp",
Address: "127.0.0.1:33060",
Schema: "",
UseTLS: false,
}

have, err := ParseDSN(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)
})
}
})

t.Run("no password", func(t *testing.T) {
Expand All @@ -84,7 +109,7 @@ func TestParseDNS(t *testing.T) {
UseTLS: false,
}

have, err := ParseDNS(dsn)
have, err := ParseDSN(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)
})
Expand All @@ -100,7 +125,7 @@ func TestParseDNS(t *testing.T) {
UseTLS: false,
}

have, err := ParseDNS(dsn)
have, err := ParseDSN(dsn)
xt.OK(t, err)
xt.Eq(t, exp, have)
})
Expand Down
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
// to the MySQL database using MySQL X Protocol.
// This will be used instead of the Open-method (which actually uses this method).
func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
ds, err := ParseDNS(name)
ds, err := ParseDSN(name)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c07b25d

Please sign in to comment.