Skip to content
This repository has been archived by the owner on Jan 8, 2024. It is now read-only.

Commit

Permalink
Rest of the equality and comparison operators.
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian committed Mar 10, 2023
1 parent f5559a3 commit 8fb651d
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "imrc"
version = "0.1.0"
version = "0.2.0"
edition = "2021"

[lib]
Expand Down
129 changes: 121 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,25 @@ impl HashMapPy {
format!("HashMap({{{}}})", contents.collect::<Vec<_>>().join(", "))
}

fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
match op {
CompareOp::Eq => (self.inner.len() == other.inner.len()).into_py(py),
CompareOp::Ne => (self.inner.len() != other.inner.len()).into_py(py),
_ => py.NotImplemented(),
CompareOp::Eq => Ok((self.inner.len() == other.inner.len()
&& self
.inner
.iter()
.map(|(k1, v1)| (v1, other.inner.get(&k1)))
.map(|(v1, v2)| PyAny::eq(v1.extract(py)?, v2))
.all(|r| r.unwrap_or(false)))
.into_py(py)),
CompareOp::Ne => Ok((self.inner.len() != other.inner.len()
|| self
.inner
.iter()
.map(|(k1, v1)| (v1, other.inner.get(&k1)))
.map(|(v1, v2)| PyAny::ne(v1.extract(py)?, v2))
.all(|r| r.unwrap_or(true)))
.into_py(py)),
_ => Ok(py.NotImplemented()),
}
}

Expand Down Expand Up @@ -255,6 +269,10 @@ impl<'source> FromPyObject<'source> for HashSetPy {
}
}

fn is_subset(one: &HashSet<Key>, two: &HashSet<Key>) -> bool {
one.iter().all(|v| two.contains(v))
}

#[pymethods]
impl HashSetPy {
#[new]
Expand All @@ -268,6 +286,22 @@ impl HashSetPy {
}
}

fn __and__(&self, other: &Self) -> Self {
self.intersection(&other)
}

fn __or__(&self, other: &Self) -> Self {
self.union(&other)
}

fn __sub__(&self, other: &Self) -> Self {
self.difference(&other)
}

fn __xor__(&self, other: &Self) -> Self {
self.symmetric_difference(&other)
}

fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<KeyIterator>> {
let iter = slf
.inner
Expand All @@ -292,11 +326,19 @@ impl HashSetPy {
format!("HashSet({{{}}})", contents.collect::<Vec<_>>().join(", "))
}

fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
match op {
CompareOp::Eq => (self.inner.len() == other.inner.len()).into_py(py),
CompareOp::Ne => (self.inner.len() != other.inner.len()).into_py(py),
_ => py.NotImplemented(),
CompareOp::Eq => Ok((self.inner.len() == other.inner.len()
&& is_subset(&self.inner, &other.inner))
.into_py(py)),
CompareOp::Ne => Ok((self.inner.len() != other.inner.len()
|| self.inner.iter().any(|k| !other.inner.contains(k)))
.into_py(py)),
CompareOp::Lt => Ok((self.inner.len() < other.inner.len()
&& is_subset(&self.inner, &other.inner))
.into_py(py)),
CompareOp::Le => Ok(is_subset(&self.inner, &other.inner).into_py(py)),
_ => Ok(py.NotImplemented()),
}
}

Expand Down Expand Up @@ -326,6 +368,69 @@ impl HashSetPy {
}
}

fn difference(&self, other: &Self) -> Self {
let mut inner = self.inner.clone();
for value in other.inner.iter() {
inner.remove(value);
}
HashSetPy { inner }
}

fn intersection(&self, other: &Self) -> Self {
let mut inner: HashSet<Key> = HashSet::new();
let larger: &HashSet<Key>;
let iter;
if self.inner.len() > other.inner.len() {
larger = &self.inner;
iter = other.inner.iter();
} else {
larger = &other.inner;
iter = self.inner.iter();
}
for value in iter {
if larger.contains(value) {
inner.insert(value.to_owned());
}
}
HashSetPy { inner }
}

fn symmetric_difference(&self, other: &Self) -> Self {
let mut inner: HashSet<Key>;
let iter;
if self.inner.len() > other.inner.len() {
inner = self.inner.clone();
iter = other.inner.iter();
} else {
inner = other.inner.clone();
iter = self.inner.iter();
}
for value in iter {
if inner.contains(value) {
inner.remove(value);
} else {
inner.insert(value.to_owned());
}
}
HashSetPy { inner }
}

