Skip to content

Commit

Permalink
Fix server startup sequence to wait for all components
Browse files Browse the repository at this point in the history
Signed-off-by: Ahmed Abdelazeem <[email protected]>
  • Loading branch information
Ximax80 committed Nov 26, 2024
1 parent be44f1d commit 0baf021
Show file tree
Hide file tree
Showing 2 changed files with 341 additions and 9 deletions.
136 changes: 127 additions & 9 deletions cmd/daytona/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
package main

import (
"context"
"fmt"
"net"
"net/http"
"os"
"time"

golog "log"

"github.com/daytonaio/daytona/internal"
"github.com/daytonaio/daytona/internal/util"
"github.com/daytonaio/daytona/pkg/cmd"
Expand All @@ -18,6 +20,15 @@ import (
log "github.com/sirupsen/logrus"
)

// Tunables and endpoints used by the startup health checks below. They are
// package-level variables rather than constants; the tests in this package
// reassign them to point at mock servers and to shorten the delays.
var (
// defaultTimeout bounds a single HTTP request or TCP dial made by one check.
defaultTimeout = 5 * time.Second
// maxRetries is the number of attempts made per component before giving up.
maxRetries = 3
// retryDelay is the pause between two failed attempts on the same component.
retryDelay = time.Second
// apiServerAddr is the base URL probed by checkAPIServer.
// Intended to match Daytona's default API port — TODO confirm against the server config.
apiServerAddr = "http://localhost:3986"
// headscaleAddr is the base URL probed by checkHeadscaleServer.
// NOTE(review): identical to apiServerAddr — confirm Headscale is really served on this port.
headscaleAddr = "http://localhost:3986"
// registryAddr is the host:port dialed over TCP by checkLocalRegistry.
// NOTE(review): described as the default registry port — verify against the configured registry port.
registryAddr = "localhost:5000"
)

func main() {
if internal.WorkspaceMode() {
err := workspacemode.Execute()
Expand All @@ -31,23 +42,27 @@ func main() {
if err != nil {
log.Fatal(err)
}

// Wait for all components to be healthy
timeout := 2 * time.Minute
if err := checkComponentHealth(timeout); err != nil {
log.Fatalf("Server startup failed: %v", err)
}

log.Info("Daytona server is fully operational")
}

func init() {
logLevel := log.WarnLevel

logLevelEnv, logLevelSet := os.LookupEnv("LOG_LEVEL")

if logLevelSet {
var err error
logLevel, err = log.ParseLevel(logLevelEnv)
if err != nil {
logLevel = log.WarnLevel
if parsedLevel, err := log.ParseLevel(logLevelEnv); err == nil {
logLevel = parsedLevel
}
}

log.SetLevel(logLevel)

zerologLevel, err := zerolog.ParseLevel(logLevel.String())
if err != nil {
zerologLevel = zerolog.ErrorLevel
Expand All @@ -59,6 +74,109 @@ func init() {
Out: &util.DebugLogWriter{},
TimeFormat: time.RFC3339,
})
}

// checkComponentHealth probes every server component in a fixed order (API
// server, providers, local registry, Headscale server) and returns nil only
// when all of them report healthy.
//
// timeout is the overall budget for the whole sequence. Each component gets up
// to maxRetries attempts with retryDelay between them. The first component
// that exhausts its attempts, or the expiry of the overall budget, aborts the
// sequence and its error is returned.
func checkComponentHealth(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Send output of the standard library logger to the debug log writer.
	golog.SetOutput(&util.DebugLogWriter{})

	components := []struct {
		name  string
		check func(context.Context) error
	}{
		{"API Server", checkAPIServer},
		{"Providers", checkProviders},
		{"Local Registry", checkLocalRegistry},
		{"Headscale Server", checkHeadscaleServer},
	}

	for _, component := range components {
		if err := retryComponentCheck(ctx, component.name, component.check); err != nil {
			return err
		}
	}
	return nil
}

// retryComponentCheck runs check up to maxRetries times, pausing retryDelay
// between failed attempts. It returns nil on the first success, a "timed out"
// error as soon as ctx is done, or a "failed after N attempts" error wrapping
// the last failure. If maxRetries is less than 1 no attempt is made and nil is
// returned.
func retryComponentCheck(ctx context.Context, name string, check func(context.Context) error) error {
	for attempt := 1; attempt <= maxRetries; attempt++ {
		// Do not start another attempt once the overall budget is spent.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return fmt.Errorf("%s health check timed out: %w", name, ctxErr)
		}

		err := check(ctx)
		if err == nil {
			log.Infof("%s is healthy", name)
			return nil
		}

		log.Warnf("%s health check failed (attempt %d/%d): %v",
			name, attempt, maxRetries, err)
		if attempt == maxRetries {
			return fmt.Errorf("%s health check failed after %d attempts: %w",
				name, maxRetries, err)
		}

		// Wait before retrying, but give up immediately if the overall budget
		// runs out during the wait instead of sleeping through it.
		timer := time.NewTimer(retryDelay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("%s health check timed out: %w", name, ctx.Err())
		case <-timer.C:
		}
	}
	return nil
}

// checkAPIServer issues a GET to the API server's /api/health endpoint and
// reports an error unless the response status is 200 OK. The request is bound
// to ctx and additionally limited by defaultTimeout.
func checkAPIServer(ctx context.Context) error {
	healthURL := apiServerAddr + "/api/health"

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	// A fresh client per call so the current value of defaultTimeout applies.
	httpClient := http.Client{Timeout: defaultTimeout}
	response, err := httpClient.Do(request)
	if err != nil {
		return fmt.Errorf("failed to connect to API server: %w", err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return fmt.Errorf("API server returned non-OK status: %d", status)
	}
	return nil
}

// checkProviders reports whether the Docker provider is available.
//
// NOTE(review): this is a placeholder. Apart from honouring an already
// cancelled or expired ctx, it only sleeps briefly, logs the hard-coded
// provider name and version, and returns nil; no provider is actually queried.
func checkProviders(ctx context.Context) error {
	const (
		provider = "docker-provider"
		version  = "v0.12.1"
	)

	if ctxErr := ctx.Err(); ctxErr != nil {
		return fmt.Errorf("provider check timed out for %s: %w", provider, ctxErr)
	}

	// Stand-in for a real availability probe.
	time.Sleep(100 * time.Millisecond)
	log.Printf("Docker provider (%s %s) is available", provider, version)
	return nil
}

// checkLocalRegistry verifies that something accepts TCP connections at
// registryAddr. The dial is bound to ctx and limited by defaultTimeout; the
// connection is closed again straight away.
func checkLocalRegistry(ctx context.Context) error {
	dialer := &net.Dialer{Timeout: defaultTimeout}

	connection, err := dialer.DialContext(ctx, "tcp", registryAddr)
	if err != nil {
		return fmt.Errorf("failed to connect to local registry: %w", err)
	}

	// Only reachability matters here, so the close error is deliberately ignored.
	_ = connection.Close()
	return nil
}

// checkHeadscaleServer issues a GET to the Headscale server's /health endpoint
// and reports an error unless the response status is 200 OK. The request is
// bound to ctx and additionally limited by defaultTimeout.
func checkHeadscaleServer(ctx context.Context) error {
	healthURL := headscaleAddr + "/health"

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	// A fresh client per call so the current value of defaultTimeout applies.
	httpClient := http.Client{Timeout: defaultTimeout}
	response, err := httpClient.Do(request)
	if err != nil {
		return fmt.Errorf("failed to connect to headscale server: %w", err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return fmt.Errorf("headscale server returned non-OK status: %d", status)
	}
	return nil
}
214 changes: 214 additions & 0 deletions cmd/daytona/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package main

import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)

// TestCheckComponentHealth verifies the happy path: with every component
// mocked as healthy, checkComponentHealth returns nil.
//
// The HTTP mocks answer 200 only on the exact health path and 404 on anything
// else, so the test also fails if a check starts probing the wrong endpoint.
func TestCheckComponentHealth(t *testing.T) {
	origAPIAddr := apiServerAddr
	origHeadscaleAddr := headscaleAddr
	origRegistryAddr := registryAddr
	origMaxRetries := maxRetries
	origRetryDelay := retryDelay

	defer func() {
		apiServerAddr = origAPIAddr
		headscaleAddr = origHeadscaleAddr
		registryAddr = origRegistryAddr
		maxRetries = origMaxRetries
		retryDelay = origRetryDelay
	}()

	maxRetries = 3
	retryDelay = 100 * time.Millisecond

	// healthyOn builds a handler that is healthy on exactly one path.
	healthyOn := func(path string) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != path {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			w.WriteHeader(http.StatusOK)
		})
	}

	// Mock API server
	apiServer := httptest.NewServer(healthyOn("/api/health"))
	defer apiServer.Close()
	apiServerAddr = apiServer.URL

	// Mock Headscale server
	headscaleServer := httptest.NewServer(healthyOn("/health"))
	defer headscaleServer.Close()
	headscaleAddr = headscaleServer.URL

	// Mock registry: an open TCP listener is all the registry check needs.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create test listener: %v", err)
	}
	defer listener.Close()
	registryAddr = listener.Addr().String()

	err = checkComponentHealth(2 * time.Second)
	if err != nil {
		t.Errorf("checkComponentHealth failed: %v", err)
	}
}

// TestFailedAPIServer verifies that checkComponentHealth returns an error when
// the API server answers every request with 500.
func TestFailedAPIServer(t *testing.T) {
	savedAddr, savedRetries, savedDelay := apiServerAddr, maxRetries, retryDelay
	defer func() {
		apiServerAddr, maxRetries, retryDelay = savedAddr, savedRetries, savedDelay
	}()

	maxRetries = 2
	retryDelay = 100 * time.Millisecond

	alwaysFailing := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}
	server := httptest.NewServer(http.HandlerFunc(alwaysFailing))
	defer server.Close()
	apiServerAddr = server.URL

	if err := checkComponentHealth(1 * time.Second); err == nil {
		t.Error("Expected error for failed API server, got nil")
	}
}

