Skip to content

Commit

Permalink
Merge pull request #9 from stealthrocket/cleanup-request-on-failure-o…
Browse files Browse the repository at this point in the history
…r-shutdown

Notify API if a response cannot be generated
  • Loading branch information
chriso authored Apr 16, 2024
2 parents c815406 + 36ffd0d commit cef264b
Showing 1 changed file with 89 additions and 31 deletions.
120 changes: 89 additions & 31 deletions cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os/signal"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
Expand All @@ -27,7 +28,10 @@ var (

const defaultEndpoint = "localhost:8000"

const pollTimeout = 30 * time.Second
const (
pollTimeout = 30 * time.Second
cleanupTimeout = 5 * time.Second
)

var httpClient = &http.Client{
Transport: http.DefaultTransport,
Expand Down Expand Up @@ -89,17 +93,21 @@ Run 'dispatch help run' to learn about Dispatch sessions.`, BridgeSession)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var wg sync.WaitGroup

// Setup signal handler.
signals := make(chan os.Signal, 2)
signal.Notify(signals, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
var signaled bool
wg.Add(1)
go func() {
defer wg.Done()

for {
select {
case <-ctx.Done():
return
case <-signals:
cancel()
if !signaled {
signaled = true
_ = cmd.Process.Signal(syscall.SIGTERM)
Expand All @@ -110,11 +118,18 @@ Run 'dispatch help run' to learn about Dispatch sessions.`, BridgeSession)
}
}()

bridgeSessionURL := fmt.Sprintf("%s/sessions/%s", DispatchBridgeUrl, BridgeSession)

