From 5ff026eecc2f8122061ef094dd36a2ef426168d4 Mon Sep 17 00:00:00 2001 From: yutopp Date: Thu, 20 Jul 2023 02:45:31 +0900 Subject: [PATCH] Support DialOption --- client.go | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 1c79acd..12fff8e 100644 --- a/client.go +++ b/client.go @@ -8,21 +8,40 @@ package rtmp import ( + "context" "net" "github.com/pkg/errors" ) -func Dial(protocol, addr string, config *ConnConfig) (*ClientConn, error) { - return DialWithDialer(&net.Dialer{}, protocol, addr, config) +type dialOptions struct { + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) } -func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfig) (*ClientConn, error) { +func WithContextDialer(dialFunc func(context.Context, string, string) (net.Conn, error)) DialOption { + return func(o *dialOptions) { + o.dialFunc = dialFunc + } +} + +type DialOption func(*dialOptions) + +func Dial(protocol, addr string, config *ConnConfig, opts ...DialOption) (*ClientConn, error) { + opt := &dialOptions{ + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + } + for _, o := range opts { + o(opt) + } + if protocol != "rtmp" { return nil, errors.Errorf("Unknown protocol: %s", protocol) } - rwc, err := dialer.Dial("tcp", addr) + // TODO: support ctx + rwc, err := opt.dialFunc(context.Background(), "tcp", addr) if err != nil { return nil, err } @@ -30,6 +49,10 @@ func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfi return newClientConnWithSetup(rwc, config) } +func DialWithDialer(dialer *net.Dialer, protocol, addr string, config *ConnConfig) (*ClientConn, error) { + return Dial(protocol, addr, config, WithContextDialer(dialer.DialContext)) +} + func makeValidAddr(addr string) (string, error) { host, port, err := net.SplitHostPort(addr) if err != nil {