Skip to content

Commit

Permalink
Use buffer pool
Browse files Browse the repository at this point in the history
  • Loading branch information
ioppermann committed Oct 9, 2024
1 parent f97943b commit 05e4118
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 66 deletions.
91 changes: 45 additions & 46 deletions http/middleware/session/HLS.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int
// Read out the path of the .ts files and look them up in the ts-map.
// Add it as ingress for the respective "sessionId". The "sessionId" is the .m3u8 file name.
reader := req.Body
r := &bodyReader{
r := &segmentReader{
reader: req.Body,
buffer: h.bufferPool.Get(),
}
req.Body = r

defer func() {
req.Body = reader

if r.size == 0 {
h.bufferPool.Put(r.buffer)
return
}

Expand All @@ -58,8 +60,10 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int
h.hlsIngressCollector.Extra(path, data)
}

h.hlsIngressCollector.Ingress(path, headerSize(req.Header))
buffer := h.bufferPool.Get()
h.hlsIngressCollector.Ingress(path, headerSize(req.Header, buffer))
h.hlsIngressCollector.Ingress(path, r.size)
h.bufferPool.Put(buffer)

segments := r.getSegments(urlpath.Dir(path))

Expand All @@ -74,6 +78,8 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int
}
h.lock.Unlock()
}

