Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TreapMap.keys as TreapSet #20

Merged
merged 6 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/AbstractKeySet.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.certora.collect

/**
Presents the keys of a [TreapMap] as a [TreapSet].

The idea here is that a `TreapMap<K, *>` is stored with the same Treap structure as a `TreapSet<K>`, so we can very
quickly create the corresponding `TreapSet<K>` when needed, in O(n) time (as opposed to the naive O(n*log(n))
method).

We lazily initialize the set, so that we don't create it until we need it. For many operations, we can avoid
creating the set entirely, and just use the map directly. However, many operations, e.g. [addAll]/[union] and
[retainAll/intersect], are much more efficient when we have a [TreapSet], so we create it when needed.
*/
internal abstract class AbstractKeySet<@Treapable K, S : TreapSet<K>> : TreapSet<K> {
/**
The map whose keys we are presenting as a set. We prefer to use the map directly when possible, so we don't
need to create the set.
*/
abstract val map: AbstractTreapMap<K, *, *>
/**
The set of keys. This is a lazy property so that we don't create the set until we need it.
*/
abstract val keys: Lazy<S>

@Suppress("Treapability")
override fun hashCode() = keys.value.hashCode()
override fun equals(other: Any?) = keys.value.equals(other)
override fun toString() = keys.value.toString()

override val size get() = map.size
override fun isEmpty() = map.isEmpty()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i always find it very unintuitive that size for treaps is an expensive operation. No avoiding it really I guess...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we could store the size in the treap nodes, but of course that increases memory usage. So far this still seems like the right tradeoff.

override fun clear() = treapSetOf<K>()

override operator fun contains(element: K) = map.containsKey(element)
override operator fun iterator() = map.entrySequence().map { it.key }.iterator()

override fun add(element: K) = keys.value.add(element)
override fun addAll(elements: Collection<K>) = keys.value.addAll(elements)
override fun remove(element: K) = keys.value.remove(element)
override fun removeAll(elements: Collection<K>) = keys.value.removeAll(elements)
override fun removeAll(predicate: (K) -> Boolean) = keys.value.removeAll(predicate)
override fun retainAll(elements: Collection<K>) = keys.value.retainAll(elements)

override fun single() = map.single().key
override fun singleOrNull() = map.singleOrNull()?.key
override fun arbitraryOrNull() = map.arbitraryOrNull()?.key

override fun containsAny(elements: Iterable<K>) = keys.value.containsAny(elements)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you create the keys set for containsAny and containsAll?
is it because if elements happens to be a TreapSet it's more efficient?

but then if elements is not then then can be pretty costly. I'm especially thinking of the case where keys is very large and elements is tiny. You'll run in time O(|keys|).

maybe you can check here if elements is a TreapSet or not, and according to that choose what to do?

Copy link
Collaborator Author

@ericeil ericeil Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm opting for simplicity here. The problem you point out (large map, small elements) is still a problem even if everything's a treap. Optimizing that is tricky, since getting the size of a Treap (as you point out elsewhere) is surprisingly expensive.

We could have different cases here:

  • Empty elements
  • Single Treap element
  • Non-treap (maybe broken down further by size?)

...but I'd much rather just keep this simple for now. It's possible we will never even use this method. :)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, I agree it's likely that calling these will be pretty rare.

override fun containsAny(predicate: (K) -> Boolean) = (this as Iterable<K>).any(predicate)
override fun containsAll(elements: Collection<K>) = keys.value.containsAll(elements)
override fun findEqual(element: K) = keys.value.findEqual(element)

override fun forEachElement(action: (K) -> Unit) = map.forEachEntry { action(it.key) }

