From 1974c00a669d02f27d9f0a498ef8bae686613e62 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 5 Oct 2024 13:32:58 -0700 Subject: [PATCH] Added experimental support for SQLite (sqlite-vec) --- CHANGELOG.md | 1 + Gemfile | 2 + README.md | 50 ++++++++- Rakefile | 14 +++ gemfiles/activerecord70.gemfile | 2 + gemfiles/activerecord71.gemfile | 2 + gemfiles/activerecord80.gemfile | 2 + lib/generators/neighbor/sqlite_generator.rb | 13 +++ .../neighbor/templates/sqlite.rb.tt | 2 + lib/neighbor.rb | 3 + lib/neighbor/attribute.rb | 31 ++++++ lib/neighbor/model.rb | 101 ++++++++++++------ lib/neighbor/sqlite.rb | 20 ++++ lib/neighbor/type/sqlite_vector.rb | 29 +++++ neighbor.gemspec | 2 +- test/sqlite_generator_test.rb | 14 +++ test/sqlite_test.rb | 62 +++++++++++ test/support/sqlite.rb | 21 ++++ test/test_helper.rb | 1 + 19 files changed, 336 insertions(+), 36 deletions(-) create mode 100644 lib/generators/neighbor/sqlite_generator.rb create mode 100644 lib/generators/neighbor/templates/sqlite.rb.tt create mode 100644 lib/neighbor/attribute.rb create mode 100644 lib/neighbor/sqlite.rb create mode 100644 lib/neighbor/type/sqlite_vector.rb create mode 100644 test/sqlite_generator_test.rb create mode 100644 test/sqlite_test.rb create mode 100644 test/support/sqlite.rb diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c88882..17aa2e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.5.0 (unreleased) +- Added experimental support for SQLite (sqlite-vec) - Changed `normalize` option to use Active Record normalization - Dropped support for Active Record < 7 diff --git a/Gemfile b/Gemfile index 9340244..2d94ae9 100644 --- a/Gemfile +++ b/Gemfile @@ -6,4 +6,6 @@ gem "rake" gem "minitest", ">= 5" gem "activerecord", "~> 7.2.0" gem "pg" +gem "sqlite3" +gem "sqlite-vec" gem "railties", require: false diff --git a/README.md b/README.md index ac6da71..68845c6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,11 @@ # Neighbor -Nearest neighbor search for Rails and Postgres +Nearest neighbor search for Rails + +Supports: + +- Postgres (cube and pgvector) +- SQLite (sqlite-vec, experimental, unreleased) [![Build Status](https://github.com/ankane/neighbor/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/neighbor/actions) @@ -12,7 +17,7 @@ Add this line to your application’s Gemfile: gem "neighbor" ``` -## Choose An Extension +### For Postgres Neighbor supports two extensions: [cube](https://www.postgresql.org/docs/current/cube.html) and [pgvector](https://github.com/pgvector/pgvector). cube ships with Postgres, while pgvector supports more dimensions and approximate nearest neighbor search. @@ -30,6 +35,20 @@ rails generate neighbor:vector rails db:migrate ``` +### For SQLite + +Add this line to your application’s Gemfile: + +```ruby +gem "sqlite-vec" +``` + +And run: + +```sh +rails generate neighbor:sqlite +``` + ## Getting Started Create a migration @@ -37,9 +56,14 @@ Create a migration ```ruby class AddEmbeddingToItems < ActiveRecord::Migration[7.2] def change + # cube add_column :items, :embedding, :cube - # or + + # pgvector add_column :items, :embedding, :vector, limit: 3 # dimensions + + # sqlite-vec + add_column :items, :embedding, :blob end end ``` @@ -81,6 +105,7 @@ See the additional docs for: - [cube](#cube) - [pgvector](#pgvector) +- [sqlite-vec](#sqlite-vec) Or check out some [examples](#examples) @@ -241,6 +266,25 @@ embedding = Neighbor::SparseVector.new({0 => 0.9, 1 => 1.3, 2 => 1.1}, 3) Item.nearest_neighbors(:embedding, embedding, distance: "euclidean").first(5) ``` +## sqlite-vec + +### Distance + +Supported values are: + +- `euclidean` +- `cosine` + +### Dimensions + +For sqlite-vec, it’s a good idea to specify the number of dimensions to ensure all records have the same number. + +```ruby +class Item < ApplicationRecord + has_neighbors :embedding, dimensions: 3 +end +``` + ## Examples - [Embeddings](#openai-embeddings) with OpenAI diff --git a/Rakefile b/Rakefile index 1862bb6..b8fc664 100644 --- a/Rakefile +++ b/Rakefile @@ -1,6 +1,20 @@ require "bundler/gem_tasks" require "rake/testtask" +namespace :test do + Rake::TestTask.new(:postgresql) do |t| + t.description = "Run tests for Postgres" + t.libs << "test" + t.test_files = FileList["test/**/*_test.rb"].exclude("test/sqlite*_test.rb") + end + + Rake::TestTask.new(:sqlite) do |t| + t.description = "Run tests for SQLite" + t.libs << "test" + t.test_files = FileList["test/**/sqlite*_test.rb"] + end +end + Rake::TestTask.new(:test) do |t| t.libs << "test" t.test_files = FileList["test/**/*_test.rb"] diff --git a/gemfiles/activerecord70.gemfile b/gemfiles/activerecord70.gemfile index 6aa76d6..a7cd3f2 100644 --- a/gemfiles/activerecord70.gemfile +++ b/gemfiles/activerecord70.gemfile @@ -6,4 +6,6 @@ gem "rake" gem "minitest", ">= 5" gem "activerecord", "~> 7.0.0" gem "pg" +gem "sqlite3", "< 2" +gem "sqlite-vec" gem "railties", require: false diff --git a/gemfiles/activerecord71.gemfile b/gemfiles/activerecord71.gemfile index 33a0056..cd41a20 100644 --- a/gemfiles/activerecord71.gemfile +++ b/gemfiles/activerecord71.gemfile @@ -6,4 +6,6 @@ gem "rake" gem "minitest", ">= 5" gem "activerecord", "~> 7.1.0" gem "pg" +gem "sqlite3", "< 2" +gem "sqlite-vec" gem "railties", require: false diff --git a/gemfiles/activerecord80.gemfile b/gemfiles/activerecord80.gemfile index ce47545..c7bd494 100644 --- a/gemfiles/activerecord80.gemfile +++ b/gemfiles/activerecord80.gemfile @@ -6,4 +6,6 @@ gem "rake" gem "minitest", ">= 5" gem "activerecord", "~> 8.0.0.beta1" gem "pg" +gem "sqlite3" +gem "sqlite-vec" gem "railties", "~> 8.0.0.beta1", require: false diff --git a/lib/generators/neighbor/sqlite_generator.rb b/lib/generators/neighbor/sqlite_generator.rb new file mode 100644 index 0000000..9655d5d --- /dev/null +++ b/lib/generators/neighbor/sqlite_generator.rb @@ -0,0 +1,13 @@ +require "rails/generators" + +module Neighbor + module Generators + class SqliteGenerator < Rails::Generators::Base + source_root File.join(__dir__, "templates") + + def copy_templates + template "sqlite.rb", "config/initializers/neighbor.rb" + end + end + end +end diff --git a/lib/generators/neighbor/templates/sqlite.rb.tt b/lib/generators/neighbor/templates/sqlite.rb.tt new file mode 100644 index 0000000..b3b2e83 --- /dev/null +++ b/lib/generators/neighbor/templates/sqlite.rb.tt @@ -0,0 +1,2 @@ +# Load the sqlite-vec extension +Neighbor::SQLite.initialize! diff --git a/lib/neighbor.rb b/lib/neighbor.rb index 1368daa..a9cf1f7 100644 --- a/lib/neighbor.rb +++ b/lib/neighbor.rb @@ -4,6 +4,7 @@ # modules require_relative "neighbor/reranking" require_relative "neighbor/sparse_vector" +require_relative "neighbor/sqlite" require_relative "neighbor/utils" require_relative "neighbor/version" @@ -31,11 +32,13 @@ def initialize_type_map(m = type_map) end ActiveSupport.on_load(:active_record) do + require_relative "neighbor/attribute" require_relative "neighbor/model" require_relative "neighbor/normalized_attribute" require_relative "neighbor/type/cube" require_relative "neighbor/type/halfvec" require_relative "neighbor/type/sparsevec" + require_relative "neighbor/type/sqlite_vector" require_relative "neighbor/type/vector" extend Neighbor::Model diff --git a/lib/neighbor/attribute.rb b/lib/neighbor/attribute.rb new file mode 100644 index 0000000..e1ac832 --- /dev/null +++ b/lib/neighbor/attribute.rb @@ -0,0 +1,31 @@ +module Neighbor + class Attribute < ActiveRecord::Type::Value + delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type + + def initialize(cast_type:, model:) + @cast_type = cast_type + @model = model + end + + private + + def cast_value(...) + new_cast_type.send(:cast_value, ...) + end + + def new_cast_type + @new_cast_type ||= begin + if @cast_type.is_a?(ActiveModel::Type::Value) + case @model.connection_db_config.adapter + when /sqlite/i + Type::SqliteVector.new + else + @cast_type + end + else + @cast_type + end + end + end + end +end diff --git a/lib/neighbor/model.rb b/lib/neighbor/model.rb index 9d73853..0757add 100644 --- a/lib/neighbor/model.rb +++ b/lib/neighbor/model.rb @@ -27,6 +27,18 @@ def self.neighbor_attributes @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize} end + if ActiveRecord::VERSION::STRING.to_f >= 7.2 + decorate_attributes(attribute_names) do |_name, cast_type| + Neighbor::Attribute.new(cast_type: cast_type, model: self) + end + else + attribute_names.each do |attribute_name| + attribute attribute_name do |cast_type| + Neighbor::Attribute.new(cast_type: cast_type, model: self) + end + end + end + if normalize if ActiveRecord::VERSION::STRING.to_f >= 7.1 attribute_names.each do |attribute_name| @@ -76,39 +88,57 @@ def self.neighbor_attributes column_info = columns_hash[attribute_name.to_s] column_type = column_info&.type + adapter = + case connection.adapter_name + when /sqlite/i + :sqlite + else + :postgresql + end + operator = - case column_type - when :bit + case adapter + when :sqlite case distance - when "hamming" - "<~>" - when "jaccard" - "<%>" - when "hamming2" - "#" - end - when :vector, :halfvec, :sparsevec - case distance - when "inner_product" - "<#>" - when "cosine" - "<=>" when "euclidean" - "<->" - when "taxicab" - "<+>" - end - when :cube - case distance - when "taxicab" - "<#>" - when "chebyshev" - "<=>" - when "euclidean", "cosine" - "<->" + "vec_distance_L2" + when "cosine" + "vec_distance_cosine" end else - raise ArgumentError, "Unsupported type: #{column_type}" + case column_type + when :bit + case distance + when "hamming" + "<~>" + when "jaccard" + "<%>" + when "hamming2" + "#" + end + when :vector, :halfvec, :sparsevec + case distance + when "inner_product" + "<#>" + when "cosine" + "<=>" + when "euclidean" + "<->" + when "taxicab" + "<+>" + end + when :cube + case distance + when "taxicab" + "<#>" + when "chebyshev" + "<=>" + when "euclidean", "cosine" + "<->" + end + else + raise ArgumentError, "Unsupported type: #{column_type}" + end end raise ArgumentError, "Invalid distance: #{distance}" unless operator @@ -140,10 +170,17 @@ def self.neighbor_attributes end end - order = "#{quoted_attribute} #{operator} #{query}" - if operator == "#" - order = "bit_count(#{order})" - end + order = + case adapter + when :sqlite + "#{operator}(#{quoted_attribute}, #{query})" + else + if operator == "#" + "bit_count(#{quoted_attribute} # #{query})" + else + "#{quoted_attribute} #{operator} #{query}" + end + end # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance # with normalized vectors: diff --git a/lib/neighbor/sqlite.rb b/lib/neighbor/sqlite.rb new file mode 100644 index 0000000..928df11 --- /dev/null +++ b/lib/neighbor/sqlite.rb @@ -0,0 +1,20 @@ +module Neighbor + module SQLite + def self.initialize! + require "sqlite_vec" + require "active_record/connection_adapters/sqlite3_adapter" + + ActiveRecord::ConnectionAdapters::SQLite3Adapter.prepend(InstanceMethods) + end + + module InstanceMethods + def configure_connection + super + db = ActiveRecord::VERSION::STRING.to_f >= 7.1 ? @raw_connection : @connection + db.enable_load_extension(1) + SqliteVec.load(db) + db.enable_load_extension(0) + end + end + end +end diff --git a/lib/neighbor/type/sqlite_vector.rb b/lib/neighbor/type/sqlite_vector.rb new file mode 100644 index 0000000..1a024e0 --- /dev/null +++ b/lib/neighbor/type/sqlite_vector.rb @@ -0,0 +1,29 @@ +module Neighbor + module Type + class SqliteVector < ActiveRecord::Type::Binary + def serialize(value) + if Utils.array?(value) + value = value.to_a.pack("f*") + end + super(value) + end + + def deserialize(value) + value = super + cast_value(value) unless value.nil? + end + + private + + def cast_value(value) + if value.is_a?(String) + value.unpack("f*") + elsif Utils.array?(value) + value.to_a + else + raise "can't cast #{value.class.name} to vector" + end + end + end + end +end diff --git a/neighbor.gemspec b/neighbor.gemspec index f95caa8..4ef7292 100644 --- a/neighbor.gemspec +++ b/neighbor.gemspec @@ -3,7 +3,7 @@ require_relative "lib/neighbor/version" Gem::Specification.new do |spec| spec.name = "neighbor" spec.version = Neighbor::VERSION - spec.summary = "Nearest neighbor search for Rails and Postgres" + spec.summary = "Nearest neighbor search for Rails" spec.homepage = "https://github.com/ankane/neighbor" spec.license = "MIT" diff --git a/test/sqlite_generator_test.rb b/test/sqlite_generator_test.rb new file mode 100644 index 0000000..90071bd --- /dev/null +++ b/test/sqlite_generator_test.rb @@ -0,0 +1,14 @@ +require_relative "test_helper" + +require "generators/neighbor/sqlite_generator" + +class SqliteGeneratorTest < Rails::Generators::TestCase + tests Neighbor::Generators::SqliteGenerator + destination File.expand_path("../tmp", __dir__) + setup :prepare_destination + + def test_works + run_generator + assert_file "config/initializers/neighbor.rb", /Neighbor::SQLite.initialize!/ + end +end diff --git a/test/sqlite_test.rb b/test/sqlite_test.rb new file mode 100644 index 0000000..ca2a930 --- /dev/null +++ b/test/sqlite_test.rb @@ -0,0 +1,62 @@ +require_relative "test_helper" + +class SqliteTest < Minitest::Test + def setup + SqliteItem.delete_all + end + + def test_cosine + create_items(SqliteItem, :embedding) + result = SqliteItem.find(1).nearest_neighbors(:embedding, distance: "cosine").first(3) + assert_equal [2, 3], result.map(&:id) + assert_elements_in_delta [0, 0.05719095841050148], result.map(&:neighbor_distance) + end + + def test_euclidean + create_items(SqliteItem, :embedding) + result = SqliteItem.find(1).nearest_neighbors(:embedding, distance: "euclidean").first(3) + assert_equal [3, 2], result.map(&:id) + assert_elements_in_delta [1, Math.sqrt(3)], result.map(&:neighbor_distance) + end + + def test_create + item = SqliteItem.create!(embedding: [1, 2, 3]) + assert_equal [1, 2, 3], item.embedding + end + + def test_vec_to_json + SqliteItem.create!(embedding: [1, 2, 3]) + assert_equal "[1.000000,2.000000,3.000000]", SqliteItem.pluck("vec_to_json(embedding)").last + end + + def test_schema + file = Tempfile.new + connection = ActiveRecord::VERSION::STRING.to_f >= 7.2 ? SqliteItem.connection_pool : SqliteItem.connection + ActiveRecord::SchemaDumper.dump(connection, file) + file.rewind + contents = file.read + refute_match "Could not dump table", contents + assert_match %{t.binary "embedding"}, contents + end + + def test_invalid_dimensions + error = assert_raises(ActiveRecord::RecordInvalid) do + SqliteItem.create!(embedding: [1, 1]) + end + assert_match "Validation failed: Embedding must have 3 dimensions", error.message + end + + def test_infinite + error = assert_raises(ActiveRecord::RecordInvalid) do + SqliteItem.create!(embedding: [Float::INFINITY, 0, 0]) + end + assert_equal "Validation failed: Embedding must have finite values", error.message + end + + def test_nan + error = assert_raises(ActiveRecord::RecordInvalid) do + SqliteItem.create!(embedding: [Float::NAN, 0, 0]) + end + assert_equal "Validation failed: Embedding must have finite values", error.message + end +end diff --git a/test/support/sqlite.rb b/test/support/sqlite.rb new file mode 100644 index 0000000..645deee --- /dev/null +++ b/test/support/sqlite.rb @@ -0,0 +1,21 @@ +class SqliteRecord < ActiveRecord::Base + self.abstract_class = true + + establish_connection adapter: "sqlite3", database: ":memory:" +end + +Neighbor::SQLite.initialize! + +SqliteRecord.connection.instance_eval do + create_table :items, force: true do |t| + t.blob :embedding + end +end + +class SqliteItem < SqliteRecord + has_neighbors :embedding, dimensions: 3 + self.table_name = "items" +end + +# ensure has_neighbors does not cause model schema to load +raise "has_neighbors loading model schema early" if SqliteItem.send(:schema_loaded?) diff --git a/test/test_helper.rb b/test/test_helper.rb index 930c062..3d5334c 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -10,6 +10,7 @@ ActiveRecord::Base.partial_inserts = false require_relative "support/postgresql" +require_relative "support/sqlite" class Minitest::Test def assert_elements_in_delta(expected, actual)