Skip to content

Commit

Permalink
dbeaver/dbeaver#23361 Prepared statements support
Browse files Browse the repository at this point in the history
  • Loading branch information
serge-rider committed Oct 29, 2024
1 parent 7d09bc8 commit 8a21ef2
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,9 @@
import java.net.URL;
import java.sql.*;
import java.util.Calendar;
import java.util.LinkedHashMap;
import java.util.Map;

public class LSqlPreparedStatement extends LSqlStatement implements PreparedStatement {

protected Map<Object, Object> parameters = new LinkedHashMap<>();

public LSqlPreparedStatement(
@NotNull LSqlConnection connection, String sql) throws SQLException {
super(connection);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.LinkedHashMap;
import java.util.Map;

public class LSqlStatement extends AbstractJdbcStatement<LSqlConnection> {

protected String queryText;
protected Map<Object, Object> parameters = new LinkedHashMap<>();

protected LSqlExecutionResult executionResult;
protected LSqlResultSet resultSet;

Expand All @@ -37,7 +41,7 @@ public LSqlStatement(@NotNull LSqlConnection connection) throws SQLException {

@Override
public ResultSet executeQuery(String sql) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return getResultSet();
}

Expand All @@ -48,49 +52,49 @@ public ResultSet executeQuery() throws SQLException {

@Override
protected boolean execute(@NotNull String sql, @Nullable int[] columnIndexes, @Nullable String[] columnNames, int autoGeneratedKeys) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return true;
}

@Override
public boolean execute() throws SQLException {
executionResult = connection.getClient().execute(queryText);
executionResult = connection.getClient().execute(queryText, parameters);
return true;
}

@Override
protected int executeUpdate(@NotNull String sql, @Nullable int[] columnIndexes, @Nullable String[] columnNames, int autoGeneratedKeys) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return (int) executionResult.getUpdateCount();
}

@Override
public long executeLargeUpdate(String sql, String[] columnNames) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return executionResult.getUpdateCount();
}

@Override
public long executeLargeUpdate(String sql) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return executionResult.getUpdateCount();
}

@Override
public long executeLargeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return executionResult.getUpdateCount();
}

@Override
public long executeLargeUpdate(String sql, int[] columnIndexes) throws SQLException {
executionResult = connection.getClient().execute(sql);
executionResult = connection.getClient().execute(sql, parameters);
return executionResult.getUpdateCount();
}

@Override
public long executeLargeUpdate() throws SQLException {
executionResult = connection.getClient().execute(queryText);
executionResult = connection.getClient().execute(queryText, parameters);
return executionResult.getUpdateCount();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.Map;
import java.util.TreeMap;

/**
* The entry point to LibSQL client API.
Expand All @@ -41,8 +43,8 @@ public LSqlClient(URL url, String authToken) {
*
* @return The result set.
*/
public LSqlExecutionResult execute(String stmt) throws SQLException {
return batch(new String[]{stmt})[0];
public LSqlExecutionResult execute(String stmt, Map<Object, Object> parameters) throws SQLException {
return batch(new String[]{stmt}, new Map[]{ parameters })[0];
}

/**
Expand All @@ -51,14 +53,14 @@ public LSqlExecutionResult execute(String stmt) throws SQLException {
* @param stmts The SQL statements.
* @return The result sets.
*/
public LSqlExecutionResult[] batch(String[] stmts) throws SQLException {
public LSqlExecutionResult[] batch(String[] stmts, Map<Object, Object>[] parameters) throws SQLException {
try {
HttpURLConnection conn = openConnection();
conn.setRequestMethod("POST");
conn.setDoOutput(true);

try (OutputStream os = conn.getOutputStream()) {
query(stmts, os);
executeQuery(stmts, parameters, os);
}
try (InputStreamReader in = new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8)) {
//String responseStr = IOUtils.readToString(in);
Expand Down Expand Up @@ -101,19 +103,73 @@ private void setAuthParameters(HttpURLConnection conn) {
}
}

private void query(String[] stmts, OutputStream os) throws IOException {
private void executeQuery(String[] stmts, Map<Object, Object>[] parameters, OutputStream os) throws IOException {
JsonWriter jsonWriter = new JsonWriter(new OutputStreamWriter(os, StandardCharsets.UTF_8));
jsonWriter.beginObject();
jsonWriter.name("statements");
jsonWriter.beginArray();
for (String stmt : stmts) {
jsonWriter.value(stmt);
for (int i = 0; i < stmts.length; i++) {
String stmt = stmts[i];
if (i < parameters.length && !CommonUtils.isEmpty(parameters[i])) {
// Query with parameters
jsonWriter.beginObject();
jsonWriter.name("q");
jsonWriter.value(stmt);
jsonWriter.name("params");
if (isIndexedParams(parameters[i])) {
Map<Integer, Object> paramTree = new TreeMap<>();
for (Map.Entry<?,?> entry : parameters[i].entrySet()) {
paramTree.put((Integer) entry.getKey(), entry.getValue());
}
jsonWriter.beginArray();
for (Object value : paramTree.values()) {
serializeParameterValue(value, jsonWriter);
}
jsonWriter.endArray();
} else {
jsonWriter.beginObject();
for (Map.Entry<?, ?> param : parameters[i].entrySet()) {
jsonWriter.name(String.valueOf(param.getKey()));
serializeParameterValue(param.getValue(), jsonWriter);
}
jsonWriter.endObject();
}

jsonWriter.endObject();
} else {
// Simple query
jsonWriter.value(stmt);
}
}
jsonWriter.endArray();
jsonWriter.endObject();
jsonWriter.flush();
}

private boolean isIndexedParams(Map<Object, Object> parameter) {
if (!parameter.isEmpty()) {
Object key1 = parameter.keySet().iterator().next();
if (key1 instanceof Integer) {
return true;
}
}
return false;
}

private static void serializeParameterValue(Object value, JsonWriter jsonWriter) throws IOException {
if (value == null) {
jsonWriter.nullValue();
} else if (value instanceof Number nValue) {
jsonWriter.value(nValue);
} else if (value instanceof Boolean bValue) {
jsonWriter.value(bValue);
} else if (value instanceof String strValue) {
jsonWriter.value(strValue);
} else {
jsonWriter.value(value.toString());
}
}

private static class Response {
public String error;
public LSqlExecutionResult results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
*/
package com.dbeaver.jdbc.driver.libsql.client;

import org.jkiss.utils.IOUtils;

import java.io.IOException;
import java.io.Reader;

public class LSqlReaderInput {
Expand All @@ -26,4 +29,17 @@ public LSqlReaderInput(Reader stream, long length) {
this.stream = stream;
this.length = length;
}

@Override
public String toString() {
try {
String str = IOUtils.readToString(stream);
if (length <= 0) {
return str;
}
return str.substring(0, (int) length);
} catch (IOException e) {
return e.getMessage();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
*/
package com.dbeaver.jdbc.driver.libsql.client;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

public class LSqlStreamInput {
private InputStream stream;
Expand All @@ -26,4 +29,19 @@ public LSqlStreamInput(InputStream stream, long length) {
this.stream = stream;
this.length = length;
}

@Override
public String toString() {
try {
ByteArrayOutputStream result = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
for (int length; (length = stream.read(buffer)) != -1; ) {
result.write(buffer, 0, length);
}
return result.toString(StandardCharsets.UTF_8);
} catch (IOException e) {
return e.getMessage();
}
}

}

0 comments on commit 8a21ef2

Please sign in to comment.