Skip to content

Commit

Permalink
add request cancellation and fix phantom completion text
Browse files Browse the repository at this point in the history
  • Loading branch information
reymondzzzz committed Nov 14, 2023
1 parent 320f39a commit 0e04bf0
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 36 deletions.
33 changes: 19 additions & 14 deletions src/main/kotlin/com/smallcloud/refactai/io/AsyncConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ class AsyncConnection : Disposable {
requestProperties: Map<String, String>? = null,
stat: UsageStatistic = UsageStatistic(),
dataReceiveEnded: (String) -> Unit = {},
dataReceived: (String) -> Unit = {},
dataReceived: (String, String) -> Unit = { _: String, _: String -> },
errorDataReceived: (JsonObject) -> Unit = {},
failedDataReceiveEnded: (Throwable?) -> Unit = {},
requestId: String = ""
): CompletableFuture<Future<*>> {
val requestProducer: AsyncRequestProducer = BasicRequestProducer(
BasicRequestBuilder
Expand All @@ -106,20 +107,22 @@ class AsyncConnection : Disposable {
requestProducer, uri,
stat = stat,
dataReceiveEnded = dataReceiveEnded, dataReceived = dataReceived,
errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded
errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded,
requestId=requestId
)
}

fun post(
uri: URI,
body: String? = null,
headers: Map<String, String>? = null,
requestProperties: Map<String, String>? = null,
stat: UsageStatistic,
dataReceiveEnded: (String) -> Unit = {},
dataReceived: (String) -> Unit = {},
errorDataReceived: (JsonObject) -> Unit = {},
failedDataReceiveEnded: (Throwable?) -> Unit = {},
uri: URI,
body: String? = null,
headers: Map<String, String>? = null,
requestProperties: Map<String, String>? = null,
stat: UsageStatistic,
dataReceiveEnded: (String) -> Unit = {},
dataReceived: (String, String) -> Unit = { _: String, _: String -> },
errorDataReceived: (JsonObject) -> Unit = {},
failedDataReceiveEnded: (Throwable?) -> Unit = {},
requestId: String = ""
): CompletableFuture<Future<*>> {
val requestProducer: AsyncRequestProducer = BasicRequestProducer(
BasicRequestBuilder
Expand All @@ -135,7 +138,8 @@ class AsyncConnection : Disposable {
requestProducer, uri,
stat = stat,
dataReceiveEnded = dataReceiveEnded, dataReceived = dataReceived,
errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded
errorDataReceived = errorDataReceived, failedDataReceiveEnded = failedDataReceiveEnded,
requestId=requestId
)
}

Expand All @@ -144,9 +148,10 @@ class AsyncConnection : Disposable {
uri: URI,
stat: UsageStatistic,
dataReceiveEnded: (String) -> Unit,
dataReceived: (String) -> Unit,
dataReceived: (String, String) -> Unit,
errorDataReceived: (JsonObject) -> Unit,
failedDataReceiveEnded: (Throwable?) -> Unit = {},
requestId: String = ""
): CompletableFuture<Future<*>> {
return CompletableFuture.supplyAsync {
return@supplyAsync client.execute(
Expand Down Expand Up @@ -179,7 +184,7 @@ class AsyncConnection : Disposable {
} catch (_: JsonSyntaxException) {
}
val (dataPieces, maybeLeftOverBuffer) = lookForCompletedDataInStreamingBuf(bufferStr)
dataPieces.forEach { dataReceived(it) }
dataPieces.forEach { dataReceived(it, requestId) }
if (maybeLeftOverBuffer == null) {
return
} else {
Expand Down
10 changes: 6 additions & 4 deletions src/main/kotlin/com/smallcloud/refactai/io/RequestHelpers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,24 @@ fun streamedInferenceFetch(
uri, body, headers,
stat = request.stat,
dataReceiveEnded = dataReceiveEnded,
dataReceived = {
val rawJson = gson.fromJson(it, JsonObject::class.java)
dataReceived = {body: String, reqId: String ->
val rawJson = gson.fromJson(body, JsonObject::class.java)
if (rawJson.has("metering_balance")) {
AccountManager.instance.meteringBalance = rawJson.get("metering_balance").asInt
}

val json = gson.fromJson(it, SMCStreamingPeace::class.java)
val json = gson.fromJson(body, SMCStreamingPeace::class.java)
InferenceGlobalContext.lastAutoModel = json.model
json.requestId = reqId
UsageStats.addStatistic(true, request.stat, request.uri.toString(), "")
dataReceived(json)
},
errorDataReceived = {
lookForCommonErrors(it, request)?.let { message ->
throw SMCExceptions(message)
}
}
},
requestId=request.id
)

return job
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ class LSPProcessHolder: Disposable {
var res = LSPCapabilities()
InferenceGlobalContext.connection.get(url.resolve("/v1/caps"),
dataReceiveEnded = {},
dataReceived = {},
errorDataReceived = {}).also {
var requestFuture: ComplexFuture<*>? = null
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ import com.smallcloud.refactai.privacy.PrivacyService.Companion.instance as Priv
import com.smallcloud.refactai.statistic.UsageStats.Companion.instance as UsageStats

private val specialSymbolsRegex = "^[:\\s\\t\\n\\r(){},.\"'\\];]*\$".toRegex()

class CompletionMode(
override var needToRender: Boolean = true
) : Mode, CaretListener {
private val scope: String = "completion"
private val app = ApplicationManager.getApplication()
private val scheduler = AppExecutorUtil.createBoundedScheduledExecutorService("SMCCompletionScheduler", 1)
private var processTask: Future<*>? = null
private var lastRequestId: String = ""
private var requestTask: Future<*>? = null
private var completionLayout: AsyncCompletionLayout? = null
private val logger = Logger.getInstance("StreamedCompletionMode")
private var completionInProgress: Boolean = false
Expand Down Expand Up @@ -72,7 +73,6 @@ class CompletionMode(

val isMultiline = currentLine.all { it == ' ' || it == '\t' }


if (!event.force) {
val docEvent = event.event ?: return
if (docEvent.offset + docEvent.newLength > editor.document.text.length) return
Expand Down Expand Up @@ -202,12 +202,23 @@ class CompletionMode(
request.body.parameters.maxNewTokens = 50
request.body.noCache = true
}
if (requestTask != null && requestTask!!.isDone && requestTask!!.isCancelled ) {
requestTask!!.cancel(true)
}
lastRequestId = request.id
streamedInferenceFetch(request, dataReceiveEnded = {
InferenceGlobalContext.status = ConnectionStatus.CONNECTED
InferenceGlobalContext.lastErrorMsg = null
}) { prediction ->
val choice = prediction.choices.first()
if ((!completionInProgress) || (choice.delta.isEmpty() && !choice.finishReason.isNullOrEmpty())) {
if (lastRequestId != prediction.requestId) {
completionLayout?.dispose()
completionLayout = null
return@streamedInferenceFetch
}

if ((!completionInProgress)
|| (choice.delta.isEmpty() && !choice.finishReason.isNullOrEmpty())) {
return@streamedInferenceFetch
}
val completion: Completion = if (completionLayout?.lastCompletionData == null ||
Expand All @@ -226,26 +237,23 @@ class CompletionMode(
completion.snippetTelemetryId = prediction.snippetTelemetryId
}
completion.updateCompletion(choice.delta)

synchronized(this) {
renderCompletion(
editorState.editor, editorState, completion, !prediction.cached
)
}
}?.also {
var requestFuture: Future<*>? = null
try {
requestFuture = it.get()
requestFuture.get()
requestTask = it.get()
requestTask!!.get()
logger.debug("Completion request finished")
completionInProgress = false
} catch (e: InterruptedException) {
handleInterruptedException(requestFuture, editorState.editor)
handleInterruptedException(requestTask, editorState.editor)
} catch (e: InterruptedIOException) {
handleInterruptedException(requestFuture, editorState.editor)
handleInterruptedException(requestTask, editorState.editor)
} catch (e: ExecutionException) {
cancelOrClose()
requestFuture?.cancel(true)
catchNetExceptions(e.cause)
} catch (e: Exception) {
InferenceGlobalContext.status = ConnectionStatus.ERROR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ChatGPTProvider : ActionListener {
InferenceGlobalContext.connection.post(req.uri,
reqStr,
mapOf("Authorization" to "Bearer ${req.token}"),
dataReceived = { response ->
dataReceived = { response, _ ->
fun parse(response: String?): String? {
val gson = Gson()
val obj = gson.fromJson(response, JsonObject::class.java)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.smallcloud.refactai.struct

import com.google.gson.annotations.Expose
import com.google.gson.annotations.SerializedName

data class SMCStreamingPeace(
val choices: List<StreamingChoice>,
val created: Double,
val model: String,
@SerializedName("snippet_telemetry_id") val snippetTelemetryId: Int? = null,
val cached: Boolean = false
val cached: Boolean = false,
@Expose
var requestId: String = ""
)


Expand Down
12 changes: 7 additions & 5 deletions src/main/kotlin/com/smallcloud/refactai/struct/SMCRequest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package com.smallcloud.refactai.struct

import com.google.gson.annotations.SerializedName
import com.smallcloud.refactai.statistic.UsageStatistic
import org.apache.commons.lang.RandomStringUtils
import java.net.URI

private fun uuid() = RandomStringUtils.randomAlphanumeric(8)

data class SMCCursor(
val file: String = "",
Expand All @@ -30,8 +31,9 @@ data class SMCRequestBody(
)

data class SMCRequest(
var uri: URI,
var body: SMCRequestBody,
var token: String,
var stat: UsageStatistic = UsageStatistic(),
var uri: URI,
var body: SMCRequestBody,
var token: String,
var id: String = uuid(),
var stat: UsageStatistic = UsageStatistic(),
)

0 comments on commit 0e04bf0

Please sign in to comment.