Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CSRF protection to the Hypervisor API #1604

Merged
merged 2 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/visor/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ var (
logTag string
hiddenflags []string
all bool
useCsrf bool
pkg bool
usr bool
localIPs []net.IP // nolint:unused
Expand Down Expand Up @@ -133,6 +134,7 @@ func init() {
RootCmd.Flags().BoolVar(&isForceColor, "forcecolor", false, "force color logging when out is not STDOUT")
hiddenflags = append(hiddenflags, "forcecolor")
RootCmd.Flags().BoolVar(&all, "all", false, "show all flags")
RootCmd.Flags().BoolVar(&useCsrf, "csrf", true, "Request a CSRF token for sensitive hypervisor API requests")
for _, j := range hiddenflags {
RootCmd.Flags().MarkHidden(j) //nolint
}
Expand Down
108 changes: 108 additions & 0 deletions pkg/visor/csrf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Package visor pkg/visor/hypervisor.go
package visor

import (
"time"

"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"strings"

"github.com/skycoin/skycoin/src/cipher"
)

const (
// CSRFHeaderName is the name of the CSRF header
CSRFHeaderName = "X-CSRF-Token"

// CSRFMaxAge is the lifetime of a CSRF token in seconds
CSRFMaxAge = time.Second * 30

csrfSecretLength = 64

csrfNonceLength = 64
)

var (
// ErrCSRFInvalid is returned when the the CSRF token is in invalid format
ErrCSRFInvalid = errors.New("invalid CSRF token")
// ErrCSRFExpired is returned when the csrf token has expired
ErrCSRFExpired = errors.New("csrf token expired")
)

var csrfSecretKey []byte

func init() {
csrfSecretKey = cipher.RandByte(csrfSecretLength)
}

// CSRFToken csrf token
type CSRFToken struct {
Nonce []byte
ExpiresAt time.Time
}

// newCSRFToken generates a new CSRF Token
func newCSRFToken() (string, error) {
token := &CSRFToken{
Nonce: cipher.RandByte(csrfNonceLength),
ExpiresAt: time.Now().Add(CSRFMaxAge),
}

tokenJSON, err := json.Marshal(token)
if err != nil {
return "", err
}

h := hmac.New(sha256.New, csrfSecretKey)
_, err = h.Write([]byte(tokenJSON))
if err != nil {
return "", err
}

sig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))

signingString := base64.RawURLEncoding.EncodeToString(tokenJSON)

return strings.Join([]string{signingString, sig}, "."), nil
}

// verifyCSRFToken checks validity of the given token
func verifyCSRFToken(headerToken string) error {
tokenParts := strings.Split(headerToken, ".")
if len(tokenParts) != 2 {
return ErrCSRFInvalid
}

signingString, err := base64.RawURLEncoding.DecodeString(tokenParts[0])
if err != nil {
return err
}

h := hmac.New(sha256.New, csrfSecretKey)
_, err = h.Write([]byte(signingString))
if err != nil {
return err
}

sig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))

if sig != tokenParts[1] {
return ErrCSRFInvalid
}

var csrfToken CSRFToken
err = json.Unmarshal(signingString, &csrfToken)
if err != nil {
return err
}

if time.Now().After(csrfToken.ExpiresAt) {
return ErrCSRFExpired
}