fn union(&self, other: &Self) -> Self {
let mut inner: HashSet<Key>;
let iter;
if self.inner.len() > other.inner.len() {
inner = self.inner.clone();
iter = other.inner.iter();
} else {
inner = other.inner.clone();
iter = self.inner.iter();
}
for value in iter {
inner.insert(value.to_owned());
}
HashSetPy { inner }
}

#[pyo3(signature = (*iterables))]
fn update(&self, iterables: &PyTuple) -> PyResult<HashSetPy> {
let mut inner = self.inner.clone();
Expand Down Expand Up @@ -414,6 +519,14 @@ impl VectorPy {
.map(|(e1, e2)| PyAny::eq(e1.extract(py)?, e2))
.all(|r| r.unwrap_or(false)))
.into_py(py)),
CompareOp::Ne => Ok((self.inner.len() != other.inner.len()
|| self
.inner
.iter()
.zip(other.inner.iter())
.map(|(e1, e2)| PyAny::ne(e1.extract(py)?, e2))
.any(|r| r.unwrap_or(true)))
.into_py(py)),
_ => Ok(py.NotImplemented()),
}
}
Expand Down
15 changes: 14 additions & 1 deletion tests/test_hash_trie_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def test_update_with_multiple_arguments():
def test_update_one_argument():
x = HashMap(a=1)

assert x.update({"b": "2"}) == HashMap(a=1, b=2)
assert x.update({"b": 2}) == HashMap(a=1, b=2)


def test_update_no_arguments():
Expand Down Expand Up @@ -309,3 +309,16 @@ def test_convert_hashtriemap():
def test_fast_convert_hashtriemap():
m = HashMap({i: i * 2 for i in range(3)})
assert HashMap.convert(m) is m


def test_more_eq():
# Non-pyrsistent-test-suite test
o = object()

assert HashMap([(o, o), (1, o)]) == HashMap([(o, o), (1, o)])
assert HashMap([(o, "foo")]) == HashMap([(o, "foo")])
assert HashMap() == HashMap([])

assert HashMap({1: 2}) != HashMap({1: 3})
assert HashMap({o: 1}) != HashMap({o: o})
assert HashMap([]) != HashMap([(o, 1)])
32 changes: 30 additions & 2 deletions tests/test_hash_trie_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def test_contains():
assert 4 not in s


@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet")
def test_supports_set_operations():
s1 = HashSet([1, 2, 3])
s2 = HashSet([3, 4, 5])
Expand All @@ -114,7 +113,6 @@ def test_supports_set_operations():
assert s1.symmetric_difference(s2) == s1 ^ s2


@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet")
def test_supports_set_comparisons():
s1 = HashSet([1, 2, 3])
s3 = HashSet([1, 2])
Expand Down Expand Up @@ -151,3 +149,33 @@ def test_update_no_elements():

def test_iterable():
assert HashSet(iter("a")) == HashSet(iter("a"))


def test_more_eq():
# Non-pyrsistent-test-suite test
o = object()

assert HashSet([o]) == HashSet([o])
assert HashSet([o, o]) == HashSet([o, o])
assert HashSet([o]) == HashSet([o, o])
assert HashSet() == HashSet([])
assert not (HashSet([1, 2]) == HashSet([1, 3]))
assert not (HashSet([o, 1]) == HashSet([o, o]))
assert not (HashSet([]) == HashSet([o]))

assert HashSet([1, 2]) != HashSet([1, 3])
assert HashSet([]) != HashSet([o])
assert not (HashSet([o]) != HashSet([o]))
assert not (HashSet([o, o]) != HashSet([o, o]))
assert not (HashSet([o]) != HashSet([o, o]))
assert not (HashSet() != HashSet([]))


def test_more_set_comparisons():
s = HashSet([1, 2, 3])

assert s == s
assert not (s < s)
assert s <= s
assert not (s > s)
assert s >= s
19 changes: 19 additions & 0 deletions tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,22 @@ def test_hashing():
def test_sequence():
m = Vector("asdf")
assert m == Vector(["a", "s", "d", "f"])


def test_more_eq():
# Non-pyrsistent-test-suite test
o = object()

assert Vector([o, o]) == Vector([o, o])
assert Vector([o]) == Vector([o])
assert Vector() == Vector([])
assert not (Vector([1, 2]) == Vector([1, 3]))
assert not (Vector([o]) == Vector([o, o]))
assert not (Vector([]) == Vector([o]))

assert Vector([1, 2]) != Vector([1, 3])
assert Vector([o]) != Vector([o, o])
assert Vector([]) != Vector([o])
assert not (Vector([o, o]) != Vector([o, o]))
assert not (Vector([o]) != Vector([o]))
assert not (Vector() != Vector([]))

0 comments on commit 8fb651d

Please sign in to comment.