-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathsocket.go
131 lines (120 loc) · 3.8 KB
/
socket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*
* Copyright 2017 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package main
import (
"context"
"encoding/json"
"net/http"
"github.com/satori/go.uuid"
"github.com/gorilla/websocket"
)
// upgrader turns incoming HTTP requests into websocket connections for
// WebSocketHandler. It uses 1 KiB read and write buffers.
//
// TODO: this is package-level state shared by every handler; consider moving
// it onto RebindManager if per-manager configuration is ever needed.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1 << 10,
	WriteBufferSize: 1 << 10,
	// Every Origin is accepted on purpose. Accepting all origins is normally
	// unsafe (cross-site websocket hijacking), but this tool intentionally
	// allows it.
	// NOTE(review): presumably because the tool's pages are served from
	// rebound hostnames; confirm.
	CheckOrigin: func(*http.Request) bool { return true },
}
// WebSocketRequest is the envelope decoded from every message a client sends
// over the socket. WebSocketHandler decodes it first to learn the request ID
// and action; WebSocketMessageHandler then dispatches on Action.
type WebSocketRequest struct {
	// RequestID is the client-supplied ID for this message. It is attached to
	// the message's context and echoed back in responses such as
	// WebSocketHostResponse.
	RequestID uuid.UUID `json:"requestId"`
	// Action names the requested operation; WebSocketMessageHandler handles "host".
	Action string `json:"action"`
}
// WebSocketInitRequest is the initial request, describing the connecting client.
//
// NOTE(review): this type is not referenced in the visible part of this file;
// presumably it is decoded elsewhere when a client first connects. Confirm.
type WebSocketInitRequest struct {
	// Navigator mirrors the JSON "navigator" object.
	// Presumably this is the browser's window.navigator as reported by the
	// client; confirm.
	Navigator struct {
		// UserAgent is the "userAgent" string reported by the client.
		UserAgent string `json:"userAgent"`
	} `json:"navigator"`
}
// WebSocketHostRequest is the payload of a "host" action message: the request
// to offer rebinds for a given host. WebSocketMessageHandler decodes it from
// the raw message and passes it to MakeOffer.
type WebSocketHostRequest struct {
	// Host is the address to offer rebinds for, decoded from the "host" key.
	// It is nil when the key is absent or null.
	Host *Address `json:"host"`
}
// WebSocketHostResponse is the reply written back for a "host" action message.
// It carries the rebind offers made for the requested host.
type WebSocketHostResponse struct {
	// RequestID echoes the RequestID of the WebSocketRequest being answered.
	RequestID uuid.UUID `json:"requestId"`
	// Offers are the rebind offers produced by MakeOffer. A nil slice encodes
	// as JSON null.
	Offers []RebindOffer `json:"offers"`
}
// WebSocketMessageHandler handles a single socket message whose envelope has
// already been decoded into wReq. rawMsg is the undecoded message; it is
// decoded again into the action-specific request type.
//
// Supported actions:
//   - "host": rawMsg is decoded as a WebSocketHostRequest, m.MakeOffer builds
//     rebind offers for it, and a WebSocketHostResponse echoing
//     wReq.RequestID is written back on conn as a text message.
//
// Messages with any other action are logged and otherwise ignored; they are
// not treated as errors.
//
// It returns a non-nil error if decoding the request, encoding the response,
// or writing the response to conn fails.
func (m *RebindManager) WebSocketMessageHandler(ctx context.Context, conn *websocket.Conn, wReq WebSocketRequest, rawMsg []byte) error {
	log.Infof(`Socket "%s" got msg "%s" for "%s" action`, socketID(ctx), requestID(ctx), wReq.Action)
	switch wReq.Action {
	case "host":
		// Decode the action-specific payload from the same raw message.
		var msg WebSocketHostRequest
		if err := json.Unmarshal(rawMsg, &msg); err != nil {
			return err
		}
		log.Debug(msg)

		// Make rebind offers based on the information provided by the host.
		offers := m.MakeOffer(ctx, msg)

		// Marshal into a response that echoes the envelope's request ID, then
		// write it back.
		resp := &WebSocketHostResponse{
			RequestID: wReq.RequestID,
			Offers:    offers,
		}
		rResp, err := json.Marshal(resp)
		if err != nil {
			return err
		}
		if err = conn.WriteMessage(websocket.TextMessage, rResp); err != nil {
			return err
		}
		// %v rather than %s: offers is a slice of RebindOffer, which is not
		// necessarily a string or Stringer.
		log.Infof(`Wrote (%d) offers (%v) to socket "%s" in response to msg "%s"`, len(offers), offers, socketID(ctx), requestID(ctx))
	default:
		// Returning an error here would make the caller tear down the socket,
		// so unknown actions are tolerated. Log them so they are not dropped
		// silently.
		log.Infof(`Socket "%s" ignoring msg "%s" with unsupported action "%s"`, socketID(ctx), requestID(ctx), wReq.Action)
	}
	return nil
}
// WebSocketHandler upgrades the HTTP request to a websocket connection, then
// reads JSON messages from it in a loop. Each message is decoded into a
// WebSocketRequest envelope and passed to WebSocketMessageHandler. The loop
// runs until the socket closes or an error occurs.
//
// After a successful upgrade the HTTP connection has been hijacked, so
// http.Error can no longer reach the client. Failures are instead reported
// with a best-effort websocket close frame. The connection is always closed
// when this handler returns.
func (m *RebindManager) WebSocketHandler(w http.ResponseWriter, req *http.Request) {
	// Upgrade to a websocket. On failure Upgrade has already replied to the
	// client with an HTTP error, so only log here; writing a second response
	// would be a superfluous WriteHeader.
	conn, err := upgrader.Upgrade(w, req, nil)
	if err != nil {
		log.Error(err)
		return
	}

	// Each socket has an ID.
	id := uuid.NewV4()
	log.Infof(`New socket connection "%s"`, id)

	// Create a cancel-able child context carrying the socket ID. It is
	// cancelled when this handler returns, so anything holding a derived
	// context can tell the socket is gone.
	ctx, triggerClose := context.WithCancel(req.Context())
	ctx = context.WithValue(ctx, socketIDKey, id)

	// When the socket closes, cancel the context and release the hijacked
	// network connection; the HTTP server no longer manages it.
	defer func() {
		log.Infof(`Socket "%s" has closed, cleaning up`, id)
		triggerClose()
		_ = conn.Close() // nothing useful to do if closing fails
	}()

	// closeWith sends a close frame telling the client why the socket is being
	// torn down. The write error is ignored because the connection is closed
	// right afterwards either way.
	closeWith := func(code int, reason string) {
		_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, reason))
	}

	// Read and handle messages until the socket closes or something fails.
	for {
		_, rawMsg, err := conn.ReadMessage()
		if err != nil {
			// A normal closure or "going away" (e.g. the page was navigated
			// away from) is how sockets are expected to end, so it is not
			// logged as an error.
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				log.Error(err)
			}
			return
		}

		// Decode the envelope to learn the request ID and action.
		var wReq WebSocketRequest
		if err := json.Unmarshal(rawMsg, &wReq); err != nil {
			log.Error(err)
			closeWith(websocket.CloseUnsupportedData, "malformed request")
			return
		}

		// Attach this message's request ID. reqCtx is scoped to the iteration,
		// so request IDs do not accumulate on the socket-wide context.
		reqCtx := context.WithValue(ctx, requestIDKey, wReq.RequestID)

		// Handle it.
		if err := m.WebSocketMessageHandler(reqCtx, conn, wReq, rawMsg); err != nil {
			log.Error(err)
			closeWith(websocket.CloseInternalServerErr, "request failed")
			return
		}
	}
}