forked from EndlessCheng/codeforces-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
treap_kthsum.go
119 lines (104 loc) · 2.1 KB
/
treap_kthsum.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package copypasta
import "time"
/* A treap that maintains the sum of its k smallest elements.
Supports inserting and removing elements (a multiset).
https://leetcode.cn/problems/divide-an-array-into-subarrays-with-minimum-cost-ii/
https://atcoder.jp/contests/abc306/tasks/abc306_e
https://atcoder.jp/contests/abc287/tasks/abc287_g
*/
// nodeSum is one treap node: a distinct key with its multiplicity, plus
// size/sum aggregates over the whole subtree rooted at the node.
type nodeSum struct {
	lr       [2]*nodeSum // lr[0]: left child (smaller keys); lr[1]: right child (larger keys)
	priority uint        // random heap priority; a child with a larger priority is rotated above its parent

	// Data for this node's own key.
	key    int // the stored key
	keyCnt int // number of copies of key (may fall to 0 after removals; the node is kept)
	keySum int // key * keyCnt

	// Aggregates over the subtree rooted here, refreshed by maintain.
	subSize int // total number of keys, counted with multiplicity
	subSum  int // sum of all keys, counted with multiplicity
}
// cmp tells a search for key a where to go from node o:
// 0 means descend into the left subtree (a < o.key),
// 1 means descend into the right subtree (a > o.key),
// and -1 means a equals o.key, so the search stops at o.
func (o *nodeSum) cmp(a int) int {
	switch {
	case a < o.key:
		return 0
	case a > o.key:
		return 1
	default:
		return -1
	}
}
// getSize returns how many keys (with multiplicity) the subtree rooted at o
// holds. It is safe to call on a nil node, which counts as an empty subtree.
func (o *nodeSum) getSize() int {
	if o == nil {
		return 0
	}
	return o.subSize
}
// getSum returns the sum of all keys (with multiplicity) in the subtree rooted
// at o. It is safe to call on a nil node, whose sum is 0.
func (o *nodeSum) getSum() int {
	if o == nil {
		return 0
	}
	return o.subSum
}
// maintain recomputes o.subSize and o.subSum from o's own key data and the
// aggregates of its children. The children must already be up to date.
func (o *nodeSum) maintain() {
	left, right := o.lr[0], o.lr[1]
	o.subSize = left.getSize() + o.keyCnt + right.getSize()
	o.subSum = left.getSum() + o.keySum + right.getSum()
}
// rotate performs a single rotation at o and returns the new subtree root.
// The child o.lr[d^1] rises into o's place and o becomes its lr[d] child:
// d == 0 lifts the right child (left rotation), and d == 1 lifts the left
// child (right rotation). Aggregates are refreshed bottom-up: o first, then
// the new root.
func (o *nodeSum) rotate(d int) *nodeSum {
	up := o.lr[d^1]
	o.lr[d^1], up.lr[d] = up.lr[d], o
	o.maintain()
	up.maintain()
	return up
}
// treapSum is a multiset of ints kept in a treap. It can report the sum of its
// k smallest elements. Build it with newTreapSum, which seeds rd with a
// non-zero value (xorshift in fastRand maps a zero state to zero).
type treapSum struct {
	rd   uint     // xorshift state consumed by fastRand for node priorities
	root *nodeSum // root of the treap; nil until the first put
}
// fastRand advances the xorshift state (shift triple 13, 17, 5) and returns
// the new state. It supplies node priorities.
//
// Xorshift maps 0 to 0. A zero-valued treapSum (one not created by
// newTreapSum) would therefore give every node priority 0, so put would never
// rotate and the treap would degrade into an unbalanced BST. Such a state is
// seeded lazily with a fixed non-zero constant. Every xorshift step is a
// bijection that fixes 0, so a non-zero state never returns to 0 and this
// branch runs at most once.
func (t *treapSum) fastRand() uint {
	if t.rd == 0 {
		t.rd = 0x9E3779B9 // arbitrary non-zero seed that fits in a 32-bit uint
	}
	t.rd ^= t.rd << 13
	t.rd ^= t.rd >> 17
	t.rd ^= t.rd << 5
	return t.rd
}
// _put adds num copies of key to the subtree rooted at o and returns the
// (possibly new) subtree root. num may be negative to remove copies. A node
// whose count reaches 0 is kept in the tree with zero weight.
func (t *treapSum) _put(o *nodeSum, key, num int) *nodeSum {
	if o == nil {
		// Key absent from this subtree: create a leaf holding num copies.
		o = &nodeSum{priority: t.fastRand(), key: key, keyCnt: num, keySum: key * num}
		o.maintain()
		return o
	}

	d := o.cmp(key)
	if d < 0 {
		// Key already stored here: adjust its multiplicity in place.
		o.keyCnt += num
		o.keySum += key * num
		o.maintain()
		return o
	}

	o.lr[d] = t._put(o.lr[d], key, num)
	if child := o.lr[d]; child.priority > o.priority {
		// Keep larger priorities above smaller ones by lifting the child.
		o = o.rotate(d ^ 1)
	}
	o.maintain()
	return o
}
// put adds num copies of key to the treap.
// num = 1 inserts one copy of key; num = -1 removes one copy of key.
// NOTE(review): removing a key that is not present creates a new node with a
// negative count (see _put), which corrupts sizes and sums. Callers must only
// remove keys they previously inserted.
func (t *treapSum) put(key, num int) {
	newRoot := t._put(t.root, key, num)
	t.root = newRoot
}
// kth returns the sum of the k smallest keys in the treap, counted with
// multiplicity. k starts at 1; k == 0 yields 0.
// It panics with -1 if k is negative or exceeds the number of stored keys.
// Without the negative check, a negative k would fall through to the leftmost
// node and return the meaningless product key*k.
func (t *treapSum) kth(k int) (sum int) {
	if k < 0 || k > t.root.getSize() {
		panic(-1)
	}
	for o := t.root; o != nil; {
		ls := o.lr[0].getSize()
		if k < ls {
			// All k smallest keys lie inside the left subtree.
			o = o.lr[0]
			continue
		}
		// Take the whole left subtree, then as many copies of o.key as needed.
		sum += o.lr[0].getSum()
		k -= ls
		if k <= o.keyCnt {
			return sum + o.key*k
		}
		// Every copy of o.key is included; continue into the right subtree.
		sum += o.keySum
		k -= o.keyCnt
		o = o.lr[1]
	}
	return sum
}
// newTreapSum returns an empty treapSum whose xorshift state is seeded from
// the current time. The seed is halved and then incremented, so it is never
// zero (zero is a fixed point of xorshift).
func newTreapSum() *treapSum {
	seed := uint(time.Now().UnixNano())/2 + 1
	return &treapSum{rd: seed}
}