Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for Threshold, Limit, and Order Arguments #12

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,51 @@ movie.nearest_neighbors(:factors, distance: "cosine").first(5).map(&:name)

See the complete code for [cube](examples/disco_item_recs_cube.rb) and [vector](examples/disco_item_recs_vector.rb)

## Method Options

There are 3 options available when calling the `nearest_neighbors` method.

### order

```ruby
movie = Movie.find_by(name: "Star Wars (1977)")
# Order all results by the neighbor_distance column in descending order
movie.nearest_neighbors(:factors, distance: "inner_product", order: { neighbor_distance: :desc })
```

### limit

```ruby
movie = Movie.find_by(name: "Star Wars (1977)")
# Limit the results to 3 records
movie.nearest_neighbors(:factors, distance: "inner_product", limit: 3)
```

### threshold

```ruby
movie = Movie.find_by(name: "Star Wars (1977)")
# Only return records where the neighbor_distance is greater than or equal to 0.9
movie.nearest_neighbors(:factors, distance: "inner_product", threshold: { gte: 0.9 })
```

### Multiple Options

All options can be used at the same time or separately.

```ruby
movie = Movie.find_by(name: "Star Wars (1977)")

# Return at most 5 records whose neighbor_distance is greater than or equal to 0.9, sorted in descending order
movie.nearest_neighbors(
:factors,
distance: "inner_product",
limit: 5,
threshold: { gte: 0.9 },
order: { neighbor_distance: :desc }
)
```

## Upgrading

### 0.2.0
Expand Down
144 changes: 101 additions & 43 deletions lib/neighbor/model.rb
Original file line number Diff line number Diff line change
@@ -1,4 +1,67 @@
module Neighbor
# Stateless helpers used to build the SQL fragments of the nearest_neighbors
# scope: the distance operator, the selected neighbor_distance expression and
# the optional threshold condition.
module Helpers
  # Distance name => SQL operator for columns of type vector.
  VECTOR_OPERATORS = {
    "inner_product" => "<#>",
    "cosine" => "<=>",
    "euclidean" => "<->"
  }.freeze

  # Distance name => SQL operator for every other column type (cube).
  # "cosine" deliberately shares the euclidean operator; the value is
  # converted afterwards by neighbor_distance_statement.
  CUBE_OPERATORS = {
    "taxicab" => "<#>",
    "chebyshev" => "<=>",
    "euclidean" => "<->",
    "cosine" => "<->"
  }.freeze

  # Threshold option key => SQL comparison operator.
  THRESHOLD_OPERATORS = {
    lt: "<",
    lte: "<=",
    gt: ">",
    gte: ">="
  }.freeze

  class << self
    # Determines the operator for the distance function.
    #
    # distance  - distance name as a String (e.g. "cosine")
    # is_vector - truthy when the column type is vector, falsy otherwise
    #
    # Returns the operator String, or nil when the distance is not supported
    # for that column type (the caller turns nil into an ArgumentError).
    def determine_operator(distance, is_vector)
      (is_vector ? VECTOR_OPERATORS : CUBE_OPERATORS)[distance]
    end

    # Builds the SQL expression selected as neighbor_distance.
    #
    # distance  - distance name as a String
    # order     - SQL expression already used for ordering
    #             ("<column> <operator> <query>")
    # is_vector - truthy when the column type is vector
    #
    # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
    # with normalized vectors:
    # cosine similarity = 1 - (euclidean distance)**2 / 2
    # cosine distance = 1 - cosine similarity
    # this transformation doesn't change the order, so only needed for select
    def neighbor_distance_statement(distance, order, is_vector)
      if !is_vector && distance == "cosine"
        "POWER(#{order}, 2) / 2.0"
      elsif is_vector && distance == "inner_product"
        # pgvector's <#> reports the negated inner product; flip the sign back
        "(#{order}) * -1"
      else
        order
      end
    end

    # Builds a single condition for ActiveRecord's where method from a
    # threshold hash.
    #
    # quoted_neighbor_field - already quoted column reference to compare
    # opts                  - non-empty Hash whose keys are :lt, :lte, :gt or
    #                         :gte and whose values are Numeric
    #
    # Returns an Array of the form [sql_fragment, *bind_values], meant to be
    # splatted into where. All comparisons are combined with AND in one
    # fragment because where cannot take several array conditions at once.
    # Example: {lt: 5, gte: 2} =>
    #   ["neighbor_distance < ? AND neighbor_distance >= ?", 5, 2]
    #
    # Raises ArgumentError when opts is not a Hash, is empty, contains an
    # unknown key or a non-numeric value.
    def args_for_threshold(quoted_neighbor_field, opts)
      raise ArgumentError, "Invalid threshold" unless opts.is_a?(Hash)
      # an empty hash would otherwise produce a where call without conditions
      raise ArgumentError, "Invalid threshold: allowed keys are lt, lte, gt, gte" if opts.empty?

      clauses = []
      bind_values = []

      opts.each do |key, value|
        comparator = THRESHOLD_OPERATORS[key]
        raise ArgumentError, "Invalid threshold: allowed keys are lt, lte, gt, gte" unless comparator
        raise ArgumentError, "Invalid threshold: value must be numeric type" unless value.is_a?(Numeric)

        # values stay bind parameters; only the fixed comparator is interpolated
        clauses << "#{quoted_neighbor_field} #{comparator} ?"
        bind_values << value
      end

      [clauses.join(" AND "), *bind_values]
    end
  end
