Skip to content

Commit

Permalink
feat: add WebSocket routing
Browse files Browse the repository at this point in the history
  • Loading branch information
canstand committed Oct 25, 2024
1 parent 3e8cb5f commit 59ce07d
Show file tree
Hide file tree
Showing 17 changed files with 1,014 additions and 177 deletions.
144 changes: 92 additions & 52 deletions browser_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type browserContextImpl struct {
options *BrowserNewContextOptions
pages []Page
routes []*routeHandlerEntry
webSocketRoutes []*webSocketRouteHandler
ownedPage Page
browser *browserImpl
serviceWorkers []Worker
Expand All @@ -44,7 +45,7 @@ func (b *browserContextImpl) SetDefaultNavigationTimeout(timeout float64) {

func (b *browserContextImpl) setDefaultNavigationTimeoutImpl(timeout *float64) {
b.timeoutSettings.SetDefaultNavigationTimeout(timeout)
b.channel.SendNoReply("setDefaultNavigationTimeoutNoReply", true, map[string]interface{}{
b.channel.SendNoReplyInternal("setDefaultNavigationTimeoutNoReply", map[string]interface{}{
"timeout": timeout,
})
}
Expand All @@ -55,7 +56,7 @@ func (b *browserContextImpl) SetDefaultTimeout(timeout float64) {

func (b *browserContextImpl) setDefaultTimeoutImpl(timeout *float64) {
b.timeoutSettings.SetDefaultTimeout(timeout)
b.channel.SendNoReply("setDefaultTimeoutNoReply", true, map[string]interface{}{
b.channel.SendNoReplyInternal("setDefaultTimeoutNoReply", map[string]interface{}{
"timeout": timeout,
})
}
Expand Down Expand Up @@ -541,7 +542,7 @@ func (b *browserContextImpl) onBinding(binding *bindingCallImpl) {
if !ok || function == nil {
return
}
binding.Call(function)
go binding.Call(function)
}

func (b *browserContextImpl) onClose() {
Expand Down Expand Up @@ -572,58 +573,56 @@ func (b *browserContextImpl) onPage(page Page) {
}

func (b *browserContextImpl) onRoute(route *routeImpl) {
go func() {
b.Lock()
route.context = b
page := route.Request().(*requestImpl).safePage()
routes := make([]*routeHandlerEntry, len(b.routes))
copy(routes, b.routes)
b.Unlock()
b.Lock()
route.context = b
page := route.Request().(*requestImpl).safePage()
routes := make([]*routeHandlerEntry, len(b.routes))
copy(routes, b.routes)
b.Unlock()

checkInterceptionIfNeeded := func() {
b.Lock()
defer b.Unlock()
if len(b.routes) == 0 {
_, err := b.connection.WrapAPICall(func() (interface{}, error) {
err := b.updateInterceptionPatterns()
return nil, err
}, true)
if err != nil {
logger.Printf("could not update interception patterns: %v\n", err)
}
checkInterceptionIfNeeded := func() {
b.Lock()
defer b.Unlock()
if len(b.routes) == 0 {
_, err := b.connection.WrapAPICall(func() (interface{}, error) {
err := b.updateInterceptionPatterns()
return nil, err
}, true)
if err != nil {
logger.Printf("could not update interception patterns: %v\n", err)
}
}
}

url := route.Request().URL()
for _, handlerEntry := range routes {
// If the page or the context was closed we stall all requests right away.
if (page != nil && page.closeWasCalled) || b.closeWasCalled {
return
}
if !handlerEntry.Matches(url) {
continue
}
if !slices.ContainsFunc(b.routes, func(entry *routeHandlerEntry) bool {
return entry == handlerEntry
}) {
continue
}
if handlerEntry.WillExceed() {
b.routes = slices.DeleteFunc(b.routes, func(rhe *routeHandlerEntry) bool {
return rhe == handlerEntry
})
}
handled := handlerEntry.Handle(route)
checkInterceptionIfNeeded()
yes := <-handled
if yes {
return
}
url := route.Request().URL()
for _, handlerEntry := range routes {
// If the page or the context was closed we stall all requests right away.
if (page != nil && page.closeWasCalled) || b.closeWasCalled {
return
}
// If the page is closed or unrouteAll() was called without waiting and interception disabled,
// the method will throw an error - silence it.
_ = route.internalContinue(true)
}()
if !handlerEntry.Matches(url) {
continue
}
if !slices.ContainsFunc(b.routes, func(entry *routeHandlerEntry) bool {
return entry == handlerEntry
}) {
continue
}
if handlerEntry.WillExceed() {
b.routes = slices.DeleteFunc(b.routes, func(rhe *routeHandlerEntry) bool {
return rhe == handlerEntry
})
}
handled := handlerEntry.Handle(route)
checkInterceptionIfNeeded()
yes := <-handled
if yes {
return
}
}
// If the page is closed or unrouteAll() was called without waiting and interception disabled,
// the method will throw an error - silence it.
_ = route.internalContinue(true)
}

func (b *browserContextImpl) updateInterceptionPatterns() error {
Expand Down Expand Up @@ -726,6 +725,40 @@ func (b *browserContextImpl) OnWebError(fn func(WebError)) {
b.On("weberror", fn)
}

func (b *browserContextImpl) RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error {
b.Lock()
defer b.Unlock()
b.webSocketRoutes = slices.Insert(b.webSocketRoutes, 0, newWebSocketRouteHandler(newURLMatcher(url, b.options.BaseURL), handler))

return b.updateWebSocketInterceptionPatterns()
}

func (b *browserContextImpl) onWebSocketRoute(wr WebSocketRoute) {
b.Lock()
index := slices.IndexFunc(b.webSocketRoutes, func(r *webSocketRouteHandler) bool {
return r.Matches(wr.URL())
})
if index == -1 {
b.Unlock()
_, err := wr.ConnectToServer()
if err != nil {
logger.Println(err)
}
return
}
handler := b.webSocketRoutes[index]
b.Unlock()
handler.Handle(wr)
}

func (b *browserContextImpl) updateWebSocketInterceptionPatterns() error {
patterns := prepareWebSocketRouteHandlerInterceptionPatterns(b.webSocketRoutes)
_, err := b.channel.Send("setWebSocketInterceptionPatterns", map[string]interface{}{
"patterns": patterns,
})
return err
}

func (b *browserContextImpl) effectiveCloseReason() *string {
b.Lock()
defer b.Unlock()
Expand Down Expand Up @@ -758,15 +791,22 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl)
bt.clock = newClock(bt)
bt.channel.On("bindingCall", func(params map[string]interface{}) {
go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
})

bt.channel.On("close", bt.onClose)
bt.channel.On("page", func(payload map[string]interface{}) {
bt.onPage(fromChannel(payload["page"]).(*pageImpl))
})
bt.channel.On("route", func(params map[string]interface{}) {
bt.onRoute(fromChannel(params["route"]).(*routeImpl))
bt.channel.CreateTask(func() {
bt.onRoute(fromChannel(params["route"]).(*routeImpl))
})
})
bt.channel.On("webSocketRoute", func(params map[string]interface{}) {
bt.channel.CreateTask(func() {
bt.onWebSocketRoute(fromChannel(params["webSocketRoute"]).(*webSocketRouteImpl))
})
})
bt.channel.On("backgroundPage", bt.onBackgroundPage)
bt.channel.On("serviceWorker", func(params map[string]interface{}) {
Expand Down
41 changes: 39 additions & 2 deletions channel.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package playwright

import "encoding/json"
import (
"encoding/json"
"fmt"
)

type channel struct {
eventEmitter
Expand All @@ -16,6 +19,23 @@ func (c *channel) MarshalJSON() ([]byte, error) {
})
}

// for catch errors of route handlers etc.
func (c *channel) CreateTask(fn func()) {
go func() {
defer func() {
if e := recover(); e != nil {
err, ok := e.(error)
if ok {
c.connection.err.Set(err)
} else {
c.connection.err.Set(fmt.Errorf("%v", e))
}
}
}()
fn()
}()
}

func (c *channel) Send(method string, options ...interface{}) (interface{}, error) {
return c.connection.WrapAPICall(func() (interface{}, error) {
return c.innerSend(method, options...).GetResultValue()
Expand All @@ -30,16 +50,33 @@ func (c *channel) SendReturnAsDict(method string, options ...interface{}) (map[s
}

func (c *channel) innerSend(method string, options ...interface{}) *protocolCallback {
if err := c.connection.err.Get(); err != nil {
c.connection.err.Set(nil)
pc := newProtocolCallback(false, c.connection.abort)
pc.SetError(err)
return pc
}
params := transformOptions(options...)
return c.connection.sendMessageToServer(c.owner, method, params, false)
}

func (c *channel) SendNoReply(method string, isInternal bool, options ...interface{}) {
// SendNoReply ignores return value and errors
// almost equivalent to `send(...).catch(() => {})`
func (c *channel) SendNoReply(method string, options ...interface{}) {
c.innerSendNoReply(method, c.owner.isInternalType, options...)
}

func (c *channel) SendNoReplyInternal(method string, options ...interface{}) {
c.innerSendNoReply(method, true, options...)
}

func (c *channel) innerSendNoReply(method string, isInternal bool, options ...interface{}) {
params := transformOptions(options...)
_, err := c.connection.WrapAPICall(func() (interface{}, error) {
return c.connection.sendMessageToServer(c.owner, method, params, true).GetResult()
}, isInternal)
if err != nil {
// ignore error actively, log only for debug
logger.Printf("SendNoReply failed: %v\n", err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion channel_owner.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (c *channelOwner) setEventSubscriptionMapping(mapping map[string]string) {
func (c *channelOwner) updateSubscription(event string, enabled bool) {
protocolEvent, ok := c.eventToSubscriptionMapping[event]
if ok {
c.channel.SendNoReply("updateSubscription", true, map[string]interface{}{
c.channel.SendNoReplyInternal("updateSubscription", map[string]interface{}{
"event": protocolEvent,
"enabled": enabled,
})
Expand Down
4 changes: 3 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type connection struct {
tracingCount atomic.Int32
abort chan struct{}
abortOnce sync.Once
err *safeValue[error] // for event listener error
closedError *safeValue[error]
}

Expand Down Expand Up @@ -301,6 +302,7 @@ func newConnection(transport transport, localUtils ...*localUtilsImpl) *connecti
objects: safe.NewSyncMap[string, *channelOwner](),
transport: transport,
isRemote: false,
err: &safeValue[error]{},
closedError: &safeValue[error]{},
}
if len(localUtils) > 0 {
Expand Down Expand Up @@ -393,7 +395,7 @@ func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback
}
}
return &protocolCallback{
done: make(chan struct{}),
done: make(chan struct{}, 1),
abort: abort,
}
}
Loading

0 comments on commit 59ce07d

Please sign in to comment.