diff --git a/client.go b/client.go index a7dead86..028e277a 100644 --- a/client.go +++ b/client.go @@ -142,9 +142,6 @@ type ( // host is the hostname of the SMTP server we are connecting to. host string - // isEncrypted indicates wether the Client connection is encrypted or not. - isEncrypted bool - // logAuthData indicates whether authentication-related data should be logged. logAuthData bool @@ -806,7 +803,7 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) { } // SetDebugLog sets or overrides whether the Client is using debug logging. The debug logger will log incoming -// and outgoing communication between the Client and the server to os.Stderr. +// and outgoing communication between the client and the server to log.Logger that is defined on the Client. // // Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using // SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use @@ -815,9 +812,26 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) { // Parameters: // - val: A boolean value indicating whether to enable (true) or disable (false) debug logging. func (c *Client) SetDebugLog(val bool) { + c.SetDebugLogWithSMTPClient(c.smtpClient, val) +} + +// SetDebugLogWithSMTPClient sets or overrides whether the provided smtp.Client is using debug logging. +// The debug logger will log incoming and outgoing communication between the client and the server to +// log.Logger that is defined on the Client. +// +// Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using +// SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use +// debug logging with caution. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// - val: A boolean value indicating whether to enable (true) or disable (false) debug logging. +func (c *Client) SetDebugLogWithSMTPClient(client *smtp.Client, val bool) { + c.mutex.Lock() + defer c.mutex.Unlock() c.useDebugLog = val - if c.smtpClient != nil { - c.smtpClient.SetDebugLog(val) + if client != nil { + client.SetDebugLog(val) } } @@ -830,9 +844,24 @@ func (c *Client) SetDebugLog(val bool) { // Parameters: // - logger: A logger that satisfies the log.Logger interface to be set for the Client. func (c *Client) SetLogger(logger log.Logger) { + c.SetLoggerWithSMTPClient(c.smtpClient, logger) +} + +// SetLoggerWithSMTPClient sets or overrides the custom logger currently used by the provided smtp.Client. +// The logger must satisfy the log.Logger interface and is only utilized when debug logging is enabled on +// the provided smtp.Client. +// +// By default, log.Stdlog is used if no custom logger is provided. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// - logger: A logger that satisfies the log.Logger interface to be set for the Client. +func (c *Client) SetLoggerWithSMTPClient(client *smtp.Client, logger log.Logger) { + c.mutex.Lock() + defer c.mutex.Unlock() c.logger = logger - if c.smtpClient != nil { - c.smtpClient.SetLogger(logger) + if client != nil { + client.SetLogger(logger) } } @@ -926,67 +955,91 @@ func (c *Client) SetLogAuthData(logAuth bool) { // SMTP server. // // Parameters: -// - dialCtx: The context.Context used to control the connection timeout and cancellation. +// - ctxDial: The context.Context used to control the connection timeout and cancellation. // // Returns: // - An error if the connection to the SMTP server fails or any subsequent command fails. -func (c *Client) DialWithContext(dialCtx context.Context) error { +func (c *Client) DialWithContext(ctxDial context.Context) error { + client, err := c.DialToSMTPClientWithContext(ctxDial) + if err != nil { + return err + } c.mutex.Lock() - defer c.mutex.Unlock() + c.smtpClient = client + c.mutex.Unlock() + return nil +} + +// DialToSMTPClientWithContext establishes and configures a smtp.Client connection using +// the provided context. +// +// This function uses the provided context to manage the connection deadline and cancellation. +// It dials the SMTP server using the Client's configured DialContextFunc or a default dialer. +// If SSL is enabled, it uses a TLS connection. After successfully connecting, it initializes +// an smtp.Client, sends the HELO/EHLO command, and optionally performs STARTTLS and SMTP AUTH +// based on the Client's configuration. Debug and authentication logging are enabled if +// configured. +// +// Parameters: +// - ctxDial: The context used to control the connection timeout and cancellation. +// +// Returns: +// - A pointer to the initialized smtp.Client. +// - An error if the connection fails, the smtp.Client cannot be created, or any subsequent commands fail. +func (c *Client) DialToSMTPClientWithContext(ctxDial context.Context) (*smtp.Client, error) { + c.mutex.RLock() + defer c.mutex.RUnlock() - ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) + ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout)) defer cancel() + isEncrypted := false + dialContextFunc := c.dialContextFunc if c.dialContextFunc == nil { netDialer := net.Dialer{} - c.dialContextFunc = netDialer.DialContext - + dialContextFunc = netDialer.DialContext if c.useSSL { tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} - c.isEncrypted = true - c.dialContextFunc = tlsDialer.DialContext + isEncrypted = true + dialContextFunc = tlsDialer.DialContext } } - connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) + connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr()) if err != nil && c.fallbackPort != 0 { // TODO: should we somehow log or append the previous error? - connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) } if err != nil { - return err + return nil, err } client, err := smtp.NewClient(connection, c.host) if err != nil { - return err - } - if client == nil { - return fmt.Errorf("SMTP client is nil") + return nil, err } - c.smtpClient = client if c.logger != nil { - c.smtpClient.SetLogger(c.logger) + client.SetLogger(c.logger) } if c.useDebugLog { - c.smtpClient.SetDebugLog(true) + client.SetDebugLog(true) } if c.logAuthData { - c.smtpClient.SetLogAuthData() + client.SetLogAuthData() } - if err = c.smtpClient.Hello(c.helo); err != nil { - return err + if err = client.Hello(c.helo); err != nil { + return nil, err } - if err = c.tls(); err != nil { - return err + if err = c.tls(client, &isEncrypted); err != nil { + return nil, err } - if err = c.auth(); err != nil { - return err + if err = c.auth(client, isEncrypted); err != nil { + return nil, err } - return nil + return client, nil } // Close terminates the connection to the SMTP server, returning an error if the disconnection @@ -999,10 +1052,27 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { // Returns: // - An error if the disconnection fails; otherwise, returns nil. func (c *Client) Close() error { - if c.smtpClient == nil || !c.smtpClient.HasConnection() { + return c.CloseWithSMTPClient(c.smtpClient) +} + +// CloseWithSMTPClient terminates the connection of the provided smtp.Client to the SMTP server, +// returning an error if the disconnection fails. If the connection is already closed, this +// method is a no-op and disregards any error. +// +// This function checks if the smtp.Client connection is active. If not, it simply returns +// without any action. If the connection is active, it attempts to gracefully close the +// connection using the Quit method. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// +// Returns: +// - An error if the disconnection fails; otherwise, returns nil. +func (c *Client) CloseWithSMTPClient(client *smtp.Client) error { + if client == nil || !client.HasConnection() { return nil } - if err := c.smtpClient.Quit(); err != nil { + if err := client.Quit(); err != nil { return fmt.Errorf("failed to close SMTP client: %w", err) } @@ -1016,12 +1086,29 @@ func (c *Client) Close() error { // the command fails, an error is returned. // // Returns: -// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. +// - An error if the connection check fails or if sending the RSET command fails; +// otherwise, returns nil. func (c *Client) Reset() error { - if err := c.checkConn(); err != nil { + return c.ResetWithSMTPClient(c.smtpClient) +} + +// ResetWithSMTPClient sends an SMTP RSET command to the provided smtp.Client, to reset +// the state of the current SMTP session. +// +// This method checks the connection to the SMTP server and, if the connection is valid, +// it sends an RSET command to reset the session state. If the connection is invalid or +// the command fails, an error is returned. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// +// Returns: +// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. +func (c *Client) ResetWithSMTPClient(client *smtp.Client) error { + if err := c.checkConn(client); err != nil { return err } - if err := c.smtpClient.Reset(); err != nil { + if err := client.Reset(); err != nil { return fmt.Errorf("failed to send RSET to SMTP client: %w", err) } @@ -1061,24 +1148,47 @@ func (c *Client) DialAndSend(messages ...*Msg) error { // - An error if the connection fails, if sending the messages fails, or if closing the // connection fails; otherwise, returns nil. func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { - c.sendMutex.Lock() - defer c.sendMutex.Unlock() - if err := c.DialWithContext(ctx); err != nil { + client, err := c.DialToSMTPClientWithContext(ctx) + if err != nil { return fmt.Errorf("dial failed: %w", err) } defer func() { - _ = c.Close() + _ = c.CloseWithSMTPClient(client) }() - if err := c.Send(messages...); err != nil { + if err = c.SendWithSMTPClient(client, messages...); err != nil { return fmt.Errorf("send failed: %w", err) } - if err := c.Close(); err != nil { + if err = c.CloseWithSMTPClient(client); err != nil { return fmt.Errorf("failed to close connection: %w", err) } return nil } +// Send attempts to send one or more Msg using the SMTP client that is assigned to the Client. +// If the Client has no active connection to the server, Send will fail with an error. For +// each of the provided Msg, it will associate a SendError with the Msg in case of a +// transmission or delivery error. +// +// This method first checks for an active connection to the SMTP server. If the connection is +// not valid, it returns a SendError. It then iterates over the provided messages, attempting +// to send each one. If an error occurs during sending, the method records the error and +// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates +// them into a single SendError to be returned. +// +// Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server +// - messages: A variadic list of pointers to Msg objects to be sent. +// +// Returns: +// - An error that represents the sending result, which may include multiple SendErrors if +// any occurred; otherwise, returns nil. +func (c *Client) Send(messages ...*Msg) (returnErr error) { + c.sendMutex.Lock() + defer c.sendMutex.Unlock() + return c.SendWithSMTPClient(c.smtpClient, messages...) +} + // auth attempts to authenticate the client using SMTP AUTH mechanisms. It checks the connection, // determines the supported authentication methods, and applies the appropriate authentication // type. An error is returned if authentication fails. @@ -1098,16 +1208,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // Returns: // - An error if the connection check fails, if no supported authentication method is found, // or if the authentication process fails. -func (c *Client) auth() error { +func (c *Client) auth(client *smtp.Client, isEnc bool) error { + var smtpAuth smtp.Auth if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth { - hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH") + hasSMTPAuth, smtpAuthType := client.Extension("AUTH") if !hasSMTPAuth { return fmt.Errorf("server does not support SMTP AUTH") } authType := c.smtpAuthType if c.smtpAuthType == SMTPAuthAutoDiscover { - discoveredType, err := c.authTypeAutoDiscover(smtpAuthType) + discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc) if err != nil { return err } @@ -1119,74 +1230,74 @@ func (c *Client) auth() error { if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) case SMTPAuthPlainNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) case SMTPAuthLogin: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) case SMTPAuthLoginNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) case SMTPAuthCramMD5: if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) { return ErrCramMD5AuthNotSupported } - c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) + smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) case SMTPAuthXOAUTH2: if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) { return ErrXOauth2AuthNotSupported } - c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) + smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) { return ErrSCRAMSHA1AuthNotSupported } - c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) case SMTPAuthSCRAMSHA256: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) { return ErrSCRAMSHA256AuthNotSupported } - c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) { return ErrSCRAMSHA1PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) case SMTPAuthSCRAMSHA256PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) { return ErrSCRAMSHA256PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) default: return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType) } } - if c.smtpAuth != nil { - if err := c.smtpClient.Auth(c.smtpAuth); err != nil { + if smtpAuth != nil { + if err := client.Auth(smtpAuth); err != nil { return fmt.Errorf("SMTP AUTH failed: %w", err) } } return nil } -func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { +func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) { if supported == "" { return "", ErrNoSupportedAuthDiscovered } @@ -1194,7 +1305,7 @@ func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin, } - if !c.isEncrypted { + if !isEnc { preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5} } mechs := strings.Split(supported, " ") @@ -1231,13 +1342,13 @@ func sliceContains(slice []string, item string) bool { // // Returns: // - An error if any part of the sending process fails; otherwise, returns nil. -func (c *Client) sendSingleMsg(message *Msg) error { - c.mutex.Lock() - defer c.mutex.Unlock() - escSupport, _ := c.smtpClient.Extension("ENHANCEDSTATUSCODES") +func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + escSupport, _ := client.Extension("ENHANCEDSTATUSCODES") if message.encoding == NoEncoding { - if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { + if ok, _ := client.Extension("8BITMIME"); !ok { return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} } } @@ -1260,16 +1371,16 @@ func (c *Client) sendSingleMsg(message *Msg) error { if c.requestDSN { if c.dsnReturnType != "" { - c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType)) + client.SetDSNMailReturnOption(string(c.dsnReturnType)) } } - if err = c.smtpClient.Mail(from); err != nil { + if err = client.Mail(from); err != nil { retError := &SendError{ Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err), affectedMsg: message, errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), } - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { retError.errlist = append(retError.errlist, resetSendErr) } return retError @@ -1279,9 +1390,9 @@ func (c *Client) sendSingleMsg(message *Msg) error { rcptSendErr.errlist = make([]error, 0) rcptSendErr.rcpt = make([]string, 0) rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",") - c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt) + client.SetDSNRcptNotifyOption(rcptNotifyOpt) for _, rcpt := range rcpts { - if err = c.smtpClient.Rcpt(rcpt); err != nil { + if err = client.Rcpt(rcpt); err != nil { rcptSendErr.Reason = ErrSMTPRcptTo rcptSendErr.errlist = append(rcptSendErr.errlist, err) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt) @@ -1292,12 +1403,12 @@ func (c *Client) sendSingleMsg(message *Msg) error { } } if hasError { - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr) } return rcptSendErr } - writer, err := c.smtpClient.Data() + writer, err := client.Data() if err != nil { return &SendError{ Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err), @@ -1322,7 +1433,7 @@ func (c *Client) sendSingleMsg(message *Msg) error { } message.isDelivered = true - if err = c.Reset(); err != nil { + if err = c.ResetWithSMTPClient(client); err != nil { return &SendError{ Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err), affectedMsg: message, errcode: errorCode(err), @@ -1344,21 +1455,24 @@ func (c *Client) sendSingleMsg(message *Msg) error { // Returns: // - An error if there is no active connection, if the NOOP command fails, or if extending // the deadline fails; otherwise, returns nil. -func (c *Client) checkConn() error { - if c.smtpClient == nil { +func (c *Client) checkConn(client *smtp.Client) error { + if client == nil { return ErrNoActiveConnection } - if !c.smtpClient.HasConnection() { + if !client.HasConnection() { return ErrNoActiveConnection } - if !c.noNoop { - if err := c.smtpClient.Noop(); err != nil { + c.mutex.RLock() + noNoop := c.noNoop + c.mutex.RUnlock() + if !noNoop { + if err := client.Noop(); err != nil { return ErrNoActiveConnection } } - if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { + if err := client.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -1405,10 +1519,10 @@ func (c *Client) setDefaultHelo() error { // Returns: // - An error if there is no active connection, if STARTTLS is required but not supported, // or if there are issues during the TLS handshake; otherwise, returns nil. -func (c *Client) tls() error { +func (c *Client) tls(client *smtp.Client, isEnc *bool) error { if !c.useSSL && c.tlspolicy != NoTLS { hasStartTLS := false - extension, _ := c.smtpClient.Extension("STARTTLS") + extension, _ := client.Extension("STARTTLS") if c.tlspolicy == TLSMandatory { hasStartTLS = true if !extension { @@ -1422,21 +1536,21 @@ func (c *Client) tls() error { } } if hasStartTLS { - if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil { + if err := client.StartTLS(c.tlsconfig); err != nil { return err } } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { switch { case errors.Is(err, smtp.ErrNonTLSConnection): - c.isEncrypted = false + *isEnc = false return nil default: return fmt.Errorf("failed to get TLS connection state: %w", err) } } - c.isEncrypted = tlsConnState.HandshakeComplete + *isEnc = tlsConnState.HandshakeComplete } return nil } diff --git a/client_119.go b/client_119.go index a28f747d..96d18a7d 100644 --- a/client_119.go +++ b/client_119.go @@ -7,12 +7,16 @@ package mail -import "errors" +import ( + "errors" -// Send attempts to send one or more Msg using the Client connection to the SMTP server. -// If the Client has no active connection to the server, Send will fail with an error. For each -// of the provided Msg, it will associate a SendError with the Msg in case of a transmission -// or delivery error. + "github.com/wneessen/go-mail/smtp" +) + +// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an +// established connection to the SMTP server. If the smtp.Client has no active connection to +// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it +// will associate a SendError with the Msg in case of a transmission or delivery error. // // This method first checks for an active connection to the SMTP server. If the connection is // not valid, it returns a SendError. It then iterates over the provided messages, attempting @@ -21,17 +25,18 @@ import "errors" // them into a single SendError to be returned. // // Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server // - messages: A variadic list of pointers to Msg objects to be sent. // // Returns: // - An error that represents the sending result, which may include multiple SendErrors if // any occurred; otherwise, returns nil. -func (c *Client) Send(messages ...*Msg) error { +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) error { escSupport := false - if c.smtpClient != nil { - escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES") + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") } - if err := c.checkConn(); err != nil { + if err := c.checkConn(client); err != nil { return &SendError{ Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), @@ -39,7 +44,7 @@ func (c *Client) Send(messages ...*Msg) error { } var errs []*SendError for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr var msgSendErr *SendError diff --git a/client_120.go b/client_120.go index 67c5b5e9..622e149b 100644 --- a/client_120.go +++ b/client_120.go @@ -9,29 +9,34 @@ package mail import ( "errors" + + "github.com/wneessen/go-mail/smtp" ) -// Send attempts to send one or more Msg using the Client connection to the SMTP server. -// If the Client has no active connection to the server, Send will fail with an error. For each -// of the provided Msg, it will associate a SendError with the Msg in case of a transmission -// or delivery error. +// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an +// established connection to the SMTP server. If the smtp.Client has no active connection to +// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it +// will associate a SendError with the Msg in case of a transmission or delivery error. // // This method first checks for an active connection to the SMTP server. If the connection is -// not valid, it returns an error wrapped in a SendError. It then iterates over the provided -// messages, attempting to send each one. If an error occurs during sending, the method records -// the error and associates it with the corresponding Msg. +// not valid, it returns a SendError. It then iterates over the provided messages, attempting +// to send each one. If an error occurs during sending, the method records the error and +// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates +// them into a single SendError to be returned. // // Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server // - messages: A variadic list of pointers to Msg objects to be sent. // // Returns: -// - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. -func (c *Client) Send(messages ...*Msg) (returnErr error) { +// - An error that represents the sending result, which may include multiple SendErrors if +// any occurred; otherwise, returns nil. +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) { escSupport := false - if c.smtpClient != nil { - escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES") + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") } - if err := c.checkConn(); err != nil { + if err := c.checkConn(client); err != nil { returnErr = &SendError{ Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), @@ -45,7 +50,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) { }() for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr errs = append(errs, sendErr) } diff --git a/client_test.go b/client_test.go index f6d5161e..195a797e 100644 --- a/client_test.go +++ b/client_test.go @@ -1772,11 +1772,8 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("failed to close the client: %s", err) } }) - if client.smtpClient == nil { - t.Errorf("client with invalid HELO should still have a smtp client, got nil") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client with invalid HELO should still have a smtp client connection, got nil") + if client.smtpClient != nil { + t.Error("client with invalid HELO should not have a smtp client") } }) t.Run("fail on base port and fallback", func(t *testing.T) { @@ -1825,11 +1822,8 @@ func TestClient_DialWithContext(t *testing.T) { if err = client.DialWithContext(ctxDial); err == nil { t.Fatalf("connection was supposed to fail, but didn't") } - if client.smtpClient == nil { - t.Fatalf("client has no smtp client") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client has no connection") + if client.smtpClient != nil { + t.Fatalf("client is not supposed to have a smtp client") } }) t.Run("connect with failing auth", func(t *testing.T) { @@ -2297,6 +2291,7 @@ func TestClient_DialAndSendWithContext(t *testing.T) { t.Errorf("client was supposed to fail on dial") } }) + // https://github.com/wneessen/go-mail/issues/380 t.Run("concurrent sending via DialAndSendWithContext", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2336,6 +2331,44 @@ func TestClient_DialAndSendWithContext(t *testing.T) { } wg.Wait() }) + // https://github.com/wneessen/go-mail/issues/385 + t.Run("concurrent sending via DialAndSendWithContext on receiver func", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + sender := testSender{client} + + ctxDial := context.Background() + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + msg := testMessage(t) + go func() { + defer wg.Done() + if goroutineErr := sender.Send(ctxDial, msg); goroutineErr != nil { + t.Errorf("failed to send message: %s", goroutineErr) + } + }() + } + wg.Wait() + }) } func TestClient_auth(t *testing.T) { @@ -2574,8 +2607,8 @@ func TestClient_authTypeAutoDiscover(t *testing.T) { } for _, tt := range tests { t.Run("AutoDiscover selects the strongest auth type: "+string(tt.expect), func(t *testing.T) { - client := &Client{smtpAuthType: SMTPAuthAutoDiscover, isEncrypted: tt.tls} - authType, err := client.authTypeAutoDiscover(tt.supported) + client := &Client{smtpAuthType: SMTPAuthAutoDiscover} + authType, err := client.authTypeAutoDiscover(tt.supported, tt.tls) if err != nil && !tt.shouldFail { t.Fatalf("failed to auto discover auth type: %s", err) } @@ -2709,6 +2742,82 @@ func TestClient_Send(t *testing.T) { }) } +func TestClient_DialToSMTPClientWithContext(t *testing.T) { + t.Run("establish a new client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + smtpClient, err := client.DialToSMTPClientWithContext(ctxDial) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to connect to the test server due to timeout") + } + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err := client.CloseWithSMTPClient(smtpClient); err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to close the test server connection due to timeout") + } + t.Errorf("failed to close client: %s", err) + } + }) + if smtpClient == nil { + t.Fatal("expected SMTP client, got nil") + } + if !smtpClient.HasConnection() { + t.Fatal("expected connection on smtp client") + } + if ok, _ := smtpClient.Extension("DSN"); !ok { + t.Error("expected DSN extension but it was not found") + } + }) + t.Run("dial to SMTP server fails on first client writeFile", func(t *testing.T) { + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + failReadWriteSeekCloser{}, + failReadWriteSeekCloser{}, + } + + ctxDial, cancelDial := context.WithTimeout(context.Background(), time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithDialContextFunc(getFakeDialFunc(fake))) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + _, err = client.DialToSMTPClientWithContext(ctxDial) + if err == nil { + t.Fatal("expected connection to fake to fail") + } + }) +} + func TestClient_sendSingleMsg(t *testing.T) { t.Run("connect and send email", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -2748,7 +2857,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -2791,7 +2900,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } }) @@ -2836,7 +2945,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2886,7 +2995,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2936,7 +3045,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2986,7 +3095,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -3029,7 +3138,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3080,7 +3189,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3131,7 +3240,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3181,7 +3290,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3231,7 +3340,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3281,7 +3390,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Error("expected mail delivery to fail") } var sendErr *SendError @@ -3334,7 +3443,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err != nil { + if err = client.checkConn(client.smtpClient); err != nil { t.Errorf("failed to check connection: %s", err) } }) @@ -3375,7 +3484,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3387,7 +3496,7 @@ func TestClient_checkConn(t *testing.T) { if err != nil { t.Fatalf("failed to create new client: %s", err) } - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3611,24 +3720,20 @@ func TestClient_XOAuth2OnFaker(t *testing.T) { } if err = c.DialWithContext(context.Background()); err == nil { t.Fatal("expected dial error got nil") - } else { - if !errors.Is(err, ErrXOauth2AuthNotSupported) { - t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) - } + } + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) } if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } client := strings.Split(wrote.String(), "\r\n") - if len(client) != 3 { - t.Fatalf("unexpected number of client requests got %d; want 3", len(client)) + if len(client) != 2 { + t.Fatalf("unexpected number of client requests got %d; want 2", len(client)) } if !strings.HasPrefix(client[0], "EHLO") { t.Fatalf("expected EHLO, got %q", client[0]) } - if client[1] != "QUIT" { - t.Fatalf("expected QUIT, got %q", client[3]) - } }) } @@ -3652,6 +3757,17 @@ func (f faker) SetDeadline(time.Time) error { return nil } func (f faker) SetReadDeadline(time.Time) error { return nil } func (f faker) SetWriteDeadline(time.Time) error { return nil } +type testSender struct { + client *Client +} + +func (t *testSender) Send(ctx context.Context, m *Msg) error { + if err := t.client.DialAndSendWithContext(ctx, m); err != nil { + return fmt.Errorf("failed to dial and send mail: %w", err) + } + return nil +} + // parseJSONLog parses a JSON encoded log from the provided buffer and returns a slice of logLine structs. // In case of a decode error, it reports the error to the testing framework. func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData { diff --git a/smtp/smtp.go b/smtp/smtp.go index 8a278ee7..24a079f6 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -601,12 +601,16 @@ func (c *Client) SetLogAuthData() { // SetDSNMailReturnOption sets the DSN mail return option for the Mail method func (c *Client) SetDSNMailReturnOption(d string) { + c.mutex.Lock() c.dsnmrtype = d + c.mutex.Unlock() } // SetDSNRcptNotifyOption sets the DSN recipient notify option for the Mail method func (c *Client) SetDSNRcptNotifyOption(d string) { + c.mutex.Lock() c.dsnrntype = d + c.mutex.Unlock() } // HasConnection checks if the client has an active connection.