diff --git a/pkg/db/client.go b/pkg/db/client.go index 48b40d4..e1e9228 100644 --- a/pkg/db/client.go +++ b/pkg/db/client.go @@ -206,8 +206,6 @@ func (t *TridentDB) StreamingInsertResults() chan *Result { } commit: - log.Printf("streaming %d records to db", count) - _, err = stmt.Exec() if err != nil { log.Fatal(err) diff --git a/pkg/dispatch/clients/webhook/webhook.go b/pkg/dispatch/clients/webhook/webhook.go index 7011820..efbaa80 100644 --- a/pkg/dispatch/clients/webhook/webhook.go +++ b/pkg/dispatch/clients/webhook/webhook.go @@ -17,6 +17,7 @@ package webhook import ( "bytes" "encoding/json" + "errors" "fmt" "net/http" @@ -85,6 +86,15 @@ func (w *Client) Submit(r event.AuthRequest) (*event.AuthResponse, error) { } defer resp.Body.Close() // nolint:errcheck + if resp.StatusCode != 200 { + var res event.ErrorResponse + err = json.NewDecoder(resp.Body).Decode(&res) + if err != nil { + return nil, err + } + return nil, errors.New(res.ErrorMsg) + } + var res event.AuthResponse err = json.NewDecoder(resp.Body).Decode(&res) return &res, err diff --git a/pkg/dispatch/dispatch.go b/pkg/dispatch/dispatch.go index aeb0f76..eaab510 100644 --- a/pkg/dispatch/dispatch.go +++ b/pkg/dispatch/dispatch.go @@ -94,28 +94,26 @@ func NewDispatcher(ctx context.Context, opts Options, wc WorkerClient) (*Dispatc // to the worker and results are then published to the Pub/Sub topic. func (d *Dispatcher) Listen(ctx context.Context) error { return d.sub.Receive(ctx, func(ctx context.Context, msg *pubsub.Message) { + // always ACK messages to avoid infinite loop handling a bad message + defer msg.Ack() + var req event.AuthRequest err := json.Unmarshal(msg.Data, &req) if err != nil { log.Printf("error unmarshaling: %s", err) - msg.Ack() return } ts := time.Now() if ts.After(req.NotAfter) { - log.Printf("received an event after end time, dropping") - msg.Ack() return } resp, err := d.wc.Submit(req) if err != nil { log.Printf("error from worker: %s", err) - msg.Nack() return } - msg.Ack() b, _ := json.Marshal(resp) d.resultc.Publish(ctx, &pubsub.Message{ diff --git a/pkg/event/event.go b/pkg/event/event.go index 8895660..af4e051 100644 --- a/pkg/event/event.go +++ b/pkg/event/event.go @@ -74,3 +74,10 @@ type AuthResponse struct { // Additional metadata from the auth provider (e.g. information about MFA) Metadata map[string]interface{} `json:"metadata"` } + +// ErrorResponse represents a failure in task processing. This response should +// be accompanied by a non-200 HTTP response code (e.g. HTTP 500). +type ErrorResponse struct { + // ErrorMsg is the result of error.Error() + ErrorMsg string `json:"error"` +} diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index 0e58fdd..47d619d 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -17,7 +17,6 @@ package server import ( "encoding/json" "errors" - "fmt" "net/http" log "github.com/sirupsen/logrus" @@ -46,16 +45,13 @@ func (s *Server) CampaignHandler(w http.ResponseWriter, r *http.Request) { err := parse.DecodeJSONBody(w, r, &c) if err != nil { - log.Errorf("error parsing json: %s", err) - var mr *parse.MalformedRequest - if errors.As(err, &mr) { http.Error(w, mr.Msg, mr.Status) } else { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + log.Errorf("unknown error decoding json: %s", err) + http.Error(w, http.StatusText(500), 500) } - return } @@ -82,30 +78,24 @@ func (s *Server) CampaignHandler(w http.ResponseWriter, r *http.Request) { // ResultsHandler takes a user defined database query (returned fields + filter) // and applies it, returning the results in JSON func (s *Server) ResultsHandler(w http.ResponseWriter, r *http.Request) { - log.Info("retrieving results for query") var q db.Query err := parse.DecodeJSONBody(w, r, &q) if err != nil { - log.Errorf("error parsing json: %s", err) - var mr *parse.MalformedRequest - if errors.As(err, &mr) { http.Error(w, mr.Msg, mr.Status) } else { - log.Errorf("there was something else we don't know: %s", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + log.Errorf("unknown error decoding json: %s", err) + http.Error(w, http.StatusText(500), 500) } - return } results, err := s.DB.SelectResults(q) if err != nil { - message := fmt.Sprintf("there was an error collecting results from the database: %s", err) - log.Error(message) - http.Error(w, message, http.StatusInternalServerError) + log.Printf("error querying database: %s", err) + http.Error(w, http.StatusText(500), 500) } err = json.NewEncoder(w).Encode(&results) @@ -120,15 +110,12 @@ func (s *Server) ResultsHandler(w http.ResponseWriter, r *http.Request) { // CampaignListHandler accepts no parameters and returns the list of active campaigns // via JSON func (s *Server) CampaignListHandler(w http.ResponseWriter, r *http.Request) { - log.Info("retrieving list of active campaigns") - log.Info("is this even deploying...") var campaigns []db.Campaign campaigns, err := s.DB.ListCampaign() if err != nil { - message := fmt.Sprintf("there was an error collecting results from the database: %s", err) - log.Error(message) - http.Error(w, message, http.StatusInternalServerError) + log.Printf("error querying database: %s", err) + http.Error(w, http.StatusText(500), 500) } err = json.NewEncoder(w).Encode(&campaigns) @@ -143,31 +130,25 @@ func (s *Server) CampaignListHandler(w http.ResponseWriter, r *http.Request) { // CampaignDescribeHandler takes a user-defined DB query with the campaignID, then // returns the parameters of that campaign via JSON func (s *Server) CampaignDescribeHandler(w http.ResponseWriter, r *http.Request) { - log.Info("retrieving description of queried campaign") var q db.Query var campaign db.Campaign err := parse.DecodeJSONBody(w, r, &q) if err != nil { - log.Errorf("error parsing json: %s", err) - var mr *parse.MalformedRequest - if errors.As(err, &mr) { http.Error(w, mr.Msg, mr.Status) } else { - log.Errorf("there was something else we don't know: %s", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + log.Errorf("unknown error decoding json: %s", err) + http.Error(w, http.StatusText(500), 500) } - return } campaign, err = s.DB.DescribeCampaign(q) if err != nil { - message := fmt.Sprintf("there was an error collecting results from the database: %s", err) - log.Error(message) - http.Error(w, message, http.StatusInternalServerError) + log.Printf("error querying database: %s", err) + http.Error(w, http.StatusText(500), 500) } err = json.NewEncoder(w).Encode(&campaign) diff --git a/pkg/worker/webhook/handlers.go b/pkg/worker/webhook/handlers.go index e4eb13b..a0c7a27 100644 --- a/pkg/worker/webhook/handlers.go +++ b/pkg/worker/webhook/handlers.go @@ -16,7 +16,7 @@ package webhook import ( "encoding/json" - "errors" + "fmt" "net/http" "time" @@ -24,7 +24,6 @@ import ( "github.com/praetorian-inc/trident/pkg/event" "github.com/praetorian-inc/trident/pkg/nozzle" - "github.com/praetorian-inc/trident/pkg/parse" "github.com/praetorian-inc/trident/pkg/util" ) @@ -47,40 +46,33 @@ func NewWebhookServer() (*Server, error) { // HealthzHandler returns an HTTP 200 ok always. func (s *Server) HealthzHandler(w http.ResponseWriter, r *http.Request) {} +func httperr(w http.ResponseWriter, err error) { + res := event.ErrorResponse{ErrorMsg: err.Error()} + w.WriteHeader(500) + json.NewEncoder(w).Encode(&res) // nolint:errcheck,gosec +} + // EventHandler accepts an AuthRequest, executes the task using the nozzle // interface and returns the AuthResponse via JSON. func (s *Server) EventHandler(w http.ResponseWriter, r *http.Request) { - log.Info("retrieving results for query") var req event.AuthRequest - err := parse.DecodeJSONBody(w, r, &req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - log.Infof("error parsing json: %s", err) - - var mr *parse.MalformedRequest - - if errors.As(err, &mr) { - http.Error(w, mr.Msg, mr.Status) - } else { - log.Errorf("there was something else we don't know: %s", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - } - + httperr(w, fmt.Errorf("error decoding body: %w", err)) return } noz, err := nozzle.Open(req.Provider, req.ProviderMetadata) if err != nil { - log.Errorf("error opening nozzle: %s", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + httperr(w, fmt.Errorf("error opening nozzle: %w", err)) return } ts := time.Now() res, err := noz.Login(req.Username, req.Password) if err != nil { - log.Errorf("error logging in to %s: %s", req.Provider, err) - http.Error(w, err.Error(), http.StatusInternalServerError) + httperr(w, fmt.Errorf("error authenticating to %s provider: %w", req.Provider, err)) return } @@ -91,8 +83,5 @@ func (s *Server) EventHandler(w http.ResponseWriter, r *http.Request) { res.Timestamp = ts res.IP = s.ip - err = json.NewEncoder(w).Encode(&res) - if err != nil { - log.Printf("error writing to http response: %s", err) - } + json.NewEncoder(w).Encode(&res) // nolint:errcheck,gosec }