Skip to content

Commit

Permalink
Refactor test server code for thread safety
Browse files Browse the repository at this point in the history
Moved serverProps outside goroutines to improve code readability and maintainability. Added a RWMutex to serverProps to ensure thread-safe access to EchoBuffer, preventing race conditions during concurrent writes.
  • Loading branch information
wneessen committed Nov 11, 2024
1 parent f7bdd8f commit 800c266
Showing 1 changed file with 92 additions and 68 deletions.
160 changes: 92 additions & 68 deletions smtp/smtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"net"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -2235,17 +2236,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-8BITMIME\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)

client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
Expand All @@ -2261,7 +2262,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err)
}
expected := "MAIL FROM:<[email protected]> BODY=8BITMIME"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
}
Expand All @@ -2273,17 +2276,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-SMTPUTF8\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)

client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
Expand All @@ -2299,7 +2302,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err)
}
expected := "MAIL FROM:<[email protected]> SMTPUTF8"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
}
Expand All @@ -2311,17 +2316,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-SMTPUTF8\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)

client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
Expand All @@ -2337,7 +2342,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err)
}
expected := "MAIL FROM:<valid-from+📧@domain.tld> SMTPUTF8"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
}
Expand All @@ -2349,17 +2356,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)

client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
Expand All @@ -2376,7 +2383,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err)
}
expected := "MAIL FROM:<[email protected]> RET=FULL"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
}
Expand All @@ -2388,17 +2397,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)

client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
Expand All @@ -2415,7 +2424,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err)
}
expected := "MAIL FROM:<[email protected]> BODY=8BITMIME SMTPUTF8 RET=FULL"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[7], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7])
}
Expand Down Expand Up @@ -2490,17 +2501,17 @@ func TestClient_Rcpt(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
if err != nil {
Expand All @@ -2519,7 +2530,9 @@ func TestClient_Rcpt(t *testing.T) {
t.Error("recpient address with newlines should fail")
}
expected := "RCPT TO:<[email protected]> NOTIFY=SUCCESS"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5])
}
Expand Down Expand Up @@ -2782,17 +2795,17 @@ func TestSendMail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)
addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort)
testHookStartTLS = func(config *tls.Config) {
Expand All @@ -2806,7 +2819,9 @@ func TestSendMail(t *testing.T) {
[]byte("test message")); err != nil {
t.Fatalf("failed to send mail: %s", err)
}
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if len(resp)-1 != len(want) {
t.Fatalf("expected %d lines, but got %d", len(want), len(resp))
}
Expand Down Expand Up @@ -2857,17 +2872,17 @@ func TestSendMail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) {
if err := simpleSMTPServer(ctx, t, &serverProps{
EchoBuffer: buf,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
props := &serverProps{
EchoBuffer: echoBuffer,
FeatureSet: featureSet,
ListenPort: serverPort,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}(echoBuffer)
}()
time.Sleep(time.Millisecond * 30)
addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort)
testHookStartTLS = func(config *tls.Config) {
Expand All @@ -2887,7 +2902,9 @@ Goodbye.`)
if err := SendMail(addr, auth, "[email protected]", []string{"[email protected]"}, message); err != nil {
t.Fatalf("failed to send mail: %s", err)
}
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if len(resp)-1 != len(want) {
t.Errorf("expected %d lines, but got %d", len(want), len(resp))
}
Expand Down Expand Up @@ -3542,6 +3559,7 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "

// serverProps represents the configuration properties for the SMTP server.
type serverProps struct {
BufferMutex sync.RWMutex
EchoBuffer io.Writer
FailOnAuth bool
FailOnDataInit bool
Expand Down Expand Up @@ -3640,9 +3658,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
t.Logf("failed to write line: %s", err)
}
if props.EchoBuffer != nil {
if _, err := props.EchoBuffer.Write([]byte(data + "\r\n")); err != nil {
t.Errorf("failed write to echo buffer: %s", err)
props.BufferMutex.Lock()
if _, berr := props.EchoBuffer.Write([]byte(data + "\r\n")); berr != nil {
t.Errorf("failed write to echo buffer: %s", berr)
}
props.BufferMutex.Unlock()
}
_ = writer.Flush()
}
Expand All @@ -3665,9 +3685,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
}
time.Sleep(time.Millisecond)
if props.EchoBuffer != nil {
props.BufferMutex.Lock()
if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil {
t.Errorf("failed write to echo buffer: %s", berr)
}
props.BufferMutex.Unlock()
}

var datastring string
Expand Down Expand Up @@ -3768,9 +3790,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
break
}
if props.EchoBuffer != nil {
if _, err := props.EchoBuffer.Write([]byte(ddata)); err != nil {
t.Errorf("failed write to echo buffer: %s", err)
props.BufferMutex.Lock()
if _, berr := props.EchoBuffer.Write([]byte(ddata)); berr != nil {
t.Errorf("failed write to echo buffer: %s", berr)
}
props.BufferMutex.Unlock()
}
ddata = strings.TrimSpace(ddata)
if ddata == "." {
Expand Down

0 comments on commit 800c266

Please sign in to comment.