Skip to content

Commit

Permalink
sync: add implementation from upstream Go for OnceFunc, OnceValue, an…
Browse files Browse the repository at this point in the history
…d OnceValues

Signed-off-by: deadprogram <[email protected]>
  • Loading branch information
deadprogram committed Jul 20, 2023
1 parent 4da1f6b commit 01d2ef3
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 0 deletions.
97 changes: 97 additions & 0 deletions src/sync/oncefunc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sync

// OnceFunc returns a function that invokes f only once. The returned function
// may be called concurrently.
//
// If f panics, the returned function will panic with the same value on every call.
func OnceFunc(f func()) func() {
var (
once Once
valid bool
p any
)
// Construct the inner closure just once to reduce costs on the fast path.
g := func() {
defer func() {
p = recover()
if !valid {
// Re-panic immediately so on the first call the user gets a
// complete stack trace into f.
panic(p)
}
}()
f()
valid = true // Set only if f does not panic
}
return func() {
once.Do(g)
if !valid {
panic(p)
}
}
}

// OnceValue returns a function that invokes f only once and returns the value
// returned by f. The returned function may be called concurrently.
//
// If f panics, the returned function will panic with the same value on every call.
func OnceValue[T any](f func() T) func() T {
var (
once Once
valid bool
p any
result T
)
g := func() {
defer func() {
p = recover()
if !valid {
panic(p)
}
}()
result = f()
valid = true
}
return func() T {
once.Do(g)
if !valid {
panic(p)
}
return result
}
}

// OnceValues returns a function that invokes f only once and returns the values
// returned by f. The returned function may be called concurrently.
//
// If f panics, the returned function will panic with the same value on every call.
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
var (
once Once
valid bool
p any
r1 T1
r2 T2
)
g := func() {
defer func() {
p = recover()
if !valid {
panic(p)
}
}()
r1, r2 = f()
valid = true
}
return func() (T1, T2) {
once.Do(g)
if !valid {
panic(p)
}
return r1, r2
}
}
159 changes: 159 additions & 0 deletions src/sync/oncefunc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sync_test

import (
"sync"
"testing"
)

// We assume that the Once.Do tests have already covered parallelism.

func TestOnceFunc(t *testing.T) {
calls := 0
f := sync.OnceFunc(func() { calls++ })
allocs := testing.AllocsPerRun(10, f)
if calls != 1 {
t.Errorf("want calls==1, got %d", calls)
}
if allocs != 0 {
t.Errorf("want 0 allocations per call, got %v", allocs)
}
}

func TestOnceValue(t *testing.T) {
calls := 0
f := sync.OnceValue(func() int {
calls++
return calls
})
allocs := testing.AllocsPerRun(10, func() { f() })
value := f()
if calls != 1 {
t.Errorf("want calls==1, got %d", calls)
}
if value != 1 {
t.Errorf("want value==1, got %d", value)
}
if allocs != 0 {
t.Errorf("want 0 allocations per call, got %v", allocs)
}
}

func TestOnceValues(t *testing.T) {
calls := 0
f := sync.OnceValues(func() (int, int) {
calls++
return calls, calls + 1
})
allocs := testing.AllocsPerRun(10, func() { f() })
v1, v2 := f()
if calls != 1 {
t.Errorf("want calls==1, got %d", calls)
}
if v1 != 1 || v2 != 2 {
t.Errorf("want v1==1 and v2==2, got %d and %d", v1, v2)
}
if allocs != 0 {
t.Errorf("want 0 allocations per call, got %v", allocs)
}
}

// TODO: need to implement more complete panic handling for these tests.
// func testOncePanicX(t *testing.T, calls *int, f func()) {
// testOncePanicWith(t, calls, f, func(label string, p any) {
// if p != "x" {
// t.Fatalf("%s: want panic %v, got %v", label, "x", p)
// }
// })
// }

// func testOncePanicWith(t *testing.T, calls *int, f func(), check func(label string, p any)) {
// // Check that the each call to f panics with the same value, but the
// // underlying function is only called once.
// for _, label := range []string{"first time", "second time"} {
// var p any
// panicked := true
// func() {
// defer func() {
// p = recover()
// }()
// f()
// panicked = false
// }()
// if !panicked {
// t.Fatalf("%s: f did not panic", label)
// }
// check(label, p)
// }
// if *calls != 1 {
// t.Errorf("want calls==1, got %d", *calls)
// }
// }

// func TestOnceFuncPanic(t *testing.T) {
// calls := 0
// f := sync.OnceFunc(func() {
// calls++
// panic("x")
// })
// testOncePanicX(t, &calls, f)
// }

// func TestOnceValuePanic(t *testing.T) {
// calls := 0
// f := sync.OnceValue(func() int {
// calls++
// panic("x")
// })
// testOncePanicX(t, &calls, func() { f() })
// }

// func TestOnceValuesPanic(t *testing.T) {
// calls := 0
// f := sync.OnceValues(func() (int, int) {
// calls++
// panic("x")
// })
// testOncePanicX(t, &calls, func() { f() })
// }
//
// func TestOnceFuncPanicNil(t *testing.T) {
// calls := 0
// f := sync.OnceFunc(func() {
// calls++
// panic(nil)
// })
// testOncePanicWith(t, &calls, f, func(label string, p any) {
// switch p.(type) {
// case nil, *runtime.PanicNilError:
// return
// }
// t.Fatalf("%s: want nil panic, got %v", label, p)
// })
// }
//
// func TestOnceFuncGoexit(t *testing.T) {
// // If f calls Goexit, the results are unspecified. But check that f doesn't
// // get called twice.
// calls := 0
// f := sync.OnceFunc(func() {
// calls++
// runtime.Goexit()
// })
// var wg sync.WaitGroup
// for i := 0; i < 2; i++ {
// wg.Add(1)
// go func() {
// defer wg.Done()
// defer func() { recover() }()
// f()
// }()
// wg.Wait()
// }
// if calls != 1 {
// t.Errorf("want calls==1, got %d", calls)
// }
// }

0 comments on commit 01d2ef3

Please sign in to comment.