diff --git a/lib/neighbor.rb b/lib/neighbor.rb index 1dc4286..8759a3d 100644 --- a/lib/neighbor.rb +++ b/lib/neighbor.rb @@ -1,38 +1,25 @@ # dependencies require "active_support" -# modules +# adapter hooks +require_relative "neighbor/mysql" require_relative "neighbor/postgresql" +require_relative "neighbor/sqlite" + +# modules require_relative "neighbor/reranking" require_relative "neighbor/sparse_vector" -require_relative "neighbor/sqlite" require_relative "neighbor/utils" require_relative "neighbor/version" module Neighbor class Error < StandardError; end - - module MysqlRegisterTypes - def initialize_type_map(m) - super - register_vector_type(m) - end - - def register_vector_type(m) - m.register_type %r(^vector)i do |sql_type| - limit = extract_limit(sql_type) - Type::MysqlVector.new(limit: limit) - end - end - end end ActiveSupport.on_load(:active_record) do require_relative "neighbor/attribute" require_relative "neighbor/model" require_relative "neighbor/normalized_attribute" - require_relative "neighbor/type/mysql_vector" - require_relative "neighbor/type/sqlite_vector" extend Neighbor::Model @@ -42,21 +29,7 @@ def register_vector_type(m) # tries to load pg gem, which may not be available end - require "active_record/connection_adapters/abstract_mysql_adapter" - - # ensure schema can be dumped - ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"} - - # ensure schema can be loaded - unless ActiveRecord::ConnectionAdapters::TableDefinition.method_defined?(:vector) - ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :vector) - end - - # prevent unknown OID warning - ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.singleton_class.prepend(Neighbor::MysqlRegisterTypes) - if ActiveRecord::VERSION::STRING.to_f < 7.1 - ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.register_vector_type(ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::TYPE_MAP) - end + Neighbor::MySQL.initialize! end require_relative "neighbor/railtie" if defined?(Rails::Railtie) diff --git a/lib/neighbor/mysql.rb b/lib/neighbor/mysql.rb new file mode 100644 index 0000000..c84fc30 --- /dev/null +++ b/lib/neighbor/mysql.rb @@ -0,0 +1,37 @@ +module Neighbor + module MySQL + def self.initialize! + require_relative "type/mysql_vector" + + require "active_record/connection_adapters/abstract_mysql_adapter" + + # ensure schema can be dumped + ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"} + + # ensure schema can be loaded + unless ActiveRecord::ConnectionAdapters::TableDefinition.method_defined?(:vector) + ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :vector) + end + + # prevent unknown OID warning + ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.singleton_class.prepend(RegisterTypes) + if ActiveRecord::VERSION::STRING.to_f < 7.1 + ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.register_vector_type(ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::TYPE_MAP) + end + end + + module RegisterTypes + def initialize_type_map(m) + super + register_vector_type(m) + end + + def register_vector_type(m) + m.register_type %r(^vector)i do |sql_type| + limit = extract_limit(sql_type) + Type::MysqlVector.new(limit: limit) + end + end + end + end +end diff --git a/lib/neighbor/sqlite.rb b/lib/neighbor/sqlite.rb index 928df11..44f7731 100644 --- a/lib/neighbor/sqlite.rb +++ b/lib/neighbor/sqlite.rb @@ -1,6 +1,7 @@ module Neighbor module SQLite def self.initialize! + require_relative "type/sqlite_vector" require "sqlite_vec" require "active_record/connection_adapters/sqlite3_adapter"