diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 82f3c0c..8d2ee8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,17 +1,73 @@ name: wsrpc -on: push + +on: + push: + jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version-file: 'go.mod' + + ## + # XXX: change this to the official action once multiple --out-format args are supported. + # See: https://github.com/golangci/golangci-lint-action/issues/612 + ## + - name: golangci-lint + uses: smartcontractkit/golangci-lint-action@54ab6c5f11d66a92d14c3f7cc41ea13f676644bd # feature/multiple-output-formats-backup + with: + # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version + version: v1.52.1 + + # Optional: working directory, useful for monorepos + # working-directory: + + # Optional: golangci-lint command line arguments. + allow-extra-out-format-args: true + args: --timeout=2m0s --out-format checkstyle:golangci-lint-report.xml + + # Optional: show only new issues if it's a pull request. The default value is `false`. + only-new-issues: true + + # Optional: if set to true then the action will use pre-installed Go. + # skip-go-installation: true + + # Optional: if set to true then the action don't cache or restore ~/go/pkg. + # skip-pkg-cache: true + + # Optional: if set to true then the action don't cache or restore ~/.cache/go-build. + # skip-build-cache: true + + - name: Print lint report artifact + if: always() + run: test -f golangci-lint-report.xml && cat golangci-lint-report.xml || true + + - name: Upload lint report artifact + if: always() + uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v3.1.0 + with: + name: golangci-lint-report + path: golangci-lint-report.xml + ci_test: name: CI Tests runs-on: ubuntu-latest steps: - name: Checkout the repo uses: actions/checkout@v2 + - name: Setup Go uses: actions/setup-go@v3 with: go-version-file: "go.mod" cache: true + - name: Cache dependencies uses: actions/cache@v2 with: @@ -21,5 +77,52 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go + - name: Run Test Suite - run: go test -p=1 -race ./... + run: set -o pipefail && go test ./... -coverpkg=./... -coverprofile=coverage.txt $1 | tee $OUTPUT_FILE + + - name: Run Race Test Suite + run: set -o pipefail && go test -race ./... -coverpkg=./... -coverprofile=race_coverage.txt $1 | tee $OUTPUT_FILE + + - name: Upload Go test results + if: always() + uses: actions/upload-artifact@v3 + with: + name: go-test-results + path: | + ./output.txt + ./coverage.txt + ./race_coverage.txt + + sonar-scan: + name: SonarQube + needs: [lint, ci_test] + runs-on: ubuntu-latest + if: always() + steps: + - name: Checkout the repo + uses: actions/checkout@v3 + with: + fetch-depth: 0 # fetches all history for all tags and branches to provide more metadata for sonar reports + + - name: Download all workflow run artifacts + uses: actions/download-artifact@9782bd6a9848b53b110e712e20e42d89988822b7 # v3.0.1 + + - name: Set SonarQube Report Paths + id: sonarqube_report_paths + shell: bash + run: | + echo "sonarqube_tests_report_paths=$(find -type f -name 'output.txt' -printf "%p,")" >> $GITHUB_OUTPUT + echo "sonarqube_coverage_report_paths=$(find -type f -name '*coverage.txt' -printf "%p,")" >> $GITHUB_OUTPUT + echo "sonarqube_golangci_report_paths=$(find -type f -name 'golangci-lint-report.xml' -printf "%p,")" >> $GITHUB_OUTPUT + + - name: SonarQube Scan + uses: sonarsource/sonarqube-scan-action@a6ba0aafc293e03de5437af7edbc97f7d3ebc91a # v1.2.0 + with: + args: > + -Dsonar.go.tests.reportPaths=${{ steps.sonarqube_report_paths.outputs.sonarqube_tests_report_paths }} + -Dsonar.go.coverage.reportPaths=${{ steps.sonarqube_report_paths.outputs.sonarqube_coverage_report_paths }} + -Dsonar.go.golangci-lint.reportPaths=${{ steps.sonarqube_report_paths.outputs.sonarqube_golangci_report_paths }} + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + SONAR_HOST_URL: ${{ secrets.SONAR_HOST_URL }} diff --git a/.github/workflows/dependency-check.yml b/.github/workflows/dependency-check.yml new file mode 100644 index 0000000..4ba25ca --- /dev/null +++ b/.github/workflows/dependency-check.yml @@ -0,0 +1,52 @@ +name: Dependency Vulnerability Check + +on: + push: + +jobs: + changes: + name: Detect changes + runs-on: ubuntu-latest + outputs: + changes: ${{ steps.changes.outputs.src }} + steps: + - name: Checkout the repo + uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.3 + - uses: dorny/paths-filter@4512585405083f25c027a35db413c2b3b9006d50 # v2.11.1 + id: changes + with: + filters: | + src: + - '**/*go.sum' + - '**/*go.mod' + - '.github/workflows/dependency-check.yml' + Go: + runs-on: ubuntu-latest + needs: [changes] + steps: + - name: Check out code + uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + id: go + + - name: Write Go Modules list + run: go list -json -m all > go.list + + - name: Check vulnerabilities + uses: sonatype-nexus-community/nancy-github-action@main + with: + nancyVersion: "v1.0.42" + + - name: Collect Metrics + if: always() + id: collect-gha-metrics + uses: smartcontractkit/push-gha-metrics-action@90fcbaac8ebf86da9c4d55dba24f6fe3029f0e0b + with: + basic-auth: ${{ secrets.GRAFANA_CLOUD_BASIC_AUTH }} + hostname: ${{ secrets.GRAFANA_CLOUD_HOST }} + this-job-name: Go + continue-on-error: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml deleted file mode 100644 index 1e419d3..0000000 --- a/.github/workflows/lint.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: lint -on: push -jobs: - golangci: - name: lint - runs-on: ubuntu-latest - steps: - - name: Checkout the repo - uses: actions/checkout@v2 - - name: Setup Go - uses: actions/setup-go@v3 - with: - go-version-file: "go.mod" - cache: true - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: v1.46.2 - - # Optional: working directory, useful for monorepos - # working-directory: somedir - - # Optional: golangci-lint command line arguments. - args: --timeout=2m0s - - # Optional: show only new issues if it's a pull request. The default value is `false`. - # only-new-issues: true - - # Optional: if set to true then the action will use pre-installed Go. - # skip-go-installation: true - - # Optional: if set to true then the action don't cache or restore ~/go/pkg. - # skip-pkg-cache: true - - # Optional: if set to true then the action don't cache or restore ~/.cache/go-build. - # skip-build-cache: true diff --git a/.gitignore b/.gitignore index 192e8f3..ba55edb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,8 @@ # Editors .vscode + +# Test & linter reports +*report.xml +*report.json +*.out diff --git a/client.go b/client.go index fec9f3e..d30e7f1 100644 --- a/client.go +++ b/client.go @@ -35,18 +35,20 @@ type ClientInterface interface { // ClientConn represents a virtual connection to a websocket endpoint, to // perform and serve RPCs. type ClientConn struct { - ctx context.Context - mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex + wg *sync.WaitGroup // The websocket address target string // A channel which receives updates when connectivity state changes - csCh <-chan connectivity.State + stateCh <-chan connectivity.State // Manages the connectivity state. csMgr *connectivityStateManager - dopts dialOptions - conn *addrConn + dopts dialOptions + addrConn *addrConn // Contains all pending method call ids and a handler to call when a // response is received @@ -57,24 +59,32 @@ type ClientConn struct { } func Dial(target string, opts ...DialOption) (*ClientConn, error) { - return DialWithContext(context.Background(), target, opts...) + ctx := context.Background() + return DialWithContext(ctx, target, opts...) } // Dial creates a client connection to the given target. By default, it's // a non-blocking dial (the function won't wait for connections to be // established, and connecting happens in the background). To make it a blocking // dial, use WithBlock() dial option. -func DialWithContext(ctx context.Context, target string, opts ...DialOption) (*ClientConn, error) { +func DialWithContext(ctxCaller context.Context, target string, opts ...DialOption) (*ClientConn, error) { + ctx, cancel := context.WithCancel(ctxCaller) + cc := &ClientConn{ - ctx: context.Background(), + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, target: target, csMgr: &connectivityStateManager{}, dopts: defaultDialOptions(), methodCalls: map[string]MethodCallHandler{}, } - for _, opt := range opts { - opt.apply(&cc.dopts) + for i, opt := range opts { + err := opt.apply(&cc.dopts) + if err != nil { + return nil, fmt.Errorf("dial option %d failed: %w", i, err) + } } // Set the backoff strategy. We may need to consider making this @@ -87,18 +97,18 @@ func DialWithContext(ctx context.Context, target string, opts ...DialOption) (*C return nil, fmt.Errorf("error connecting: %w", err) } cc.mu.Lock() - cc.conn = addrConn + cc.addrConn = addrConn cc.mu.Unlock() if cc.dopts.block { for { - s := cc.csMgr.getState() - if s == connectivity.Ready { + curState := cc.csMgr.getState() + if curState == connectivity.Ready { break } // Wait for a state change to re run the for loop - if !cc.WaitForStateChange(ctx, s) { + if !cc.WaitForStateChange(ctx, curState) { addrConn.cancel() return nil, ctx.Err() @@ -151,22 +161,26 @@ func (cc *ClientConn) GetState() connectivity.State { // newAddrConn creates an addrConn for the addr and sets it to cc.conn. func (cc *ClientConn) newAddrConn(addr string) *addrConn { - csCh := make(chan connectivity.State) + stateCh := make(chan connectivity.State) ac := &addrConn{ state: connectivity.Idle, - stateCh: csCh, - cc: cc, + wg: &sync.WaitGroup{}, + stateCh: stateCh, addr: addr, dopts: cc.dopts, } ac.ctx, ac.cancel = context.WithCancel(cc.ctx) cc.mu.Lock() - cc.conn = ac - cc.csCh = csCh + cc.addrConn = ac + cc.stateCh = stateCh + cc.csMgr.getNotifyChan() cc.mu.Unlock() + cc.wg.Add(1) go cc.listenForConnectivityChange() + + cc.wg.Add(1) go cc.listenForRead() return ac @@ -175,32 +189,41 @@ func (cc *ClientConn) newAddrConn(addr string) *addrConn { // listenForConnectivityChange listens for the addrConn's connectivity to change // and updates the ClientConn ConnectivityStateManager. func (cc *ClientConn) listenForConnectivityChange() { + defer cc.wg.Done() for { - s := <-cc.csCh - - cc.csMgr.updateState(s) + select { + case <-cc.ctx.Done(): + return + case s := <-cc.stateCh: + cc.csMgr.updateState(s) + } } } // listenForRead listens for the connectivity state to be ready and enables the // read handler. func (cc *ClientConn) listenForRead() { + defer cc.wg.Done() + var done chan struct{} for { - notifyChan := cc.csMgr.getNotifyChan() - <-notifyChan - - s := cc.csMgr.getState() + select { + case <-cc.ctx.Done(): + return + case <-cc.csMgr.getNotifyChan(): + s := cc.csMgr.getState() - if s == connectivity.Ready { - if done == nil { - done = make(chan struct{}) - } - go cc.handleRead(done) - } else { - if done != nil { - close(done) - done = nil + if s == connectivity.Ready { + if done == nil { + done = make(chan struct{}) + } + cc.wg.Add(1) + go cc.handleRead(done) + } else { + if done != nil { + close(done) + done = nil + } } } } @@ -209,11 +232,12 @@ func (cc *ClientConn) listenForRead() { // handleRead listens to the transport read channel and passes the message to the // readFn handler. func (cc *ClientConn) handleRead(done <-chan struct{}) { + defer cc.wg.Done() var tr transport.ClientTransport var conn *addrConn cc.mu.RLock() - conn = cc.conn + conn = cc.addrConn // if connection has been closed, then conn can be nil if conn == nil { @@ -223,10 +247,14 @@ func (cc *ClientConn) handleRead(done <-chan struct{}) { } conn.mu.RLock() - tr = cc.conn.transport + tr = cc.addrConn.transport conn.mu.RUnlock() cc.mu.RUnlock() + if nil == tr { + return + } + for { select { case in := <-tr.Read(): @@ -238,13 +266,15 @@ func (cc *ClientConn) handleRead(done <-chan struct{}) { switch ex := msg.Exchange.(type) { case *message.Message_Request: + cc.wg.Add(1) go cc.handleMessageRequest(ex.Request) case *message.Message_Response: + cc.wg.Add(1) go cc.handleMessageResponse(ex.Response) default: cc.dopts.logger.Errorf("Invalid message type: %T", ex) } - case <-done: + case <-cc.ctx.Done(): return } } @@ -253,6 +283,7 @@ func (cc *ClientConn) handleRead(done <-chan struct{}) { // handleMessageRequest looks up the method matching the method name and calls // the handler. func (cc *ClientConn) handleMessageRequest(r *message.Request) { + defer cc.wg.Done() methodName := r.GetMethod() if md, ok := cc.service.methods[methodName]; ok { // Create a decoder function to unmarshal the message @@ -275,9 +306,9 @@ func (cc *ClientConn) handleMessageRequest(r *message.Request) { var tr transport.ClientTransport cc.mu.RLock() - cc.conn.mu.RLock() - tr = cc.conn.transport - cc.conn.mu.RUnlock() + cc.addrConn.mu.RLock() + tr = cc.addrConn.transport + cc.addrConn.mu.RUnlock() cc.mu.RUnlock() if err := tr.Write(ctx, replyMsg); err != nil { @@ -289,6 +320,7 @@ func (cc *ClientConn) handleMessageRequest(r *message.Request) { // handleMessageResponse finds the call which matches the method call id of the // response and sends the payload to the call channel. func (cc *ClientConn) handleMessageResponse(r *message.Response) { + defer cc.wg.Done() callID := r.GetCallId() cc.mu.Lock() @@ -323,13 +355,15 @@ func (cc *ClientConn) register(sd *ServiceDesc, ss interface{}) { // Close tears down the ClientConn and all underlying connections. func (cc *ClientConn) Close() { - + cc.cancel() cc.mu.Lock() - conn := cc.conn - cc.conn = nil + addrConn := cc.addrConn + cc.addrConn = nil cc.mu.Unlock() - conn.teardown() + addrConn.teardown() //closes lower level + + cc.wg.Wait() } // Invoke sends the RPC request on the wire and returns after response is @@ -341,10 +375,11 @@ func (cc *ClientConn) Invoke(ctx context.Context, method string, args interface{ } // Ensure the connection state is ready + cc.mu.RLock() - cc.conn.mu.RLock() - state := cc.conn.state - cc.conn.mu.RUnlock() + cc.addrConn.mu.RLock() + state := cc.addrConn.state + cc.addrConn.mu.RUnlock() cc.mu.RUnlock() if state != connectivity.Ready { @@ -376,9 +411,9 @@ func (cc *ClientConn) Invoke(ctx context.Context, method string, args interface{ var tr transport.ClientTransport cc.mu.RLock() - cc.conn.mu.RLock() - tr = cc.conn.transport - cc.conn.mu.RUnlock() + cc.addrConn.mu.RLock() + tr = cc.addrConn.transport + cc.addrConn.mu.RUnlock() cc.mu.RUnlock() if err := tr.Write(ctx, reqB); err != nil { @@ -432,8 +467,7 @@ func (cc *ClientConn) removeMethodCall(id string) { type addrConn struct { ctx context.Context cancel context.CancelFunc - - cc *ClientConn + wg *sync.WaitGroup addr string dopts dialOptions @@ -473,6 +507,7 @@ func (ac *addrConn) connect() error { ac.mu.Unlock() // Start a goroutine connecting to the server asynchronously. + ac.wg.Add(1) go ac.resetTransport() return nil @@ -484,17 +519,23 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State) { return } ac.state = s - ac.stateCh <- s + select { + case ac.stateCh <- s: + return + case <-ac.ctx.Done(): + return + } + } // resetTransport attempts to connect to the server. If the connection fails, // it will continuously attempt reconnection with an exponential backoff. func (ac *addrConn) resetTransport() { - for i := 0; ; i++ { + defer ac.wg.Done() + for { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() - return } @@ -505,16 +546,16 @@ func (ac *addrConn) resetTransport() { ac.transport = nil ac.updateConnectivityState(connectivity.Connecting) - ac.mu.Unlock() newTr, reconnect, err := ac.createTransport(addr, copts) + + ac.mu.Unlock() if err != nil { ac.dopts.logger.Errorf("failed to connect to server at %s, got: %v", addr, err) // After connection failure, the addrConn enters TRANSIENT_FAILURE. ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() - return } ac.updateConnectivityState(connectivity.TransientFailure) @@ -540,8 +581,6 @@ func (ac *addrConn) resetTransport() { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() - newTr.Close() - return } ac.transport = newTr @@ -555,9 +594,12 @@ func (ac *addrConn) resetTransport() { // Block until the created transport is down. When this happens, we // attempt to reconnect by starting again from the top - <-reconnect.Done() - - ac.dopts.logger.Info("Reconnecting to server...") + select { + case <-ac.ctx.Done(): + return + case <-reconnect.Done(): + ac.dopts.logger.Info("Reconnecting to server...") + } } } @@ -570,10 +612,10 @@ func (ac *addrConn) createTransport(addr string, copts transport.ConnectOptions) once := sync.Once{} // Called when the transport closes - onClose := func() { + afterWritePump := func() { ac.mu.Lock() once.Do(func() { - if ac.state == connectivity.Ready { + if connectivity.Ready == ac.state { ac.updateConnectivityState(connectivity.Idle) } }) @@ -581,7 +623,7 @@ func (ac *addrConn) createTransport(addr string, copts transport.ConnectOptions) reconnect.Fire() } - tr, err := transport.NewClientTransport(ac.cc.ctx, ac.dopts.logger, addr, copts, onClose) + tr, err := transport.NewClientTransport(ac.ctx, ac.dopts.logger, addr, copts, afterWritePump) return tr, reconnect, err } @@ -589,21 +631,25 @@ func (ac *addrConn) createTransport(addr string, copts transport.ConnectOptions) // tearDown starts to tear down the addrConn. func (ac *addrConn) teardown() { ac.mu.Lock() - defer ac.mu.Unlock() + ac.cancel() - if ac.state == connectivity.Shutdown { + if connectivity.Shutdown == ac.state { + ac.mu.Unlock() return } - ac.updateConnectivityState(connectivity.Shutdown) - curTr := ac.transport ac.transport = nil ac.cancel() + ac.updateConnectivityState(connectivity.Shutdown) + ac.mu.Unlock() if curTr != nil { + //syncronously closes lower level curTr.Close() } + + ac.wg.Wait() } // connectivityStateManager keeps the connectivity.State of ClientConn. @@ -620,17 +666,15 @@ func (csm *connectivityStateManager) updateState(state connectivity.State) { csm.mu.Lock() defer csm.mu.Unlock() - if csm.state == connectivity.Shutdown { - return - } if csm.state == state { return } csm.state = state if csm.notifyChan != nil { // There are other goroutines waiting on this channel. - close(csm.notifyChan) + notifyChan := csm.notifyChan csm.notifyChan = nil + close(notifyChan) } } diff --git a/credentials/tls.go b/credentials/tls.go index db8ce03..7f819f2 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -3,6 +3,7 @@ package credentials import ( "crypto/ed25519" "crypto/rand" + "crypto/subtle" "crypto/tls" "crypto/x509" "errors" @@ -13,15 +14,19 @@ import ( type StaticSizedPublicKey [ed25519.PublicKeySize]byte +func (p StaticSizedPublicKey) String() string { + return fmt.Sprintf("%x", p[:]) +} + // NewClientTLSConfig uses the private key and public keys to construct a mutual // TLS config for the client. -func NewClientTLSConfig(priv ed25519.PrivateKey, pubs *PublicKeys) (*tls.Config, error) { +func NewClientTLSConfig(priv *PrivateKey, pubs *PublicKeys) (*tls.Config, error) { return newMutualTLSConfig(priv, pubs) } // NewServerTLSConfig uses the private key and public keys to construct a mutual // TLS config for the server. -func NewServerTLSConfig(priv ed25519.PrivateKey, pubs *PublicKeys) (*tls.Config, error) { +func NewServerTLSConfig(priv *PrivateKey, pubs *PublicKeys) (*tls.Config, error) { c, err := newMutualTLSConfig(priv, pubs) if err != nil { return nil, err @@ -39,7 +44,7 @@ func NewServerTLSConfig(priv ed25519.PrivateKey, pubs *PublicKeys) (*tls.Config, // // Certificates are currently used similarly to GPG keys and only functionally // as certificates to support the crypto/tls go module. -func newMutualTLSConfig(priv ed25519.PrivateKey, pubs *PublicKeys) (*tls.Config, error) { +func newMutualTLSConfig(priv *PrivateKey, pubs *PublicKeys) (*tls.Config, error) { cert, err := newMinimalX509Cert(priv) if err != nil { return nil, err @@ -65,33 +70,55 @@ func newMutualTLSConfig(priv ed25519.PrivateKey, pubs *PublicKeys) (*tls.Config, // Generates a minimal certificate (that wouldn't be considered valid outside of // this networking protocol) from an Ed25519 private key. -func newMinimalX509Cert(priv ed25519.PrivateKey) (tls.Certificate, error) { +func newMinimalX509Cert(priv *PrivateKey) (tls.Certificate, error) { + ed25519Priv := priv.key + template := x509.Certificate{ SerialNumber: big.NewInt(0), // serial number must be set, so we set it to 0 } - encodedCert, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) + encodedCert, err := x509.CreateCertificate(rand.Reader, &template, &template, ed25519Priv.Public(), ed25519Priv) if err != nil { return tls.Certificate{}, err } return tls.Certificate{ Certificate: [][]byte{encodedCert}, - PrivateKey: priv, + PrivateKey: ed25519Priv, SupportedSignatureAlgorithms: []tls.SignatureScheme{tls.Ed25519}, }, nil } +type PrivateKey struct { + key ed25519.PrivateKey +} + +func ValidPrivateKeyFromEd25519(key ed25519.PrivateKey) (*PrivateKey, error) { + if len(key) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid key length: %d, expected: %d", len(key), ed25519.PrivateKeySize) + } + + return &PrivateKey{ + key: key, + }, nil +} + // PublicKeys wraps a slice of keys so we can update the keys dynamically. type PublicKeys struct { mu sync.RWMutex keys []ed25519.PublicKey } -func NewPublicKeys(keys ...ed25519.PublicKey) *PublicKeys { +func ValidPublicKeysFromEd25519(keys ...ed25519.PublicKey) (*PublicKeys, error) { + for _, key := range keys { + if len(key) != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid key length: %d, expected: %d", len(key), ed25519.PublicKeySize) + } + } + return &PublicKeys{ keys: keys, - } + }, nil } func (r *PublicKeys) Keys() []ed25519.PublicKey { @@ -125,10 +152,10 @@ func (r *PublicKeys) VerifyPeerCertificate() func(rawCerts [][]byte, verifiedCha // Replace replaces the existing keys with new keys. Use this to dynamically // update the allowable keys at runtime. -func (r *PublicKeys) Replace(pubs []ed25519.PublicKey) { +func (r *PublicKeys) Replace(pubs *PublicKeys) { r.mu.Lock() defer r.mu.Unlock() - r.keys = pubs + r.keys = pubs.keys } // isValidPublicKey checks the public key against a list of valid keys. @@ -136,7 +163,7 @@ func (r *PublicKeys) isValidPublicKey(pub ed25519.PublicKey) bool { r.mu.RLock() defer r.mu.RUnlock() for _, vpub := range r.keys { - if pub.Equal(vpub) { + if subtle.ConstantTimeCompare(pub, vpub) > 0 { return true } } diff --git a/credentials/tls_test.go b/credentials/tls_test.go index af0f7dd..be41af4 100644 --- a/credentials/tls_test.go +++ b/credentials/tls_test.go @@ -14,13 +14,22 @@ import ( ) func Test_NewClientTLSConfig(t *testing.T) { - _, cpriv, err := ed25519.GenerateKey(nil) + _, ed25519cpriv, err := ed25519.GenerateKey(nil) require.NoError(t, err) - spub, spriv, err := ed25519.GenerateKey(nil) + spub, ed25519spriv, err := ed25519.GenerateKey(nil) require.NoError(t, err) - tlsCfg, err := NewClientTLSConfig(cpriv, NewPublicKeys(spub)) + cpriv, err := ValidPrivateKeyFromEd25519(ed25519cpriv) + require.NoError(t, err) + + spriv, err := ValidPrivateKeyFromEd25519(ed25519spriv) + require.NoError(t, err) + + spubs, err := ValidPublicKeysFromEd25519(spub) + require.NoError(t, err) + + tlsCfg, err := NewClientTLSConfig(cpriv, spubs) require.NoError(t, err) require.Len(t, tlsCfg.Certificates, 1) @@ -36,7 +45,10 @@ func Test_NewClientTLSConfig(t *testing.T) { require.NoError(t, err) // Test an invalid client certificate - _, invspriv, err := ed25519.GenerateKey(nil) + _, ed25519invspriv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + invspriv, err := ValidPrivateKeyFromEd25519(ed25519invspriv) require.NoError(t, err) invscert, err := newMinimalX509Cert(invspriv) @@ -47,13 +59,22 @@ func Test_NewClientTLSConfig(t *testing.T) { } func Test_NewServerTLSConfig(t *testing.T) { - _, spriv, err := ed25519.GenerateKey(nil) + _, ed25519spriv, err := ed25519.GenerateKey(nil) require.NoError(t, err) - cpub, cpriv, err := ed25519.GenerateKey(nil) + cpub, ed25519cpriv, err := ed25519.GenerateKey(nil) require.NoError(t, err) - tlsCfg, err := NewServerTLSConfig(spriv, NewPublicKeys(cpub)) + spriv, err := ValidPrivateKeyFromEd25519(ed25519spriv) + require.NoError(t, err) + + cpriv, err := ValidPrivateKeyFromEd25519(ed25519cpriv) + require.NoError(t, err) + + cpubs, err := ValidPublicKeysFromEd25519(cpub) + require.NoError(t, err) + + tlsCfg, err := NewServerTLSConfig(spriv, cpubs) require.NoError(t, err) require.Len(t, tlsCfg.Certificates, 1) @@ -70,7 +91,10 @@ func Test_NewServerTLSConfig(t *testing.T) { require.NoError(t, err) // Test an invalid client certificate - _, invcpriv, err := ed25519.GenerateKey(rand.New(rand.NewSource(42))) //nolint:gosec + _, ed25519invcpriv, err := ed25519.GenerateKey(rand.New(rand.NewSource(42))) //nolint:gosec + require.NoError(t, err) + + invcpriv, err := ValidPrivateKeyFromEd25519(ed25519invcpriv) require.NoError(t, err) invccert, err := newMinimalX509Cert(invcpriv) @@ -115,3 +139,51 @@ func Test_PubKeyFromCert_MustBeEd25519KeyError(t *testing.T) { _, err = PubKeyFromCert(cert) require.EqualError(t, err, "requires an ed25519 public key") } + +func Test_IsValidPublicKey(t *testing.T) { + t.Run("pub_key_included", func(t *testing.T) { + cpub, _, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + pk, err := ValidPublicKeysFromEd25519(cpub) + require.NoError(t, err) + + require.True(t, pk.isValidPublicKey(cpub)) + }) + + t.Run("pub_key_not_included", func(t *testing.T) { + cpub, _, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + cpub2, _, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + pk, err := ValidPublicKeysFromEd25519(cpub) + require.NoError(t, err) + + // Test + require.False(t, pk.isValidPublicKey(cpub2)) + }) +} + +func Test_NewPublicKeys(t *testing.T) { + t.Run("key_length_32", func(t *testing.T) { + cpub, _, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + _, err = ValidPublicKeysFromEd25519(cpub) + require.NoError(t, err) + }) + + t.Run("key_length_not_32", func(t *testing.T) { + shortKey := make([]byte, ed25519.PublicKeySize-1) + + _, err := ValidPublicKeysFromEd25519(shortKey) + require.Error(t, err) + + longKey := make([]byte, ed25519.PublicKeySize+1) + + _, err = ValidPublicKeysFromEd25519(longKey) + require.Error(t, err) + }) +} diff --git a/dialoptions.go b/dialoptions.go index a65371b..ed56908 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -2,7 +2,6 @@ package wsrpc import ( "crypto/ed25519" - "log" "time" "github.com/smartcontractkit/wsrpc/credentials" @@ -22,7 +21,7 @@ type dialOptions struct { // DialOption configures how we set up the connection. type DialOption interface { - apply(*dialOptions) + apply(*dialOptions) error } // funcDialOption wraps a function that modifies dialOptions into an @@ -31,8 +30,9 @@ type funcDialOption struct { f func(*dialOptions) } -func (fdo *funcDialOption) apply(do *dialOptions) { +func (fdo *funcDialOption) apply(do *dialOptions) error { fdo.f(do) + return nil } func newFuncDialOption(f func(*dialOptions)) *funcDialOption { @@ -41,21 +41,42 @@ func newFuncDialOption(f func(*dialOptions)) *funcDialOption { } } +type funcDialOptionWithErr struct { + funcWithErr func(*dialOptions) error +} + +func (fdo *funcDialOptionWithErr) apply(do *dialOptions) error { + return fdo.funcWithErr(do) +} + +func newFuncDialOptionWithErr(f func(*dialOptions) error) *funcDialOptionWithErr { + return &funcDialOptionWithErr{ + funcWithErr: f, + } +} + // WithTransportCredentials returns a DialOption which configures a connection // level security credentials (e.g., TLS/SSL). func WithTransportCreds(privKey ed25519.PrivateKey, serverPubKey ed25519.PublicKey) DialOption { - return newFuncDialOption(func(o *dialOptions) { - pubs := credentials.NewPublicKeys(serverPubKey) + return newFuncDialOptionWithErr(func(o *dialOptions) error { + privKey, err := credentials.ValidPrivateKeyFromEd25519(privKey) + if err != nil { + return err + } + + pubs, err := credentials.ValidPublicKeysFromEd25519(serverPubKey) + if err != nil { + return err + } // Generate the TLS config for the client config, err := credentials.NewClientTLSConfig(privKey, pubs) if err != nil { - log.Println(err) - - return + return err } o.copts.TransportCredentials = credentials.NewTLS(config, pubs) + return nil }) } @@ -76,6 +97,12 @@ func WithWriteTimeout(d time.Duration) DialOption { }) } +func WithReadLimit(size int64) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.copts.ReadLimit = size + }) +} + func WithLogger(lggr logger.Logger) DialOption { return newFuncDialOption(func(o *dialOptions) { o.logger = lggr diff --git a/examples/simple/keys/keys.go b/examples/simple/keys/keys.go index d874975..80e342c 100644 --- a/examples/simple/keys/keys.go +++ b/examples/simple/keys/keys.go @@ -4,7 +4,7 @@ import ( "crypto/ed25519" "encoding/hex" "errors" - + "fmt" "github.com/smartcontractkit/wsrpc/credentials" ) @@ -56,6 +56,9 @@ func FromHex(keyHex string) []byte { // ToStaticSizedBytes convert bytes to a statically sized byte array of the // of ed25519.PublicKeySize func ToStaticSizedBytes(b []byte) (credentials.StaticSizedPublicKey, error) { + if len(b) != ed25519.PublicKeySize { + return credentials.StaticSizedPublicKey{}, fmt.Errorf("provided public key is %d bytes, expected %d bytes", len(b), ed25519.PublicKeySize) + } var sb credentials.StaticSizedPublicKey if ed25519.PublicKeySize != copy(sb[:], b) { diff --git a/examples/simple/server/main.go b/examples/simple/server/main.go index 671ae1c..f503a4a 100644 --- a/examples/simple/server/main.go +++ b/examples/simple/server/main.go @@ -44,7 +44,7 @@ func main() { log.Fatalf("[MAIN] failed to listen: %v", err) } s := wsrpc.NewServer( - wsrpc.Creds(privKey, pubKeys), + wsrpc.WithCreds(privKey, pubKeys), wsrpc.WithHealthcheck("127.0.0.1:1337"), ) diff --git a/go.mod b/go.mod index fb19f1d..1886c4a 100644 --- a/go.mod +++ b/go.mod @@ -11,3 +11,7 @@ require ( go.uber.org/zap v1.24.0 google.golang.org/protobuf v1.26.0 ) + +replace golang.org/x/text => golang.org/x/text v0.11.0 + +replace golang.org/x/net => golang.org/x/net v0.14.0 diff --git a/go.sum b/go.sum index fde2f1f..9134dff 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= @@ -39,28 +40,33 @@ go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/methods/methods.go b/internal/methods/methods.go new file mode 100644 index 0000000..ea3a449 --- /dev/null +++ b/internal/methods/methods.go @@ -0,0 +1,68 @@ +package methods + +import ( + "fmt" + + "github.com/smartcontractkit/wsrpc/credentials" + "github.com/smartcontractkit/wsrpc/internal/message" +) + +type MethodCalls struct { + MethodCalls map[credentials.StaticSizedPublicKey]*MethodCallsForPublicKey +} + +func NewMethodCalls() *MethodCalls { + return &MethodCalls{ + MethodCalls: make(map[credentials.StaticSizedPublicKey]*MethodCallsForPublicKey), + } +} + +func (m *MethodCalls) PutMethodCallForPublicKey(pubKey credentials.StaticSizedPublicKey, id string, ch chan<- *message.Response) { + var methodCallsForPubKey *MethodCallsForPublicKey + var ok bool + if methodCallsForPubKey, ok = m.MethodCalls[pubKey]; !ok { + methodCallsForPubKey = NewMethodCallsForPublicKey() + m.MethodCalls[pubKey] = methodCallsForPubKey + } + methodCallsForPubKey.PutMessageResponseChannel(id, ch) +} + +func (m *MethodCalls) GetMessageResponseChannelForPublicKey(pubKey credentials.StaticSizedPublicKey, id string) (chan<- *message.Response, error) { + if methodCallsForPubKey, ok := m.MethodCalls[pubKey]; ok { + return methodCallsForPubKey.GetMessageResponseChannel(id) + } + + return nil, fmt.Errorf("public key not found: %v", pubKey) +} + +func (m *MethodCalls) DeleteMethodCall(pubKey credentials.StaticSizedPublicKey, id string) { + if methodCallsForPubKey, ok := m.MethodCalls[pubKey]; ok { + methodCallsForPubKey.Delete(id) + } +} + +type MethodCallsForPublicKey struct { + MethodCallsForPublicKey map[string]chan<- *message.Response +} + +func NewMethodCallsForPublicKey() *MethodCallsForPublicKey { + return &MethodCallsForPublicKey{ + MethodCallsForPublicKey: make(map[string]chan<- *message.Response), + } +} + +func (m *MethodCallsForPublicKey) PutMessageResponseChannel(id string, ch chan<- *message.Response) { + m.MethodCallsForPublicKey[id] = ch +} + +func (m *MethodCallsForPublicKey) GetMessageResponseChannel(id string) (chan<- *message.Response, error) { + call, ok := m.MethodCallsForPublicKey[id] + if !ok { + return nil, fmt.Errorf("id not found: %v", id) + } + return call, nil +} + +func (m *MethodCallsForPublicKey) Delete(id string) { + delete(m.MethodCallsForPublicKey, id) +} diff --git a/internal/methods/methods_test.go b/internal/methods/methods_test.go new file mode 100644 index 0000000..80df070 --- /dev/null +++ b/internal/methods/methods_test.go @@ -0,0 +1,113 @@ +package methods_test + +import ( + "crypto/ed25519" + + "testing" + + "github.com/smartcontractkit/wsrpc/credentials" + "github.com/smartcontractkit/wsrpc/internal/message" + "github.com/smartcontractkit/wsrpc/internal/methods" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMethodCalls(t *testing.T) { + methodCalls := methods.NewMethodCalls() + assert.NotNil(t, methodCalls.MethodCalls) +} + +func TestPutMethodCallForPublicKey(t *testing.T) { + methodCalls := methods.NewMethodCalls() + _, pubKeyBytes, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + var pubKey credentials.StaticSizedPublicKey + copy(pubKey[:], pubKeyBytes) + + id := "testID" + ch := make(chan *message.Response) + + methodCalls.PutMethodCallForPublicKey(pubKey, id, ch) + + assert.NotNil(t, methodCalls.MethodCalls[pubKey]) + assert.NotNil(t, methodCalls.MethodCalls[pubKey].MethodCallsForPublicKey[id]) +} + +func TestGetMessageResponseChannelForPublicKey(t *testing.T) { + methodCalls := methods.NewMethodCalls() + _, pubKeyBytes, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + var pubKey credentials.StaticSizedPublicKey + copy(pubKey[:], pubKeyBytes) + + id := "testID" + ch := make(chan<- *message.Response) + + methodCalls.PutMethodCallForPublicKey(pubKey, id, ch) + + retrievedCh, err := methodCalls.GetMessageResponseChannelForPublicKey(pubKey, id) + + assert.Nil(t, err) + assert.Equal(t, ch, retrievedCh) +} + +func TestDeleteMethodCall(t *testing.T) { + methodCalls := methods.NewMethodCalls() + _, pubKeyBytes, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + var pubKey credentials.StaticSizedPublicKey + copy(pubKey[:], pubKeyBytes) + id := "testID" + ch := make(chan *message.Response) + + methodCalls.PutMethodCallForPublicKey(pubKey, id, ch) + methodCalls.DeleteMethodCall(pubKey, id) + + _, err = methodCalls.GetMessageResponseChannelForPublicKey(pubKey, id) + + assert.NotNil(t, err) +} + +func TestNewMethodCallsForPublicKey(t *testing.T) { + methodCallsForPubKey := methods.NewMethodCallsForPublicKey() + assert.NotNil(t, methodCallsForPubKey.MethodCallsForPublicKey) +} + +func TestPutMessageResponseChannel(t *testing.T) { + methodCallsForPubKey := methods.NewMethodCallsForPublicKey() + id := "testID" + ch := make(chan *message.Response) + + methodCallsForPubKey.PutMessageResponseChannel(id, ch) + + assert.NotNil(t, methodCallsForPubKey.MethodCallsForPublicKey[id]) +} + +func TestGetMessageResponseChannel(t *testing.T) { + methodCallsForPubKey := methods.NewMethodCallsForPublicKey() + id := "testID" + ch := make(chan<- *message.Response) + + methodCallsForPubKey.PutMessageResponseChannel(id, ch) + + retrievedCh, err := methodCallsForPubKey.GetMessageResponseChannel(id) + + assert.Nil(t, err) + assert.Equal(t, ch, retrievedCh) +} + +func TestDelete(t *testing.T) { + methodCallsForPubKey := methods.NewMethodCallsForPublicKey() + id := "testID" + ch := make(chan *message.Response) + + methodCallsForPubKey.PutMessageResponseChannel(id, ch) + methodCallsForPubKey.Delete(id) + + _, err := methodCallsForPubKey.GetMessageResponseChannel(id) + + assert.NotNil(t, err) +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 5ef94b1..34fba22 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -4,16 +4,15 @@ import ( "context" "time" - "github.com/gorilla/websocket" - "github.com/smartcontractkit/wsrpc/credentials" "github.com/smartcontractkit/wsrpc/logger" ) const ( - // Time allowed to write a message to the connection. defaultWriteTimeout = 10 * time.Second + defaultReadLimit = int64(100_000_000) // 100 MB + // Time allowed to read the next pong message from the peer. pongWait = 20 * time.Second @@ -21,9 +20,26 @@ const ( pingPeriod = (pongWait * 9) / 10 ) +// Abstracts websocket.Conn +type WebSocketConn interface { + SetReadLimit(limit int64) + SetReadDeadline(time time.Time) error + SetPongHandler(handler func(string) error) + SetWriteDeadline(time.Time) error + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error + Close() error +} + // ConnectOptions covers all relevant options for communicating with the server. type ConnectOptions struct { + // Time allowed to write a message to the connection. WriteTimeout time.Duration + + // Size of request allowed + ReadLimit int64 + // TransportCredentials stores the Authenticator required to setup a client // connection. TransportCredentials credentials.TransportCredentials @@ -40,13 +56,16 @@ type ClientTransport interface { // Close tears down this transport. Once it returns, the transport // should not be accessed any more. - Close() error + Close() + + // Start starts this transport. + Start() } // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, lggr logger.Logger, addr string, opts ConnectOptions, onClose func()) (ClientTransport, error) { - return newWebsocketClient(ctx, lggr, addr, opts, onClose) +func NewClientTransport(ctx context.Context, lggr logger.Logger, addr string, opts ConnectOptions, afterWritePump func()) (ClientTransport, error) { + return newWebsocketClient(ctx, lggr, addr, opts, afterWritePump) } // state of transport. @@ -63,6 +82,7 @@ const ( // ServerConfig consists of all the configurations to establish a server transport. type ServerConfig struct { + ReadLimit int64 WriteTimeout time.Duration } @@ -82,11 +102,11 @@ type ServerTransport interface { // NewServerTransport creates a ServerTransport with conn or non-nil error // if it fails. -func NewServerTransport(c *websocket.Conn, config *ServerConfig, onClose func()) (ServerTransport, error) { - return newWebsocketServer(c, config, onClose), nil +func NewServerTransport(c WebSocketConn, config *ServerConfig, afterWritePump func()) ServerTransport { + return newWebsocketServer(c, config, afterWritePump) } -func handlePong(conn *websocket.Conn) func(string) error { +func handlePong(conn WebSocketConn) func(string) error { return func(msg string) error { return conn.SetReadDeadline(time.Now().Add(pongWait)) } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go new file mode 100644 index 0000000..63a7ff6 --- /dev/null +++ b/internal/transport/transport_test.go @@ -0,0 +1,48 @@ +package transport + +import ( + "time" +) + +type mockWebSocketConn struct { + readLimit int64 + readDeadline time.Time + pongHandler func(string) error + writeDeadline time.Time + messageType int + messageData []byte +} + +func (m *mockWebSocketConn) SetReadLimit(limit int64) { + m.readLimit = limit +} + +func (m *mockWebSocketConn) SetReadDeadline(t time.Time) error { + m.readDeadline = t + return nil +} + +func (m *mockWebSocketConn) SetPongHandler(handler func(string) error) { + m.pongHandler = handler +} + +func (m *mockWebSocketConn) SetWriteDeadline(t time.Time) error { + m.writeDeadline = t + return nil +} + +func (m *mockWebSocketConn) ReadMessage() (messageType int, p []byte, err error) { + return m.messageType, m.messageData, nil +} + +func (m *mockWebSocketConn) WriteMessage(messageType int, data []byte) error { + return nil +} + +func (m *mockWebSocketConn) WriteControl(messageType int, data []byte, deadline time.Time) error { + return nil +} + +func (m *mockWebSocketConn) Close() error { + return nil +} diff --git a/internal/transport/websocket_client.go b/internal/transport/websocket_client.go index 3469eee..75b6cee 100644 --- a/internal/transport/websocket_client.go +++ b/internal/transport/websocket_client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "sync" "time" "github.com/gorilla/websocket" @@ -18,60 +19,79 @@ type WebsocketClient struct { writeTimeout time.Duration // Underlying communication channel - conn *websocket.Conn + conn WebSocketConn // Callback function called when the transport is closed - onClose func() + afterWritePump func() + + wg sync.WaitGroup // Communication channels write chan []byte read chan []byte // A signal channel called when the reader encounters a websocket close error - done chan struct{} + closeWritePump chan struct{} // A signal channel called when the transport is closed - interrupt chan struct{} + closeConn chan struct{} log logger.Logger } // newWebsocketClient establishes the transport with the required ConnectOptions // and returns it to the caller. -func newWebsocketClient(ctx context.Context, log logger.Logger, addr string, opts ConnectOptions, onClose func()) (_ *WebsocketClient, err error) { - writeTimeout := defaultWriteTimeout - if opts.WriteTimeout != 0 { - writeTimeout = opts.WriteTimeout - } - +func newWebsocketClient(ctx context.Context, log logger.Logger, addr string, opts ConnectOptions, afterWritePump func()) (*WebsocketClient, error) { d := websocket.Dialer{ TLSClientConfig: opts.TransportCredentials.Config, HandshakeTimeout: 45 * time.Second, } url := fmt.Sprintf("wss://%s", addr) - conn, _, err := d.DialContext(ctx, url, http.Header{}) + + var conn WebSocketConn + var err error + conn, _, err = d.DialContext(ctx, url, http.Header{}) + if err != nil { return nil, fmt.Errorf("[wsrpc] error while dialing %w", err) } - c := &WebsocketClient{ - ctx: ctx, - writeTimeout: writeTimeout, - conn: conn, - onClose: onClose, - write: make(chan []byte), // Should this be buffered? - read: make(chan []byte), // Should this be buffered? - done: make(chan struct{}), - interrupt: make(chan struct{}), - log: log, - } + c := newWebsocketClientConfig(ctx, log, addr, opts, afterWritePump, conn) // Start go routines to establish the read/write channels - go c.start() + c.Start() return c, nil } +func newWebsocketClientConfig(ctx context.Context, log logger.Logger, addr string, opts ConnectOptions, afterWritePump func(), conn WebSocketConn) *WebsocketClient { + writeTimeout := defaultWriteTimeout + if opts.WriteTimeout != 0 { + writeTimeout = opts.WriteTimeout + } + + readLimit := defaultReadLimit + if opts.ReadLimit != 0 { + readLimit = opts.ReadLimit + } + + conn.SetReadLimit(readLimit) + + c := &WebsocketClient{ + ctx: ctx, + writeTimeout: writeTimeout, + conn: conn, + afterWritePump: afterWritePump, + write: make(chan []byte), + read: make(chan []byte), + closeWritePump: make(chan struct{}), + closeConn: make(chan struct{}), + log: log, + } + + return c +} + // Read returns a channel which provides the messages as they are read. func (c *WebsocketClient) Read() <-chan []byte { return c.read @@ -80,9 +100,9 @@ func (c *WebsocketClient) Read() <-chan []byte { // Write writes a message the websocket connection. func (c *WebsocketClient) Write(ctx context.Context, msg []byte) error { select { - case <-c.done: + case <-c.closeWritePump: return fmt.Errorf("[wsrpc] could not write message, websocket is closed") - case <-c.interrupt: + case <-c.closeConn: return fmt.Errorf("[wsrpc] could not write message, transport is closed") case <-ctx.Done(): return fmt.Errorf("[wsrpc] could not write message, context is done") @@ -92,20 +112,20 @@ func (c *WebsocketClient) Write(ctx context.Context, msg []byte) error { } // Close closes the websocket connection and cleans up pump goroutines. -func (c *WebsocketClient) Close() error { - close(c.interrupt) +func (c *WebsocketClient) Close() { + close(c.closeConn) - return nil + c.wg.Wait() } -// start run readPump in a goroutine and waits on writePump. -func (c WebsocketClient) start() { - defer c.onClose() - +// Start runs readPump and writePump in goroutines. +func (c *WebsocketClient) Start() { // Set up reader + c.wg.Add(1) go c.readPump() - c.writePump() + c.wg.Add(1) + go c.writePump() } // readPump pumps messages from the websocket connection. When a websocket @@ -116,7 +136,10 @@ func (c WebsocketClient) start() { // that there is at most one reader on a connection by executing all reads from // this goroutine. func (c *WebsocketClient) readPump() { - defer close(c.done) + defer func() { + close(c.closeWritePump) + c.wg.Done() + }() //nolint:errcheck c.conn.SetReadDeadline(time.Now().Add(pongWait)) @@ -124,9 +147,9 @@ func (c *WebsocketClient) readPump() { for { _, msg, err := c.conn.ReadMessage() + if err != nil { c.log.Errorw("[wsrpc] Read error", "err", err) - return } @@ -141,11 +164,15 @@ func (c *WebsocketClient) readPump() { // from this goroutine. func (c *WebsocketClient) writePump() { ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() + defer func() { + ticker.Stop() + c.afterWritePump() + c.wg.Done() + }() for { select { - case <-c.done: + case <-c.closeWritePump: // When the read detects a websocket closure, it will close the done // channel so we can exit return @@ -170,7 +197,7 @@ func (c *WebsocketClient) writePump() { return } - case <-c.interrupt: + case <-c.closeConn: // Cleanly close the connection by sending a close message and then // waiting (with timeout) for the server to close the connection. err := c.conn.WriteMessage(websocket.CloseMessage, @@ -181,7 +208,7 @@ func (c *WebsocketClient) writePump() { } c.conn.Close() select { - case <-c.done: + case <-c.closeWritePump: case <-time.After(time.Second): } diff --git a/internal/transport/websocket_client_test.go b/internal/transport/websocket_client_test.go index d11d0be..c4e0a0e 100644 --- a/internal/transport/websocket_client_test.go +++ b/internal/transport/websocket_client_test.go @@ -1 +1,55 @@ package transport + +import ( + "context" + "testing" + "time" + + "github.com/smartcontractkit/wsrpc/logger" + "github.com/stretchr/testify/assert" +) + +func TestNewWebsocketClientConfig(t *testing.T) { + ctx := context.Background() + afterWritePump := func() {} + mockLogger := logger.DefaultLogger + mockConn := &mockWebSocketConn{} + + tests := []struct { + name string + opts ConnectOptions + expectedTimeout time.Duration + expectedLimit int64 + }{ + { + name: "Default values", + opts: ConnectOptions{}, + expectedTimeout: defaultWriteTimeout, + expectedLimit: int64(defaultReadLimit), + }, + { + name: "Custom WriteTimeout", + opts: ConnectOptions{ + WriteTimeout: 5 * time.Second, + }, + expectedTimeout: 5 * time.Second, + expectedLimit: int64(defaultReadLimit), + }, + { + name: "Custom ReadLimit", + opts: ConnectOptions{ + ReadLimit: 2048, + }, + expectedTimeout: defaultWriteTimeout, + expectedLimit: 2048, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := newWebsocketClientConfig(ctx, mockLogger, "addr", tt.opts, afterWritePump, mockConn) + assert.Equal(t, tt.expectedTimeout, client.writeTimeout) + assert.Equal(t, tt.expectedLimit, mockConn.readLimit) + }) + } +} diff --git a/internal/transport/websocket_server.go b/internal/transport/websocket_server.go index 1f112bf..8f8ea08 100644 --- a/internal/transport/websocket_server.go +++ b/internal/transport/websocket_server.go @@ -17,43 +17,54 @@ type WebsocketServer struct { writeTimeout time.Duration // Underlying communication channel - conn *websocket.Conn + conn WebSocketConn // The current state of the server transport state transportState // Callback function called when the transport is closed - onClose func() + afterWritePump func() // Communication channels write chan []byte read chan []byte // A signal channel called when the reader encounters a websocket close error - done chan struct{} + closeWritePump chan struct{} // A signal channel called when the transport is closed - interrupt chan struct{} + closeConn chan struct{} } // newWebsocketServer server upgrades an HTTP connection to a websocket connection. -func newWebsocketServer(c *websocket.Conn, config *ServerConfig, onClose func()) *WebsocketServer { +func newWebsocketServer(c WebSocketConn, config *ServerConfig, afterWritePump func()) *WebsocketServer { + s := newWebsocketServerWithConfig(c, config, afterWritePump) + + go s.start() + + return s +} + +func newWebsocketServerWithConfig(c WebSocketConn, config *ServerConfig, afterWritePump func()) *WebsocketServer { writeTimeout := defaultWriteTimeout if config.WriteTimeout != 0 { writeTimeout = config.WriteTimeout } - s := &WebsocketServer{ - writeTimeout: writeTimeout, - conn: c, - onClose: onClose, - write: make(chan []byte), - read: make(chan []byte), - done: make(chan struct{}), - interrupt: make(chan struct{}), + readLimit := defaultReadLimit + if config.ReadLimit != 0 { + readLimit = config.ReadLimit } - go s.start() - + s := &WebsocketServer{ + writeTimeout: writeTimeout, + conn: c, + afterWritePump: afterWritePump, + write: make(chan []byte), + read: make(chan []byte), + closeWritePump: make(chan struct{}), + closeConn: make(chan struct{}), + } + s.conn.SetReadLimit(readLimit) return s } @@ -65,9 +76,9 @@ func (s *WebsocketServer) Read() <-chan []byte { // Write writes a message the websocket connection. func (s *WebsocketServer) Write(ctx context.Context, msg []byte) error { select { - case <-s.done: + case <-s.closeWritePump: return fmt.Errorf("[wsrpc] could not write message, websocket is closed") - case <-s.interrupt: + case <-s.closeConn: return fmt.Errorf("[wsrpc] could not write message, transport is closed") case <-ctx.Done(): return fmt.Errorf("[wsrpc] could not write message, context is done") @@ -77,7 +88,7 @@ func (s *WebsocketServer) Write(ctx context.Context, msg []byte) error { } // Close closes the websocket connection and cleans up pump goroutines. Notifies -// the caller with the onClose callback. +// the caller with the afterWritePump callback. func (s *WebsocketServer) Close() error { s.mu.Lock() // Make sure we only Close once. @@ -89,8 +100,8 @@ func (s *WebsocketServer) Close() error { s.state = closing - // Close the write channel to stop the go routine - close(s.interrupt) + // Close the connection and writePump, which causes readPump to close + close(s.closeConn) s.mu.Unlock() @@ -101,7 +112,7 @@ func (s *WebsocketServer) Close() error { func (s *WebsocketServer) start() { defer func() { s.Close() - s.onClose() + s.afterWritePump() }() // Set up reader @@ -115,7 +126,7 @@ func (s *WebsocketServer) start() { // ensures that there is at most one reader on a connection by executing all // reads from this goroutine. func (s *WebsocketServer) readPump() { - defer close(s.done) + defer close(s.closeWritePump) //nolint:errcheck s.conn.SetReadDeadline(time.Now().Add(pongWait)) @@ -127,8 +138,7 @@ func (s *WebsocketServer) readPump() { // allowing us to clean up the goroutine. if err != nil { log.Println("[wsrpc] Read error: ", err) - - break + return } s.read <- msg @@ -146,7 +156,7 @@ func (s *WebsocketServer) writePump() { for { select { - case <-s.done: + case <-s.closeWritePump: // When the read detects a websocket closure, it will close the done // channel so we can exit. return @@ -168,7 +178,7 @@ func (s *WebsocketServer) writePump() { return } - case <-s.interrupt: + case <-s.closeConn: // Cleanly close the connection by sending a close message and then // waiting (with timeout) for the server to close the connection. err := s.conn.WriteMessage(websocket.CloseMessage, @@ -179,7 +189,7 @@ func (s *WebsocketServer) writePump() { } s.conn.Close() select { - case <-s.done: + case <-s.closeWritePump: case <-time.After(time.Second): } diff --git a/internal/transport/websocket_server_test.go b/internal/transport/websocket_server_test.go index d11d0be..fd1a803 100644 --- a/internal/transport/websocket_server_test.go +++ b/internal/transport/websocket_server_test.go @@ -1 +1,51 @@ package transport + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewWebsocketServerWithConfig(t *testing.T) { + mockConn := &mockWebSocketConn{} + + tests := []struct { + name string + config *ServerConfig + wantTimeout time.Duration + wantLimit int64 + }{ + { + name: "Default values", + config: &ServerConfig{}, + wantTimeout: defaultWriteTimeout, + wantLimit: defaultReadLimit, + }, + { + name: "Custom WriteTimeout", + config: &ServerConfig{ + WriteTimeout: 2 * time.Second, + }, + wantTimeout: 2 * time.Second, + wantLimit: defaultReadLimit, + }, + { + name: "Custom ReadLimit", + config: &ServerConfig{ + ReadLimit: 2048, + }, + wantTimeout: defaultWriteTimeout, + wantLimit: 2048, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := newWebsocketServerWithConfig(mockConn, tt.config, nil) + + assert.Equal(t, tt.wantTimeout, server.writeTimeout, "Unexpected value for writeTimeout") + assert.Equal(t, tt.wantLimit, mockConn.readLimit, "Unexpected value for readLimit") + }) + } +} diff --git a/intgtest/bi_client_test.go b/intgtest/bi/bi_client_test.go similarity index 84% rename from intgtest/bi_client_test.go rename to intgtest/bi/bi_client_test.go index 7a21901..e078d35 100644 --- a/intgtest/bi_client_test.go +++ b/intgtest/bi/bi_client_test.go @@ -1,4 +1,4 @@ -package intgtest +package bi_test import ( "context" @@ -12,20 +12,21 @@ import ( "github.com/smartcontractkit/wsrpc" pb "github.com/smartcontractkit/wsrpc/intgtest/internal/rpcs" + "github.com/smartcontractkit/wsrpc/intgtest/utils" "github.com/smartcontractkit/wsrpc/peer" ) func Test_Bidirectional_ConcurrentCalls(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) @@ -33,7 +34,7 @@ func Test_Bidirectional_ConcurrentCalls(t *testing.T) { sClient := pb.NewEchoClient(s) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), wsrpc.WithBlock(), ) @@ -42,7 +43,7 @@ func Test_Bidirectional_ConcurrentCalls(t *testing.T) { cClient := pb.NewEchoClient(conn) // Register the handlers on the wsrpc client - pb.RegisterEchoServer(conn, &echoServer{}) + pb.RegisterEchoServer(conn, &utils.EchoServer{}) // Make a client to server call resp, err := cClient.Echo(context.Background(), &pb.EchoRequest{ @@ -71,16 +72,16 @@ func Test_Bidirectional_ConcurrentCalls(t *testing.T) { // 2. Client makes a call back to the server in the handler // 3. Server returns the response from the client as the echo func Test_Bidirectional_MultiplexCalls(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) @@ -88,7 +89,7 @@ func Test_Bidirectional_MultiplexCalls(t *testing.T) { sClient := pb.NewEchoClient(s) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), wsrpc.WithBlock(), ) diff --git a/intgtest/bi/doc.go b/intgtest/bi/doc.go new file mode 100644 index 0000000..42c7ff1 --- /dev/null +++ b/intgtest/bi/doc.go @@ -0,0 +1 @@ +package bi diff --git a/intgtest/connection_test.go b/intgtest/connection/connection_test.go similarity index 57% rename from intgtest/connection_test.go rename to intgtest/connection/connection_test.go index 107a6f1..b920e52 100644 --- a/intgtest/connection_test.go +++ b/intgtest/connection/connection_test.go @@ -1,4 +1,4 @@ -package intgtest +package connection_test import ( "context" @@ -13,14 +13,15 @@ import ( "github.com/smartcontractkit/wsrpc/connectivity" "github.com/smartcontractkit/wsrpc/credentials" pb "github.com/smartcontractkit/wsrpc/intgtest/internal/rpcs" + "github.com/smartcontractkit/wsrpc/intgtest/utils" ) func Test_ServerNotRunning(t *testing.T) { // Setup Keys - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) @@ -36,11 +37,11 @@ func Test_ServerNotRunning(t *testing.T) { func Test_AutomatedConnectionRetry(t *testing.T) { // Setup Keys - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start client - conn, err := setupClientConn(t, 1000*time.Millisecond, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) @@ -54,19 +55,19 @@ func Test_AutomatedConnectionRetry(t *testing.T) { assert.Error(t, err, "connection is not ready") // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) t.Cleanup(s.Stop) // Wait for the connection - waitForReadyConnection(t, conn) + utils.WaitForReadyConnection(t, conn) resp, err := c.Echo(context.Background(), &pb.EchoRequest{ Body: "bodyarg", @@ -78,13 +79,13 @@ func Test_AutomatedConnectionRetry(t *testing.T) { func Test_BlockingDial(t *testing.T) { // Setup Keys - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} unblocked := make(chan *wsrpc.ClientConn) go func() { - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), wsrpc.WithBlock(), ) @@ -95,11 +96,11 @@ func Test_BlockingDial(t *testing.T) { // Start the server in a goroutine. We wait to start up the server so we can // test the blocking mechanism. - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) time.Sleep(500 * time.Millisecond) go s.Serve(lis) @@ -116,10 +117,10 @@ func Test_BlockingDial(t *testing.T) { func Test_BlockingDialTimeout(t *testing.T) { // Setup Keys - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) // Start client - _, err := setupClientConn(t, 50*time.Millisecond, + _, err := utils.SetupClientConnWithOptsAndTimeout(t, 50*time.Millisecond, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), wsrpc.WithBlock(), ) @@ -128,50 +129,74 @@ func Test_BlockingDialTimeout(t *testing.T) { } func Test_InvalidCredentials(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client2.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) t.Cleanup(s.Stop) // Start client - conn, err := setupClientConn(t, 100*time.Millisecond, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) t.Cleanup(conn.Close) // Test that it fails to connect - assert.Eventually(t, func() bool { + require.Eventually(t, func() bool { return conn.GetState() == connectivity.TransientFailure }, 5*time.Second, 100*time.Millisecond) // Update the servers allowed list of public keys to include the client's - s.UpdatePublicKeys([]ed25519.PublicKey{keypairs.Client1.PubKey}) + err = s.UpdatePublicKeys(keypairs.Client1.PubKey) + require.NoError(t, err) - waitForReadyConnection(t, conn) + utils.WaitForReadyConnection(t, conn) +} + +func Test_InvalidKeyLengthCredentials(t *testing.T) { + keypairs := utils.GenerateKeys(t) + pubKeys := []ed25519.PublicKey{keypairs.Client2.PubKey} + + // Start the server + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), + ) + + // Register the ping server implementation with the wsrpc server + pb.RegisterEchoServer(s, &utils.EchoServer{}) + + // Start serving + go s.Serve(lis) + t.Cleanup(s.Stop) + + // Start client + _, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, + wsrpc.WithTransportCreds(keypairs.Client1.PrivKey[:ed25519.PublicKeySize-1], keypairs.Server.PubKey), + ) + require.Error(t, err) } func Test_GetConnectedPeerPublicKeys(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) @@ -180,13 +205,13 @@ func Test_GetConnectedPeerPublicKeys(t *testing.T) { require.Empty(t, s.GetConnectedPeerPublicKeys()) // Start client - conn, err := setupClientConn(t, 100*time.Millisecond, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) t.Cleanup(conn.Close) - waitForReadyConnection(t, conn) + utils.WaitForReadyConnection(t, conn) connectedKeys := s.GetConnectedPeerPublicKeys() require.Len(t, s.GetConnectedPeerPublicKeys(), 1) @@ -196,16 +221,16 @@ func Test_GetConnectedPeerPublicKeys(t *testing.T) { } func Test_GetNotificationChan(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) @@ -214,13 +239,13 @@ func Test_GetNotificationChan(t *testing.T) { notifyChan := s.GetConnectionNotifyChan() // Start client - conn, err := setupClientConn(t, 100*time.Millisecond, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 100*time.Millisecond, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) t.Cleanup(conn.Close) - waitForReadyConnection(t, conn) + utils.WaitForReadyConnection(t, conn) // Wait for connection notification select { @@ -230,3 +255,33 @@ func Test_GetNotificationChan(t *testing.T) { assert.Fail(t, "did not notify") } } + +func Test_ServerOpenConnections(t *testing.T) { + keypairs := utils.GenerateKeys(t) + pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} + + // Start the server + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), + ) + + // Register the ping server implementation with the wsrpc server + pb.RegisterEchoServer(s, &utils.EchoServer{}) + + // Start serving + go s.Serve(lis) + t.Cleanup(s.Stop) + + require.Equal(t, s.OpenConnections(), 0) + + // Start client + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, + wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), + ) + require.NoError(t, err) + t.Cleanup(conn.Close) + + utils.WaitForReadyConnection(t, conn) + + require.Equal(t, s.OpenConnections(), 1) +} diff --git a/intgtest/connection/doc.go b/intgtest/connection/doc.go new file mode 100644 index 0000000..c2b0392 --- /dev/null +++ b/intgtest/connection/doc.go @@ -0,0 +1 @@ +package connection diff --git a/intgtest/uni/doc.go b/intgtest/uni/doc.go new file mode 100644 index 0000000..0f8f38d --- /dev/null +++ b/intgtest/uni/doc.go @@ -0,0 +1 @@ +package uni diff --git a/intgtest/uni_client_server_test.go b/intgtest/uni/uni_client_server_test.go similarity index 60% rename from intgtest/uni_client_server_test.go rename to intgtest/uni/uni_client_server_test.go index 34e318c..f7ec2e3 100644 --- a/intgtest/uni_client_server_test.go +++ b/intgtest/uni/uni_client_server_test.go @@ -1,4 +1,4 @@ -package intgtest +package uni_test import ( "context" @@ -11,26 +11,27 @@ import ( "github.com/smartcontractkit/wsrpc" pb "github.com/smartcontractkit/wsrpc/intgtest/internal/rpcs" + "github.com/smartcontractkit/wsrpc/intgtest/utils" ) func Test_ClientServer_SimpleCall(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) t.Cleanup(s.Stop) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) @@ -39,7 +40,7 @@ func Test_ClientServer_SimpleCall(t *testing.T) { c := pb.NewEchoClient(conn) // Wait for the connection to be established - waitForReadyConnection(t, conn) + utils.WaitForReadyConnection(t, conn) resp, err := c.Echo(context.Background(), &pb.EchoRequest{ Body: "bodyarg", @@ -50,23 +51,23 @@ func Test_ClientServer_SimpleCall(t *testing.T) { } func Test_ClientServer_ConcurrentCalls(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the echo server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Start serving go s.Serve(lis) t.Cleanup(s.Stop) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), wsrpc.WithBlock(), ) @@ -76,17 +77,22 @@ func Test_ClientServer_ConcurrentCalls(t *testing.T) { c := pb.NewEchoClient(conn) respCh := make(chan *pb.EchoResponse) - defer close(respCh) + doneCh := make(chan []*pb.EchoResponse) - reqs := []echoReq{ - {message: &pb.EchoRequest{Body: "call1", DelayMs: 500}}, - {message: &pb.EchoRequest{Body: "call2"}, timeout: 200 * time.Millisecond}, + reqs := []utils.EchoReq{ + {Message: &pb.EchoRequest{Body: "call1", DelayMs: 500}}, + {Message: &pb.EchoRequest{Body: "call2"}, Timeout: 200 * time.Millisecond}, } - processEchos(t, c, reqs, respCh) + go func() { + doneCh <- utils.WaitForResponses(t, respCh, len(reqs)) + }() - actual := waitForResponses(t, respCh, 2) + utils.ProcessEchos(t, c, reqs, respCh) + actual := <-doneCh + + assert.Equal(t, len(reqs), len(actual)) assert.Equal(t, "call2", actual[0].Body) assert.Equal(t, "call1", actual[1].Body) } diff --git a/intgtest/uni_server_client_test.go b/intgtest/uni/uni_server_client_test.go similarity index 61% rename from intgtest/uni_server_client_test.go rename to intgtest/uni/uni_server_client_test.go index 63de735..d20163a 100644 --- a/intgtest/uni_server_client_test.go +++ b/intgtest/uni/uni_server_client_test.go @@ -1,4 +1,4 @@ -package intgtest +package uni_test import ( "context" @@ -11,16 +11,17 @@ import ( "github.com/smartcontractkit/wsrpc" pb "github.com/smartcontractkit/wsrpc/intgtest/internal/rpcs" + "github.com/smartcontractkit/wsrpc/intgtest/utils" "github.com/smartcontractkit/wsrpc/peer" ) func Test_ServerClient_SimpleCall(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Start serving @@ -30,17 +31,17 @@ func Test_ServerClient_SimpleCall(t *testing.T) { c := pb.NewEchoClient(s) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 5*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), ) require.NoError(t, err) t.Cleanup(conn.Close) // Register the handlers on the wsrpc client - pb.RegisterEchoServer(conn, &echoServer{}) + pb.RegisterEchoServer(conn, &utils.EchoServer{}) // Wait for the connection to be established - waitForReadyConnection(t, conn) + utils.WaitForReadyConnection(t, conn) ctx := peer.NewCallContext(context.Background(), keypairs.Client1.StaticallySizedPublicKey(t)) ctx, cancel := context.WithTimeout(ctx, 2*time.Second) @@ -53,16 +54,16 @@ func Test_ServerClient_SimpleCall(t *testing.T) { } func Test_ServerClient_ConcurrentCalls(t *testing.T) { - keypairs := generateKeys(t) + keypairs := utils.GenerateKeys(t) pubKeys := []ed25519.PublicKey{keypairs.Client1.PubKey} - // Start the server - lis, s := setupServer(t, - wsrpc.Creds(keypairs.Server.PrivKey, pubKeys), + // Start the serverTest_ServerClient_ConcurrentCalls + lis, s := utils.SetupServer(t, + wsrpc.WithCreds(keypairs.Server.PrivKey, pubKeys), ) // Register the ping server implementation with the wsrpc server - pb.RegisterEchoServer(s, &echoServer{}) + pb.RegisterEchoServer(s, &utils.EchoServer{}) // Create an RPC client for the server c := pb.NewEchoClient(s) @@ -71,7 +72,7 @@ func Test_ServerClient_ConcurrentCalls(t *testing.T) { t.Cleanup(s.Stop) // Start client - conn, err := setupClientConn(t, 5*time.Second, + conn, err := utils.SetupClientConnWithOptsAndTimeout(t, 500*time.Second, wsrpc.WithTransportCreds(keypairs.Client1.PrivKey, keypairs.Server.PubKey), wsrpc.WithBlock(), ) @@ -79,21 +80,26 @@ func Test_ServerClient_ConcurrentCalls(t *testing.T) { t.Cleanup(conn.Close) // Register the handlers on the wsrpc client - pb.RegisterEchoServer(conn, &echoServer{}) + pb.RegisterEchoServer(conn, &utils.EchoServer{}) respCh := make(chan *pb.EchoResponse) - defer close(respCh) + doneCh := make(chan []*pb.EchoResponse) pk := keypairs.Client1.StaticallySizedPublicKey(t) - reqs := []echoReq{ - {message: &pb.EchoRequest{Body: "call1", DelayMs: 500}, pubKey: &pk}, - {message: &pb.EchoRequest{Body: "call2"}, timeout: 200 * time.Millisecond, pubKey: &pk}, + reqs := []utils.EchoReq{ + {Message: &pb.EchoRequest{Body: "call1", DelayMs: 500}, PubKey: &pk}, + {Message: &pb.EchoRequest{Body: "call2"}, Timeout: 2000 * time.Millisecond, PubKey: &pk}, } - processEchos(t, c, reqs, respCh) + go func() { + doneCh <- utils.WaitForResponses(t, respCh, len(reqs)) + }() + + utils.ProcessEchos(t, c, reqs, respCh) - actual := waitForResponses(t, respCh, 2) + actual := <-doneCh + assert.Equal(t, len(reqs), len(actual)) assert.Equal(t, "call2", actual[0].Body) assert.Equal(t, "call1", actual[1].Body) } diff --git a/intgtest/utils_test.go b/intgtest/utils/testutils.go similarity index 64% rename from intgtest/utils_test.go rename to intgtest/utils/testutils.go index 5cfefb6..608800c 100644 --- a/intgtest/utils_test.go +++ b/intgtest/utils/testutils.go @@ -1,4 +1,4 @@ -package intgtest +package utils import ( "context" @@ -14,16 +14,17 @@ import ( "github.com/smartcontractkit/wsrpc/connectivity" "github.com/smartcontractkit/wsrpc/credentials" pb "github.com/smartcontractkit/wsrpc/intgtest/internal/rpcs" + "github.com/smartcontractkit/wsrpc/logger" "github.com/smartcontractkit/wsrpc/peer" ) const targetURI = "127.0.0.1:1338" // Implements the ping server RPC call handlers -type echoServer struct{} +type EchoServer struct{} // Echo echoes the request back to the client -func (s *echoServer) Echo(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { +func (s *EchoServer) Echo(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { if req.DelayMs > 0 { time.Sleep(time.Duration(req.DelayMs) * time.Millisecond) } @@ -53,8 +54,8 @@ type keys struct { Client2 keypair } -// generateKeys generates keypairs for the server and clients. -func generateKeys(t *testing.T) keys { +// GenerateKeys generates keypairs for the server and clients. +func GenerateKeys(t *testing.T) keys { t.Helper() // Setup Keys @@ -74,19 +75,20 @@ func generateKeys(t *testing.T) keys { } } -// setupClientConn is a convenience method to setup a client connection for most +// SetupClientConnWithOptsAndTimeout is a convenience method to setup a client connection for most // testing usecases. -func setupClientConn(t *testing.T, timeout time.Duration, opts ...wsrpc.DialOption) (*wsrpc.ClientConn, error) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, timeout) +func SetupClientConnWithOptsAndTimeout(t *testing.T, timeout time.Duration, opts ...wsrpc.DialOption) (*wsrpc.ClientConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) t.Cleanup(cancel) + opts = append(opts, wsrpc.WithLogger(logger.Test(t))) + return wsrpc.DialWithContext(ctx, targetURI, opts...) } -// setupServer is a convenience method to set up a server for most testing +// SetupServer is a convenience method to set up a server for most testing // usecases. -func setupServer(t *testing.T, opts ...wsrpc.ServerOption) (net.Listener, *wsrpc.Server) { +func SetupServer(t *testing.T, opts ...wsrpc.ServerOption) (net.Listener, *wsrpc.Server) { // Attempt to reconnect to the port which the OS may not have had time // to clean up between tests. var ( @@ -103,81 +105,76 @@ func setupServer(t *testing.T, opts ...wsrpc.ServerOption) (net.Listener, *wsrpc return lis, wsrpc.NewServer(opts...) } -type echoReq struct { - // Sets the timeout on the request context. Defaults to no timeout - timeout time.Duration +type EchoReq struct { + // Sets the Timeout on the request context. Defaults to no Timeout + Timeout time.Duration // Insert the client connection's public key into the context. This is // required for server to client calls, but optional for client to server // calls - pubKey *credentials.StaticSizedPublicKey - // The message that will be sent in the request - message *pb.EchoRequest + PubKey *credentials.StaticSizedPublicKey + // The Message that will be sent in the request + Message *pb.EchoRequest } -func processEchos(t *testing.T, +func ProcessEchos(t *testing.T, c pb.EchoClient, - reqs []echoReq, + reqs []EchoReq, ch chan<- *pb.EchoResponse, ) { - t.Helper() - wg := sync.WaitGroup{} for _, req := range reqs { wg.Add(1) - go func(req echoReq) { - wg.Done() + go func(req EchoReq) { + defer wg.Done() ctx := context.Background() - if req.timeout > 0 { - tctx, cancel := context.WithTimeout(context.Background(), req.timeout) + if req.Timeout > 0 { + tctx, cancel := context.WithTimeout(context.Background(), req.Timeout) defer cancel() ctx = tctx } - if req.pubKey != nil { - ctx = peer.NewCallContext(ctx, *req.pubKey) + if req.PubKey != nil { + ctx = peer.NewCallContext(ctx, *req.PubKey) } - resp, err := c.Echo(ctx, req.message) + resp, err := c.Echo(ctx, req.Message) require.NoError(t, err) ch <- resp }(req) - - wg.Wait() } + wg.Wait() + close(ch) } -func waitForResponses(t *testing.T, ch <-chan *pb.EchoResponse, limit int) []*pb.EchoResponse { - // Stores the calls in the order they were received. Call 1 should arrive second - // because of the delayed response. +func WaitForResponses(t *testing.T, ch <-chan *pb.EchoResponse, limit int) []*pb.EchoResponse { + // Stores the calls in the order they were received. + resps := []*pb.EchoResponse{} i := 0 -loop: - for { - if i == limit { - break - } + timer := time.After(5 * time.Second) +loop: + for i < limit { select { case resp := <-ch: resps = append(resps, resp) - case <-time.After(3 * time.Second): + i++ + case <-timer: break loop } - - i++ } - require.Len(t, resps, 2) - return resps } -func waitForReadyConnection(t *testing.T, conn *wsrpc.ClientConn) { +func WaitForReadyConnection(t *testing.T, conn *wsrpc.ClientConn) { t.Helper() + require.Equal(t, false, conn.GetState() == connectivity.Shutdown) + require.Eventually(t, func() bool { return conn.GetState() == connectivity.Ready }, 5*time.Second, 100*time.Millisecond) diff --git a/server.go b/server.go index 62834fe..7087af2 100644 --- a/server.go +++ b/server.go @@ -16,6 +16,7 @@ import ( "github.com/smartcontractkit/wsrpc/credentials" "github.com/smartcontractkit/wsrpc/internal/message" + "github.com/smartcontractkit/wsrpc/internal/methods" "github.com/smartcontractkit/wsrpc/internal/transport" "github.com/smartcontractkit/wsrpc/internal/wsrpcsync" "github.com/smartcontractkit/wsrpc/peer" @@ -41,7 +42,7 @@ type Server struct { // Contains all pending method call ids and the channel to respond to when // a result is received - methodCalls map[string]chan<- *message.Response + methodCalls *methods.MethodCalls // Signals a quit event when the server wants to quit quit *wsrpcsync.Event @@ -65,7 +66,7 @@ func NewServer(opt ...ServerOption) *Server { WriteBufferSize: opts.writeBufferSize, }, connMgr: newConnectionsManager(), - methodCalls: map[string]chan<- *message.Response{}, + methodCalls: methods.NewMethodCalls(), quit: wsrpcsync.NewEvent(), done: wsrpcsync.NewEvent(), serveWG: sync.WaitGroup{}, @@ -89,7 +90,8 @@ func (s *Server) Serve(lis net.Listener) { }) hcsrv := &http.Server{ - Handler: hchandler, + Handler: hchandler, + ReadTimeout: s.opts.healthcheckTimeout, } //nolint:errcheck @@ -101,8 +103,9 @@ func (s *Server) Serve(lis net.Listener) { wshandler := http.NewServeMux() wshandler.HandleFunc("/", s.wshandler) wssrv := &http.Server{ - TLSConfig: s.opts.creds.Config, - Handler: wshandler, + TLSConfig: s.opts.creds.Config, + Handler: wshandler, + ReadTimeout: s.opts.wsTimeout, } //nolint:errcheck @@ -114,6 +117,12 @@ func (s *Server) Serve(lis net.Listener) { <-s.done.Done() } +func (s *Server) OpenConnections() int { + s.connMgr.mu.Lock() + defer s.connMgr.mu.Unlock() + return len(s.connMgr.conns) +} + // wshandler upgrades the HTTP connection to a websocket connection and // registers the connection's pub key for the client. func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) { @@ -138,7 +147,10 @@ func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) { done := make(chan struct{}) config := &transport.ServerConfig{} - onClose := func() { + config.ReadLimit = s.opts.wsReadLimit + config.WriteTimeout = s.opts.wsTimeout + + afterWritePump := func() { // There is no connection manager when we are shutting down, so // we can ignore removing the connection. s.mu.RLock() @@ -152,19 +164,21 @@ func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) { close(done) } - // Initialize the transport - tr, err := transport.NewServerTransport(conn, config, onClose) - if err != nil { + s.mu.RLock() + if nil == s.connMgr { + s.mu.RUnlock() return } + s.serveWG.Add(1) + + // Initialize the transport + tr := transport.NewServerTransport(conn, config, afterWritePump) + // Register the transport against the public key - s.mu.RLock() s.connMgr.registerConnection(pubKey, tr) s.mu.RUnlock() - s.serveWG.Add(1) - // Start the reader handler go s.handleRead(pubKey, done) @@ -193,6 +207,10 @@ func (s *Server) sendMsg(ctx context.Context, pub [32]byte, msg []byte) error { // readFn handler. func (s *Server) handleRead(pubKey credentials.StaticSizedPublicKey, done <-chan struct{}) { s.mu.RLock() + if nil == s.connMgr { + s.mu.RUnlock() + return + } tr, err := s.connMgr.getTransport(pubKey) s.mu.RUnlock() if err != nil { @@ -213,7 +231,7 @@ func (s *Server) handleRead(pubKey credentials.StaticSizedPublicKey, done <-chan case *message.Message_Request: go s.handleMessageRequest(pubKey, ex.Request) case *message.Message_Response: - go s.handleMessageResponse(ex.Response) + go s.handleMessageResponse(pubKey, ex.Response) default: log.Println("Invalid message type") } @@ -227,45 +245,66 @@ func (s *Server) handleRead(pubKey credentials.StaticSizedPublicKey, done <-chan // the handler. The connection client's public key is injected into the context, // so the handler is able to identify the caller. func (s *Server) handleMessageRequest(pubKey credentials.StaticSizedPublicKey, r *message.Request) { + if err := s.validateMessageRequest(r); err != nil { + log.Printf("error validating request: %s", err) + return + } + methodName := r.GetMethod() - if md, ok := s.service.methods[methodName]; ok { - // Create a decoder function to unmarshal the message - dec := func(v interface{}) error { - return UnmarshalProtoMessage(r.GetPayload(), v) - } + md := s.service.methods[methodName] + // Create a decoder function to unmarshal the message + dec := func(v interface{}) error { + return UnmarshalProtoMessage(r.GetPayload(), v) + } - // Inject the peer's public key into the context so the handler can use it - ctx := peer.NewContext(context.Background(), &peer.Peer{PublicKey: pubKey}) - v, herr := md.Handler(s.service.serviceImpl, ctx, dec) + // Inject the peer's public key into the context so the handler can use it + ctx := peer.NewContext(context.Background(), &peer.Peer{PublicKey: pubKey}) + v, herr := md.Handler(s.service.serviceImpl, ctx, dec) - msg, err := message.NewResponse(r.GetCallId(), v, herr) - if err != nil { - return - } + msg, err := message.NewResponse(r.GetCallId(), v, herr) + if err != nil { + return + } - replyMsg, err := MarshalProtoMessage(msg) - if err != nil { - return - } + replyMsg, err := MarshalProtoMessage(msg) + if err != nil { + return + } - if err := s.sendMsg(ctx, pubKey, replyMsg); err != nil { - log.Printf("error sending message: %s", err) - } + if err := s.sendMsg(ctx, pubKey, replyMsg); err != nil { + log.Printf("error sending message: %s", err) } } // handleMessageResponse finds the call which matches the method call id of the // response and sends the payload to the call channel. -func (s *Server) handleMessageResponse(r *message.Response) { +func (s *Server) handleMessageResponse(pubKey credentials.StaticSizedPublicKey, r *message.Response) { s.mu.Lock() defer s.mu.Unlock() callID := r.GetCallId() - if call, ok := s.methodCalls[callID]; ok { - call <- r + call, err := s.methodCalls.GetMessageResponseChannelForPublicKey(pubKey, callID) + if err != nil { + log.Printf("error handling message response: %s", err) + return + } - s.removeMethodCall(callID) // Delete the call now that we have completed the request/response cycle + call <- r + s.removeMethodCall(pubKey, callID) // Delete the call now that we have completed the request/response cycle +} + +func (s *Server) validateMessageRequest(r *message.Request) error { + methodName := r.GetMethod() + if _, ok := s.service.methods[methodName]; !ok { + return fmt.Errorf("unrecognized method: %v", methodName) } + + callId := r.GetCallId() + if id, err := uuid.Parse(callId); err != nil || id.Version() != 4 { + return fmt.Errorf("invalid callId %s: %w", callId, err) + } + + return nil } // RegisterService registers a service and its implementation to the wsrpc @@ -303,10 +342,6 @@ func (s *Server) Invoke(ctx context.Context, method string, args interface{}, re return err } - s.mu.Lock() - wait := s.registerMethodCall(callID) - s.mu.Unlock() - // Extract the public key from context p, ok := peer.FromContext(ctx) if !ok { @@ -314,6 +349,10 @@ func (s *Server) Invoke(ctx context.Context, method string, args interface{}, re } pubKey := p.PublicKey + s.mu.Lock() + wait := s.registerMethodCall(pubKey, callID) + s.mu.Unlock() + if err = s.sendMsg(ctx, pubKey, req); err != nil { return err } @@ -333,7 +372,7 @@ func (s *Server) Invoke(ctx context.Context, method string, args interface{}, re case <-ctx.Done(): // Remove the call since we have timeout s.mu.Lock() - s.removeMethodCall(callID) + s.removeMethodCall(pubKey, callID) s.mu.Unlock() return fmt.Errorf("call timeout: %w", ctx.Err()) @@ -344,12 +383,18 @@ func (s *Server) Invoke(ctx context.Context, method string, args interface{}, re // UpdatePublicKeys updates the list of allowable public keys in the TLS config // and drops the connections which match the deleted keys. -func (s *Server) UpdatePublicKeys(pubKeys []ed25519.PublicKey) { +func (s *Server) UpdatePublicKeys(ed25519PubKeys ...ed25519.PublicKey) error { s.mu.Lock() defer s.mu.Unlock() + pubKeys, err := credentials.ValidPublicKeysFromEd25519(ed25519PubKeys...) + if err != nil { + return fmt.Errorf("invalid public keys: %s", err) + } + s.opts.creds.PublicKeys.Replace(pubKeys) //credentials.NewPublicKeys(pubKeys...) s.removeConnectionsToDeletedKeys(pubKeys) + return nil } // GetConnectionNotifyChan gets the connection notification channel. @@ -394,10 +439,10 @@ func (s *Server) Stop() { // When the list of allowable certs are updated, we need to refresh the existing // connections as well and shutdown any client connections no longer allowed. -func (s *Server) removeConnectionsToDeletedKeys(pubKeys []ed25519.PublicKey) { +func (s *Server) removeConnectionsToDeletedKeys(pubKeys *credentials.PublicKeys) { pubKeysMap := make(map[credentials.StaticSizedPublicKey]bool) - for _, pk := range pubKeys { + for _, pk := range pubKeys.Keys() { pubKey, err := credentials.ToStaticallySizedPublicKey(pk) if err != nil { log.Print("[Server] error reading keys while removing connections: ", err) @@ -438,9 +483,9 @@ func (s *Server) ensureSingleClientConnection(cert *x509.Certificate) ([ed25519. // registerMethodCall registers a method call to the method call map. // // This requires a lock on cc.mu. -func (s *Server) registerMethodCall(id string) <-chan *message.Response { +func (s *Server) registerMethodCall(pubKey credentials.StaticSizedPublicKey, id string) <-chan *message.Response { wait := make(chan *message.Response) - s.methodCalls[id] = wait + s.methodCalls.PutMethodCallForPublicKey(pubKey, id, wait) return wait } @@ -448,8 +493,8 @@ func (s *Server) registerMethodCall(id string) <-chan *message.Response { // removeMethodCall deregisters a method call to the method call map. // // This requires a lock on cc.mu. -func (s *Server) removeMethodCall(id string) { - delete(s.methodCalls, id) +func (s *Server) removeMethodCall(pubKey credentials.StaticSizedPublicKey, id string) { + s.methodCalls.DeleteMethodCall(pubKey, id) } // connectionsManager manages the active clients connections. diff --git a/server_test.go b/server_test.go index 1394507..c2cde43 100644 --- a/server_test.go +++ b/server_test.go @@ -4,16 +4,20 @@ import ( "crypto/ed25519" "net" "net/http" + "strings" "testing" "time" + "github.com/google/uuid" "github.com/smartcontractkit/wsrpc/examples/simple/keys" + "github.com/smartcontractkit/wsrpc/internal/message" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_Server_UpdatePublicKeys(t *testing.T) { + _, sPrivKey, err := ed25519.GenerateKey(nil) require.NoError(t, err) @@ -21,34 +25,48 @@ func Test_Server_UpdatePublicKeys(t *testing.T) { require.NoError(t, err) s := NewServer( - Creds(sPrivKey, []ed25519.PublicKey{c1PubKey}), + WithCreds(sPrivKey, []ed25519.PublicKey{c1PubKey}), ) - assert.Equal(t, []ed25519.PublicKey{c1PubKey}, s.opts.creds.PublicKeys.Keys()) + require.Equal(t, []ed25519.PublicKey{c1PubKey}, s.opts.creds.PublicKeys.Keys()) + t.Run("valid_update", func(t *testing.T) { + c2PubKey, _, err := ed25519.GenerateKey(nil) + require.NoError(t, err) - c2PubKey, _, err := ed25519.GenerateKey(nil) - require.NoError(t, err) + err = s.UpdatePublicKeys(c2PubKey) + require.NoError(t, err) + + assert.Equal(t, []ed25519.PublicKey{c2PubKey}, s.opts.creds.PublicKeys.Keys()) + }) + + t.Run("keys_not_32", func(t *testing.T) { + shortKey := make([]byte, ed25519.PublicKeySize-1) + longKey := make([]byte, ed25519.PublicKeySize+1) - s.UpdatePublicKeys([]ed25519.PublicKey{c2PubKey}) + err := s.UpdatePublicKeys(shortKey) + require.Error(t, err) - assert.Equal(t, []ed25519.PublicKey{c2PubKey}, s.opts.creds.PublicKeys.Keys()) + err = s.UpdatePublicKeys(longKey) + require.Error(t, err) + + }) } func Test_Healthcheck(t *testing.T) { // Start the server - privKey := keys.FromHex("c1afd224cec2ff6066746bf9b7cdf7f9f4694ab7ef2ca1692ff923a30df203483b0f149627adb7b6fafe1497a9dfc357f22295a5440786c3bc566dfdb0176808") + privKey := keys.FromHex(keys.ServerPrivKey) pubKeys := []ed25519.PublicKey{} lis, err := net.Listen("tcp", "127.0.0.1:1338") require.NoError(t, err) s := NewServer( - Creds(privKey, pubKeys), + WithCreds(privKey, pubKeys), WithHealthcheck("127.0.0.1:1337"), ) // Start serving go s.Serve(lis) - defer s.Stop() + t.Cleanup(s.Stop) // Test until the server boots assert.Eventually(t, func() bool { @@ -59,6 +77,115 @@ func Test_Healthcheck(t *testing.T) { } return assert.Equal(t, http.StatusOK, resp.StatusCode) - }, 1*time.Second, 100*time.Millisecond) + }, 5*time.Second, 100*time.Millisecond) + +} + +func Test_Server_HTTPTimeout_Defaults(t *testing.T) { + // Start the server + privKey := keys.FromHex(keys.ServerPrivKey) + pubKeys := []ed25519.PublicKey{} + + defaultServer := NewServer( + WithCreds(privKey, pubKeys), + WithHealthcheck("127.0.0.1:1337"), + ) + + assert.Equal(t, 5*time.Second, defaultServer.opts.healthcheckTimeout) + assert.Equal(t, 10*time.Second, defaultServer.opts.wsTimeout) + + expectedTimeout := 1 * time.Nanosecond + timeoutServer := NewServer( + WithCreds(privKey, pubKeys), + WithHealthcheck("127.0.0.1:1337"), + WithHTTPReadTimeout(expectedTimeout, expectedTimeout*2), + ) + assert.Equal(t, expectedTimeout, timeoutServer.opts.healthcheckTimeout) + assert.Equal(t, expectedTimeout*2, timeoutServer.opts.wsTimeout) +} + +func Test_Server_HTTPTimeout(t *testing.T) { + // Start the server + privKey := keys.FromHex(keys.ServerPrivKey) + pubKeys := []ed25519.PublicKey{} + + lis, err := net.Listen("tcp", "127.0.0.1:1339") + require.NoError(t, err) + + expectedTimeout := 1 * time.Nanosecond + s := NewServer( + WithCreds(privKey, pubKeys), + WithHealthcheck("127.0.0.1:1336"), + WithHTTPReadTimeout(expectedTimeout, expectedTimeout*2), + ) + + // Start serving + go s.Serve(lis) + t.Cleanup(s.Stop) + + // Test until the server boots + assert.Eventually(t, func() bool { + // Run a http call + _, err := http.Get("http://127.0.0.1:1336/healthz") + if err != nil { + return strings.Contains(err.Error(), "EOF") // Check if the error contains "timeout" + } + + return false + }, 5*time.Second, 100*time.Millisecond) +} + +func Test_Server_ValidateMessageRequest(t *testing.T) { + s := &Server{ + service: &serviceInfo{ + methods: map[string]*MethodDesc{ + "TestMethod": {}, + }, + }, + } + + t.Run("valid_request", func(t *testing.T) { + req := &message.Request{ + Method: "TestMethod", + CallId: uuid.New().String(), + Payload: []byte("test payload"), + } + + err := s.validateMessageRequest(req) + require.NoError(t, err) + }) + + t.Run("invalid_method", func(t *testing.T) { + req := &message.Request{ + Method: "InvalidMethod", + CallId: uuid.New().String(), + Payload: []byte("test payload"), + } + + err := s.validateMessageRequest(req) + require.Error(t, err) + }) + + t.Run("invalid_call_id", func(t *testing.T) { + req := &message.Request{ + Method: "TestMethod", + CallId: "invalid uuid", + Payload: []byte("test payload"), + } + + err := s.validateMessageRequest(req) + require.Error(t, err) + }) +} + +func Test_Server_OpenConnections(t *testing.T) { + privKey := keys.FromHex(keys.ServerPrivKey) + pubKeys := []ed25519.PublicKey{} + + s := NewServer( + WithCreds(privKey, pubKeys), + WithHealthcheck("127.0.0.1:1337"), + ) + assert.Equal(t, 0, s.OpenConnections()) } diff --git a/serveroptions.go b/serveroptions.go index 6618913..75a9de9 100644 --- a/serveroptions.go +++ b/serveroptions.go @@ -2,6 +2,7 @@ package wsrpc import ( "crypto/ed25519" + "time" "github.com/smartcontractkit/wsrpc/credentials" ) @@ -21,6 +22,15 @@ type serverOptions struct { // The address that the healthcheck will run on healthcheckAddr string + + // The HTTP ReadTimeout the healthcheck will use. Set to 0 for no timeout + healthcheckTimeout time.Duration + + // The HTTP ReadTimeout the ws server will use. Set to 0 for no timeout + wsTimeout time.Duration + + // The request size limit the ws server will use in bytes. Defaults to 10MB. + wsReadLimit int64 } // funcServerOption wraps a function that modifies serverOptions into an @@ -39,10 +49,32 @@ func (fdo *funcServerOption) apply(do *serverOptions) { fdo.f(do) } -// Creds returns a ServerOption that sets credentials for server connections. -func Creds(privKey ed25519.PrivateKey, pubKeys []ed25519.PublicKey) ServerOption { +// returns a ServerOption that sets the healthcheck HTTP read timeout and the server HTTP read timeout +func WithHTTPReadTimeout(hctime time.Duration, wstime time.Duration) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.healthcheckTimeout = hctime + o.wsTimeout = wstime + }) +} + +func WithWSReadLimit(numBytes int64) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.wsReadLimit = numBytes + }) +} + +// WithCreds returns a ServerOption that sets credentials for server connections. +func WithCreds(privKey ed25519.PrivateKey, pubKeys []ed25519.PublicKey) ServerOption { return newFuncServerOption(func(o *serverOptions) { - pubs := credentials.NewPublicKeys(pubKeys...) + privKey, err := credentials.ValidPrivateKeyFromEd25519(privKey) + if err != nil { + return + } + + pubs, err := credentials.ValidPublicKeysFromEd25519(pubKeys...) + if err != nil { + return + } config, err := credentials.NewServerTLSConfig(privKey, pubs) if err != nil { @@ -70,8 +102,11 @@ func ReadBufferSize(s int) ServerOption { } var defaultServerOptions = serverOptions{ - writeBufferSize: 4096, - readBufferSize: 4096, + writeBufferSize: 4096, + readBufferSize: 4096, + healthcheckTimeout: 5 * time.Second, + wsTimeout: 10 * time.Second, + wsReadLimit: int64(10_000_000), } // WithHealthcheck specifies whether to run a healthcheck endpoint. If a url diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..ff2d4bc --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,12 @@ +# projectKey is required (may be found under "Project Information" in Sonar or in project url) +sonar.projectKey=smartcontractkit_wsrpc +sonar.sources=. + +# Full exclusions from the static analysis +sonar.exclusions=**/protoc-gen-go-wsrpc/protoc-gen-go-wsrpc, **/mocks/**/*, **/testdata/**/*, **/script/**/*, **/generated/**/*, **/fixtures/**/*, **/docs/**/*, **/tools/**/*, **/*.pb.go, **/*report.xml, **/*.txt, **/*.abi, **/*.bin +# Coverage exclusions +sonar.coverage.exclusions=**/*_test.go, **/utils/**/*, **/examples/**/*, **/intgtest/**/* + +# Tests' root folder, inclusions (tests to check and count) and exclusions +sonar.tests=. +sonar.test.inclusions=**/*_test.go \ No newline at end of file diff --git a/uni_client.go b/uni_client.go index 00f700e..53e0f9f 100644 --- a/uni_client.go +++ b/uni_client.go @@ -47,18 +47,29 @@ type UniClientConn struct { } // DialUniWithContext will blocks until connection is established or context expires. -func DialUniWithContext(ctx context.Context, lggr Logger, target string, privKey ed25519.PrivateKey, serverPubKey ed25519.PublicKey) (*UniClientConn, error) { - pubs := credentials.NewPublicKeys(serverPubKey) +func DialUniWithContext(ctx context.Context, lggr Logger, target string, ed25519PrivKey ed25519.PrivateKey, serverPubKey ed25519.PublicKey) (*UniClientConn, error) { + privKey, err := credentials.ValidPrivateKeyFromEd25519(ed25519PrivKey) + if err != nil { + return nil, err + } + + pubs, err := credentials.ValidPublicKeysFromEd25519(serverPubKey) + if err != nil { + return nil, err + } + tlsConfig, err := credentials.NewClientTLSConfig(privKey, pubs) if err != nil { return nil, err } + conn, err := retryConnectWithBackoff(ctx, lggr, func(ctx2 context.Context) (Conn, error) { return connect(ctx2, target, tlsConfig) }) if err != nil { return nil, err } + return &UniClientConn{conn: conn, tlsConfig: tlsConfig, target: target, lggr: lggr, connector: connect}, nil }