return nil
}
40 changes: 40 additions & 0 deletions pkg/visor/hypervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ func (hv *Hypervisor) makeMux() chi.Router {

r.Get("/ping", hv.getPong())

r.Get("/csrf", hv.getCsrf())

if hv.c.EnableAuth {
r.Group(func(r chi.Router) {
r.Post("/create-account", hv.users.CreateAccount())
Expand Down Expand Up @@ -299,6 +301,29 @@ func (hv *Hypervisor) getPong() http.HandlerFunc {
}
}

// Csrf provides a temporal security token.
type Csrf struct {
Token string `json:"csrf_token"`
}

func (hv *Hypervisor) getCsrf() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if useCsrf {
token, err := newCSRFToken()
if err != nil {
httputil.WriteJSON(w, r, http.StatusInternalServerError, err)
return
}

httputil.WriteJSON(w, r, http.StatusOK, Csrf{
Token: token,
})
} else {
httputil.WriteJSON(w, r, http.StatusOK, Csrf{Token: ""})
}
}
}

// About provides info about the hypervisor.
type About struct {
PubKey cipher.PubKey `json:"public_key"` // The hypervisor's public key.
Expand Down Expand Up @@ -1346,6 +1371,21 @@ func (hv *Hypervisor) visorCtx(w http.ResponseWriter, r *http.Request) (*httpCtx
return nil, false
}

if useCsrf && (r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE") {
csrfToken := r.Header.Get(CSRFHeaderName)
if csrfToken == "" {
errMsg := fmt.Errorf("no csrf token for %s request", r.Method)
httputil.WriteJSON(w, r, http.StatusForbidden, errMsg)
return nil, false
}

err = verifyCSRFToken(csrfToken)
if err != nil {
httputil.WriteJSON(w, r, http.StatusForbidden, err)
return nil, false
}
}

if pk != hv.c.PK {
v, ok := hv.visorConn(pk)

Expand Down
35 changes: 31 additions & 4 deletions static/skywire-manager-src/src/app/services/api.service.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Injectable, NgZone } from '@angular/core';
import { HttpClient, HttpErrorResponse, HttpHeaders } from '@angular/common/http';
import { Observable, throwError } from 'rxjs';
import { catchError, map } from 'rxjs/operators';
import { catchError, first, map, mergeMap } from 'rxjs/operators';
import { webSocket } from 'rxjs/webSocket';
import { Router } from '@angular/router';

Expand All @@ -22,6 +22,7 @@ export class RequestOptions {
requestType = RequestTypes.Json;
ignoreAuth = false;
vpnKeyForAuth: string;
csrfToken: string;

public constructor(init?: Partial<RequestOptions>) {
Object.assign(this, init);
Expand Down Expand Up @@ -69,23 +70,45 @@ export class ApiService {
* @param url Endpoint URL, after the "/api/" part.
*/
post(url: string, body: any = {}, options: RequestOptions = null): Observable<any> {
return this.request('POST', url, body, options);
return this.getCsrf().pipe(first(), mergeMap(csrf => {
options = options ? options : new RequestOptions();
options.csrfToken = csrf;

return this.request('POST', url, body, options);
}));
}

/**
* Makes a request to a PUT endpoint.
* @param url Endpoint URL, after the "/api/" part.
*/
put(url: string, body: any = {}, options: RequestOptions = null): Observable<any> {
return this.request('PUT', url, body, options);
return this.getCsrf().pipe(first(), mergeMap(csrf => {
options = options ? options : new RequestOptions();
options.csrfToken = csrf;

return this.request('PUT', url, body, options);
}));
}

/**
* Makes a request to a DELETE endpoint.
* @param url Endpoint URL, after the "/api/" part.
*/
delete(url: string, options: RequestOptions = null): Observable<any> {
return this.request('DELETE', url, {}, options);
return this.getCsrf().pipe(first(), mergeMap(csrf => {
options = options ? options : new RequestOptions();
options.csrfToken = csrf;

return this.request('DELETE', url, {}, options);
}));
}

/**
* Gets a csrf token from the node, to be able to make protected requests.
*/
private getCsrf(): Observable<string> {
return this.get('csrf').pipe(map(response => response.csrf_token));
}

/**
Expand Down Expand Up @@ -138,6 +161,10 @@ export class ApiService {
requestOptions.headers = requestOptions.headers.append('Content-Type', 'application/json');
}

if (options.csrfToken) {
requestOptions.headers = requestOptions.headers.append('X-CSRF-Token', options.csrfToken);
}

return requestOptions;
}

Expand Down
Loading