Skip to content

Commit

Permalink
Switch ExtensionSpecifier from #sqlite_extension_path to #to_path
Browse files Browse the repository at this point in the history
which is a way better idea, thanks @tenderlove
  • Loading branch information
flavorjones committed Nov 26, 2024
1 parent 175c05e commit a8b1655
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
18 changes: 9 additions & 9 deletions lib/sqlite3/database.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ module SQLite3
# db.enable_load_extension(true)
# db.load_extension("/path/to/extension")
#
# As of v2.4.0, it's also possible to pass an object that responds to
# +#sqlite_extension_path+. This documentation will refer to the supported interface as
# +_ExtensionSpecifier+, which can be expressed in RBS syntax as:
# As of v2.4.0, it's also possible to pass an object that responds to +#to_path+. This
# documentation will refer to the supported interface as +_ExtensionSpecifier+, which can be
# expressed in RBS syntax as:
#
# interface _ExtensionSpecifier
# def sqlite_extension_path: () → String
# def to_path: () → String
# end
#
# So, for example, if you are using the {sqlean gem}[https://github.com/flavorjones/sqlean-ruby]
Expand Down Expand Up @@ -718,7 +718,7 @@ def busy_handler_timeout=(milliseconds)
#
# [Parameters]
# - +extension_specifier+: (String | +_ExtensionSpecifier+) If a String, it is the filesystem path
# to the sqlite extension file. If an object that responds to #sqlite_extension_path, the
# to the sqlite extension file. If an object that responds to #to_path, the
# return value of that method is used as the filesystem path to the sqlite extension file.
#
# [Example] Using a filesystem path:
Expand All @@ -730,8 +730,8 @@ def busy_handler_timeout=(milliseconds)
# db.load_extension(SQLean::VSV)
#
def load_extension(extension_specifier)
if extension_specifier.respond_to?(:sqlite_extension_path)
extension_specifier = extension_specifier.sqlite_extension_path
if extension_specifier.respond_to?(:to_path)
extension_specifier = extension_specifier.to_path
elsif !extension_specifier.is_a?(String)
raise TypeError, "extension_specifier #{extension_specifier.inspect} is not a String or a valid extension specifier object"
end
Expand All @@ -747,11 +747,11 @@ def marshal_extensions(extensions) # :nodoc:

extensions.each do |extension|
# marshall the extension into an object if it's the name of a constant that responds to
# `#sqlite_extension_path`
# `#to_path`
if extension.is_a?(String)
begin
extension_spec = Object.const_get(extension)
if extension_spec.respond_to?(:sqlite_extension_path)
if extension_spec.respond_to?(:to_path)
extension = extension_spec
end
rescue NameError
Expand Down
19 changes: 16 additions & 3 deletions test/test_database.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

module SQLite3
class FakeExtensionSpecifier
def self.sqlite_extension_path
def self.to_path
"/path/to/extension"
end
end
Expand Down Expand Up @@ -672,20 +672,28 @@ def test_load_extension_error_with_nonexistent_path
db.enable_load_extension(true)

assert_raises(SQLite3::Exception) { db.load_extension("/path/to/extension") }
assert_raises(SQLite3::Exception) { db.load_extension(Pathname.new("foo")) }
end

def test_load_extension_error_with_invalid_argument
skip("extensions are not enabled") unless db.respond_to?(:load_extension)
db.enable_load_extension(true)

assert_raises(TypeError) { db.load_extension(1) }
assert_raises(TypeError) { db.load_extension(Pathname.new("foo")) }
assert_raises(TypeError) { db.load_extension({a: 1}) }
assert_raises(TypeError) { db.load_extension([]) }
assert_raises(TypeError) { db.load_extension(Object.new) }
end

def test_load_extension_with_an_extension_descriptor
mock_database_load_extension_internal(db)

db.load_extension(FakeExtensionSpecifier)
db.load_extension(Pathname.new("/path/to/ext2"))
assert_equal(["/path/to/ext2"], db.load_extension_internal_path)

db.load_extension_internal_path.clear # reset

db.load_extension(FakeExtensionSpecifier)
assert_equal(["/path/to/extension"], db.load_extension_internal_path)
end

Expand All @@ -712,6 +720,11 @@ def enable_load_extension(...)
def test_marshal_extensions_object_is_an_extension_specifier
mock_database_load_extension_internal(db)

db.marshal_extensions([Pathname.new("/path/to/extension")])
assert_equal(["/path/to/extension"], db.load_extension_internal_path)

db.load_extension_internal_path.clear # reset

db.marshal_extensions([FakeExtensionSpecifier])
assert_equal(["/path/to/extension"], db.load_extension_internal_path)

Expand Down

0 comments on commit a8b1655

Please sign in to comment.