forked from Azure/go-ntlmssp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
negotiator.go
144 lines (126 loc) · 3.76 KB
/
negotiator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package ntlmssp
import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"strings"
)
// GetDomain splits a user name of the form `DOMAIN\user` at its first
// backslash. It returns the bare user name followed by the domain; when the
// input contains no backslash the input is returned unchanged as the user
// name and the domain is empty.
func GetDomain(user string) (string, string) {
	sep := strings.Index(user, "\\")
	if sep < 0 {
		// No domain qualifier present.
		return user, ""
	}
	// Everything before the first backslash is the domain, the rest
	// (including any further backslashes) is the user name.
	return user[sep+1:], user[:sep]
}
// Negotiator is an http.RoundTripper decorator that automatically converts
// basic authentication to NTLM/Negotiate authentication when appropriate.
// The embedded RoundTripper performs the actual HTTP exchanges.
type Negotiator struct {
	http.RoundTripper
}
// RoundTrip sends the request to the server, handling any authentication
// re-sends as needed.
//
// Requests that do not carry a Basic Authorization header are handed straight
// to the underlying RoundTripper (http.DefaultTransport when none is set).
// Otherwise the request is first sent without credentials. On a 401 that does
// not offer NTLM/Negotiate the original Basic header is replayed; on a 401
// that does offer NTLM/Negotiate the Basic credentials are recycled for a
// negotiate -> challenge -> authenticate exchange.
//
// A request body, if present, is buffered in memory so it can be replayed on
// every re-send. Note that req's Authorization header and Body are modified
// in place while retrying.
//
// On error the returned response is nil and any intermediate response body
// has been closed.
func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
	// Use default round tripper if not provided
	rt := l.RoundTripper
	if rt == nil {
		rt = http.DefaultTransport
	}
	// If it is not basic auth, just round trip the request as usual
	reqauth := authheader(req.Header.Values("Authorization"))
	if !reqauth.IsBasic() {
		return rt.RoundTrip(req)
	}
	reqauthBasic := reqauth.Basic()

	// Save request body so that it can be replayed on each attempt. Requests
	// without a body are left alone: attaching an empty reader would make the
	// transport treat the length as unknown.
	hasBody := req.Body != nil && req.Body != http.NoBody
	body := bytes.Buffer{}
	if hasBody {
		_, err = body.ReadFrom(req.Body)
		// The body must be closed even when reading it failed.
		req.Body.Close()
		if err != nil {
			return nil, err
		}
	}
	// rewindBody gives req a fresh reader over the buffered body.
	rewindBody := func() {
		if hasBody {
			req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
		}
	}
	// discard drains and closes a response we are not going to return, so
	// the underlying connection can be reused for the next attempt.
	discard := func(r *http.Response) {
		// Best effort: a failure to drain only costs connection reuse.
		_, _ = io.Copy(ioutil.Discard, r.Body)
		r.Body.Close()
	}
	// setAuth installs an NTLM or Negotiate Authorization header carrying msg.
	setAuth := func(ntlm bool, msg []byte) {
		scheme := "Negotiate "
		if ntlm {
			scheme = "NTLM "
		}
		req.Header.Set("Authorization", scheme+base64.StdEncoding.EncodeToString(msg))
	}

	// first try anonymous, in case the server still finds us
	// authenticated from previous traffic
	rewindBody()
	req.Header.Del("Authorization")
	res, err = rt.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	if res.StatusCode != http.StatusUnauthorized {
		return res, nil
	}

	resauth := authheader(res.Header.Values("Www-Authenticate"))
	if !resauth.IsNegotiate() && !resauth.IsNTLM() {
		// Unauthorized, Negotiate not requested, let's try with basic auth
		req.Header.Set("Authorization", string(reqauthBasic))
		discard(res)
		rewindBody()

		res, err = rt.RoundTrip(req)
		if err != nil {
			return nil, err
		}
		if res.StatusCode != http.StatusUnauthorized {
			return res, nil
		}
		resauth = authheader(res.Header.Values("Www-Authenticate"))
	}

	if !resauth.IsNegotiate() && !resauth.IsNTLM() {
		// Still unauthorized and no NTLM/Negotiate on offer: let the client
		// deal with the response.
		return res, nil
	}

	// 401 with request:Basic and response:Negotiate
	discard(res)

	// recycle credentials
	u, p, err := reqauth.GetBasicCreds()
	if err != nil {
		return nil, err
	}

	// get domain from username
	u, domain := GetDomain(u)

	// send negotiate
	negotiateMessage, err := NewNegotiateMessage(domain, "")
	if err != nil {
		return nil, err
	}
	setAuth(resauth.IsNTLM(), negotiateMessage)
	rewindBody()

	res, err = rt.RoundTrip(req)
	if err != nil {
		return nil, err
	}

	// receive challenge?
	resauth = authheader(res.Header.Values("Www-Authenticate"))
	challengeMessage, err := resauth.GetData()
	if err != nil {
		// The response is dropped on this path, so release its body.
		discard(res)
		return nil, err
	}
	if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
		// Negotiation failed, let client deal with response
		return res, nil
	}
	discard(res)

	// send authenticate
	authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
	if err != nil {
		return nil, err
	}
	setAuth(resauth.IsNTLM(), authenticateMessage)
	rewindBody()

	return rt.RoundTrip(req)
}