-
Notifications
You must be signed in to change notification settings - Fork 3
/
session.go
211 lines (187 loc) · 5.94 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
// Copyright 2021 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package session
import (
"context"
"net/http"
"reflect"
"time"
"github.com/pkg/errors"
"github.com/flamego/flamego"
)
// Session holds the data associated with the session of the current request.
type Session interface {
	// ID reports the identifier of this session.
	ID() string
	// RegenerateID replaces the session ID with a freshly generated one. The
	// response writer and request are supplied, presumably so the new ID can be
	// delivered to the client — confirm against the store implementations.
	RegenerateID(w http.ResponseWriter, r *http.Request) error

	// Get looks up the value stored under key and yields nil when the key is
	// absent.
	Get(key interface{}) interface{}
	// Set stores val under key.
	Set(key, val interface{})
	// SetFlash records val as the flash value of the session. Sessioner reads
	// the value under flashKey at the start of each request, deletes it from
	// the session and exposes it as a Flash; presumably SetFlash writes to that
	// key — confirm in the implementation.
	SetFlash(val interface{})
	// Delete removes key, together with its value, from the session.
	Delete(key interface{})
	// Flush discards every value currently held by the session.
	Flush()

	// Encode serializes the session data into bytes.
	Encode() ([]byte, error)
	// HasChanged reports whether the session data was modified. Sessioner
	// saves changed sessions to the store and only touches unchanged ones.
	HasChanged() bool
}
// CookieOptions describes the HTTP cookie that carries the session ID.
type CookieOptions struct {
	// Name of the cookie. Sessioner uses "flamego_session" when it is empty.
	Name string
	// Path attribute of the cookie. Sessioner uses "/" when it is empty.
	Path string
	// Domain attribute of the cookie. Left unset by default.
	Domain string
	// MaxAge attribute of the cookie. Left unset by default.
	MaxAge int
	// Secure controls whether the cookie carries the Secure attribute.
	Secure bool
	// HTTPOnly controls whether the cookie carries the HttpOnly attribute.
	// Sessioner turns it on only when the entire CookieOptions is left at its
	// zero value; once any field is set, HTTPOnly must be set explicitly.
	HTTPOnly bool
	// SameSite attribute of the cookie. Sessioner substitutes
	// http.SameSiteLaxMode for any value outside the range
	// http.SameSiteDefaultMode..http.SameSiteNoneMode, including the zero value.
	SameSite http.SameSite
}
// Options configures the session.Sessioner middleware. Sessioner replaces
// unset (or out-of-range) fields with defaults before using them.
type Options struct {
	// Initer initializes the session store. Defaults to session.MemoryIniter.
	Initer Initer
	// Config is handed to Initer as the configuration of the session store.
	Config interface{}
	// Cookie controls the HTTP cookie that carries the session ID. When it is
	// left completely unset, HTTPOnly is enabled as part of the defaults.
	Cookie CookieOptions
	// IDLength is the length of session IDs. Values below 3 fall back to the
	// default of 16.
	IDLength int
	// GCInterval is the interval between GC operations. Values shorter than one
	// second fall back to the default of 5 minutes.
	GCInterval time.Duration
	// ErrorFunc receives errors that occur in the background, such as during
	// GC. By default those errors are dropped silently.
	ErrorFunc func(err error)
	// ReadIDFunc extracts the session ID from a request. By default the ID is
	// read from the cookie named by Cookie.Name.
	ReadIDFunc func(r *http.Request) string
	// WriteIDFunc delivers the session ID with the response. The `created`
	// argument reports whether a new session was created in the session store.
	// By default the ID is written as a cookie, and only when created is true.
	WriteIDFunc func(w http.ResponseWriter, r *http.Request, sid string, created bool)
}
// minimumSIDLength is the smallest session ID length Sessioner will use; a
// shorter Options.IDLength is replaced with the default. The file store needs
// at least this many characters to build a filename.
const minimumSIDLength = 3