h.bufferPool.Put(r.buffer)
}()
} else if strings.HasSuffix(path, ".ts") {
// Get the size of the .ts file and store it in the ts-map for later use.
Expand All @@ -87,9 +93,11 @@ func (h *handler) handleHLSIngress(c echo.Context, _ string, data map[string]int
req.Body = reader

if r.size != 0 {
buffer := h.bufferPool.Get()
h.lock.Lock()
h.rxsegments[path] = r.size + headerSize(req.Header)
h.rxsegments[path] = r.size + headerSize(req.Header, buffer)
h.lock.Unlock()
h.bufferPool.Put(buffer)
}
}()
}
Expand Down Expand Up @@ -171,6 +179,7 @@ func (h *handler) handleHLSEgress(c echo.Context, _ string, data map[string]inte
// the data that we need to rewrite.
rewriter = &sessionRewriter{
ResponseWriter: res.Writer,
buffer: h.bufferPool.Get(),
}

res.Writer = rewriter
Expand All @@ -188,21 +197,29 @@ func (h *handler) handleHLSEgress(c echo.Context, _ string, data map[string]inte
if rewrite {
if res.Status < 200 || res.Status >= 300 {
res.Write(rewriter.buffer.Bytes())
h.bufferPool.Put(rewriter.buffer)
return nil
}

buffer := h.bufferPool.Get()

// Rewrite the data befor sending it to the client
rewriter.rewriteHLS(sessionID, c.Request().URL)
rewriter.rewriteHLS(sessionID, c.Request().URL, buffer)

res.Header().Set("Cache-Control", "private")
res.Write(rewriter.buffer.Bytes())
res.Write(buffer.Bytes())

h.bufferPool.Put(buffer)
h.bufferPool.Put(rewriter.buffer)
}

if isM3U8 || isTS {
if res.Status >= 200 && res.Status < 300 {
// Collect how many bytes we've written in this session
h.hlsEgressCollector.Egress(sessionID, headerSize(res.Header()))
buffer := h.bufferPool.Get()
h.hlsEgressCollector.Egress(sessionID, headerSize(res.Header(), buffer))
h.hlsEgressCollector.Egress(sessionID, res.Size)
h.bufferPool.Put(buffer)

if isTS {
// Activate the session. If the session is already active, this is a noop
Expand All @@ -214,13 +231,13 @@ func (h *handler) handleHLSEgress(c echo.Context, _ string, data map[string]inte
return nil
}

type bodyReader struct {
type segmentReader struct {
reader io.ReadCloser
buffer bytes.Buffer
buffer *bytes.Buffer
size int64
}

func (r *bodyReader) Read(b []byte) (int, error) {
func (r *segmentReader) Read(b []byte) (int, error) {
n, err := r.reader.Read(b)
if n > 0 {
r.buffer.Write(b[:n])
Expand All @@ -230,15 +247,15 @@ func (r *bodyReader) Read(b []byte) (int, error) {
return n, err
}

func (r *bodyReader) Close() error {
func (r *segmentReader) Close() error {
return r.reader.Close()
}

func (r *bodyReader) getSegments(dir string) []string {
func (r *segmentReader) getSegments(dir string) []string {
segments := []string{}

// Find all segment URLs in the .m3u8
scanner := bufio.NewScanner(&r.buffer)
scanner := bufio.NewScanner(r.buffer)
for scanner.Scan() {
line := scanner.Text()

Expand Down Expand Up @@ -280,65 +297,49 @@ func (r *bodyReader) getSegments(dir string) []string {
return segments
}

type bodysizeReader struct {
reader io.ReadCloser
size int64
}

func (r *bodysizeReader) Read(b []byte) (int, error) {
n, err := r.reader.Read(b)
r.size += int64(n)

return n, err
}

func (r *bodysizeReader) Close() error {
return r.reader.Close()
}

type sessionRewriter struct {
http.ResponseWriter
buffer bytes.Buffer
buffer *bytes.Buffer
}

func (g *sessionRewriter) Write(data []byte) (int, error) {
// Write the data into internal buffer for later rewrite
w, err := g.buffer.Write(data)

return w, err
return g.buffer.Write(data)
}

func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL) {
var buffer bytes.Buffer

func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL, buffer *bytes.Buffer) {
isMaster := false

// Find all URLS in the .m3u8 and add the session ID to the query string
scanner := bufio.NewScanner(&g.buffer)
scanner := bufio.NewScanner(g.buffer)
for scanner.Scan() {
line := scanner.Text()
byteline := scanner.Bytes()

// Write empty lines unmodified
if len(line) == 0 {
buffer.WriteString(line + "\n")
if len(byteline) == 0 {
buffer.Write(byteline)
buffer.WriteByte('\n')
continue
}

// Write comments unmodified
if strings.HasPrefix(line, "#") {
buffer.WriteString(line + "\n")
if byteline[0] == '#' {
buffer.Write(byteline)
buffer.WriteByte('\n')
continue
}

u, err := url.Parse(line)
u, err := url.Parse(string(byteline))
if err != nil {
buffer.WriteString(line + "\n")
buffer.Write(byteline)
buffer.WriteByte('\n')
continue
}

// Write anything that doesn't end in .m3u8 or .ts unmodified
if !strings.HasSuffix(u.Path, ".m3u8") && !strings.HasSuffix(u.Path, ".ts") {
buffer.WriteString(line + "\n")
buffer.Write(byteline)
buffer.WriteByte('\n')
continue
}

Expand Down Expand Up @@ -407,6 +408,4 @@ func (g *sessionRewriter) rewriteHLS(sessionID string, requestURL *url.URL) {

buffer.WriteString(urlpath.Base(requestURL.Path) + "?" + q.Encode())
}

g.buffer = buffer
}
112 changes: 112 additions & 0 deletions http/middleware/session/HLS_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package session

import (
"bytes"
"io"
"net/url"
"os"
"testing"

"github.com/datarhei/core/v16/mem"
"github.com/stretchr/testify/require"
)

func TestHLSSegmentReader(t *testing.T) {
data, err := os.ReadFile("./fixtures/segments.txt")
require.NoError(t, err)

r := bytes.NewReader(data)

br := &segmentReader{
reader: io.NopCloser(r),
buffer: &bytes.Buffer{},
}

_, err = io.ReadAll(br)
require.NoError(t, err)

segments := br.getSegments("/foobar")
require.Equal(t, []string{
"/foobar/test_0_0_0303.ts",
"/foobar/test_0_0_0304.ts",
"/foobar/test_0_0_0305.ts",
"/foobar/test_0_0_0306.ts",
"/foobar/test_0_0_0307.ts",
"/foobar/test_0_0_0308.ts",
"/foobar/test_0_0_0309.ts",
"/foobar/test_0_0_0310.ts",
}, segments)
}

func BenchmarkHLSSegmentReader(b *testing.B) {
pool := mem.NewBufferPool()

data, err := os.ReadFile("./fixtures/segments.txt")
require.NoError(b, err)

rd := bytes.NewReader(data)
r := io.NopCloser(rd)

for i := 0; i < b.N; i++ {
rd.Reset(data)
br := &segmentReader{
reader: io.NopCloser(r),
buffer: pool.Get(),
}

_, err := io.ReadAll(br)
require.NoError(b, err)

pool.Put(br.buffer)
}
}

func TestHLSRewrite(t *testing.T) {
data, err := os.ReadFile("./fixtures/segments.txt")
require.NoError(t, err)

br := &sessionRewriter{
buffer: &bytes.Buffer{},
}

_, err = br.Write(data)
require.NoError(t, err)

u, err := url.Parse("http://example.com/test.m3u8")
require.NoError(t, err)

buffer := &bytes.Buffer{}

br.rewriteHLS("oT5GV8eWBbRAh4aib5egoK", u, buffer)

data, err = os.ReadFile("./fixtures/segments_with_session.txt")
require.NoError(t, err)

require.Equal(t, data, buffer.Bytes())
}

func BenchmarkHLSRewrite(b *testing.B) {
pool := mem.NewBufferPool()

data, err := os.ReadFile("./fixtures/segments.txt")
require.NoError(b, err)

u, err := url.Parse("http://example.com/test.m3u8")
require.NoError(b, err)

for i := 0; i < b.N; i++ {
br := &sessionRewriter{
buffer: pool.Get(),
}

_, err = br.Write(data)
require.NoError(b, err)

buffer := pool.Get()

br.rewriteHLS("oT5GV8eWBbRAh4aib5egoK", u, buffer)

pool.Put(br.buffer)
pool.Put(buffer)
}
}
20 changes: 11 additions & 9 deletions http/middleware/session/HTTP.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/lithammer/shortuuid/v4"
)

func (h *handler) handleHTTP(c echo.Context, ctxuser string, data map[string]interface{}, next echo.HandlerFunc) error {
func (h *handler) handleHTTP(c echo.Context, _ string, data map[string]interface{}, next echo.HandlerFunc) error {
req := c.Request()
res := c.Response()

Expand All @@ -30,33 +30,35 @@ func (h *handler) handleHTTP(c echo.Context, ctxuser string, data map[string]int
id := shortuuid.New()

reader := req.Body
r := &fakeReader{
r := &bodysizeReader{
reader: req.Body,
}
req.Body = r

writer := res.Writer
w := &fakeWriter{
w := &bodysizeWriter{
ResponseWriter: res.Writer,
}
res.Writer = w

h.httpCollector.RegisterAndActivate(id, "", location, referrer)
h.httpCollector.Extra(id, data)

defer h.httpCollector.Close(id)

defer func() {
buffer := h.bufferPool.Get()

req.Body = reader
h.httpCollector.Ingress(id, r.size+headerSize(req.Header))
}()
h.httpCollector.Ingress(id, r.size+headerSize(req.Header, buffer))

defer func() {
res.Writer = writer

h.httpCollector.Egress(id, w.size+headerSize(res.Header()))
h.httpCollector.Egress(id, w.size+headerSize(res.Header(), buffer))
data["code"] = res.Status
h.httpCollector.Extra(id, data)

h.httpCollector.Close(id)

h.bufferPool.Put(buffer)
}()

return next(c)
Expand Down
Loading

0 comments on commit 05e4118

Please sign in to comment.