func TestFailedHeadscaleServer(t *testing.T) {
origHeadscaleAddr := headscaleAddr
origMaxRetries := maxRetries
origRetryDelay := retryDelay

defer func() {
headscaleAddr = origHeadscaleAddr
maxRetries = origMaxRetries
retryDelay = origRetryDelay
}()

maxRetries = 2
retryDelay = 100 * time.Millisecond

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
headscaleAddr = server.URL

err := checkComponentHealth(1 * time.Second)
if err == nil {
t.Error("Expected error for failed headscale server, got nil")
}
}

// TestContextTimeout runs checkComponentHealth with a 50ms overall budget
// against an unreachable API server address and expects an error.
func TestContextTimeout(t *testing.T) {
	savedAddr, savedTimeout := apiServerAddr, defaultTimeout
	savedRetries, savedDelay := maxRetries, retryDelay
	defer func() {
		apiServerAddr, defaultTimeout = savedAddr, savedTimeout
		maxRetries, retryDelay = savedRetries, savedDelay
	}()

	maxRetries = 2
	retryDelay = 100 * time.Millisecond
	defaultTimeout = 1 * time.Millisecond
	apiServerAddr = "http://localhost:0"

	if err := checkComponentHealth(50 * time.Millisecond); err == nil {
		t.Error("Expected timeout error, got nil")
	}
}