// ErrMinimumSIDLength reports that a session ID is shorter than
// minimumSIDLength. Sessioner itself does not use it; presumably session stores
// return it when validating IDs — confirm against the store implementations.
var ErrMinimumSIDLength = errors.Errorf("the SID does not have the minimum required length %d", minimumSIDLength)
// Sessioner returns a middleware handler that injects session.Session and
// session.Store into the request context, which are used for manipulating
// session data.
//
// Only the first Options value, if any, is consulted; its unset fields receive
// defaults before the session store is initialized. Store GC is started with
// context.Background(), so nothing in this function ever stops it. Sessioner
// panics if the store cannot be initialized, and the returned handler panics
// when loading or saving a session fails with anything but context.Canceled.
func Sessioner(opts ...Options) flamego.Handler {
	var opt Options
	if len(opts) > 0 {
		opt = opts[0]
	}

	// withDefaults fills in every unset field of o and returns the result.
	withDefaults := func(o Options) Options {
		if o.Initer == nil {
			o.Initer = MemoryIniter()
		}

		// HTTPOnly is only switched on when no cookie options were supplied at
		// all; a partially filled CookieOptions is taken as given.
		if reflect.DeepEqual(o.Cookie, CookieOptions{}) {
			o.Cookie = CookieOptions{HTTPOnly: true}
		}
		if o.Cookie.Name == "" {
			o.Cookie.Name = "flamego_session"
		}
		if known := o.Cookie.SameSite >= http.SameSiteDefaultMode && o.Cookie.SameSite <= http.SameSiteNoneMode; !known {
			o.Cookie.SameSite = http.SameSiteLaxMode
		}
		if o.Cookie.Path == "" {
			o.Cookie.Path = "/"
		}

		// NOTE: The file store requires at least 3 characters for the filename.
		if o.IDLength < minimumSIDLength {
			o.IDLength = 16
		}
		if o.GCInterval < time.Second {
			o.GCInterval = 5 * time.Minute
		}
		if o.ErrorFunc == nil {
			o.ErrorFunc = func(error) {}
		}

		if o.ReadIDFunc == nil {
			o.ReadIDFunc = func(r *http.Request) string {
				ck, err := r.Cookie(o.Cookie.Name)
				if err == nil {
					return ck.Value
				}
				return ""
			}
		}
		if o.WriteIDFunc == nil {
			o.WriteIDFunc = func(w http.ResponseWriter, r *http.Request, sid string, created bool) {
				// The default writer only emits a cookie for newly created
				// sessions.
				if created {
					ck := &http.Cookie{
						Name:     o.Cookie.Name,
						Value:    sid,
						Path:     o.Cookie.Path,
						Domain:   o.Cookie.Domain,
						MaxAge:   o.Cookie.MaxAge,
						Secure:   o.Cookie.Secure,
						HttpOnly: o.Cookie.HTTPOnly,
						SameSite: o.Cookie.SameSite,
					}
					http.SetCookie(w, ck)
					// Also attach the cookie to the in-flight request itself.
					r.AddCookie(ck)
				}
			}
		}
		return o
	}
	opt = withDefaults(opt)

	ctx := context.Background()
	// The store is given an IDWriter that forwards to WriteIDFunc, always
	// flagging the ID as newly created.
	idWriter := IDWriter(func(w http.ResponseWriter, r *http.Request, sid string) {
		opt.WriteIDFunc(w, r, sid, true)
	})
	store, err := opt.Initer(ctx, opt.Config, idWriter)
	if err != nil {
		panic("session: " + err.Error())
	}

	mgr := newManager(store)
	mgr.startGC(ctx, opt.GCInterval, opt.ErrorFunc)

	return flamego.ContextInvoker(func(c flamego.Context) {
		req := c.Request().Request
		sess, created, err := mgr.load(req, opt.ReadIDFunc(req), opt.IDLength)
		if err != nil {
			if !errors.Is(err, context.Canceled) {
				panic("session: load: " + err.Error())
			}
			c.ResponseWriter().WriteHeader(http.StatusUnprocessableEntity)
			return
		}
		opt.WriteIDFunc(c.ResponseWriter(), req, sess.ID(), created)

		// A flash is taken out of the session before being handed to
		// downstream handlers, so it is only observed by this request.
		flash := sess.Get(flashKey)
		if flash != nil {
			sess.Delete(flashKey)
		}
		c.Map(store, sess)
		c.MapTo(flash, (*Flash)(nil))

		c.Next()

		// Fetch the request again after c.Next() in case downstream handlers
		// replaced it, then persist modified sessions or touch unchanged ones.
		switch {
		case sess.HasChanged():
			err = store.Save(c.Request().Context(), sess)
		default:
			err = store.Touch(c.Request().Context(), sess.ID())
		}
		if err == nil || errors.Is(err, context.Canceled) {
			return
		}
		panic("session: save: " + err.Error())
	})
}