diff --git a/README.md b/README.md index bb604e4..be51ad2 100644 --- a/README.md +++ b/README.md @@ -254,6 +254,51 @@ movie.nearest_neighbors(:factors, distance: "cosine").first(5).map(&:name) See the complete code for [cube](examples/disco_item_recs_cube.rb) and [vector](examples/disco_item_recs_vector.rb) +## Method Options + +There are 3 options available when calling with the `nearest_neighbor` method. + +### order + +```ruby +movie = Movie.find_by(name: "Star Wars (1977)") +# Order all results by the neighbor_distance column in descending order +movie.nearest_neighbors(:factors, distance: "inner_product", order: { neighbor_distance: :desc }) +``` + +### limit + +```ruby +movie = Movie.find_by(name: "Star Wars (1977)") +# Limit the results to 3 records +movie.nearest_neighbors(:factors, distance: "inner_product", limit: 3) +``` + +### threshold + +```ruby +movie = Movie.find_by(name: "Star Wars (1977)") +# Only return records where the neighbor_distance is greater than or equal to 0.9 +movie.nearest_neighbors(:factors, distance: "inner_product", threshold: { gte: 0.9 }) +``` + +### Multiple Options + +All options can be used at the same time or separately. + +```ruby +movie = Movie.find_by(name: "Star Wars (1977)") + +# Only return 5 records where the neighbor_distance is greater than or equal to 0.9 in descending order +movie.nearest_neighbors( + :factors, + distance: "inner_product", + limit: 5, + threshold: { gte: 0.9 }, + order: { neighbor_distance: :desc } +) +``` + ## Upgrading ### 0.2.0 diff --git a/lib/neighbor/model.rb b/lib/neighbor/model.rb index 142aeba..1675628 100644 --- a/lib/neighbor/model.rb +++ b/lib/neighbor/model.rb @@ -1,4 +1,67 @@ module Neighbor + module Helpers + class << self + # Determines the operator for the distance function. + def determine_operator(distance, is_vector) + if is_vector + case distance + when "inner_product" + "<#>" + when "cosine" + "<=>" + when "euclidean" + "<->" + end + else + case distance + when "taxicab" + "<#>" + when "chebyshev" + "<=>" + when "euclidean", "cosine" + "<->" + end + end + end + # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance + # with normalized vectors: + # cosine similarity = 1 - (euclidean distance)**2 / 2 + # cosine distance = 1 - cosine similarity + # this transformation doesn't change the order, so only needed for select + def neighbor_distance_statement(distance, order, is_vector) + if !is_vector && distance == "cosine" + "POWER(#{order}, 2) / 2.0" + elsif is_vector && distance == "inner_product" + "(#{order}) * -1" + else + order + end + end + # Opts key must be lt, lte, gt, or gte and have a numeric value + # Returns an array of strings that can be passed to ActiveRecord where method + # Example: {lt: 5, gte: 2} => ["neighbor_distance < 5", "neighbor_distance >= 2"] + def args_for_threshold(quoted_neighbor_field, opts) + raise ArgumentError, "Invalid threshold" unless opts.is_a?(Hash) + + opts.map do |key, value| + raise ArgumentError, "Invalid threshold: allowed keys are lt, lte, gt, gte" unless [:lt, :lte, :gt, :gte].include?(key) + raise ArgumentError, "Invalid threshold: value must be numeric type" unless value.is_a?(Numeric) + + case key + when :lt + ["#{quoted_neighbor_field} < ?", value] + when :lte + ["#{quoted_neighbor_field} <= ?", value] + when :gt + ["#{quoted_neighbor_field} > ?", value] + when :gte + ["#{quoted_neighbor_field} >= ?", value] + end + end + end + end + end + module Model def has_neighbors(*attribute_names, dimensions: nil, normalize: nil) if attribute_names.empty? @@ -32,51 +95,46 @@ def self.neighbor_attributes return if @neighbor_attributes.size != attribute_names.size - scope :nearest_neighbors, ->(attribute_name, vector = nil, distance:) { + scope :nearest_neighbors, ->(attribute_name, vector = nil, distance:, **kwargs) { + # Check optional arguments for threshold if vector.nil? && !attribute_name.nil? && attribute_name.respond_to?(:to_a) vector = attribute_name attribute_name = :neighbor_vector end + attribute_name = attribute_name.to_sym - options = neighbor_attributes[attribute_name] + raise ArgumentError, "Invalid attribute" unless options + normalize = options[:normalize] dimensions = options[:dimensions] + # Check optional arguments in options + order_option = kwargs[:order] || nil + limit_option = kwargs[:limit] || nil + threshold_option = kwargs[:threshold] || nil + return none if vector.nil? distance = distance.to_s + # Define the quoted attribute names quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}" + quoted_neighbor = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name('neighbor_distance')}" column_info = klass.type_for_attribute(attribute_name).column_info - operator = - if column_info[:type] == :vector - case distance - when "inner_product" - "<#>" - when "cosine" - "<=>" - when "euclidean" - "<->" - end - else - case distance - when "taxicab" - "<#>" - when "chebyshev" - "<=>" - when "euclidean", "cosine" - "<->" - end - end + # Check if column type is vector or cube + is_cube = column_info[:type] == :cube + is_vector = column_info[:type] == :vector + + operator = Neighbor::Helpers.determine_operator(distance, is_vector) raise ArgumentError, "Invalid distance: #{distance}" unless operator # ensure normalize set (can be true or false) - if distance == "cosine" && column_info[:type] == :cube && normalize.nil? + if distance == "cosine" && is_cube && normalize.nil? raise Neighbor::Error, "Set normalize for cosine distance with cube" end @@ -84,33 +142,33 @@ def self.neighbor_attributes # important! neighbor_vector should already be typecast # but use to_f as extra safeguard against SQL injection - query = - if column_info[:type] == :vector - connection.quote("[#{vector.map(&:to_f).join(", ")}]") - else - "cube(array[#{vector.map(&:to_f).join(", ")}])" - end + query = is_vector ? connection.quote("[#{vector.map(&:to_f).join(", ")}]") : "cube(array[#{vector.map(&:to_f).join(", ")}])" order = "#{quoted_attribute} #{operator} #{query}" - # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance - # with normalized vectors: - # cosine similarity = 1 - (euclidean distance)**2 / 2 - # cosine distance = 1 - cosine similarity - # this transformation doesn't change the order, so only needed for select - neighbor_distance = - if column_info[:type] != :vector && distance == "cosine" - "POWER(#{order}, 2) / 2.0" - elsif column_info[:type] == :vector && distance == "inner_product" - "(#{order}) * -1" - else - order - end + neighbor_distance = Neighbor::Helpers.neighbor_distance_statement(distance, order, is_vector) + + # Add ActiveRecord methods to options_chain if they are present in options + options_chain = [] + options_chain << [:limit, limit_option] if limit_option + options_chain << [:reorder, order_option] if order_option # for select, use column_names instead of * to account for ignored columns - select(*column_names, "#{neighbor_distance} AS neighbor_distance") + select_query = select(*column_names, "#{neighbor_distance} AS neighbor_distance") .where.not(attribute_name => nil) .order(Arel.sql(order)) + + # Add threshold query to select query if threshold option is present + if threshold_option + select_query = from(select_query, table_name.to_sym).where( + *Neighbor::Helpers.args_for_threshold(quoted_neighbor, threshold_option) + ) + end + + # Run through all options and apply them to the select query + options_chain.inject(select_query) do |obj, method_and_args| + obj.send(*method_and_args) + end } def nearest_neighbors(attribute_name = :neighbor_vector, **options) diff --git a/test/neighbor_attrs_test.rb b/test/neighbor_attrs_test.rb new file mode 100644 index 0000000..9f6d2cf --- /dev/null +++ b/test/neighbor_attrs_test.rb @@ -0,0 +1,85 @@ +require_relative "test_helper" + +class NeighborAttrsTest < Minitest::Test + def setup + Item.delete_all + end + + def test_attribute_order_desc + 4.times { |i| DimensionsItem.create!(embedding: [-i, 3, i]) } + + result_scores = DimensionsItem.nearest_neighbors( + :embedding, [3, 3, 3], + distance: "euclidean", + order: { neighbor_distance: :desc } + ).map(&:neighbor_distance) + + assert_equal result_scores, result_scores.sort.reverse + end + + def test_attribute_order_asc + 4.times { |i| DimensionsItem.create!(embedding: [-i, 3, i]) } + + result_scores = DimensionsItem.nearest_neighbors( + :embedding, [3, 3, 3], + distance: "euclidean", + order: { neighbor_distance: :desc } + ).map(&:neighbor_distance) + + assert_equal result_scores, result_scores.sort.reverse + end + + def test_attribute_limit + 3.times { |i| DimensionsItem.create!(embedding: [-i, 3, i]) } + + results = DimensionsItem.nearest_neighbors( + :embedding, [3, 3, 3], + distance: "euclidean", + limit: 1 + ) + + assert_equal 1, results.length + end + + def test_attribute_threshold_lt + # Close neighbor + DimensionsItem.create!(embedding: [3, 3, 3]) + # Far away neighbor + DimensionsItem.create!(embedding: [3, 3, 10]) + + results = DimensionsItem.nearest_neighbors( + :embedding, [3, 3, 3], + distance: "euclidean", + threshold: { lt: 1 } + ) + + assert_equal 1, results.length + end + + class MultipleAttibuteTests < Minitest::Test + def setup + 5.times { |i| DimensionsItem.create!(embedding: [-i, 5, i]) } + + # Run query using all options + @results = DimensionsItem.nearest_neighbors( + :embedding, [-3, 5, 3], + distance: "euclidean", + order: { neighbor_distance: :desc }, + threshold: { lte: 3 }, + limit: 2 + ) + end + + def test_multiple_attributes_limit + assert_equal 2, @results.length + end + + def test_multiple_attributes_order + assert_equal @results.map(&:neighbor_distance), @results.map(&:neighbor_distance).sort.reverse + end + + def test_multiple_attributes_threshold + assert @results.all? { |result| result.neighbor_distance <= 3 } + end + end +end \ No newline at end of file diff --git a/test/neighbor_test.rb b/test/neighbor_test.rb index 42b0d0f..e401a0f 100644 --- a/test/neighbor_test.rb +++ b/test/neighbor_test.rb @@ -132,7 +132,7 @@ def test_attribute_not_loaded Item.select(:id).find(1).nearest_neighbors(:embedding, distance: "euclidean") end end - + def test_large_dimensions max_dimensions = vector? ? 16000 : 100 error = assert_raises(ActiveRecord::StatementInvalid) do