diff --git a/pkg/visor/cmd.go b/pkg/visor/cmd.go index 97650c04c..d212d95fb 100644 --- a/pkg/visor/cmd.go +++ b/pkg/visor/cmd.go @@ -45,6 +45,7 @@ var ( logTag string hiddenflags []string all bool + useCsrf bool pkg bool usr bool localIPs []net.IP // nolint:unused @@ -133,6 +134,7 @@ func init() { RootCmd.Flags().BoolVar(&isForceColor, "forcecolor", false, "force color logging when out is not STDOUT") hiddenflags = append(hiddenflags, "forcecolor") RootCmd.Flags().BoolVar(&all, "all", false, "show all flags") + RootCmd.Flags().BoolVar(&useCsrf, "csrf", true, "Request a CSRF token for sensitive hypervisor API requests") for _, j := range hiddenflags { RootCmd.Flags().MarkHidden(j) //nolint } diff --git a/pkg/visor/csrf.go b/pkg/visor/csrf.go new file mode 100644 index 000000000..f05580ffb --- /dev/null +++ b/pkg/visor/csrf.go @@ -0,0 +1,108 @@ +// Package visor pkg/visor/hypervisor.go +package visor + +import ( + "time" + + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "strings" + + "github.com/skycoin/skycoin/src/cipher" +) + +const ( + // CSRFHeaderName is the name of the CSRF header + CSRFHeaderName = "X-CSRF-Token" + + // CSRFMaxAge is the lifetime of a CSRF token in seconds + CSRFMaxAge = time.Second * 30 + + csrfSecretLength = 64 + + csrfNonceLength = 64 +) + +var ( + // ErrCSRFInvalid is returned when the the CSRF token is in invalid format + ErrCSRFInvalid = errors.New("invalid CSRF token") + // ErrCSRFExpired is returned when the csrf token has expired + ErrCSRFExpired = errors.New("csrf token expired") +) + +var csrfSecretKey []byte + +func init() { + csrfSecretKey = cipher.RandByte(csrfSecretLength) +} + +// CSRFToken csrf token +type CSRFToken struct { + Nonce []byte + ExpiresAt time.Time +} + +// newCSRFToken generates a new CSRF Token +func newCSRFToken() (string, error) { + token := &CSRFToken{ + Nonce: cipher.RandByte(csrfNonceLength), + ExpiresAt: time.Now().Add(CSRFMaxAge), + } + + tokenJSON, err := json.Marshal(token) + if err != nil { + return "", err + } + + h := hmac.New(sha256.New, csrfSecretKey) + _, err = h.Write([]byte(tokenJSON)) + if err != nil { + return "", err + } + + sig := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + signingString := base64.RawURLEncoding.EncodeToString(tokenJSON) + + return strings.Join([]string{signingString, sig}, "."), nil +} + +// verifyCSRFToken checks validity of the given token +func verifyCSRFToken(headerToken string) error { + tokenParts := strings.Split(headerToken, ".") + if len(tokenParts) != 2 { + return ErrCSRFInvalid + } + + signingString, err := base64.RawURLEncoding.DecodeString(tokenParts[0]) + if err != nil { + return err + } + + h := hmac.New(sha256.New, csrfSecretKey) + _, err = h.Write([]byte(signingString)) + if err != nil { + return err + } + + sig := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + if sig != tokenParts[1] { + return ErrCSRFInvalid + } + + var csrfToken CSRFToken + err = json.Unmarshal(signingString, &csrfToken) + if err != nil { + return err + } + + if time.Now().After(csrfToken.ExpiresAt) { + return ErrCSRFExpired + } + + return nil +} diff --git a/pkg/visor/hypervisor.go b/pkg/visor/hypervisor.go index a797cb87c..7a7016c63 100644 --- a/pkg/visor/hypervisor.go +++ b/pkg/visor/hypervisor.go @@ -215,6 +215,8 @@ func (hv *Hypervisor) makeMux() chi.Router { r.Get("/ping", hv.getPong()) + r.Get("/csrf", hv.getCsrf()) + if hv.c.EnableAuth { r.Group(func(r chi.Router) { r.Post("/create-account", hv.users.CreateAccount()) @@ -299,6 +301,29 @@ func (hv *Hypervisor) getPong() http.HandlerFunc { } } +// Csrf provides a temporal security token. +type Csrf struct { + Token string `json:"csrf_token"` +} + +func (hv *Hypervisor) getCsrf() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if useCsrf { + token, err := newCSRFToken() + if err != nil { + httputil.WriteJSON(w, r, http.StatusInternalServerError, err) + return + } + + httputil.WriteJSON(w, r, http.StatusOK, Csrf{ + Token: token, + }) + } else { + httputil.WriteJSON(w, r, http.StatusOK, Csrf{Token: ""}) + } + } +} + // About provides info about the hypervisor. type About struct { PubKey cipher.PubKey `json:"public_key"` // The hypervisor's public key. @@ -1346,6 +1371,21 @@ func (hv *Hypervisor) visorCtx(w http.ResponseWriter, r *http.Request) (*httpCtx return nil, false } + if useCsrf && (r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE") { + csrfToken := r.Header.Get(CSRFHeaderName) + if csrfToken == "" { + errMsg := fmt.Errorf("no csrf token for %s request", r.Method) + httputil.WriteJSON(w, r, http.StatusForbidden, errMsg) + return nil, false + } + + err = verifyCSRFToken(csrfToken) + if err != nil { + httputil.WriteJSON(w, r, http.StatusForbidden, err) + return nil, false + } + } + if pk != hv.c.PK { v, ok := hv.visorConn(pk) diff --git a/static/skywire-manager-src/src/app/services/api.service.ts b/static/skywire-manager-src/src/app/services/api.service.ts index 9e5c753cc..40931c74e 100644 --- a/static/skywire-manager-src/src/app/services/api.service.ts +++ b/static/skywire-manager-src/src/app/services/api.service.ts @@ -1,7 +1,7 @@ import { Injectable, NgZone } from '@angular/core'; import { HttpClient, HttpErrorResponse, HttpHeaders } from '@angular/common/http'; import { Observable, throwError } from 'rxjs'; -import { catchError, map } from 'rxjs/operators'; +import { catchError, first, map, mergeMap } from 'rxjs/operators'; import { webSocket } from 'rxjs/webSocket'; import { Router } from '@angular/router'; @@ -22,6 +22,7 @@ export class RequestOptions { requestType = RequestTypes.Json; ignoreAuth = false; vpnKeyForAuth: string; + csrfToken: string; public constructor(init?: Partial) { Object.assign(this, init); @@ -69,7 +70,12 @@ export class ApiService { * @param url Endpoint URL, after the "/api/" part. */ post(url: string, body: any = {}, options: RequestOptions = null): Observable { - return this.request('POST', url, body, options); + return this.getCsrf().pipe(first(), mergeMap(csrf => { + options = options ? options : new RequestOptions(); + options.csrfToken = csrf; + + return this.request('POST', url, body, options); + })); } /** @@ -77,7 +83,12 @@ export class ApiService { * @param url Endpoint URL, after the "/api/" part. */ put(url: string, body: any = {}, options: RequestOptions = null): Observable { - return this.request('PUT', url, body, options); + return this.getCsrf().pipe(first(), mergeMap(csrf => { + options = options ? options : new RequestOptions(); + options.csrfToken = csrf; + + return this.request('PUT', url, body, options); + })); } /** @@ -85,7 +96,19 @@ export class ApiService { * @param url Endpoint URL, after the "/api/" part. */ delete(url: string, options: RequestOptions = null): Observable { - return this.request('DELETE', url, {}, options); + return this.getCsrf().pipe(first(), mergeMap(csrf => { + options = options ? options : new RequestOptions(); + options.csrfToken = csrf; + + return this.request('DELETE', url, {}, options); + })); + } + + /** + * Gets a csrf token from the node, to be able to make protected requests. + */ + private getCsrf(): Observable { + return this.get('csrf').pipe(map(response => response.csrf_token)); } /** @@ -138,6 +161,10 @@ export class ApiService { requestOptions.headers = requestOptions.headers.append('Content-Type', 'application/json'); } + if (options.csrfToken) { + requestOptions.headers = requestOptions.headers.append('X-CSRF-Token', options.csrfToken); + } + return requestOptions; }