Skip to content

Commit

Permalink
correct set.union method (#574)
Browse files Browse the repository at this point in the history
* correct set.union method

The `union` function on `set` is inconsistent with the behaviour of
`update` as it does not support multiple iterable positional arguments
as is the case in the Bazel specification of the language. This PR
will align `starlark-go` with the Bazel spec.

Note that the `set.union` method no longer uses the `Union` function
defined in `value.go` in order to avoid making a new `set` instance
for each interable processed.

* correct set.union method

Fixes in PR from @adonovan.

* correct set.union method

function name change
  • Loading branch information
andponlin-canva authored Jan 28, 2025
1 parent 2fb1215 commit d908c3e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 33 deletions.
8 changes: 4 additions & 4 deletions doc/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -3762,11 +3762,11 @@ x.symmetric_difference([3, 4, 5]) # set([1, 2, 4, 5])
<a id='set·union'></a>
### set·union
`S.union(iterable)` returns a new set into which have been inserted
all the elements of set S and all the elements of the argument, which
must be iterable.
`S.union(iterable...)` returns a new set into which have been inserted
all the elements of set S and each element of the iterable sequences.
`union` fails if any element of the iterable is not hashable.
`union` fails if any argument is not an iterable sequence, or if any
sequence element is not hashable.
```python
x = set([1, 2])
Expand Down
55 changes: 27 additions & 28 deletions starlark/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -2337,41 +2337,18 @@ func set_symmetric_difference(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple)

// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·union.
func set_union(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
var iterable Iterable
if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &iterable); err != nil {
return nil, err
}
iter := iterable.Iterate()
defer iter.Done()
union, err := b.Receiver().(*Set).Union(iter)
if err != nil {
receiverSet := b.Receiver().(*Set).clone()
if err := setUpdate(receiverSet, args, kwargs); err != nil {
return nil, nameErr(b, err)
}
return union, nil
return receiverSet, nil
}

// https://github.com/google/starlark-go/blob/master/doc/spec.md#set·update.
func set_update(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) {
if len(kwargs) > 0 {
return nil, nameErr(b, "update does not accept keyword arguments")
}

receiverSet := b.Receiver().(*Set)

for i, arg := range args {
iterable, ok := arg.(Iterable)
if !ok {
return nil, fmt.Errorf("update: argument #%d is not iterable: %s", i+1, arg.Type())
}
if err := func() error {
iter := iterable.Iterate()
defer iter.Done()
return receiverSet.InsertAll(iter)
}(); err != nil {
return nil, nameErr(b, err)
}
if err := setUpdate(b.Receiver().(*Set), args, kwargs); err != nil {
return nil, nameErr(b, err)
}

return None, nil
}

Expand Down Expand Up @@ -2474,6 +2451,28 @@ func updateDict(dict *Dict, updates Tuple, kwargs []Tuple) error {
return nil
}

func setUpdate(s *Set, args Tuple, kwargs []Tuple) error {
if len(kwargs) > 0 {
return errors.New("does not accept keyword arguments")
}

for i, arg := range args {
iterable, ok := arg.(Iterable)
if !ok {
return fmt.Errorf("argument #%d is not iterable: %s", i+1, arg.Type())
}
if err := func() error {
iter := iterable.Iterate()
defer iter.Done()
return s.InsertAll(iter)
}(); err != nil {
return err
}
}

return nil
}

// nameErr returns an error message of the form "name: msg"
// where name is b.Name() and msg is a string or error.
func nameErr(b *Builtin, msg interface{}) error {
Expand Down
11 changes: 10 additions & 1 deletion starlark/testdata/set.star
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,23 @@ assert.eq(list(set("a".elems()).union("b".elems())), ["a", "b"])
assert.eq(list(set("ab".elems()).union("bc".elems())), ["a", "b", "c"])
assert.eq(set().union([]), set())
assert.eq(type(x.union(y)), "set")
assert.eq(list(x.union()), [1, 2, 3])
assert.eq(list(x.union(y)), [1, 2, 3, 4, 5])
assert.eq(list(x.union(y, [6, 7])), [1, 2, 3, 4, 5, 6, 7])
assert.eq(list(x.union([5, 1])), [1, 2, 3, 5])
assert.eq(list(x.union((6, 5, 4))), [1, 2, 3, 6, 5, 4])
assert.fails(lambda : x.union([1, 2, {}]), "unhashable type: dict")
assert.fails(lambda : x.union(1, 2, 3), "argument #1 is not iterable: int")

# set.update (allows any iterable for the right operand)
# The update function will mutate the set so the tests below are
# scoped using a function.

def test_update_return_value():
assert.eq(set(x).update(y), None)

test_update_return_value()

def test_update_elems_singular():
s = set("a".elems())
s.update("b".elems())
Expand Down Expand Up @@ -130,7 +139,7 @@ test_update_non_iterable()

def test_update_kwargs():
s = set(x)
assert.fails(lambda: x.update(gee = [3, 4]), "update: update does not accept keyword arguments")
assert.fails(lambda: x.update(gee = [3, 4]), "update: does not accept keyword arguments")

test_update_kwargs()

Expand Down

0 comments on commit d908c3e

Please sign in to comment.