From 200544bcd4bb6b1b0dfdcb96d0a6a6a3b8ad4351 Mon Sep 17 00:00:00 2001 From: reymondzzzz Date: Tue, 14 Nov 2023 18:23:09 +0300 Subject: [PATCH] add request cancellation and fix phantom completion text --- .../smallcloud/refactai/io/AsyncConnection.kt | 33 +++++++++++-------- .../smallcloud/refactai/io/RequestHelpers.kt | 10 +++--- .../refactai/lsp/LSPProcessHolder.kt | 1 - .../modes/completion/CompletionMode.kt | 28 ++++++++++------ .../refactai/panes/gptchat/ChatGPTProvider.kt | 2 +- .../refactai/struct/SMCPrediction.kt | 5 ++- .../smallcloud/refactai/struct/SMCRequest.kt | 17 +++++++--- 7 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/main/kotlin/com/smallcloud/refactai/io/AsyncConnection.kt b/src/main/kotlin/com/smallcloud/refactai/io/AsyncConnection.kt index 2467f76f..3056f076 100644 --- a/src/main/kotlin/com/smallcloud/refactai/io/AsyncConnection.kt +++ b/src/main/kotlin/com/smallcloud/refactai/io/AsyncConnection.kt @@ -87,9 +87,10 @@ class AsyncConnection : Disposable { requestProperties: Map? = null, stat: UsageStatistic = UsageStatistic(), dataReceiveEnded: (String) -> Unit = {}, - dataReceived: (String) -> Unit = {}, + dataReceived: (String, String) -> Unit = { _: String, _: String -> }, errorDataReceived: (JsonObject) -> Unit = {}, failedDataReceiveEnded: (Throwable?) -> Unit = {}, + requestId: String = "" ): CompletableFuture> { val requestProducer: AsyncRequestProducer = BasicRequestProducer( BasicRequestBuilder @@ -106,20 +107,22 @@ class AsyncConnection : Disposable { requestProducer, uri, stat = stat, dataReceiveEnded = dataReceiveEnded, dataReceived = dataReceived, - errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded + errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded, + requestId=requestId ) } fun post( - uri: URI, - body: String? = null, - headers: Map? = null, - requestProperties: Map? = null, - stat: UsageStatistic, - dataReceiveEnded: (String) -> Unit = {}, - dataReceived: (String) -> Unit = {}, - errorDataReceived: (JsonObject) -> Unit = {}, - failedDataReceiveEnded: (Throwable?) -> Unit = {}, + uri: URI, + body: String? = null, + headers: Map? = null, + requestProperties: Map? = null, + stat: UsageStatistic, + dataReceiveEnded: (String) -> Unit = {}, + dataReceived: (String, String) -> Unit = { _: String, _: String -> }, + errorDataReceived: (JsonObject) -> Unit = {}, + failedDataReceiveEnded: (Throwable?) -> Unit = {}, + requestId: String = "" ): CompletableFuture> { val requestProducer: AsyncRequestProducer = BasicRequestProducer( BasicRequestBuilder @@ -135,7 +138,8 @@ class AsyncConnection : Disposable { requestProducer, uri, stat = stat, dataReceiveEnded = dataReceiveEnded, dataReceived = dataReceived, - errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded + errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded, + requestId=requestId ) } @@ -144,9 +148,10 @@ class AsyncConnection : Disposable { uri: URI, stat: UsageStatistic, dataReceiveEnded: (String) -> Unit, - dataReceived: (String) -> Unit, + dataReceived: (String, String) -> Unit, errorDataReceived: (JsonObject) -> Unit, failedDataReceiveEnded: (Throwable?) -> Unit = {}, + requestId: String = "" ): CompletableFuture> { return CompletableFuture.supplyAsync { return@supplyAsync client.execute( @@ -179,7 +184,7 @@ class AsyncConnection : Disposable { } catch (_: JsonSyntaxException) { } val (dataPieces, maybeLeftOverBuffer) = lookForCompletedDataInStreamingBuf(bufferStr) - dataPieces.forEach { dataReceived(it) } + dataPieces.forEach { dataReceived(it, requestId) } if (maybeLeftOverBuffer == null) { return } else { diff --git a/src/main/kotlin/com/smallcloud/refactai/io/RequestHelpers.kt b/src/main/kotlin/com/smallcloud/refactai/io/RequestHelpers.kt index c14f9348..bf2e1543 100644 --- a/src/main/kotlin/com/smallcloud/refactai/io/RequestHelpers.kt +++ b/src/main/kotlin/com/smallcloud/refactai/io/RequestHelpers.kt @@ -61,14 +61,15 @@ fun streamedInferenceFetch( uri, body, headers, stat = request.stat, dataReceiveEnded = dataReceiveEnded, - dataReceived = { - val rawJson = gson.fromJson(it, JsonObject::class.java) + dataReceived = {body: String, reqId: String -> + val rawJson = gson.fromJson(body, JsonObject::class.java) if (rawJson.has("metering_balance")) { AccountManager.instance.meteringBalance = rawJson.get("metering_balance").asInt } - val json = gson.fromJson(it, SMCStreamingPeace::class.java) + val json = gson.fromJson(body, SMCStreamingPeace::class.java) InferenceGlobalContext.lastAutoModel = json.model + json.requestId = reqId UsageStats.addStatistic(true, request.stat, request.uri.toString(), "") dataReceived(json) }, @@ -76,7 +77,8 @@ fun streamedInferenceFetch( lookForCommonErrors(it, request)?.let { message -> throw SMCExceptions(message) } - } + }, + requestId=request.id ) return job diff --git a/src/main/kotlin/com/smallcloud/refactai/lsp/LSPProcessHolder.kt b/src/main/kotlin/com/smallcloud/refactai/lsp/LSPProcessHolder.kt index b7a75177..3930b8a0 100644 --- a/src/main/kotlin/com/smallcloud/refactai/lsp/LSPProcessHolder.kt +++ b/src/main/kotlin/com/smallcloud/refactai/lsp/LSPProcessHolder.kt @@ -224,7 +224,6 @@ class LSPProcessHolder: Disposable { var res = LSPCapabilities() InferenceGlobalContext.connection.get(url.resolve("/v1/caps"), dataReceiveEnded = {}, - dataReceived = {}, errorDataReceived = {}).also { var requestFuture: ComplexFuture<*>? = null try { diff --git a/src/main/kotlin/com/smallcloud/refactai/modes/completion/CompletionMode.kt b/src/main/kotlin/com/smallcloud/refactai/modes/completion/CompletionMode.kt index 079fbc35..22085252 100644 --- a/src/main/kotlin/com/smallcloud/refactai/modes/completion/CompletionMode.kt +++ b/src/main/kotlin/com/smallcloud/refactai/modes/completion/CompletionMode.kt @@ -34,7 +34,6 @@ import com.smallcloud.refactai.privacy.PrivacyService.Companion.instance as Priv import com.smallcloud.refactai.statistic.UsageStats.Companion.instance as UsageStats private val specialSymbolsRegex = "^[:\\s\\t\\n\\r(){},.\"'\\];]*\$".toRegex() - class CompletionMode( override var needToRender: Boolean = true ) : Mode, CaretListener { @@ -42,6 +41,8 @@ class CompletionMode( private val app = ApplicationManager.getApplication() private val scheduler = AppExecutorUtil.createBoundedScheduledExecutorService("SMCCompletionScheduler", 1) private var processTask: Future<*>? = null + private var lastRequestId: String = "" + private var requestTask: Future<*>? = null private var completionLayout: AsyncCompletionLayout? = null private val logger = Logger.getInstance("StreamedCompletionMode") private var completionInProgress: Boolean = false @@ -72,7 +73,6 @@ class CompletionMode( val isMultiline = currentLine.all { it == ' ' || it == '\t' } - if (!event.force) { val docEvent = event.event ?: return if (docEvent.offset + docEvent.newLength > editor.document.text.length) return @@ -202,12 +202,23 @@ class CompletionMode( request.body.parameters.maxNewTokens = 50 request.body.noCache = true } + if (requestTask != null && requestTask!!.isDone && requestTask!!.isCancelled ) { + requestTask!!.cancel(true) + } + lastRequestId = request.id streamedInferenceFetch(request, dataReceiveEnded = { InferenceGlobalContext.status = ConnectionStatus.CONNECTED InferenceGlobalContext.lastErrorMsg = null }) { prediction -> val choice = prediction.choices.first() - if ((!completionInProgress) || (choice.delta.isEmpty() && !choice.finishReason.isNullOrEmpty())) { + if (lastRequestId != prediction.requestId) { + completionLayout?.dispose() + completionLayout = null + return@streamedInferenceFetch + } + + if ((!completionInProgress) + || (choice.delta.isEmpty() && !choice.finishReason.isNullOrEmpty())) { return@streamedInferenceFetch } val completion: Completion = if (completionLayout?.lastCompletionData == null || @@ -226,26 +237,23 @@ class CompletionMode( completion.snippetTelemetryId = prediction.snippetTelemetryId } completion.updateCompletion(choice.delta) - synchronized(this) { renderCompletion( editorState.editor, editorState, completion, !prediction.cached ) } }?.also { - var requestFuture: Future<*>? = null try { - requestFuture = it.get() - requestFuture.get() + requestTask = it.get() + requestTask!!.get() logger.debug("Completion request finished") completionInProgress = false } catch (e: InterruptedException) { - handleInterruptedException(requestFuture, editorState.editor) + handleInterruptedException(requestTask, editorState.editor) } catch (e: InterruptedIOException) { - handleInterruptedException(requestFuture, editorState.editor) + handleInterruptedException(requestTask, editorState.editor) } catch (e: ExecutionException) { cancelOrClose() - requestFuture?.cancel(true) catchNetExceptions(e.cause) } catch (e: Exception) { InferenceGlobalContext.status = ConnectionStatus.ERROR diff --git a/src/main/kotlin/com/smallcloud/refactai/panes/gptchat/ChatGPTProvider.kt b/src/main/kotlin/com/smallcloud/refactai/panes/gptchat/ChatGPTProvider.kt index 426e60f7..0fd15b71 100644 --- a/src/main/kotlin/com/smallcloud/refactai/panes/gptchat/ChatGPTProvider.kt +++ b/src/main/kotlin/com/smallcloud/refactai/panes/gptchat/ChatGPTProvider.kt @@ -122,7 +122,7 @@ class ChatGPTProvider : ActionListener { InferenceGlobalContext.connection.post(req.uri, reqStr, mapOf("Authorization" to "Bearer ${req.token}"), - dataReceived = { response -> + dataReceived = { response, _ -> fun parse(response: String?): String? { val gson = Gson() val obj = gson.fromJson(response, JsonObject::class.java) diff --git a/src/main/kotlin/com/smallcloud/refactai/struct/SMCPrediction.kt b/src/main/kotlin/com/smallcloud/refactai/struct/SMCPrediction.kt index 564a93b4..18fd2434 100644 --- a/src/main/kotlin/com/smallcloud/refactai/struct/SMCPrediction.kt +++ b/src/main/kotlin/com/smallcloud/refactai/struct/SMCPrediction.kt @@ -1,5 +1,6 @@ package com.smallcloud.refactai.struct +import com.google.gson.annotations.Expose import com.google.gson.annotations.SerializedName data class SMCStreamingPeace( @@ -7,7 +8,9 @@ data class SMCStreamingPeace( val created: Double, val model: String, @SerializedName("snippet_telemetry_id") val snippetTelemetryId: Int? = null, - val cached: Boolean = false + val cached: Boolean = false, + @Expose + var requestId: String = "" ) diff --git a/src/main/kotlin/com/smallcloud/refactai/struct/SMCRequest.kt b/src/main/kotlin/com/smallcloud/refactai/struct/SMCRequest.kt index c73a8cdf..0ecccfe3 100644 --- a/src/main/kotlin/com/smallcloud/refactai/struct/SMCRequest.kt +++ b/src/main/kotlin/com/smallcloud/refactai/struct/SMCRequest.kt @@ -3,7 +3,15 @@ package com.smallcloud.refactai.struct import com.google.gson.annotations.SerializedName import com.smallcloud.refactai.statistic.UsageStatistic import java.net.URI +import java.util.concurrent.ThreadLocalRandom +import kotlin.streams.asSequence +private val charPool : List = ('a'..'z') + ('A'..'Z') + ('0'..'9') +private fun uuid() = ThreadLocalRandom.current() + .ints(8.toLong(), 0, charPool.size) + .asSequence() + .map(charPool::get) + .joinToString("") data class SMCCursor( val file: String = "", @@ -30,8 +38,9 @@ data class SMCRequestBody( ) data class SMCRequest( - var uri: URI, - var body: SMCRequestBody, - var token: String, - var stat: UsageStatistic = UsageStatistic(), + var uri: URI, + var body: SMCRequestBody, + var token: String, + var id: String = uuid(), + var stat: UsageStatistic = UsageStatistic(), )