end

module Model
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
if attribute_names.empty?
Expand Down Expand Up @@ -32,85 +95,80 @@ def self.neighbor_attributes

return if @neighbor_attributes.size != attribute_names.size

scope :nearest_neighbors, ->(attribute_name, vector = nil, distance:) {
scope :nearest_neighbors, ->(attribute_name, vector = nil, distance:, **kwargs) {
# Check optional arguments for threshold
if vector.nil? && !attribute_name.nil? && attribute_name.respond_to?(:to_a)
vector = attribute_name
attribute_name = :neighbor_vector
end

attribute_name = attribute_name.to_sym

options = neighbor_attributes[attribute_name]

raise ArgumentError, "Invalid attribute" unless options

normalize = options[:normalize]
dimensions = options[:dimensions]

# Check optional arguments in options
order_option = kwargs[:order] || nil
limit_option = kwargs[:limit] || nil
threshold_option = kwargs[:threshold] || nil

return none if vector.nil?

distance = distance.to_s

# Define the quoted attribute names
quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"
quoted_neighbor = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name('neighbor_distance')}"

column_info = klass.type_for_attribute(attribute_name).column_info

operator =
if column_info[:type] == :vector
case distance
when "inner_product"
"<#>"
when "cosine"
"<=>"
when "euclidean"
"<->"
end
else
case distance
when "taxicab"
"<#>"
when "chebyshev"
"<=>"
when "euclidean", "cosine"
"<->"
end
end
# Check if column type is vector or cube
is_cube = column_info[:type] == :cube
is_vector = column_info[:type] == :vector

operator = Neighbor::Helpers.determine_operator(distance, is_vector)

raise ArgumentError, "Invalid distance: #{distance}" unless operator

# ensure normalize set (can be true or false)
if distance == "cosine" && column_info[:type] == :cube && normalize.nil?
if distance == "cosine" && is_cube && normalize.nil?
raise Neighbor::Error, "Set normalize for cosine distance with cube"
end

vector = Neighbor::Vector.cast(vector, dimensions: dimensions, normalize: normalize, column_info: column_info)

# important! neighbor_vector should already be typecast
# but use to_f as extra safeguard against SQL injection
query =
if column_info[:type] == :vector
connection.quote("[#{vector.map(&:to_f).join(", ")}]")
else
"cube(array[#{vector.map(&:to_f).join(", ")}])"
end
query = is_vector ? connection.quote("[#{vector.map(&:to_f).join(", ")}]") : "cube(array[#{vector.map(&:to_f).join(", ")}])"

order = "#{quoted_attribute} #{operator} #{query}"

# https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
# with normalized vectors:
# cosine similarity = 1 - (euclidean distance)**2 / 2
# cosine distance = 1 - cosine similarity
# this transformation doesn't change the order, so only needed for select
neighbor_distance =
if column_info[:type] != :vector && distance == "cosine"
"POWER(#{order}, 2) / 2.0"
elsif column_info[:type] == :vector && distance == "inner_product"
"(#{order}) * -1"
else
order
end
neighbor_distance = Neighbor::Helpers.neighbor_distance_statement(distance, order, is_vector)

# Add ActiveRecord methods to options_chain if they are present in options
options_chain = []
options_chain << [:limit, limit_option] if limit_option
options_chain << [:reorder, order_option] if order_option

# for select, use column_names instead of * to account for ignored columns
select(*column_names, "#{neighbor_distance} AS neighbor_distance")
select_query = select(*column_names, "#{neighbor_distance} AS neighbor_distance")
.where.not(attribute_name => nil)
.order(Arel.sql(order))

# Add threshold query to select query if threshold option is present
if threshold_option
select_query = from(select_query, table_name.to_sym).where(
*Neighbor::Helpers.args_for_threshold(quoted_neighbor, threshold_option)
)
end

# Run through all options and apply them to the select query
options_chain.inject(select_query) do |obj, method_and_args|
obj.send(*method_and_args)
end
}

