Skip to content

Commit

Permalink
Added support for binary vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 5, 2024
1 parent c988a51 commit feea554
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 12 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,22 @@ class Item < ApplicationRecord
end
```

### Binary Vectors

Use the `type` option for binary vectors

```ruby
class Item < ApplicationRecord
has_neighbors :embedding, type: :bit
end
```

Get the nearest neighbors by Hamming distance

```ruby
Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
```

## MariaDB

### Distance
Expand Down
8 changes: 5 additions & 3 deletions lib/neighbor/attribute.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ def new_cast_type
if @cast_type.is_a?(ActiveModel::Type::Value)
case Utils.adapter(@model)
when :sqlite
case @type.to_s
when "int8"
case @type
when :int8
Type::SqliteInt8Vector.new
when "float32", ""
when :bit
@cast_type
when :float32, nil
Type::SqliteFloat32Vector.new
else
raise ArgumentError, "Unsupported type"
Expand Down
15 changes: 10 additions & 5 deletions lib/neighbor/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def self.neighbor_attributes

attribute_names.each do |attribute_name|
raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type}
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
end

if ActiveRecord::VERSION::STRING.to_f >= 7.2
Expand Down Expand Up @@ -63,11 +63,12 @@ def self.neighbor_attributes
column_info = self.class.columns_hash[k.to_s]
dimensions = v[:dimensions]
dimensions ||= column_info&.limit unless column_info&.type == :binary
type = v[:type] || column_info&.type

if !Neighbor::Utils.validate_dimensions(value, column_info&.type, dimensions).nil?
if !Neighbor::Utils.validate_dimensions(value, type, dimensions).nil?
errors.add(k, "must have #{dimensions} dimensions")
end
if !Neighbor::Utils.validate_finite(value, column_info&.type)
if !Neighbor::Utils.validate_finite(value, type)
errors.add(k, "must have finite values")
end
end
Expand Down Expand Up @@ -106,7 +107,8 @@ def self.neighbor_attributes

column_attribute = klass.type_for_attribute(attribute_name)
vector = column_attribute.cast(vector)
Neighbor::Utils.validate(vector, dimensions: dimensions, column_info: column_info)
dimensions ||= column_info&.limit unless column_info&.type == :binary
Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || column_info&.type)
vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize

query = connection.quote(column_attribute.serialize(vector))
Expand All @@ -129,8 +131,11 @@ def self.neighbor_attributes
order =
case adapter
when :sqlite
if type.to_s == "int8"
case type
when :int8
"#{operator}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
when :bit
"#{operator}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
else
"#{operator}(#{quoted_attribute}, #{query})"
end
Expand Down
9 changes: 5 additions & 4 deletions lib/neighbor/utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ def self.validate_finite(value, type)
end
end

def self.validate(value, dimensions:, column_info:)
dimensions ||= column_info&.limit unless column_info&.type == :binary
if (message = validate_dimensions(value, column_info&.type, dimensions))
def self.validate(value, dimensions:, type:)
if (message = validate_dimensions(value, type, dimensions))
raise Error, message
end

if !validate_finite(value, column_info&.type)
if !validate_finite(value, type)
raise Error, "Values must be finite"
end
end
Expand Down Expand Up @@ -65,6 +64,8 @@ def self.operator(adapter, column_type, distance)
"vec_distance_cosine"
when "taxicab"
"vec_distance_L1"
when "hamming"
"vec_distance_hamming"
end
when :mariadb
case column_type
Expand Down
28 changes: 28 additions & 0 deletions test/sqlite_bit_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
require_relative "test_helper"
require_relative "support/sqlite"

class SqliteBitTest < Minitest::Test
def setup
SqliteItem.delete_all
end

def test_hamming
create_bit_items
result = SqliteItem.find(1).nearest_neighbors(:binary_embedding, distance: "hamming").first(3)
assert_equal [2, 3], result.map(&:id)
assert_elements_in_delta [2, 3], result.map(&:neighbor_distance)
end

def test_hamming_scope
create_bit_items
result = SqliteItem.nearest_neighbors(:binary_embedding, "\x05", distance: "hamming").first(5)
assert_equal [2, 3, 1], result.map(&:id)
assert_elements_in_delta [0, 1, 2], result.map(&:neighbor_distance)
end

def create_bit_items
SqliteItem.create!(id: 1, binary_embedding: "\x00")
SqliteItem.create!(id: 2, binary_embedding: "\x05")
SqliteItem.create!(id: 3, binary_embedding: "\x07")
end
end
2 changes: 2 additions & 0 deletions test/support/sqlite.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class SqliteRecord < ActiveRecord::Base
create_table :items, force: true do |t|
t.binary :embedding
t.binary :int8_embedding
t.binary :binary_embedding
end

if ActiveRecord::VERSION::MAJOR >= 8
Expand Down Expand Up @@ -44,6 +45,7 @@ class SqliteRecord < ActiveRecord::Base
class SqliteItem < SqliteRecord
has_neighbors :embedding, dimensions: 3
has_neighbors :int8_embedding, dimensions: 3, type: :int8
has_neighbors :binary_embedding, type: :bit
self.table_name = "items"
end

Expand Down

0 comments on commit feea554

Please sign in to comment.