override fun <R : Any> mapReduce(map: (K) -> R, reduce: (R, R) -> R) =
this.map.mapReduce({ k, _ -> map(k) }, reduce)
override fun <R : Any> parallelMapReduce(map: (K) -> R, reduce: (R, R) -> R, parallelThresholdLog2: Int) =
this.map.parallelMapReduce({ k, _ -> map(k) }, reduce, parallelThresholdLog2)
}
12 changes: 4 additions & 8 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
return when {
otherMap == null -> false
otherMap === this -> true
otherMap.isEmpty() -> false // NB AbstractTreapMap always contains at least one entry
else -> otherMap.useAsTreap(
{ otherTreap -> this.self.deepEquals(otherTreap) },
{ other.size == this.size && other.entries.all { this.containsEntry(it) }}
Expand All @@ -111,6 +112,9 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override val size: Int get() = computeSize()
override fun isEmpty(): Boolean = false

// NB AbstractTreapMap always contains at least one entry
override fun single() = singleOrNull() ?: throw IllegalArgumentException("Map contains more than one entry")

override fun containsKey(key: K) =
key.toTreapKey()?.let { self.find(it) }?.shallowContainsKey(key) ?: false

Expand Down Expand Up @@ -139,14 +143,6 @@ internal sealed class AbstractTreapMap<@Treapable K, V, @Treapable S : AbstractT
override fun iterator() = entrySequence().iterator()
}

override val keys: ImmutableSet<K>
get() = object: AbstractSet<K>(), ImmutableSet<K> {
override val size get() = [email protected]
override fun isEmpty() = [email protected]()
override operator fun contains(element: K) = containsKey(element)
override operator fun iterator() = entrySequence().map { it.key }.iterator()
}

override val values: ImmutableCollection<V>
get() = object: AbstractCollection<V>(), ImmutableCollection<V> {
override val size get() = [email protected]
Expand Down
19 changes: 2 additions & 17 deletions collect/src/main/kotlin/com/certora/collect/AbstractTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
*/
abstract fun shallowForEach(action: (element: E) -> Unit): Unit

abstract fun shallowGetSingleElement(): E?

abstract infix fun shallowUnion(that: S): S
abstract infix fun shallowIntersect(that: S): S?
abstract infix fun shallowDifference(that: S): S?
Expand Down Expand Up @@ -85,6 +83,7 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
other == null -> false
this === other -> true
other !is Set<*> -> false
other.isEmpty() -> false // NB AbstractTreapSet always contains at least one element
else -> (other as Set<E>).useAsTreap(
{ otherTreap -> this.self.deepEquals(otherTreap) },
{ this.size == other.size && this.containsAll(other) }
Expand Down Expand Up @@ -136,26 +135,12 @@ internal sealed class AbstractTreapSet<@Treapable E, S : AbstractTreapSet<E, S>>
override fun findEqual(element: E): E? =
element.toTreapKey()?.let { self.find(it) }?.shallowFindEqual(element)

@Suppress("UNCHECKED_CAST")
override fun single(): E = getSingleElement() ?: when {
isEmpty() -> throw NoSuchElementException("Set is empty")
size > 1 -> throw IllegalArgumentException("Set has more than one element")
else -> null as E // The single element must have been null!
}

override fun singleOrNull(): E? = getSingleElement()

override fun forEachElement(action: (element: E) -> Unit): Unit {
left?.forEachElement(action)
shallowForEach(action)
right?.forEachElement(action)
}

internal fun getSingleElement(): E? = when {
left === null && right === null -> shallowGetSingleElement()
else -> null
}

override fun <R : Any> mapReduce(map: (E) -> R, reduce: (R, R) -> R): R =
notForking(self) { mapReduceImpl(map, reduce) }

Expand Down Expand Up @@ -186,7 +171,7 @@ internal infix fun <@Treapable E, S : AbstractTreapSet<E, S>> S?.treapUnion(that
this == null -> that
that == null -> this
this === that -> this
that.getSingleElement() != null -> add(that)
that.singleOrNull() != null -> add(that)
else -> {
// remember, a.comparePriorityTo(b)==0 <=> a.compareKeyTo(b)==0
val c = this.comparePriorityTo(that)
Expand Down
4 changes: 3 additions & 1 deletion collect/src/main/kotlin/com/certora/collect/EmptyTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
override fun remove(key: K): TreapMap<K, V> = this
override fun remove(key: K, value: V): TreapMap<K, V> = this

override fun single(): Map.Entry<K, V> = throw NoSuchElementException("Empty map.")
override fun singleOrNull(): Map.Entry<K, V>? = null
override fun arbitraryOrNull(): Map.Entry<K, V>? = null

override fun forEachEntry(action: (Map.Entry<K, V>) -> Unit): Unit {}
Expand Down Expand Up @@ -66,7 +68,7 @@ internal class EmptyTreapMap<@Treapable K, V> private constructor() : TreapMap<K
m.asSequence().map { MapEntry(it.key, null to it.value) }

override val entries: ImmutableSet<Map.Entry<K, V>> get() = persistentSetOf<Map.Entry<K, V>>()
override val keys: ImmutableSet<K> get() = persistentSetOf<K>()
override val keys: TreapSet<K> get() = treapSetOf<K>()
override val values: ImmutableCollection<V> get() = persistentSetOf<V>()

@Suppress("Treapability", "UNCHECKED_CAST")
Expand Down
14 changes: 14 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/HashTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ internal class HashTreapMap<@Treapable K, V>(
this as? HashTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? HashTreapMap<K, V>

override fun singleOrNull() = MapEntry(key, value).takeIf { next == null && left == null && right == null }
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might wonder why I don't object to this, where I am objecting to the also nonsense. I confess I don't have a great rule of thumb, except this looks more idiomatic, and the takeIf is clearly indicating what's happening, whereas using also to exceptionally abort is pretty surprising

override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowMerger(merger: (K, V?, V?) -> V?): (HashTreapMap<K, V>?, HashTreapMap<K, V>?) -> HashTreapMap<K, V>? = { t1, t2 ->
Expand Down Expand Up @@ -350,6 +351,17 @@ internal class HashTreapMap<@Treapable K, V>(
forEachPair { (k, v) -> action(MapEntry(k, v)) }
right?.forEachEntry(action)
}

private fun treapSetFromKeys(): HashTreapSet<K> =
HashTreapSet(treapKey, next?.toKeyList(), left?.treapSetFromKeys(), right?.treapSetFromKeys())

inner class KeySet : AbstractKeySet<K, HashTreapSet<K>>() {
override val map get() = this@HashTreapMap
override val keys = lazy { treapSetFromKeys() }
override fun hashCode() = super.hashCode() // avoids treapability warning
}

override val keys get() = KeySet()
}

internal interface KeyValuePairList<K, V> {
Expand All @@ -359,6 +371,8 @@ internal interface KeyValuePairList<K, V> {
operator fun component1() = key
operator fun component2() = value

fun toKeyList(): ElementList.More<K> = ElementList.More(key, next?.toKeyList())

class More<K, V>(
override val key: K,
override val value: V,
Expand Down
7 changes: 5 additions & 2 deletions collect/src/main/kotlin/com/certora/collect/HashTreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ internal class HashTreapSet<@Treapable E>(
override fun Iterable<E>.toTreapSetOrNull(): HashTreapSet<E>? =
(this as? HashTreapSet<E>)
?: (this as? TreapSet.Builder<E>)?.build() as? HashTreapSet<E>
?: (this as? HashTreapMap<E, *>.KeySet)?.keys?.value

private inline fun ElementList<E>?.forEachNodeElement(action: (E) -> Unit) {
var current = this
Expand Down Expand Up @@ -228,8 +229,10 @@ internal class HashTreapSet<@Treapable E>(
}
}.iterator()

override fun shallowGetSingleElement(): E? = element.takeIf { next == null }

override fun singleOrNull(): E? = element.takeIf { next == null && left == null && right == null }
override fun single(): E = element.also {
if (next != null || left != null || right != null) { throw IllegalArgumentException("Set contains more than one element") }
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just because you can write something as a one liner, doesn't mean you should.

This "post fixing" conditional stuff sucks in ruby, let's not bring it into kotlin :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine. :)

override fun arbitraryOrNull(): E? = element

override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R {
Expand Down
12 changes: 12 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/SortedTreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ internal class SortedTreapMap<@Treapable K, V>(
this as? SortedTreapMap<K, V>
?: (this as? PersistentMap.Builder<K, V>)?.build() as? SortedTreapMap<K, V>

override fun singleOrNull(): Map.Entry<K, V>? = MapEntry(key, value).takeIf { left == null && right == null }
override fun arbitraryOrNull(): Map.Entry<K, V>? = MapEntry(key, value)

override fun getShallowUnionMerger(
Expand Down Expand Up @@ -163,4 +164,15 @@ internal class SortedTreapMap<@Treapable K, V>(
action(this.asEntry())
right?.forEachEntry(action)
}

private fun treapSetFromKeys(): SortedTreapSet<K> =
SortedTreapSet(treapKey, left?.treapSetFromKeys(), right?.treapSetFromKeys())

inner class KeySet : AbstractKeySet<K, SortedTreapSet<K>>() {
override val map get() = this@SortedTreapMap
override val keys = lazy { treapSetFromKeys() }
override fun hashCode() = super.hashCode() // avoids treapability warning
}

override val keys get() = KeySet()
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal class SortedTreapSet<@Treapable E>(
override fun Iterable<E>.toTreapSetOrNull(): SortedTreapSet<E>? =
(this as? SortedTreapSet<E>)
?: (this as? PersistentSet.Builder<E>)?.build() as? SortedTreapSet<E>

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a question about line 29, how come you check if it is a PresistentSet.Builder and not SortedTreapSet.Builder?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no SortedTreapSet.Builder.

?: (this as? SortedTreapMap<E, *>.KeySet)?.keys?.value

override val self get() = this
override fun iterator(): Iterator<E> = this.asTreapSequence().map { it.treapKey }.iterator()
Expand All @@ -49,7 +50,10 @@ internal class SortedTreapSet<@Treapable E>(
override fun shallowRemove(element: E): SortedTreapSet<E>? = null
override fun shallowRemoveAll(predicate: (E) -> Boolean): SortedTreapSet<E>? = this.takeIf { !predicate(treapKey) }
override fun shallowComputeHashCode(): Int = treapKey.hashCode()
override fun shallowGetSingleElement(): E = treapKey
override fun singleOrNull(): E? = treapKey.takeIf { left == null && right == null }
override fun single(): E = treapKey.also {
if (left != null || right != null) { throw IllegalArgumentException("Set contains more than one element") }

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason not to do (here and also in the hash version)

    override fun single(): E = singleOrNull() ?: throw ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E might be nullable, so singleOrNull doesn't give the correct behavior.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same objection as above. Just use a block body

}
override fun arbitraryOrNull(): E? = treapKey
override fun shallowForEach(action: (element: E) -> Unit): Unit { action(treapKey) }
override fun <R : Any> shallowMapReduce(map: (E) -> R, reduce: (R, R) -> R): R = map(treapKey)
Expand Down
11 changes: 11 additions & 0 deletions collect/src/main/kotlin/com/certora/collect/TreapMap.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
override fun remove(key: K, value: @UnsafeVariance V): TreapMap<K, V>
override fun putAll(m: Map<out K, @UnsafeVariance V>): TreapMap<K, V>
override fun clear(): TreapMap<K, V>
override val keys: TreapSet<K>

/**
A [PersistentMap.Builder] that produces a [TreapMap].
Expand All @@ -23,6 +24,16 @@ public sealed interface TreapMap<K, V> : PersistentMap<K, V> {
@Suppress("Treapability")
override fun builder(): Builder<K, @UnsafeVariance V> = TreapMapBuilder(this)

/**
If this map contains exactly one entry, returns that entry. Otherwise, throws.
*/
public fun single(): Map.Entry<K, V>

/**
If this map contains exactly one entry, returns that entry. Otherwise, returns null
*/
public fun singleOrNull(): Map.Entry<K, V>?

/**
Returns an arbitrary entry from the map, or null if the map is empty.
*/
Expand Down
2 changes: 1 addition & 1 deletion collect/src/main/kotlin/com/certora/collect/TreapSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public sealed interface TreapSet<out T> : PersistentSet<T> {
public fun containsAny(predicate: (T) -> Boolean): Boolean

/**
If this set contains exactly one element, returns that element. Otherwise, throws [NoSuchElementException].
If this set contains exactly one element, returns that element. Otherwise, throws.
*/
public fun single(): T

Expand Down
Loading