From 08862253c9019dd8ef168cbea0c1356b15ee4822 Mon Sep 17 00:00:00 2001 From: Can Stand Date: Wed, 13 Dec 2023 16:22:33 +0800 Subject: [PATCH] refactor(transport): handle transport error, no more log.Fatal --- browser.go | 4 +- browser_context.go | 17 +++-- browser_type.go | 32 ++++----- connection.go | 50 +++++++++----- element_handle.go | 2 +- fetch.go | 2 +- go.mod | 1 + go.sum | 2 + helpers.go | 17 +++++ jsonPipe.go | 40 ++++++++--- page.go | 4 +- run.go | 38 +---------- tests/helper_test.go | 44 +++++++++++++ tests/playwright_test.go | 27 ++++++++ transport.go | 139 +++++++++++++++++++++++++++------------ 15 files changed, 289 insertions(+), 130 deletions(-) create mode 100644 tests/playwright_test.go diff --git a/browser.go b/browser.go index 55e86ee7..ed205792 100644 --- a/browser.go +++ b/browser.go @@ -79,7 +79,7 @@ func (b *browserImpl) NewContext(options ...BrowserNewContextOptions) (BrowserCo } channel, err := b.channel.Send("newContext", options, overrides) if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } context := fromChannel(channel).(*browserContextImpl) context.browser = b @@ -108,7 +108,7 @@ func (b *browserImpl) NewPage(options ...BrowserNewPageOptions) (Page, error) { func (b *browserImpl) NewBrowserCDPSession() (CDPSession, error) { channel, err := b.channel.Send("newBrowserCDPSession") if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } cdpSession := fromChannel(channel).(*cdpSessionImpl) diff --git a/browser_context.go b/browser_context.go index b44db2c0..9a13cd73 100644 --- a/browser_context.go +++ b/browser_context.go @@ -7,6 +7,7 @@ import ( "log" "os" "strings" + "time" ) type browserContextImpl struct { @@ -79,7 +80,7 @@ func (b *browserContextImpl) NewCDPSession(page interface{}) (CDPSession, error) channel, err := b.channel.Send("newCDPSession", params) if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } cdpSession := fromChannel(channel).(*cdpSessionImpl) @@ -93,7 +94,7 @@ func (b *browserContextImpl) NewPage() (Page, error) { } channel, err := b.channel.Send("newPage") if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } return fromChannel(channel).(*pageImpl), nil } @@ -103,7 +104,7 @@ func (b *browserContextImpl) Cookies(urls ...string) ([]Cookie, error) { "urls": urls, }) if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } cookies := make([]Cookie, len(result.([]interface{}))) for i, item := range result.([]interface{}) { @@ -369,7 +370,15 @@ func (b *browserContextImpl) Close(options ...BrowserContextCloseOptions) error _, err = b.channel.Send("close", map[string]interface{}{ "reason": b.closeReason, }) - <-b.closed + if err != nil { + return err + } + timeout := b.timeoutSettings.Timeout() + select { + case <-time.After(time.Duration(timeout) * time.Millisecond): + return ErrTimeout + case <-b.closed: + } return err } diff --git a/browser_type.go b/browser_type.go index 3455010f..4dcdebae 100644 --- a/browser_type.go +++ b/browser_type.go @@ -25,7 +25,7 @@ func (b *browserTypeImpl) Launch(options ...BrowserTypeLaunchOptions) (Browser, } channel, err := b.channel.Send("launch", options, overrides) if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } browser := fromChannel(channel).(*browserImpl) b.didLaunchBrowser(browser) @@ -81,7 +81,7 @@ func (b *browserTypeImpl) LaunchPersistentContext(userDataDir string, options .. } channel, err := b.channel.Send("launchPersistentContext", options, overrides) if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } context := fromChannel(channel).(*browserContextImpl) b.didCreateContext(context, option, tracesDir) @@ -97,9 +97,15 @@ func (b *browserTypeImpl) Connect(wsEndpoint string, options ...BrowserTypeConne return nil, err } jsonPipe := fromChannel(pipe.(map[string]interface{})["pipe"]).(*jsonPipe) - connection := newConnection(jsonPipe.Close, localUtils) - connection.isRemote = true - var browser *browserImpl + connection := newConnection(jsonPipe, localUtils) + + playwright, err := connection.Start() + if err != nil { + return nil, err + } + playwright.setSelectors(b.playwright.Selectors) + browser := fromChannel(playwright.initializer["preLaunchedBrowser"]).(*browserImpl) + browser.shouldCloseConnectionOnClose = true pipeClosed := func() { for _, context := range browser.Contexts() { pages := context.Pages() @@ -112,21 +118,7 @@ func (b *browserTypeImpl) Connect(wsEndpoint string, options ...BrowserTypeConne connection.cleanup() } jsonPipe.On("closed", pipeClosed) - connection.onmessage = func(message map[string]interface{}) error { - if err := jsonPipe.Send(message); err != nil { - pipeClosed() - return err - } - return nil - } - jsonPipe.On("message", connection.Dispatch) - playwright, err := connection.Start() - if err != nil { - return nil, err - } - playwright.setSelectors(b.playwright.Selectors) - browser = fromChannel(playwright.initializer["preLaunchedBrowser"]).(*browserImpl) - browser.shouldCloseConnectionOnClose = true + b.didLaunchBrowser(browser) return browser, nil } diff --git a/connection.go b/connection.go index 97b8b861..2eae3970 100644 --- a/connection.go +++ b/connection.go @@ -25,6 +25,7 @@ type result struct { } type connection struct { + transport transport apiZone sync.Map objects map[string]*channelOwner lastID int @@ -33,21 +34,38 @@ type connection struct { callbacks sync.Map afterClose func() onClose func() error - onmessage func(map[string]interface{}) error isRemote bool localUtils *localUtilsImpl tracingCount atomic.Int32 abort chan struct{} - closedError atomic.Value + closedError *safeValue[error] } func (c *connection) Start() (*Playwright, error) { + go func() { + for { + msg, err := c.transport.Poll() + if err != nil { + _ = c.transport.Close() + c.cleanup(err) + return + } + c.Dispatch(msg) + } + }() + + c.onClose = func() error { + if err := c.transport.Close(); err != nil { + return err + } + return nil + } + return c.rootObject.initialize() } func (c *connection) Stop() error { - err := c.onClose() - if err != nil { + if err := c.onClose(); err != nil { return err } c.cleanup() @@ -56,9 +74,9 @@ func (c *connection) Stop() error { func (c *connection) cleanup(cause ...error) { if len(cause) > 0 { - c.closedError.Store(fmt.Errorf("%w: %w", ErrTargetClosed, cause[0])) + c.closedError.Set(fmt.Errorf("%w: %w", ErrTargetClosed, cause[0])) } else { - c.closedError.Store(ErrTargetClosed) + c.closedError.Set(ErrTargetClosed) } if c.afterClose != nil { c.afterClose() @@ -71,7 +89,7 @@ func (c *connection) cleanup(cause ...error) { } func (c *connection) Dispatch(msg *message) { - if c.closedError.Load() != nil { + if c.closedError.Get() != nil { return } method := msg.Method @@ -198,8 +216,8 @@ func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} { } func (c *connection) sendMessageToServer(object *channelOwner, method string, params interface{}, noReply bool) (*protocolCallback, error) { - if e := c.closedError.Load(); e != nil { - return nil, e.(error) + if err := c.closedError.Get(); err != nil { + return nil, err } if object.wasCollected { return nil, errors.New("The object has been collected to prevent unbounded heap growth.") @@ -233,7 +251,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa c.LocalUtils().AddStackToTracingNoReply(id, stack) } - if err := c.onmessage(message); err != nil { + if err := c.transport.Send(message); err != nil { return nil, fmt.Errorf("could not send message: %w", err) } @@ -307,15 +325,17 @@ func serializeCallLocation(caller stack.Call) map[string]interface{} { } } -func newConnection(onClose func() error, localUtils ...*localUtilsImpl) *connection { +func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection { connection := &connection{ - abort: make(chan struct{}, 1), - objects: make(map[string]*channelOwner), - onClose: onClose, - isRemote: false, + abort: make(chan struct{}, 1), + objects: make(map[string]*channelOwner), + transport: transport, + isRemote: false, + closedError: &safeValue[error]{}, } if len(localUtils) > 0 { connection.localUtils = localUtils[0] + connection.isRemote = true } connection.rootObject = newRootChannelOwner(connection) return connection diff --git a/element_handle.go b/element_handle.go index cf214e0c..9c9969ca 100644 --- a/element_handle.go +++ b/element_handle.go @@ -261,7 +261,7 @@ func (e *elementHandleImpl) Screenshot(options ...ElementHandleScreenshotOptions } data, err := e.channel.Send("screenshot", options, overrides) if err != nil { - return nil, fmt.Errorf("could not send message :%w", err) + return nil, err } image, err := base64.StdEncoding.DecodeString(data.(string)) if err != nil { diff --git a/fetch.go b/fetch.go index 105e0192..59da4159 100644 --- a/fetch.go +++ b/fetch.go @@ -37,7 +37,7 @@ func (r *apiRequestImpl) NewContext(options ...APIRequestNewContextOptions) (API channel, err := r.channel.Send("newRequest", options, overrides) if err != nil { - return nil, fmt.Errorf("could not send message: %w", err) + return nil, err } return fromChannel(channel).(*apiRequestContextImpl), nil } diff --git a/go.mod b/go.mod index e7f3aaea..c5a8fdb2 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/go-stack/stack v1.8.1 github.com/gorilla/websocket v1.5.1 github.com/h2non/filetype v1.1.3 + github.com/mitchellh/go-ps v1.0.0 github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.0 go.uber.org/multierr v1.11.0 diff --git a/go.sum b/go.sum index 975805fb..e374082d 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc= +github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/helpers.go b/helpers.go index d2e46158..2c7fbcae 100644 --- a/helpers.go +++ b/helpers.go @@ -578,3 +578,20 @@ func prepareRecordHarOptions(option recordHarInputOptions) recordHarOptions { } return out } + +type safeValue[T any] struct { + sync.Mutex + v T +} + +func (s *safeValue[T]) Set(v T) { + s.Lock() + defer s.Unlock() + s.v = v +} + +func (s *safeValue[T]) Get() T { + s.Lock() + defer s.Unlock() + return s.v +} diff --git a/jsonPipe.go b/jsonPipe.go index 486d3acd..378bbb81 100644 --- a/jsonPipe.go +++ b/jsonPipe.go @@ -2,11 +2,13 @@ package playwright import ( "encoding/json" - "log" + "errors" + "fmt" ) type jsonPipe struct { channelOwner + msgChan chan *message } func (j *jsonPipe) Send(message map[string]interface{}) error { @@ -19,23 +21,43 @@ func (j *jsonPipe) Close() error { _, err := j.channel.Send("close") return err } + +func (j *jsonPipe) Poll() (*message, error) { + msg := <-j.msgChan + if msg == nil { + return nil, errors.New("jsonPipe closed") + } + return msg, nil +} + func newJsonPipe(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *jsonPipe { - j := &jsonPipe{} + j := &jsonPipe{ + msgChan: make(chan *message, 2), + } j.createChannelOwner(j, parent, objectType, guid, initializer) j.channel.On("message", func(ev map[string]interface{}) { + var msg message m, err := json.Marshal(ev["message"]) - if err != nil { - log.Fatal(err) + if err == nil { + err = json.Unmarshal(m, &msg) } - var msg message - err = json.Unmarshal(m, &msg) if err != nil { - log.Fatal(err) + msg = message{ + Error: &struct { + Error Error "json:\"error\"" + }{ + Error: Error{ + Name: "Error", + Message: fmt.Sprintf("jsonPipe: could not decode message: %s", err.Error()), + }, + }, + } } - j.Emit("message", &msg) + j.msgChan <- &msg }) - j.channel.On("closed", func() { + j.channel.Once("closed", func() { j.Emit("closed") + close(j.msgChan) }) return j } diff --git a/page.go b/page.go index 92292229..244a260e 100644 --- a/page.go +++ b/page.go @@ -342,7 +342,7 @@ func (p *pageImpl) Screenshot(options ...PageScreenshotOptions) ([]byte, error) } data, err := p.channel.Send("screenshot", options, overrides) if err != nil { - return nil, fmt.Errorf("could not send message :%w", err) + return nil, err } image, err := base64.StdEncoding.DecodeString(data.(string)) if err != nil { @@ -363,7 +363,7 @@ func (p *pageImpl) PDF(options ...PagePdfOptions) ([]byte, error) { } data, err := p.channel.Send("pdf", options) if err != nil { - return nil, fmt.Errorf("could not send message :%w", err) + return nil, err } pdf, err := base64.StdEncoding.DecodeString(data.(string)) if err != nil { diff --git a/run.go b/run.go index d2efe35b..9d98a33a 100644 --- a/run.go +++ b/run.go @@ -191,43 +191,11 @@ func (d *PlaywrightDriver) DownloadDriver() error { } func (d *PlaywrightDriver) run() (*connection, error) { - cmd := exec.Command(d.DriverBinaryLocation, "run-driver") - cmd.SysProcAttr = defaultSysProcAttr - cmd.Stderr = os.Stderr - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("could not get stdin pipe: %w", err) - } - stdout, err := cmd.StdoutPipe() + transport, err := newPipeTransport(d.DriverBinaryLocation) if err != nil { - return nil, fmt.Errorf("could not get stdout pipe: %w", err) - } - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("could not start driver: %w", err) + return nil, err } - transport := newPipeTransport(stdin, stdout) - go func() { - if err := transport.Start(); err != nil { - log.Fatal(err) - } - }() - connection := newConnection(func() error { - if err := stdin.Close(); err != nil { - return fmt.Errorf("could not close stdin: %v", err) - } - if err := stdout.Close(); err != nil { - return fmt.Errorf("could not close stdout: %v", err) - } - if err := cmd.Process.Kill(); err != nil { - return fmt.Errorf("could not kill process: %v", err) - } - if _, err := cmd.Process.Wait(); err != nil { - return fmt.Errorf("could not wait for process: %v", err) - } - return nil - }) - connection.onmessage = transport.Send - transport.onmessage = connection.Dispatch + connection := newConnection(transport) return connection, nil } diff --git a/tests/helper_test.go b/tests/helper_test.go index 4e93b00c..3c31adc0 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -18,6 +18,7 @@ import ( "sync" "testing" + "github.com/mitchellh/go-ps" "github.com/playwright-community/playwright-go" "github.com/stretchr/testify/require" ) @@ -371,3 +372,46 @@ func getFileLastModifiedTimeMs(path string) (int64, error) { } return info.ModTime().UnixMilli(), nil } + +// find and kill playwright process, only work for Windows/macOs +func killPlaywrightProcess() error { + all, err := ps.Processes() + if err != nil { + return err + } + for _, process := range all { + if process.Executable() == "node" || process.Executable() == "node.exe" { + parent, err := ps.FindProcess(process.PPid()) + if err != nil { + return err + } + if parent.Executable() == "bash" || parent.Executable() == "sh" || parent.Executable() == "cmd.exe" { + grandpa, err := ps.FindProcess(parent.PPid()) + if err != nil { + return err + } + if strings.HasPrefix(grandpa.Executable(), "__debug_bin") || grandpa.Executable() == filepath.Base(os.Args[0]) { + if err := killProcessByPid(parent.Pid()); err != nil { + return err + } + if err := killProcessByPid(process.Pid()); err != nil { + return err + } + return nil + } + } + } + } + return fmt.Errorf("playwright process not found") +} + +func killProcessByPid(pid int) error { + process, err := os.FindProcess(pid) + if err != nil { + return err + } + if err := process.Kill(); err != nil { + return err + } + return nil +} diff --git a/tests/playwright_test.go b/tests/playwright_test.go new file mode 100644 index 00000000..904c8b4f --- /dev/null +++ b/tests/playwright_test.go @@ -0,0 +1,27 @@ +package playwright_test + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestShouldNotHangWhenPlaywrightUnexpectedExit(t *testing.T) { + BeforeEach(t) + defer AfterEach(t, false) + defer BeforeAll() // need restart playwright driver + if !isChromium { + t.Skip("browser agnostic testing") + return + } + if runtime.GOOS == "linux" { + t.Skip("ignore linux, hard to find the playwright process") + return + } + + err := killPlaywrightProcess() + require.NoError(t, err) + _, err = browser.NewContext() + require.Error(t, err) +} diff --git a/transport.go b/transport.go index 8d4f1113..50875189 100644 --- a/transport.go +++ b/transport.go @@ -8,42 +8,48 @@ import ( "io" "log" "os" - "sync" + "os/exec" "github.com/go-jose/go-jose/v3/json" ) +type transport interface { + Send(msg map[string]interface{}) error + Poll() (*message, error) + Close() error +} + type pipeTransport struct { - stdin io.WriteCloser - stdout io.ReadCloser - onmessage func(msg *message) - rLock sync.Mutex + writer io.WriteCloser + bufReader *bufio.Reader + closed chan struct{} + onClose func() error } -func (t *pipeTransport) Start() error { - reader := bufio.NewReader(t.stdout) - for { - lengthContent := make([]byte, 4) - _, err := io.ReadFull(reader, lengthContent) - if err == io.EOF || errors.Is(err, os.ErrClosed) { - return nil - } else if err != nil { - return fmt.Errorf("could not read padding: %w", err) - } - length := binary.LittleEndian.Uint32(lengthContent) +func (t *pipeTransport) Poll() (*message, error) { + if t.isClosed() { + return nil, fmt.Errorf("transport closed") + } + lengthContent := make([]byte, 4) + _, err := io.ReadFull(t.bufReader, lengthContent) + if err == io.EOF || errors.Is(err, os.ErrClosed) { + return nil, fmt.Errorf("pipe closed: %w", err) + } else if err != nil { + return nil, fmt.Errorf("could not read padding: %w", err) + } + length := binary.LittleEndian.Uint32(lengthContent) - msg := &message{} - if err := json.NewDecoder(io.LimitReader(reader, int64(length))).Decode(&msg); err != nil { - return fmt.Errorf("could not decode json: %w", err) - } - if os.Getenv("DEBUGP") != "" { - fmt.Fprint(os.Stdout, "\x1b[33mRECV>\x1b[0m\n") - if err := json.NewEncoder(os.Stdout).Encode(msg); err != nil { - log.Printf("could not encode json: %v", err) - } + msg := &message{} + if err := json.NewDecoder(io.LimitReader(t.bufReader, int64(length))).Decode(&msg); err != nil { + return nil, fmt.Errorf("could not decode json: %w", err) + } + if os.Getenv("DEBUGP") != "" { + fmt.Fprint(os.Stdout, "\x1b[33mRECV>\x1b[0m\n") + if err := json.NewEncoder(os.Stdout).Encode(msg); err != nil { + log.Printf("could not encode json: %v", err) } - t.onmessage(msg) } + return msg, nil } type message struct { @@ -57,33 +63,84 @@ type message struct { } `json:"error,omitempty"` } -func (t *pipeTransport) Send(message map[string]interface{}) error { - msg, err := json.Marshal(message) +func (t *pipeTransport) Send(msg map[string]interface{}) error { + if t.isClosed() { + return fmt.Errorf("transport closed") + } + msgBytes, err := json.Marshal(msg) if err != nil { - return fmt.Errorf("could not marshal json: %w", err) + return fmt.Errorf("pipeTransport: could not marshal json: %w", err) } if os.Getenv("DEBUGP") != "" { fmt.Fprint(os.Stdout, "\x1b[32mSEND>\x1b[0m\n") - if err := json.NewEncoder(os.Stdout).Encode(message); err != nil { + if err := json.NewEncoder(os.Stdout).Encode(msg); err != nil { log.Printf("could not encode json: %v", err) } } lengthPadding := make([]byte, 4) - t.rLock.Lock() - defer t.rLock.Unlock() - binary.LittleEndian.PutUint32(lengthPadding, uint32(len(msg))) - if _, err = t.stdin.Write(lengthPadding); err != nil { - return err - } - if _, err = t.stdin.Write(msg); err != nil { + binary.LittleEndian.PutUint32(lengthPadding, uint32(len(msgBytes))) + if _, err = t.writer.Write(append(lengthPadding, msgBytes...)); err != nil { return err } return nil } -func newPipeTransport(stdin io.WriteCloser, stdout io.ReadCloser) *pipeTransport { - return &pipeTransport{ - stdout: stdout, - stdin: stdin, +func (t *pipeTransport) Close() error { + select { + case <-t.closed: + return nil + default: + return t.onClose() } } + +func (t *pipeTransport) isClosed() bool { + select { + case <-t.closed: + return true + default: + return false + } +} + +func newPipeTransport(driverCli string) (transport, error) { + t := &pipeTransport{ + closed: make(chan struct{}, 1), + } + + cmd := exec.Command(driverCli, "run-driver") + cmd.SysProcAttr = defaultSysProcAttr + cmd.Stderr = os.Stderr + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("could not create stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("could not create stdout pipe: %w", err) + } + t.writer = stdin + t.bufReader = bufio.NewReader(stdout) + + t.onClose = func() error { + select { + case <-t.closed: + default: + close(t.closed) + } + if err := t.writer.Close(); err != nil { + return err + } + // playwright-cli will exit when its stdin is closed + if err := cmd.Wait(); err != nil { + return err + } + return nil + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("could not start driver: %w", err) + } + + return t, nil +}