Skip to content

Commit

Permalink
Allow direct port connection in Go SDK (#325)
Browse files Browse the repository at this point in the history
* Allow direct port connection in Go SDK

* Update package version
  • Loading branch information
dmgardiner25 authored Oct 5, 2023
1 parent 26a9259 commit 97233d2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 20 deletions.
58 changes: 39 additions & 19 deletions go/tunnels/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,33 +136,26 @@ func (c *Client) Connect(ctx context.Context, hostID string) error {
return nil
}

// Opens a stream connected to a remote port for clients which cannot or do not want to forward local TCP ports.
// Returns a readWriteCloser which can be used to read and write to the remote port.
// Set AcceptLocalConnectionsForForwardedPorts to false in ConnectAsync to ensure TCP listeners are not created
// This will return an error if the port is not yet forwarded,
// the caller should first call WaitForForwardedPort.
func (c *Client) ConnectToForwardedPort(ctx context.Context, listenerIn *net.TCPListener, port uint16) error {
// ConnectListenerToForwardedPort opens a stream to a remote port and connects it to a given listener.
//
// Ensure that the port is already forwarded before calling this function
// by calling WaitForForwardedPort. Otherwise, this will return an error.
//
// Set acceptLocalConnectionsForForwardedPorts to false when creating the client to ensure
// TCP listeners are not created for all ports automatically when the client connects.
func (c *Client) ConnectListenerToForwardedPort(ctx context.Context, listenerIn net.Listener, port uint16) error {
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}

go func() {
for {
conn, err := listenerIn.AcceptTCP()
conn, err := listenerIn.Accept()
if err != nil {
sendError(err)
sendError(err, errc)
return
}

go func() {
if err := c.handleConnection(ctx, conn, port); err != nil {
sendError(err)
if err := c.ConnectToForwardedPort(ctx, conn, port); err != nil {
sendError(err, errc)
}
}()
}
Expand All @@ -171,6 +164,24 @@ func (c *Client) ConnectToForwardedPort(ctx context.Context, listenerIn *net.TCP
return awaitError(ctx, errc)
}

// ConnectToForwardedPort opens a stream to a remote port and connects it to a given connection.
//
// Ensure that the port is already forwarded before calling this function
// by calling WaitForForwardedPort. Otherwise, this will return an error.
//
// Set acceptLocalConnectionsForForwardedPorts to false when creating the client to ensure
// TCP listeners are not created for all ports automatically when the client connects.
func (c *Client) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, port uint16) error {
errc := make(chan error, 1)
go func() {
if err := c.handleConnection(ctx, conn, port); err != nil {
sendError(err, errc)
}
}()

return awaitError(ctx, errc)
}

// WaitForForwardedPort waits for the specified port to be forwarded.
// It is common practice to call this function before ConnectToForwardedPort.
func (c *Client) WaitForForwardedPort(ctx context.Context, port uint16) error {
Expand Down Expand Up @@ -207,6 +218,15 @@ func (c *Client) RefreshPorts(ctx context.Context) error {
return err
}

func sendError(err error, errc chan error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}

func awaitError(ctx context.Context, errc chan error) error {
select {
case err := <-errc:
Expand Down
17 changes: 16 additions & 1 deletion go/tunnels/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,22 @@ func TestPortForwarding(t *testing.T) {
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
return
}
err = c.ConnectToForwardedPort(ctx, listen, streamPort)

// Test connecting with a listener
err = c.ConnectListenerToForwardedPort(ctx, listen, streamPort)
if err != nil {
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
return
}

// Connect to the listener and and test connecting with the given connection
conn, err := listen.Accept()
if err != nil {
done <- fmt.Errorf("accept connection failed: %v", err)
return
}

err = c.ConnectToForwardedPort(ctx, conn, streamPort)
if err != nil {
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
return
Expand Down

0 comments on commit 97233d2

Please sign in to comment.