diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 0000000..2def876 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,33 @@ +name: Pull Request Workflow +on: + pull_request: + workflow_dispatch: + push: + branches: + - master +jobs: + checks: + name: Workspace Checks + runs-on: ubuntu-latest + steps: + + - name: Checkout + uses: actions/checkout@v2 + + - uses: actions/setup-go@v2 + with: + go-version: '1.20' + + - name: Lint + run: if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then exit 1; fi + + - name: Tidy + run: | + go mod tidy + if [[ -n $(git status -s) ]]; then exit 1; fi + + - name: Vet + run: go vet ./... + + - name: Test + run: go test -race ./... diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..4f6cdd8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Spire Technology LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..b8690b8 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# go-keymutex + +Go library for keyed mutexes. + +## Installation + +```bash +go get github.com/spiretechnology/go-keymutex +``` diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..63c6694 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/spiretechnology/go-keymutex + +go 1.20 + +require github.com/stretchr/testify v1.8.4 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..fa4b6e6 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/keymutex.go b/keymutex.go new file mode 100644 index 0000000..34cd2b1 --- /dev/null +++ b/keymutex.go @@ -0,0 +1,72 @@ +package keymutex + +import "sync" + +// KeyMutex is a mutex that can be locked and unlocked on arbitrary keys. +type KeyMutex[T comparable] struct { + mut sync.Mutex + m map[T]*sync.Mutex + refCounts map[T]int +} + +// Lock locks the mutex for the given key. +func (km *KeyMutex[T]) Lock(key T) { + km.lockWithWaiting(key, nil) +} + +// Unlock unlocks the mutex for the given key. +func (km *KeyMutex[T]) Unlock(key T) { + // Acquire the map lock + km.mut.Lock() + defer km.mut.Unlock() + + // Ensure Unlock is not called more times than Lock + if km.refCounts[key] <= 0 { + return + } + + // Get the mutex for the key + mut := km.m[key] + + // Decrement the counter for the key + km.refCounts[key]-- + + // If the counter is zero, delete the mutex + if km.refCounts[key] == 0 { + delete(km.m, key) + delete(km.refCounts, key) + } + + // Unlock the mutex + mut.Unlock() +} + +func (km *KeyMutex[T]) lockWithWaiting(key T, chanCallback chan<- struct{}) { + // Acquire the map lock + km.mut.Lock() + + // Ensure the map exists + if km.m == nil { + km.m = map[T]*sync.Mutex{} + } + if km.refCounts == nil { + km.refCounts = map[T]int{} + } + + // Get the mutex for the key. Create it if it doesn't exist + mut, ok := km.m[key] + if !ok { + mut = &sync.Mutex{} + km.m[key] = mut + } + + // Increment the counter for the key + km.refCounts[key]++ + + // Lock the mutex + if chanCallback != nil { + chanCallback <- struct{}{} + } + km.mut.Unlock() + mut.Lock() +} diff --git a/keymutex_test.go b/keymutex_test.go new file mode 100644 index 0000000..51bf2e4 --- /dev/null +++ b/keymutex_test.go @@ -0,0 +1,103 @@ +package keymutex + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKeyMutex(t *testing.T) { + var km KeyMutex[int] + var wg sync.WaitGroup + + var sequence1, sequence2 []string + key1 := 1 + key2 := 2 + + km.Lock(1) + + // In the background, queue a sequence of events + wg.Add(2) + go func() { + defer wg.Done() + km.Lock(key1) + require.Equal(t, 1, km.refCounts[key1], "refCounts[key1] should be 1") + defer km.Unlock(key1) + go func() { + defer wg.Done() + km.Lock(key1) + defer km.Unlock(key1) + sequence1 = append(sequence1, "C") + }() + sequence1 = append(sequence1, "B") + km.Unlock(key1) + }() + + // This should not deadlock, even though key1 is already locked + km.Lock(key2) + require.Equal(t, 1, km.refCounts[key2], "refCounts[key2] should be 1") + sequence2 = append(sequence2, "A") + km.Unlock(key2) + key2RefCount, key2RefCountOk := km.refCounts[key2] + require.Equal(t, 0, key2RefCount, "refCounts[key2] should be 0") + require.Equal(t, false, key2RefCountOk, "refCounts[key2] should not exist") + + // Add to the sequence and unlock the key, allowing the goroutines to continue + sequence1 = append(sequence1, "A") + km.Unlock(key1) + + // Wait for the goroutines to finish + wg.Wait() + + require.Equal(t, []string{"A", "B", "C"}, sequence1) + require.Equal(t, []string{"A"}, sequence2) + require.Equal(t, 0, km.refCounts[key1], "refCounts[key1] should be 0") + require.Equal(t, 0, km.refCounts[key2], "refCounts[key2] should be 0") +} + +func TestKeyMutexLocking(t *testing.T) { + var km KeyMutex[int] + var wgAcquiringLock sync.WaitGroup + var wgAllLocksReleased sync.WaitGroup + iterCount := 5 + var grantedCount int + + km.Lock(1) + + chanUnsuspend := make(chan struct{}) + + // Queue up a bunch of goroutines waiting to acquire the same lock + for i := 0; i < iterCount; i++ { + wgAcquiringLock.Add(1) + wgAllLocksReleased.Add(1) + go func() { + defer wgAllLocksReleased.Done() + defer km.Unlock(1) + chanWaiting := make(chan struct{}) + go func() { + <-chanWaiting + wgAcquiringLock.Done() + }() + km.lockWithWaiting(1, chanWaiting) + <-chanUnsuspend + grantedCount++ + }() + } + + // Because we acquired the first lock, the grantedCount should still be zero here + require.Equal(t, 0, grantedCount, "grantedCount should be 0") + + // Wait for all goroutines to be waiting to acquire the lock + wgAcquiringLock.Wait() + require.Equal(t, iterCount+1, km.refCounts[1], "refCounts[1] should be %d", iterCount+1) + + // Allow all locks to be acquired sequentially + km.Unlock(1) + close(chanUnsuspend) + + // Acquire one more lock, which should wait until all the other locks are released + wgAllLocksReleased.Wait() + require.Equal(t, 0, km.refCounts[1], "refCounts[1] should be 0") + require.Equal(t, iterCount, grantedCount, "grantedCount should be %d", iterCount) +}