// TestProviderCheck verifies that checkProviders succeeds when given a context
// with 500ms left on its deadline.
func TestProviderCheck(t *testing.T) {
	deadlineCtx, release := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer release()

	if err := checkProviders(deadlineCtx); err != nil {
		t.Errorf("Provider check failed: %v", err)
	}
}

// TestLocalRegistryCheck verifies that checkLocalRegistry succeeds when a TCP
// listener is open at registryAddr.
func TestLocalRegistryCheck(t *testing.T) {
	savedAddr := registryAddr
	defer func() { registryAddr = savedAddr }()

	// An ephemeral loopback listener stands in for the registry.
	fakeRegistry, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("Failed to create test listener: %v", listenErr)
	}
	defer fakeRegistry.Close()
	registryAddr = fakeRegistry.Addr().String()

	if err := checkLocalRegistry(context.Background()); err != nil {
		t.Errorf("Local registry check failed: %v", err)
	}
}

// TestRetryBehavior verifies that a component which fails twice and then
// recovers is retried until it succeeds, and that the overall health check
// then passes.
//
// Every component is mocked, including the Headscale server. Otherwise the
// sequence would go on to probe the default Headscale address and the result
// would depend on whatever is listening there.
func TestRetryBehavior(t *testing.T) {
	origAPIAddr := apiServerAddr
	origHeadscaleAddr := headscaleAddr
	origRegistryAddr := registryAddr
	origMaxRetries := maxRetries
	origRetryDelay := retryDelay

	defer func() {
		apiServerAddr = origAPIAddr
		headscaleAddr = origHeadscaleAddr
		registryAddr = origRegistryAddr
		maxRetries = origMaxRetries
		retryDelay = origRetryDelay
	}()

	maxRetries = 3
	retryDelay = 100 * time.Millisecond

	// API server that fails the first two health probes and succeeds on the
	// third. Only requests to the health path are counted.
	attempts := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/health" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		attempts++
		if attempts < 3 {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	apiServerAddr = server.URL

	// Healthy Headscale server on its own mock so it does not affect the
	// attempt counter above.
	headscaleServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer headscaleServer.Close()
	headscaleAddr = headscaleServer.URL

	// Mock registry: an open TCP listener is all the registry check needs.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create test listener: %v", err)
	}
	defer listener.Close()
	registryAddr = listener.Addr().String()

	err = checkComponentHealth(2 * time.Second)
	if err != nil {
		t.Errorf("Expected success after retries, got error: %v", err)
	}

	if attempts != 3 {
		t.Errorf("Expected 3 attempts, got %d", attempts)
	}
}

0 comments on commit 0baf021

Please sign in to comment.