Skip to content

Commit

Permalink
fix Intersection2Vector panic (#19383)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouhongshen authored Oct 16, 2024
1 parent 6f6fdf5 commit 82055b3
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 23 deletions.
129 changes: 106 additions & 23 deletions pkg/container/vector/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -4301,7 +4301,12 @@ func BuildVarlenaFromArray[T types.RealNumbers](vec *Vector, v *types.Varlena, a

// Intersection2VectorOrdered computes a ∩ b ==> ret, keeping all items unique and sorted.
// It assumes that a and b are both already sorted.
func Intersection2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, ret *Vector, mp *mpool.MPool, cmp func(x, y T) int) {
func Intersection2VectorOrdered[T types.OrderedT | types.Decimal128](
a, b []T,
ret *Vector,
mp *mpool.MPool,
cmp func(x, y T) int) (err error) {

var long, short []T
if len(a) < len(b) {
long = b
Expand All @@ -4312,44 +4317,76 @@ func Intersection2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, r
}
var lenLong, lenShort = len(long), len(short)

ret.PreExtend(lenLong+lenShort, mp)
if err = ret.PreExtend(lenLong+lenShort, mp); err != nil {
return err
}

for i := range short {
idx := sort.Search(lenLong, func(j int) bool {
return cmp(long[j], short[i]) >= 0
})

if idx >= lenLong {
break
}

j := idx
if cmp(short[i], long[idx]) == 0 {
AppendFixed(ret, short[i], false, mp)
if err = AppendFixed(ret, short[i], false, mp); err != nil {
return err
}

j++

// skip the same item
for j < lenLong && cmp(long[j], long[j-1]) == 0 {
j++
}

if j >= lenLong {
break
}
}

idx = j

long = long[idx:]
lenLong = len(long)
}
return nil
}

// Union2VectorOrdered computes a ∪ b ==> ret, keeping all items unique and sorted.
// It assumes that a and b are both already sorted.
func Union2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, ret *Vector, mp *mpool.MPool, cmp func(x, y T) int) {
func Union2VectorOrdered[T types.OrderedT | types.Decimal128](
a, b []T,
ret *Vector,
mp *mpool.MPool,
cmp func(x, y T) int) (err error) {

var i, j int
var prevVal T
var lenA, lenB = len(a), len(b)

ret.PreExtend(lenA+lenB, mp)
if err = ret.PreExtend(lenA+lenB, mp); err != nil {
return err
}

for i < lenA && j < lenB {
if cmp(a[i], b[j]) <= 0 {
if (i == 0 && j == 0) || cmp(prevVal, a[i]) != 0 {
prevVal = a[i]
AppendFixed(ret, a[i], false, mp)
if err = AppendFixed(ret, a[i], false, mp); err != nil {
return err
}
}
i++
} else {
if (i == 0 && j == 0) || cmp(prevVal, b[j]) != 0 {
prevVal = b[j]
AppendFixed(ret, b[j], false, mp)
if err = AppendFixed(ret, b[j], false, mp); err != nil {
return err
}
}
j++
}
Expand All @@ -4358,21 +4395,30 @@ func Union2VectorOrdered[T types.OrderedT | types.Decimal128](a, b []T, ret *Vec
for ; i < lenA; i++ {
if (i == 0 && j == 0) || cmp(prevVal, a[i]) != 0 {
prevVal = a[i]
AppendFixed(ret, a[i], false, mp)
if err = AppendFixed(ret, a[i], false, mp); err != nil {
return err
}
}
}

for ; j < lenB; j++ {
if (i == 0 && j == 0) || cmp(prevVal, b[j]) != 0 {
prevVal = b[j]
AppendFixed(ret, b[j], false, mp)
if err = AppendFixed(ret, b[j], false, mp); err != nil {
return err
}
}
}
return nil
}

// Intersection2VectorVarlen computes va ∩ vb ==> ret, keeping all items unique and sorted.
// It assumes that va and vb are both already sorted.
func Intersection2VectorVarlen(va, vb *Vector, ret *Vector, mp *mpool.MPool) {
func Intersection2VectorVarlen(
va, vb *Vector,
ret *Vector,
mp *mpool.MPool) (err error) {

var shortCol, longCol []types.Varlena
var shortArea, longArea []byte

Expand All @@ -4393,28 +4439,53 @@ func Intersection2VectorVarlen(va, vb *Vector, ret *Vector, mp *mpool.MPool) {

var lenLong, lenShort = len(longCol), len(shortCol)

ret.PreExtend(lenLong+lenShort, mp)
if err = ret.PreExtend(lenLong+lenShort, mp); err != nil {
return err
}

for i := range shortCol {
shortBytes := shortCol[i].GetByteSlice(shortArea)
idx := sort.Search(lenLong, func(j int) bool {
return bytes.Compare(longCol[j].GetByteSlice(longArea), shortBytes) >= 0
})

if idx >= lenLong {
break
}

j := idx
if bytes.Equal(shortBytes, longCol[idx].GetByteSlice(longArea)) {
AppendBytes(ret, shortBytes, false, mp)
if err = AppendBytes(ret, shortBytes, false, mp); err != nil {
return err
}

// skip the same item
j++
for j < lenLong && bytes.Equal(
longCol[j].GetByteSlice(longArea), longCol[j-1].GetByteSlice(longArea)) {
j++
}

if j >= lenLong {
break
}
}

idx = j

longCol = longCol[idx:]
lenLong = len(longCol)
}
return nil
}

// Union2VectorValen computes va ∪ vb ==> ret, keeping all items unique and sorted.
// It assumes that va and vb are both already sorted.
func Union2VectorValen(va, vb *Vector, ret *Vector, mp *mpool.MPool) {
func Union2VectorValen(
va, vb *Vector,
ret *Vector,
mp *mpool.MPool) (err error) {

var i, j int
var prevVal []byte

Expand All @@ -4423,40 +4494,52 @@ func Union2VectorValen(va, vb *Vector, ret *Vector, mp *mpool.MPool) {

var lenA, lenB = len(cola), len(colb)

ret.PreExtend(lenA+lenB, mp)
if err = ret.PreExtend(lenA+lenB, mp); err != nil {
return err
}

for i < lenA && j < lenB {
bb := colb[j].GetByteSlice(areab)
ba := cola[i].GetByteSlice(areaa)
bb := colb[j].GetByteSlice(areab)

if bytes.Compare(ba, bb) <= 0 {
if (i == 0 && j == 0) || bytes.Equal(prevVal, ba) {
if (i == 0 && j == 0) || !bytes.Equal(prevVal, ba) {
prevVal = ba
AppendBytes(ret, ba, false, mp)
if err = AppendBytes(ret, ba, false, mp); err != nil {
return err
}
}
i++
} else {
if (i == 0 && j == 0) || bytes.Equal(prevVal, bb) {
if (i == 0 && j == 0) || !bytes.Equal(prevVal, bb) {
prevVal = bb
AppendBytes(ret, bb, false, mp)
if err = AppendBytes(ret, bb, false, mp); err != nil {
return err
}
}
j++
}
}

for ; i < lenA; i++ {
ba := cola[i].GetByteSlice(areaa)
if (i == 0 && j == 0) || bytes.Equal(prevVal, ba) {
if (i == 0 && j == 0) || !bytes.Equal(prevVal, ba) {
prevVal = ba
AppendBytes(ret, ba, false, mp)
if err = AppendBytes(ret, ba, false, mp); err != nil {
return err
}
}
}

for ; j < lenB; j++ {
bb := colb[j].GetByteSlice(areab)
if (i == 0 && j == 0) || bytes.Equal(prevVal, bb) {
if (i == 0 && j == 0) || !bytes.Equal(prevVal, bb) {
prevVal = bb
AppendBytes(ret, bb, false, mp)
if err = AppendBytes(ret, bb, false, mp); err != nil {
return err
}
}
}

return nil
}
Loading

0 comments on commit 82055b3

Please sign in to comment.