From 8a21ef29239d3b9b6f981879c9bc9876c53af851 Mon Sep 17 00:00:00 2001 From: serge-rider Date: Tue, 29 Oct 2024 18:39:22 +0100 Subject: [PATCH] dbeaver/dbeaver#23361 Prepared statements support --- .../driver/libsql/LSqlPreparedStatement.java | 4 -- .../jdbc/driver/libsql/LSqlStatement.java | 22 +++--- .../jdbc/driver/libsql/client/LSqlClient.java | 70 +++++++++++++++++-- .../driver/libsql/client/LSqlReaderInput.java | 16 +++++ .../driver/libsql/client/LSqlStreamInput.java | 18 +++++ 5 files changed, 110 insertions(+), 20 deletions(-) diff --git a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlPreparedStatement.java b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlPreparedStatement.java index 3c21f7a..0b5ca48 100644 --- a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlPreparedStatement.java +++ b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlPreparedStatement.java @@ -26,13 +26,9 @@ import java.net.URL; import java.sql.*; import java.util.Calendar; -import java.util.LinkedHashMap; -import java.util.Map; public class LSqlPreparedStatement extends LSqlStatement implements PreparedStatement { - protected Map parameters = new LinkedHashMap<>(); - public LSqlPreparedStatement( @NotNull LSqlConnection connection, String sql) throws SQLException { super(connection); diff --git a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlStatement.java b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlStatement.java index acefb5b..2421c6d 100644 --- a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlStatement.java +++ b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/LSqlStatement.java @@ -24,10 +24,14 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; +import java.util.LinkedHashMap; +import java.util.Map; public class LSqlStatement extends AbstractJdbcStatement { protected String queryText; + protected Map parameters = new LinkedHashMap<>(); + protected LSqlExecutionResult executionResult; protected LSqlResultSet resultSet; @@ -37,7 +41,7 @@ public LSqlStatement(@NotNull LSqlConnection connection) throws SQLException { @Override public ResultSet executeQuery(String sql) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return getResultSet(); } @@ -48,49 +52,49 @@ public ResultSet executeQuery() throws SQLException { @Override protected boolean execute(@NotNull String sql, @Nullable int[] columnIndexes, @Nullable String[] columnNames, int autoGeneratedKeys) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return true; } @Override public boolean execute() throws SQLException { - executionResult = connection.getClient().execute(queryText); + executionResult = connection.getClient().execute(queryText, parameters); return true; } @Override protected int executeUpdate(@NotNull String sql, @Nullable int[] columnIndexes, @Nullable String[] columnNames, int autoGeneratedKeys) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return (int) executionResult.getUpdateCount(); } @Override public long executeLargeUpdate(String sql, String[] columnNames) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return executionResult.getUpdateCount(); } @Override public long executeLargeUpdate(String sql) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return executionResult.getUpdateCount(); } @Override public long executeLargeUpdate(String sql, int autoGeneratedKeys) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return executionResult.getUpdateCount(); } @Override public long executeLargeUpdate(String sql, int[] columnIndexes) throws SQLException { - executionResult = connection.getClient().execute(sql); + executionResult = connection.getClient().execute(sql, parameters); return executionResult.getUpdateCount(); } @Override public long executeLargeUpdate() throws SQLException { - executionResult = connection.getClient().execute(queryText); + executionResult = connection.getClient().execute(queryText, parameters); return executionResult.getUpdateCount(); } diff --git a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlClient.java b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlClient.java index a56e49e..472e4dd 100644 --- a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlClient.java +++ b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlClient.java @@ -16,6 +16,8 @@ import java.net.URL; import java.nio.charset.StandardCharsets; import java.sql.SQLException; +import java.util.Map; +import java.util.TreeMap; /** * The entry point to LibSQL client API. @@ -41,8 +43,8 @@ public LSqlClient(URL url, String authToken) { * * @return The result set. */ - public LSqlExecutionResult execute(String stmt) throws SQLException { - return batch(new String[]{stmt})[0]; + public LSqlExecutionResult execute(String stmt, Map parameters) throws SQLException { + return batch(new String[]{stmt}, new Map[]{ parameters })[0]; } /** @@ -51,14 +53,14 @@ public LSqlExecutionResult execute(String stmt) throws SQLException { * @param stmts The SQL statements. * @return The result sets. */ - public LSqlExecutionResult[] batch(String[] stmts) throws SQLException { + public LSqlExecutionResult[] batch(String[] stmts, Map[] parameters) throws SQLException { try { HttpURLConnection conn = openConnection(); conn.setRequestMethod("POST"); conn.setDoOutput(true); try (OutputStream os = conn.getOutputStream()) { - query(stmts, os); + executeQuery(stmts, parameters, os); } try (InputStreamReader in = new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8)) { //String responseStr = IOUtils.readToString(in); @@ -101,19 +103,73 @@ private void setAuthParameters(HttpURLConnection conn) { } } - private void query(String[] stmts, OutputStream os) throws IOException { + private void executeQuery(String[] stmts, Map[] parameters, OutputStream os) throws IOException { JsonWriter jsonWriter = new JsonWriter(new OutputStreamWriter(os, StandardCharsets.UTF_8)); jsonWriter.beginObject(); jsonWriter.name("statements"); jsonWriter.beginArray(); - for (String stmt : stmts) { - jsonWriter.value(stmt); + for (int i = 0; i < stmts.length; i++) { + String stmt = stmts[i]; + if (i < parameters.length && !CommonUtils.isEmpty(parameters[i])) { + // Query with parameters + jsonWriter.beginObject(); + jsonWriter.name("q"); + jsonWriter.value(stmt); + jsonWriter.name("params"); + if (isIndexedParams(parameters[i])) { + Map paramTree = new TreeMap<>(); + for (Map.Entry entry : parameters[i].entrySet()) { + paramTree.put((Integer) entry.getKey(), entry.getValue()); + } + jsonWriter.beginArray(); + for (Object value : paramTree.values()) { + serializeParameterValue(value, jsonWriter); + } + jsonWriter.endArray(); + } else { + jsonWriter.beginObject(); + for (Map.Entry param : parameters[i].entrySet()) { + jsonWriter.name(String.valueOf(param.getKey())); + serializeParameterValue(param.getValue(), jsonWriter); + } + jsonWriter.endObject(); + } + + jsonWriter.endObject(); + } else { + // Simple query + jsonWriter.value(stmt); + } } jsonWriter.endArray(); jsonWriter.endObject(); jsonWriter.flush(); } + private boolean isIndexedParams(Map parameter) { + if (!parameter.isEmpty()) { + Object key1 = parameter.keySet().iterator().next(); + if (key1 instanceof Integer) { + return true; + } + } + return false; + } + + private static void serializeParameterValue(Object value, JsonWriter jsonWriter) throws IOException { + if (value == null) { + jsonWriter.nullValue(); + } else if (value instanceof Number nValue) { + jsonWriter.value(nValue); + } else if (value instanceof Boolean bValue) { + jsonWriter.value(bValue); + } else if (value instanceof String strValue) { + jsonWriter.value(strValue); + } else { + jsonWriter.value(value.toString()); + } + } + private static class Response { public String error; public LSqlExecutionResult results; diff --git a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlReaderInput.java b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlReaderInput.java index 3fb6d9a..ee59a05 100644 --- a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlReaderInput.java +++ b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlReaderInput.java @@ -16,6 +16,9 @@ */ package com.dbeaver.jdbc.driver.libsql.client; +import org.jkiss.utils.IOUtils; + +import java.io.IOException; import java.io.Reader; public class LSqlReaderInput { @@ -26,4 +29,17 @@ public LSqlReaderInput(Reader stream, long length) { this.stream = stream; this.length = length; } + + @Override + public String toString() { + try { + String str = IOUtils.readToString(stream); + if (length <= 0) { + return str; + } + return str.substring(0, (int) length); + } catch (IOException e) { + return e.getMessage(); + } + } } diff --git a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlStreamInput.java b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlStreamInput.java index d385e6b..62674f0 100644 --- a/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlStreamInput.java +++ b/com.dbeaver.jdbc.driver.libsql/src/main/java/com/dbeaver/jdbc/driver/libsql/client/LSqlStreamInput.java @@ -16,7 +16,10 @@ */ package com.dbeaver.jdbc.driver.libsql.client; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; public class LSqlStreamInput { private InputStream stream; @@ -26,4 +29,19 @@ public LSqlStreamInput(InputStream stream, long length) { this.stream = stream; this.length = length; } + + @Override + public String toString() { + try { + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + for (int length; (length = stream.read(buffer)) != -1; ) { + result.write(buffer, 0, length); + } + return result.toString(StandardCharsets.UTF_8); + } catch (IOException e) { + return e.getMessage(); + } + } + }