def nearest_neighbors(attribute_name = :neighbor_vector, **options)
Expand Down
85 changes: 85 additions & 0 deletions test/neighbor_attrs_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
require_relative "test_helper"

# Exercises the optional order, limit and threshold keyword arguments of the
# nearest_neighbors scope.
class NeighborAttrsTest < Minitest::Test
  # NOTE(review): only Item rows are removed here while the tests create
  # DimensionsItem rows; presumably both models share a table - confirm.
  def setup
    Item.delete_all
  end

  # order: { neighbor_distance: :desc } returns the farthest neighbors first.
  def test_attribute_order_desc
    4.times { |i| DimensionsItem.create!(embedding: [-i, 3, i]) }

    result_scores = DimensionsItem.nearest_neighbors(
      :embedding, [3, 3, 3],
      distance: "euclidean",
      order: { neighbor_distance: :desc }
    ).map(&:neighbor_distance)

    # guard against a vacuous pass on an empty result
    refute_empty result_scores
    assert_equal result_scores.sort.reverse, result_scores
  end

  # order: { neighbor_distance: :asc } returns the closest neighbors first.
  def test_attribute_order_asc
    4.times { |i| DimensionsItem.create!(embedding: [-i, 3, i]) }

    result_scores = DimensionsItem.nearest_neighbors(
      :embedding, [3, 3, 3],
      distance: "euclidean",
      order: { neighbor_distance: :asc }
    ).map(&:neighbor_distance)

    # guard against a vacuous pass on an empty result
    refute_empty result_scores
    assert_equal result_scores.sort, result_scores
  end

  # limit caps the number of returned records.
  def test_attribute_limit
    3.times { |i| DimensionsItem.create!(embedding: [-i, 3, i]) }

    results = DimensionsItem.nearest_neighbors(
      :embedding, [3, 3, 3],
      distance: "euclidean",
      limit: 1
    )

    assert_equal 1, results.length
  end

  # threshold: { lt: 1 } keeps only records closer than 1 to the query vector.
  def test_attribute_threshold_lt
    # Close neighbor
    DimensionsItem.create!(embedding: [3, 3, 3])
    # Far away neighbor
    DimensionsItem.create!(embedding: [3, 3, 10])

    results = DimensionsItem.nearest_neighbors(
      :embedding, [3, 3, 3],
      distance: "euclidean",
      threshold: { lt: 1 }
    )

    assert_equal 1, results.length
  end

  # Runs one query that combines order, threshold and limit, then checks each
  # option separately against the shared result.
  class MultipleAttibuteTests < Minitest::Test
    def setup
      # start from a clean slate so rows created by other tests cannot leak in
      DimensionsItem.delete_all
      5.times { |i| DimensionsItem.create!(embedding: [-i, 5, i]) }

      # Run query using all options
      @results = DimensionsItem.nearest_neighbors(
        :embedding, [-3, 5, 3],
        distance: "euclidean",
        order: { neighbor_distance: :desc },
        threshold: { lte: 3 },
        limit: 2
      )
    end

    def test_multiple_attributes_limit
      assert_equal 2, @results.length
    end

    def test_multiple_attributes_order
      distances = @results.map(&:neighbor_distance)

      assert_equal distances.sort.reverse, distances
    end

    def test_multiple_attributes_threshold
      assert @results.all? { |result| result.neighbor_distance <= 3 }
    end
  end
end
2 changes: 1 addition & 1 deletion test/neighbor_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_attribute_not_loaded
Item.select(:id).find(1).nearest_neighbors(:embedding, distance: "euclidean")
end
end

def test_large_dimensions
max_dimensions = vector? ? 16000 : 100
error = assert_raises(ActiveRecord::StatementInvalid) do
Expand Down