diff --git a/lib/pgdice.rb b/lib/pgdice.rb index 281bbef..5179c29 100644 --- a/lib/pgdice.rb +++ b/lib/pgdice.rb @@ -42,6 +42,8 @@ require 'pgdice/partition_dropper' require 'pgdice/partition_dropper_factory' +require 'pgdice/query_executor' + require 'pgdice/database_connection' require 'pgdice/database_connection_factory' diff --git a/lib/pgdice/database_connection_factory.rb b/lib/pgdice/database_connection_factory.rb index 756b3e8..5bbfeba 100644 --- a/lib/pgdice/database_connection_factory.rb +++ b/lib/pgdice/database_connection_factory.rb @@ -10,7 +10,8 @@ class DatabaseConnectionFactory def initialize(configuration, opts = {}) @configuration = configuration - @query_executor = opts[:query_executor] ||= ->(query) { pg_connection.exec(query) } + @query_executor = opts[:query_executor] ||= PgDice::QueryExecutor.new(logger: logger, + connection_supplier: -> { pg_connection }) end def call diff --git a/lib/pgdice/query_executor.rb b/lib/pgdice/query_executor.rb new file mode 100644 index 0000000..665a815 --- /dev/null +++ b/lib/pgdice/query_executor.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true + +# Entry point for DatabaseConnection +module PgDice + # Wrapper class around pg_connection to reset connection on PG errors + class QueryExecutor + attr_reader :logger + + def initialize(logger:, connection_supplier:) + @logger = logger + @connection_supplier = connection_supplier + end + + def call(query) + @connection_supplier.call.exec(query) + rescue PG::Error => error + logger.error { "Caught error: #{error}. Going to reset connection and try again" } + @connection_supplier.call.reset + @connection_supplier.call.exec(query) + end + end +end diff --git a/test/pgdice/query_executor_test.rb b/test/pgdice/query_executor_test.rb new file mode 100644 index 0000000..88e7d3a --- /dev/null +++ b/test/pgdice/query_executor_test.rb @@ -0,0 +1,53 @@ +# frozen_string_literal: true + +require 'test_helper' + +class QueryExecutorTest < Minitest::Test + def setup + @should_raise = true + @resetter_call_count = 0 + @runner_call_count = 0 + + @resetter = resetter + @runner = runner + end + + def test_retry_on_pg_error + PgDice::QueryExecutor.new(logger: logger, connection_supplier: -> { MockPgConnection.new(@runner, @resetter) }) + .call('blah') + + assert_equal 2, @runner_call_count, 'Runner should be called twice when we catch a PG error' + assert_equal 1, @resetter_call_count, 'Resetter should be called once when we catch a PG error' + end + + private + + def resetter + proc do + @should_raise = false + @resetter_call_count += 1 + end + end + + def runner + proc do + @runner_call_count += 1 + raise PG::Error, 'Something bad' if @should_raise + end + end +end + +class MockPgConnection + def initialize(runner, resetter) + @runner = runner + @resetter = resetter + end + + def exec(query) + @runner.call(query) + end + + def reset + @resetter.call + end +end