diff --git a/config-example.yaml b/config-example.yaml index a8c250310f..95c080f43c 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -40,12 +40,15 @@ grpc_listen_addr: 127.0.0.1:50443 # are doing. grpc_allow_insecure: false -# The Access-Control-Allow-Origin header specifies which origins are allowed to access resources. +# The allow_origins list will allow you to set the Access-Control-Allow-Origin header to the origin in the list. +# This will allow you to enable cors and set headscale without a reverse proxy. +# Multiple origins can be set in the allow_origins list. # Options: -# - "*" to allow access from any origin (not recommended for sensitive data). -# - "http://example.com" to only allow access from a specific origin. -# - "" to disable Cross-Origin Resource Sharing (CORS). -access_control_allow_origin: "" +# - "*" is disabled (due to security risks). +# - "https://example.com" to only allow access from a specific origin. +# - "https://example.com:1234" to allow access from a specific origin with a port. +cors: + allow_origins: [] # The Noise section includes specific configuration for the # TS2021 Noise protocol diff --git a/hscontrol/app.go b/hscontrol/app.go index 72ac1693d4..dc735cab1f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -455,18 +455,63 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { return os.Remove(h.cfg.UnixSocket) } +// corsHeaderMiddleware will add an "Access-Control-Allow-Origin" to enable CORS. func (h *Headscale) corsHeadersMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", h.cfg.AccessControlAllowOrigins) + // skip disabled CORS endpoints + if !h.enabledCorsRoutes(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + + origin := r.Header.Get("Origin") + // we compare origin from the allowed Origins list. Then add the header with origin + for _, allowedOrigin := range h.cfg.AllowedOrigins.Origins { + if allowedOrigin == origin { + w.Header().Set("Vary", "Origin") + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + break + } + } next.ServeHTTP(w, r) }) } +func (h *Headscale) enabledCorsRoutes(routerPath string) bool { + // enable all api endpoints + if strings.HasPrefix(routerPath, "/api/") { + return true + } + + // A list of enabled CORS endpoints + enabledRoutes := []string{ + "/health", + "/key", + "/register/{registration_id}", + "/oidc/callback", + "/verify", + "/derp", + "/derp/probe", + "/derp/latency-check", + "/bootstrap-dns", + "/machine/register", + "/machine/map", + } + + for _, routes := range enabledRoutes { + if routes == routerPath { + return true + } + } + + return false +} + func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router := mux.NewRouter() router.Use(prometheusMiddleware) - if h.cfg.AccessControlAllowOrigins != "" { + if len(h.cfg.AllowedOrigins.Origins) != 0 { router.Use(h.corsHeadersMiddleware) } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 080d78fe55..00111fd672 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -66,7 +66,7 @@ type Config struct { Log LogConfig DisableUpdateCheck bool - AccessControlAllowOrigins string + AllowedOrigins CorsConfig Database DatabaseConfig @@ -210,6 +210,10 @@ type LogTailConfig struct { Enabled bool } +type CorsConfig struct { + Origins []string +} + type CLIConfig struct { Address string APIKey string @@ -534,6 +538,14 @@ func logtailConfig() LogTailConfig { } } +func corsConfig() CorsConfig { + allowedOrigins := viper.GetStringSlice("cors.allowed_origins") + + return CorsConfig{ + Origins: allowedOrigins, + } +} + func policyConfig() PolicyConfig { policyPath := viper.GetString("policy.path") policyMode := viper.GetString("policy.mode") @@ -907,7 +919,7 @@ func LoadServerConfig() (*Config, error) { GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), DisableUpdateCheck: false, - AccessControlAllowOrigins: viper.GetString("access_control_allow_origin"), + AllowedOrigins: corsConfig(), PrefixV4: prefix4, PrefixV6: prefix6,