diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp index f342791c4c5943..cccb1e2033e425 100644 --- a/be/src/common/config.cpp +++ b/be/src/common/config.cpp @@ -61,6 +61,8 @@ DEFINE_Int32(brpc_port, "8060"); DEFINE_Int32(arrow_flight_sql_port, "-1"); +DEFINE_mString(public_access_ip, ""); + // the number of bthreads for brpc, the default value is set to -1, // which means the number of bthreads is #cpu-cores DEFINE_Int32(brpc_num_threads, "256"); diff --git a/be/src/common/config.h b/be/src/common/config.h index 43b994ae8a3a24..4a1c07149afd2d 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -100,6 +100,11 @@ DECLARE_Int32(brpc_port); // Default -1, do not start arrow flight sql server. DECLARE_Int32(arrow_flight_sql_port); +// If priority_networks is incorrect but cannot be modified, set public_access_ip as BE’s real IP. +// For ADBC client fetch result, default is empty, the ADBC client uses the backend ip to fetch the result. +// If ADBC client cannot access the backend ip, can set public_access_ip to modify the fetch result ip. +DECLARE_mString(public_access_ip); + // the number of bthreads for brpc, the default value is set to -1, // which means the number of bthreads is #cpu-cores DECLARE_Int32(brpc_num_threads); diff --git a/be/src/service/internal_service.cpp b/be/src/service/internal_service.cpp index 86b75376a1017a..e541c738be595c 100644 --- a/be/src/service/internal_service.cpp +++ b/be/src/service/internal_service.cpp @@ -747,6 +747,9 @@ void PInternalServiceImpl::fetch_arrow_flight_schema(google::protobuf::RpcContro auto st = serialize_arrow_schema(&schema, &schema_str); if (st.ok()) { result->set_schema(std::move(schema_str)); + if (config::public_access_ip != "") { + result->set_be_arrow_flight_ip(config::public_access_ip); + } } st.to_protobuf(result->mutable_status()); }); diff --git a/be/src/util/arrow/row_batch.cpp b/be/src/util/arrow/row_batch.cpp index 6662f2e0ba7aee..b3cf5c5452024f 100644 --- a/be/src/util/arrow/row_batch.cpp +++ b/be/src/util/arrow/row_batch.cpp @@ -168,12 +168,13 @@ Status convert_to_arrow_schema(const RowDescriptor& row_desc, Status convert_expr_ctxs_arrow_schema(const vectorized::VExprContextSPtrs& output_vexpr_ctxs, std::shared_ptr* result) { std::vector> fields; - for (auto expr_ctx : output_vexpr_ctxs) { + for (int i = 0; i < output_vexpr_ctxs.size(); i++) { std::shared_ptr arrow_type; - auto root_expr = expr_ctx->root(); + auto root_expr = output_vexpr_ctxs.at(i)->root(); RETURN_IF_ERROR(convert_to_arrow_type(root_expr->type(), &arrow_type)); - auto field_name = root_expr->is_slot_ref() ? root_expr->expr_name() - : root_expr->data_type()->get_name(); + auto field_name = root_expr->is_slot_ref() && !root_expr->expr_name().empty() + ? root_expr->expr_name() + : fmt::format("{}_{}", root_expr->data_type()->get_name(), i); fields.push_back( std::make_shared(field_name, arrow_type, root_expr->is_nullable())); } diff --git a/fe/fe-core/pom.xml b/fe/fe-core/pom.xml index c5e31fcaccacf9..0375db287e6e21 100644 --- a/fe/fe-core/pom.xml +++ b/fe/fe-core/pom.xml @@ -699,6 +699,10 @@ under the License. io.grpc grpc-core + + org.apache.arrow + flight-sql-jdbc-driver + io.grpc grpc-context diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java index 7dc45029e7e509..a719081496b05b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/DistributedPlanner.java @@ -268,7 +268,6 @@ private PlanFragment createMergeFragment(PlanFragment inputFragment) mergePlan.init(ctx.getRootAnalyzer()); Preconditions.checkState(mergePlan.hasValidStats()); PlanFragment fragment = new PlanFragment(ctx.getNextFragmentId(), mergePlan, DataPartition.UNPARTITIONED); - fragment.setResultSinkType(ctx.getRootAnalyzer().getContext().getResultSinkType()); inputFragment.setDestination(mergePlan); return fragment; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java index 1140d326fe0c27..559c9b6d73d295 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OriginalPlanner.java @@ -257,6 +257,7 @@ public void createPlanFragments(StatementBase statement, Analyzer analyzer, TQue LOG.debug("substitute result Exprs {}", resExprs); rootFragment.setOutputExprs(resExprs); } + rootFragment.setResultSinkType(ConnectContext.get().getResultSinkType()); LOG.debug("finalize plan fragments"); for (PlanFragment fragment : fragments) { fragment.finalize(queryStmt); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index dcadf2795b1f21..d86ee57414643b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -112,6 +112,7 @@ public enum ConnectType { protected volatile long loginTime; // for arrow flight protected volatile String peerIdentity; + private final Map preparedQuerys = new HashMap<>(); private String runningQuery; private TNetworkAddress resultFlightServerAddr; private TNetworkAddress resultInternalServiceAddr; @@ -611,6 +612,18 @@ public void resetLoginTime() { this.loginTime = System.currentTimeMillis(); } + public void addPreparedQuery(String preparedStatementId, String preparedQuery) { + preparedQuerys.put(preparedStatementId, preparedQuery); + } + + public String getPreparedQuery(String preparedStatementId) { + return preparedQuerys.get(preparedStatementId); + } + + public void removePreparedQuery(String preparedStatementId) { + preparedQuerys.remove(preparedStatementId); + } + public void setRunningQuery(String runningQuery) { this.runningQuery = runningQuery; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java index 2ac2520ead95d4..cf34283fa95536 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java @@ -1687,7 +1687,7 @@ private TNetworkAddress toArrowFlightHost(TNetworkAddress host) throws Exception if (backend.getArrowFlightSqlPort() < 0) { return null; } - return new TNetworkAddress(backend.getHost(), backend.getArrowFlightSqlPort()); + return backend.getArrowFlightAddress(); } // estimate if this fragment contains UnionNode diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java index 237708183cb179..a383c6526d9f0c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java @@ -605,9 +605,6 @@ public void finalizeQuery() { } private void handleQueryWithRetry(TUniqueId queryId) throws Exception { - if (context.getConnectType() == ConnectType.ARROW_FLIGHT_SQL) { - context.setReturnResultFromLocal(false); - } // queue query here syncJournalIfNeeded(); QueueOfferToken offerRet = null; @@ -642,6 +639,9 @@ private void handleQueryWithRetry(TUniqueId queryId) throws Exception { DebugUtil.printId(queryId), i, DebugUtil.printId(newQueryId)); context.setQueryId(newQueryId); } + if (context.getConnectType() == ConnectType.ARROW_FLIGHT_SQL) { + context.setReturnResultFromLocal(false); + } handleQueryStmt(); break; } catch (RpcException e) { @@ -2305,18 +2305,23 @@ private void handleLockTablesStmt() { } public void handleExplainStmt(String result, boolean isNereids) throws IOException { - // TODO support arrow flight sql ShowResultSetMetaData metaData = ShowResultSetMetaData.builder() .addColumn(new Column("Explain String" + (isNereids ? "(Nereids Planner)" : "(Old Planner)"), ScalarType.createVarchar(20))) .build(); - sendMetaData(metaData); + if (context.getConnectType() == ConnectType.MYSQL) { + sendMetaData(metaData); - // Send result set. - for (String item : result.split("\n")) { - serializer.reset(); - serializer.writeLenEncodedString(item); - context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer()); + // Send result set. + for (String item : result.split("\n")) { + serializer.reset(); + serializer.writeLenEncodedString(item); + context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer()); + } + } else if (context.getConnectType() == ConnectType.ARROW_FLIGHT_SQL) { + context.getFlightSqlChannel() + .addResult(DebugUtil.printId(context.queryId()), context.getRunningQuery(), metaData, result); + context.setReturnResultFromLocal(true); } context.getState().setEof(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java index d2f8b46b893683..b7eda2c3ff5b3a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java @@ -47,6 +47,7 @@ import org.apache.arrow.flight.sql.SqlInfoBuilder; import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; @@ -61,20 +62,31 @@ import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; +import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable { private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class); @@ -82,49 +94,97 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable private final BufferAllocator rootAllocator = new RootAllocator(); private final SqlInfoBuilder sqlInfoBuilder; private final FlightSessionsManager flightSessionsManager; + private final ExecutorService executorService = Executors.newFixedThreadPool(100); public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) { this.location = location; this.flightSessionsManager = flightSessionsManager; sqlInfoBuilder = new SqlInfoBuilder(); - sqlInfoBuilder.withFlightSqlServerName("DorisFE") - .withFlightSqlServerVersion("1.0") - .withFlightSqlServerArrowVersion("13.0") - .withFlightSqlServerReadOnly(false) - .withSqlIdentifierQuoteChar("`") - .withSqlDdlCatalog(true) - .withSqlDdlSchema(false) - .withSqlDdlTable(false) + sqlInfoBuilder.withFlightSqlServerName("DorisFE").withFlightSqlServerVersion("1.0") + .withFlightSqlServerArrowVersion("13.0").withFlightSqlServerReadOnly(false) + .withSqlIdentifierQuoteChar("`").withSqlDdlCatalog(true).withSqlDdlSchema(false).withSqlDdlTable(false) .withSqlIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE) .withSqlQuotedIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE); } + private static ByteBuffer serializeMetadata(final Schema schema) { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), schema); + + return ByteBuffer.wrap(outputStream.toByteArray()); + } catch (final IOException e) { + throw new RuntimeException("Failed to serialize arrow flight sql schema", e); + } + } + + private void getStreamStatementResult(String handle, ServerStreamListener listener) { + String[] handleParts = handle.split(":"); + String executedPeerIdentity = handleParts[0]; + String queryId = handleParts[1]; + ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity); + try { + // The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different. + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull( + connectContext.getFlightSqlChannel().getResult(queryId)); + final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot(); + listener.start(vectorSchemaRoot); + listener.putNext(); + } catch (Exception e) { + listener.error(e); + String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e) + + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " + + connectContext.getState().getErrorMessage(); + LOG.warn(errMsg, e); + throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); + } finally { + listener.completed(); + // The result has been sent or sent failed, delete it. + connectContext.getFlightSqlChannel().invalidate(queryId); + } + } + @Override public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, final ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamPreparedStatement unimplemented").toRuntimeException(); + getStreamStatementResult(command.getPreparedStatementHandle().toStringUtf8(), listener); + } + + @Override + public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, + final ServerStreamListener listener) { + getStreamStatementResult(ticketStatementQuery.getStatementHandle().toStringUtf8(), listener); } @Override public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context, final StreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("closePreparedStatement unimplemented").toRuntimeException(); + executorService.submit(() -> { + try { + String[] handleParts = request.getPreparedStatementHandle().toStringUtf8().split(":"); + String executedPeerIdentity = handleParts[0]; + String preparedStatementId = handleParts[1]; + flightSessionsManager.getConnectContext(executedPeerIdentity).removePreparedQuery(preparedStatementId); + } catch (final Exception e) { + listener.onError(e); + return; + } + listener.onCompleted(); + }); } - @Override - public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context, + private FlightInfo executeQueryStatement(String peerIdentity, ConnectContext connectContext, String query, final FlightDescriptor descriptor) { - ConnectContext connectContext = null; + Preconditions.checkState(null != connectContext); + Preconditions.checkState(!query.isEmpty()); try { - connectContext = flightSessionsManager.getConnectContext(context.peerIdentity()); // After the previous query was executed, there was no getStreamStatement to take away the result. connectContext.getFlightSqlChannel().reset(); - final String query = request.getQuery(); final FlightSqlConnectProcessor flightSQLConnectProcessor = new FlightSqlConnectProcessor(connectContext); flightSQLConnectProcessor.handleQuery(query); if (connectContext.getState().getStateType() == MysqlStateType.ERR) { - throw new RuntimeException("after handleQuery"); + throw new RuntimeException("after executeQueryStatement handleQuery"); } if (connectContext.isReturnResultFromLocal()) { @@ -132,30 +192,30 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi if (connectContext.getFlightSqlChannel().resultNum() == 0) { // a random query id and add empty results String queryId = UUID.randomUUID().toString(); - connectContext.getFlightSqlChannel().addEmptyResult(queryId, query); + connectContext.getFlightSqlChannel().addOKResult(queryId, query); - final ByteString handle = ByteString.copyFromUtf8(context.peerIdentity() + ":" + queryId); + final ByteString handle = ByteString.copyFromUtf8(peerIdentity + ":" + queryId); TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder().setStatementHandle(handle) .build(); return getFlightInfoForSchema(ticketStatement, descriptor, connectContext.getFlightSqlChannel().getResult(queryId).getVectorSchemaRoot().getSchema()); + } else { + // A Flight Sql request can only contain one statement that returns result, + // otherwise expected thrown exception during execution. + Preconditions.checkState(connectContext.getFlightSqlChannel().resultNum() == 1); + + // The tokens used for authentication between getStreamStatement and getFlightInfoStatement + // are different. So put the peerIdentity into the ticket and then getStreamStatement is used to + // find the correct ConnectContext. + // queryId is used to find query results. + final ByteString handle = ByteString.copyFromUtf8( + peerIdentity + ":" + DebugUtil.printId(connectContext.queryId())); + TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder().setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticketStatement, descriptor, + connectContext.getFlightSqlChannel().getResult(DebugUtil.printId(connectContext.queryId())) + .getVectorSchemaRoot().getSchema()); } - - // A Flight Sql request can only contain one statement that returns result, - // otherwise expected thrown exception during execution. - Preconditions.checkState(connectContext.getFlightSqlChannel().resultNum() == 1); - - // The tokens used for authentication between getStreamStatement and getFlightInfoStatement - // are different. So put the peerIdentity into the ticket and then getStreamStatement is used to find - // the correct ConnectContext. - // queryId is used to find query results. - final ByteString handle = ByteString.copyFromUtf8( - context.peerIdentity() + ":" + DebugUtil.printId(connectContext.queryId())); - TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder().setStatementHandle(handle) - .build(); - return getFlightInfoForSchema(ticketStatement, descriptor, - connectContext.getFlightSqlChannel().getResult(DebugUtil.printId(connectContext.queryId())) - .getVectorSchemaRoot().getSchema()); } else { // Now only query stmt will pull results from BE. final ByteString handle = ByteString.copyFromUtf8( @@ -176,24 +236,31 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi return new FlightInfo(schema, descriptor, endpoints, -1, -1); } } catch (Exception e) { - if (null != connectContext) { - connectContext.setCommand(MysqlCommand.COM_SLEEP); - String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage( - e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " - + connectContext.getState().getErrorMessage(); - LOG.warn(errMsg, e); - throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); - } - LOG.warn("get flight info statement failed, " + e.getMessage(), e); - throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); + connectContext.setCommand(MysqlCommand.COM_SLEEP); + String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e) + + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " + + connectContext.getState().getErrorMessage(); + LOG.warn(errMsg, e); + throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); } } + @Override + public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context, + final FlightDescriptor descriptor) { + ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity()); + return executeQueryStatement(context.peerIdentity(), connectContext, request.getQuery(), descriptor); + } + @Override public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, final FlightDescriptor descriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("getFlightInfoPreparedStatement unimplemented") - .toRuntimeException(); + String[] handleParts = command.getPreparedStatementHandle().toStringUtf8().split(":"); + String executedPeerIdentity = handleParts[0]; + String preparedStatementId = handleParts[1]; + ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity); + return executeQueryStatement(executedPeerIdentity, connectContext, + connectContext.getPreparedQuery(preparedStatementId), descriptor); } @Override @@ -202,42 +269,6 @@ public SchemaResult getSchemaStatement(final CommandStatementQuery command, fina throw CallStatus.UNIMPLEMENTED.withDescription("getSchemaStatement unimplemented").toRuntimeException(); } - @Override - public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, - final ServerStreamListener listener) { - ConnectContext connectContext = null; - final String handle = ticketStatementQuery.getStatementHandle().toStringUtf8(); - String[] handleParts = handle.split(":"); - String executedPeerIdentity = handleParts[0]; - String queryId = handleParts[1]; - try { - // The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different. - connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity); - final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull( - connectContext.getFlightSqlChannel().getResult(queryId)); - final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot(); - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (Exception e) { - listener.error(e); - if (null != connectContext) { - String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e) - + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " - + connectContext.getState().getErrorMessage(); - LOG.warn(errMsg, e); - throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); - } - LOG.warn("get stream statement failed, " + e.getMessage(), e); - throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); - } finally { - listener.completed(); - if (null != connectContext) { - // The result has been sent, delete it. - connectContext.getFlightSqlChannel().invalidate(queryId); - } - } - } - @Override public void close() throws Exception { AutoCloseables.close(rootAllocator); @@ -248,10 +279,54 @@ public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement unimplemented").toRuntimeException(); + // TODO can only execute complete SQL, not support SQL parameters. + executorService.submit(() -> { + ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity()); + try { + final String query = request.getQuery(); + String preparedStatementId = UUID.randomUUID().toString(); + final ByteString handle = ByteString.copyFromUtf8(context.peerIdentity() + ":" + preparedStatementId); + connectContext.addPreparedQuery(preparedStatementId, query); + + VectorSchemaRoot emptyVectorSchemaRoot = new VectorSchemaRoot(new ArrayList<>(), new ArrayList<>()); + final Schema parameterSchema = emptyVectorSchemaRoot.getSchema(); + // TODO FE does not have the ability to convert root fragment output expr into arrow schema. + // However, the metaData schema returned by createPreparedStatement is usually not used by the client, + // but it cannot be empty, otherwise it will be mistaken by the client as an updata statement. + // see: https://github.com/apache/arrow/issues/38911 + Schema metaData = connectContext.getFlightSqlChannel() + .createOneOneSchemaRoot("ResultMeta", "UNIMPLEMENTED").getSchema(); + listener.onNext(new Result( + Any.pack(buildCreatePreparedStatementResult(handle, parameterSchema, metaData)) + .toByteArray())); + } catch (Exception e) { + connectContext.setCommand(MysqlCommand.COM_SLEEP); + String errMsg = "create prepared statement failed, " + e.getMessage() + ", " + + Util.getRootCauseMessage(e) + ", error code: " + connectContext.getState().getErrorCode() + + ", error msg: " + connectContext.getState().getErrorMessage(); + LOG.warn(errMsg, e); + listener.onError(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException()); + return; + } catch (final Throwable t) { + listener.onError(CallStatus.INTERNAL.withDescription("Unknown error: " + t).toRuntimeException()); + return; + } + listener.onCompleted(); + }); } @Override @@ -268,8 +343,22 @@ public Runnable acceptPutStatement(CommandStatementUpdate command, CallContext c @Override public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate command, CallContext context, FlightStream flightStream, StreamListener ackStream) { - throw CallStatus.UNIMPLEMENTED.withDescription("acceptPutPreparedStatementUpdate unimplemented") - .toRuntimeException(); + return () -> { + while (flightStream.next()) { + final VectorSchemaRoot root = flightStream.getRoot(); + final int rowCount = root.getRowCount(); + // TODO support update + Preconditions.checkState(rowCount == 0); + + final int recordCount = -1; + final DoPutUpdateResult build = DoPutUpdateResult.newBuilder().setRecordCount(recordCount).build(); + try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + } + } + ackStream.onCompleted(); + }; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java index ae353fdb033ce3..1655d69c80feb8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java @@ -121,7 +121,10 @@ public Schema fetchArrowFlightSchema(int timeoutMs) { Status status = new Status(); status.setPstatus(pResult.getStatus()); throw new RuntimeException(String.format("fetch arrow flight schema failed, finstId: %s, errmsg: %s", - DebugUtil.printId(tid), status)); + DebugUtil.printId(tid), status.getErrorMsg())); + } + if (pResult.hasBeArrowFlightIp()) { + ctx.getResultFlightServerAddr().hostname = pResult.getBeArrowFlightIp().toStringUtf8(); } if (pResult.hasSchema() && pResult.getSchema().size() > 0) { RootAllocator rootAllocator = new RootAllocator(Integer.MAX_VALUE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java index 174e733c2db1a0..5eeb89ba031ff4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java @@ -47,12 +47,8 @@ public class FlightSqlChannel { public FlightSqlChannel() { // The Stmt result is not picked up by the Client within 10 minutes and will be deleted. - resultCache = - CacheBuilder.newBuilder() - .maximumSize(100) - .expireAfterWrite(10, TimeUnit.MINUTES) - .removalListener(new ResultRemovalListener()) - .build(); + resultCache = CacheBuilder.newBuilder().maximumSize(100).expireAfterWrite(10, TimeUnit.MINUTES) + .removalListener(new ResultRemovalListener()).build(); allocator = new RootAllocator(Long.MAX_VALUE); } @@ -98,19 +94,53 @@ public void addResult(String queryId, String runningQuery, ResultSet resultSet) resultCache.put(queryId, flightSqlResultCacheEntry); } - public void addEmptyResult(String queryId, String query) { + public void addResult(String queryId, String runningQuery, ResultSetMetaData metaData, String result) { List schemaFields = new ArrayList<>(); List dataFields = new ArrayList<>(); - schemaFields.add(new Field("StatusResult", FieldType.nullable(new Utf8()), null)); - VarCharVector varCharVector = new VarCharVector("StatusResult", allocator); + + // TODO: only support varchar type + for (Column col : metaData.getColumns()) { + schemaFields.add(new Field(col.getName(), FieldType.nullable(new Utf8()), null)); + VarCharVector varCharVector = new VarCharVector(col.getName(), allocator); + varCharVector.allocateNew(); + varCharVector.setValueCount(result.split("\n").length); + dataFields.add(varCharVector); + } + + int rowNum = 0; + for (String item : result.split("\n")) { + if (item == null || item.equals(FeConstants.null_string)) { + dataFields.get(0).setNull(rowNum); + } else { + ((VarCharVector) dataFields.get(0)).setSafe(rowNum, item.getBytes()); + } + rowNum += 1; + } + VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(schemaFields, dataFields); + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = new FlightSqlResultCacheEntry(vectorSchemaRoot, + runningQuery); + resultCache.put(queryId, flightSqlResultCacheEntry); + } + + /** + * Create a SchemaRoot with one row and one column. + */ + public VectorSchemaRoot createOneOneSchemaRoot(String colName, String colValue) { + List schemaFields = new ArrayList<>(); + List dataFields = new ArrayList<>(); + schemaFields.add(new Field(colName, FieldType.nullable(new Utf8()), null)); + VarCharVector varCharVector = new VarCharVector(colName, allocator); varCharVector.allocateNew(); varCharVector.setValueCount(1); - varCharVector.setSafe(0, "OK".getBytes()); + varCharVector.setSafe(0, colValue.getBytes()); dataFields.add(varCharVector); - VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(schemaFields, dataFields); - final FlightSqlResultCacheEntry flightSqlResultCacheEntry = new FlightSqlResultCacheEntry(vectorSchemaRoot, - query); + return new VectorSchemaRoot(schemaFields, dataFields); + } + + public void addOKResult(String queryId, String query) { + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = new FlightSqlResultCacheEntry( + createOneOneSchemaRoot("StatusResult", "OK"), query); resultCache.put(queryId, flightSqlResultCacheEntry); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java index f850384d4ed96c..275bc8085dd99a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java @@ -66,8 +66,7 @@ static ConnectContext buildConnectContext(String peerIdentity, UserIdentity user connectContext.setConnectScheduler(ExecuteEnv.getInstance().getScheduler()); if (!ExecuteEnv.getInstance().getScheduler().registerConnection(connectContext)) { - connectContext.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS, - "Reach limit of connections"); + connectContext.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS, "Reach limit of connections"); throw CallStatus.UNAUTHENTICATED.withDescription("Reach limit of connections").toRuntimeException(); } return connectContext; diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java index e1866b094b2641..fc0e79290377bb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java @@ -17,6 +17,7 @@ package org.apache.doris.service.arrowflight.sessions; +import org.apache.doris.common.util.Util; import org.apache.doris.qe.ConnectContext; import org.apache.doris.service.ExecuteEnv; import org.apache.doris.service.arrowflight.tokens.FlightTokenDetails; @@ -37,18 +38,23 @@ public FlightSessionsWithTokenManager(FlightTokenManager flightTokenManager) { @Override public ConnectContext getConnectContext(String peerIdentity) { - ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getContext(peerIdentity); - if (null == connectContext) { - connectContext = createConnectContext(peerIdentity); + try { + ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getContext(peerIdentity); if (null == connectContext) { - flightTokenManager.invalidateToken(peerIdentity); - String err = "UserSession expire after access, need reauthorize."; - LOG.error(err); - throw CallStatus.UNAUTHENTICATED.withDescription(err).toRuntimeException(); + connectContext = createConnectContext(peerIdentity); + if (null == connectContext) { + flightTokenManager.invalidateToken(peerIdentity); + String err = "UserSession expire after access, need reauthorize."; + LOG.error(err); + throw CallStatus.UNAUTHENTICATED.withDescription(err).toRuntimeException(); + } + return connectContext; } return connectContext; + } catch (Exception e) { + LOG.warn("getConnectContext failed, " + e.getMessage(), e); + throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); } - return connectContext; } @Override diff --git a/fe/fe-core/src/test/java/org/apache/doris/service/FlightSqlJDBC.java b/fe/fe-core/src/test/java/org/apache/doris/service/FlightSqlJDBC.java new file mode 100644 index 00000000000000..083d9537ea8730 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/service/FlightSqlJDBC.java @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from + +package org.apache.doris.service; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +public class FlightSqlJDBC { + // JDBC driver name and database URL + static final String DB_URL = "jdbc:arrow-flight-sql://0.0.0.0:10478?useServerPrepStmts=false" + + "&cachePrepStmts=true&useSSL=false&useEncryption=false"; + + // Database credentials + static final String USER = "root"; + static final String PASS = ""; + + public static void main(String[] args) throws ClassNotFoundException { + Connection conn = null; + Statement stmt = null; + Class.forName("org.apache.arrow.driver.jdbc.ArrowFlightJdbcDriver"); + try { + + conn = DriverManager.getConnection(DB_URL, USER, PASS); + stmt = conn.createStatement(); + + stmt.executeQuery("set dry_run_query=true"); + stmt.executeQuery("use information_schema;"); + String sql = "show tables;"; + + ResultSet resultSet = stmt.executeQuery(sql); + while (resultSet.next()) { + String col1 = resultSet.getString(1); + System.out.println(col1); + } + + stmt.execute(sql); + try (final ResultSet resultSet2 = stmt.getResultSet()) { + final int columnCount = resultSet2.getMetaData().getColumnCount(); + System.out.println(columnCount); + while (resultSet2.next()) { + String col1 = resultSet2.getString(1); + System.out.println(col1); + } + } + + resultSet.close(); + stmt.close(); + conn.close(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } +} diff --git a/fe/pom.xml b/fe/pom.xml index 18990f4fc241ef..0a8c2598843e22 100644 --- a/fe/pom.xml +++ b/fe/pom.xml @@ -305,7 +305,11 @@ under the License. 1.1.0 0.45.2-public 1.11.2 - 13.0.0 + + 15.0.0-SNAPSHOT 0.13.1 2.7.4-11 @@ -375,6 +379,10 @@ under the License. cloudera https://repository.cloudera.com/repository/libs-release-local/ + + arrow-apache-nightlies + https://nightlies.apache.org/arrow/java + @@ -1511,6 +1519,11 @@ under the License. flight-sql ${arrow.version} + + org.apache.arrow + flight-sql-jdbc-driver + ${arrow.version} + org.apache.arrow arrow-memory-core diff --git a/gensrc/proto/internal_service.proto b/gensrc/proto/internal_service.proto index 529977f43d85d0..544a91a3cb57f5 100644 --- a/gensrc/proto/internal_service.proto +++ b/gensrc/proto/internal_service.proto @@ -258,6 +258,7 @@ message PFetchArrowFlightSchemaResult { optional PStatus status = 1; // valid when status is ok optional bytes schema = 2; + optional bytes be_arrow_flight_ip = 3; }; message KeyTuple {