// Poll for work in the background.
var successfulPolls int64
wg.Add(1)
go func() {
for {
if err := poll(ctx, httpClient); err != nil {
defer wg.Done()

for ctx.Err() == nil {
// Fetch a request from the API.
requestID, res, err := poll(ctx, httpClient, bridgeSessionURL)
if err != nil {
if ctx.Err() != nil {
return
}
Expand All @@ -126,13 +141,42 @@ Run 'dispatch help run' to learn about Dispatch sessions.`, BridgeSession)
slog.Warn(err.Error())
}
time.Sleep(1 * time.Second)
continue
}

atomic.AddInt64(&successfulPolls, +1)

// Asynchronously send the request to invoke a function to
// the local application.
wg.Add(1)
go func() {
defer wg.Done()

err := invoke(ctx, httpClient, bridgeSessionURL, requestID, res)
if err != nil {
if ctx.Err() == nil {
slog.Warn(err.Error())
}

// Notify upstream if we're unable to generate a response,
// either because the local application can't be contacted,
// is misbehaving, or a shutdown sequence has been initiated.
ctx, cancel := context.WithTimeout(context.Background(), cleanupTimeout)
defer cancel()
if err := cleanup(ctx, httpClient, bridgeSessionURL, requestID); err != nil {
slog.Debug(err.Error())
}
}
}()
}
}()

err := cmd.Run()

// Cancel the context and wait for all goroutines to return.
cancel()
wg.Wait()

// If the command was halted by a signal rather than some other error,
// assume that the command invocation succeeded and that the user may
// want to resume this session.
Expand All @@ -158,53 +202,45 @@ Run 'dispatch help run' to learn about Dispatch sessions.`, BridgeSession)
return cmd
}

func poll(ctx context.Context, client *http.Client) error {
// Fetch a request from the API.
bridgeSessionURL := fmt.Sprintf("%s/sessions/%s", DispatchBridgeUrl, BridgeSession)
slog.Debug("getting request from API", "url", bridgeSessionURL)
func poll(ctx context.Context, client *http.Client, url string) (string, *http.Response, error) {
slog.Debug("getting request from API", "url", url)

bridgeGetReq, err := http.NewRequestWithContext(ctx, "GET", bridgeSessionURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
panic(err)
}
bridgeGetReq.Header.Add("Authorization", "Bearer "+DispatchApiKey)
bridgeGetReq.Header.Add("Connect-Timeout-Ms", strconv.FormatInt(pollTimeout.Milliseconds(), 10))
req.Header.Add("Authorization", "Bearer "+DispatchApiKey)
req.Header.Add("Request-Timeout", strconv.FormatInt(int64(pollTimeout.Seconds()), 10))
if DispatchBridgeHostHeader != "" {
bridgeGetReq.Host = DispatchBridgeHostHeader
req.Host = DispatchBridgeHostHeader
}

bridgeGetRes, err := client.Do(bridgeGetReq)
res, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to contact Dispatch API (%s): %v", DispatchBridgeUrl, err)
return "", nil, fmt.Errorf("failed to contact Dispatch API (%s): %v", DispatchBridgeUrl, err)
}
if bridgeGetRes.StatusCode != http.StatusOK {
bridgeGetRes.Body.Close()
if res.StatusCode != http.StatusOK {
res.Body.Close()

switch bridgeGetRes.StatusCode {
switch res.StatusCode {
case http.StatusUnauthorized:
return authError{}
return "", nil, authError{}
case http.StatusGatewayTimeout:
// A 504 is expected when long polling and no requests
// are available. Return a nil in this case and let the
// caller try again.
return nil
return "", nil, nil
default:
return fmt.Errorf("failed to contact Dispatch API (%s): response code %d", DispatchBridgeUrl, bridgeGetRes.StatusCode)
return "", nil, fmt.Errorf("failed to contact Dispatch API (%s): response code %d", DispatchBridgeUrl, res.StatusCode)
}
}

requestID := bridgeGetRes.Header.Get("X-Request-Id")
requestID := res.Header.Get("X-Request-Id")

go func() {
if err := invoke(ctx, client, bridgeSessionURL, requestID, bridgeGetRes); err != nil {
slog.Warn(err.Error())
}
}()

return nil
return requestID, res, nil
}

func invoke(ctx context.Context, client *http.Client, bridgeSessionURL, requestID string, bridgeGetRes *http.Response) error {
func invoke(ctx context.Context, client *http.Client, url, requestID string, bridgeGetRes *http.Response) error {
defer bridgeGetRes.Body.Close()

slog.Debug("sending request from Dispatch API to local application", "endpoint", LocalEndpoint, "request_id", requestID)
Expand Down Expand Up @@ -243,7 +279,7 @@ func invoke(ctx context.Context, client *http.Client, bridgeSessionURL, requestI
slog.Debug("sending local application's response to Dispatch API", "request_id", requestID)

// Send the response back to the API.
bridgePostReq, err := http.NewRequestWithContext(ctx, "POST", bridgeSessionURL, bufio.NewReader(&bufferedEndpointRes))
bridgePostReq, err := http.NewRequestWithContext(ctx, "POST", url, bufio.NewReader(&bufferedEndpointRes))
if err != nil {
panic(err)
}
Expand All @@ -256,8 +292,30 @@ func invoke(ctx context.Context, client *http.Client, bridgeSessionURL, requestI
if err != nil {
return fmt.Errorf("failed to contact Dispatch API: %v", err)
}
if bridgePostRes.StatusCode != 202 {
if bridgePostRes.StatusCode != http.StatusAccepted {
return fmt.Errorf("failed to contact Dispatch API: response code %d", bridgePostRes.StatusCode)
}
return nil
}

func cleanup(ctx context.Context, client *http.Client, url, requestID string) error {
slog.Debug("cleaning up request", "request_id", requestID)

req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil)
if err != nil {
panic(err)
}
req.Header.Add("Authorization", "Bearer "+DispatchApiKey)
req.Header.Add("X-Request-ID", requestID)
if DispatchBridgeHostHeader != "" {
req.Host = DispatchBridgeHostHeader
}
res, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to contact Dispatch API: %v", err)
}
if res.StatusCode != http.StatusOK {
return fmt.Errorf("failed to contact Dispatch API: response code %d", res.StatusCode)
}
return nil
}

0 comments on commit cef264b

Please sign in to comment.