Skip to content

Commit

Permalink
catch oom error of coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
nkbai committed Nov 16, 2023
1 parent 8ed5315 commit 853a8aa
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import net.bytedance.security.app.getConfig
import net.bytedance.security.app.rules.DirectModeRule
import net.bytedance.security.app.rules.SliceModeRule
import net.bytedance.security.app.taintflow.TaintAnalyzer
import net.bytedance.security.app.util.oomHandler
import soot.SootMethod


Expand Down Expand Up @@ -62,7 +63,7 @@ class SliceModeProcessor(ctx: PreAnalyzeContext) : DirectModeProcessor(ctx) {
null
}
for (sinkPtr in taintRuleSourceSinkCollector.analyzerData.sinkPointerSet) {
val job = scope.launch(CoroutineName("createAnalyzersForSourceAndSink-${rule.name}")) {
val job = scope.launch(CoroutineName("createAnalyzersForSourceAndSink-${rule.name}") + oomHandler) {

val entryItem = if (callstacks == null) {
val result = ctx.callGraph.traceAndCross(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ import net.bytedance.security.app.pointer.PLLocalPointer
import net.bytedance.security.app.pointer.PLObject
import net.bytedance.security.app.pointer.PLPointer
import net.bytedance.security.app.pointer.PointerFactory
import net.bytedance.security.app.util.profiler
import net.bytedance.security.app.util.runInMilliSeconds
import net.bytedance.security.app.util.toSortedMap
import net.bytedance.security.app.util.toSortedSet
import net.bytedance.security.app.util.*
import soot.*
import soot.jimple.*
import soot.jimple.internal.*
Expand Down Expand Up @@ -88,7 +85,7 @@ class TwoStagePointerAnalyze(
profiler.startPointAnalyze(name)
thisStageStart = System.currentTimeMillis()
try {
val job = localScope.launch(Dispatchers.Default) {
val job = localScope.launch(Dispatchers.Default + CoroutineName("PointerAnalyzeStage1") + oomHandler) {
scope = this
analyzeMethod(entryMethod, null, 0)
}
Expand All @@ -100,7 +97,7 @@ class TwoStagePointerAnalyze(
Log.logInfo("$name fistStageAnalyze finished")
thisStageStart = System.currentTimeMillis()
try {
val job = localScope.launch(Dispatchers.Default) {
val job = localScope.launch(Dispatchers.Default + CoroutineName("PointerAnalyzeStage2") + oomHandler) {
scope = this
secondStageAnalyze()
}
Expand Down
24 changes: 16 additions & 8 deletions src/main/kotlin/net/bytedance/security/app/util/TaskQueue.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ import net.bytedance.security.app.Log
import java.util.*
import kotlin.system.exitProcess

/**
* global hander for oom
*/
val oomHandler = CoroutineExceptionHandler { ctx, exception ->
if (exception is OutOfMemoryError) {
val coroutineName = ctx[CoroutineName]?.name
Log.logErr("${coroutineName} CoroutineException because of oom")
exitProcess(37)
}
throw exception
}

/**
* A simple multithreaded task wrapper
*/
Expand Down Expand Up @@ -56,21 +68,16 @@ class TaskQueue<TaskData>(
suspend fun runTask(): Job {
val scope = CoroutineScope(Dispatchers.Default)
val jobs = ArrayList<Job>()

for (i in 0 until numberThreads) {
val handler = CoroutineExceptionHandler { _, exception ->
if (exception is OutOfMemoryError) {
exitProcess(37)
}
throw exception
}
val job = scope.launch(CoroutineName("$name-$i") + handler) {
val job = scope.launch(CoroutineName("$name-$i") + oomHandler) {
for (taskData in queue) {
action(taskData, i)
}
}
jobs.add(job)
}
return scope.launch(CoroutineName("$name-joinAll")) { jobs.joinAll() }
return scope.launch(CoroutineName("$name-joinAll") + oomHandler) { jobs.joinAll() }
}
}

Expand Down Expand Up @@ -101,3 +108,4 @@ suspend fun runInMilliSeconds(job: Job, milliSeconds: Long, name: String, timeou
Log.logWarn("$name runInMilliSeconds cost more than expected expect=$milliSeconds, actual=${end - start}")
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import test.SootHelper
import test.TestHelper
import java.io.IOException
import kotlin.concurrent.thread
import kotlin.system.exitProcess

class SootConcurrentErrorTest {
private val ctx = PreAnalyzeContext()
Expand Down Expand Up @@ -102,4 +103,33 @@ class SootConcurrentErrorTest {

Thread.sleep(5000)
}

@Test
fun testlaunchOOM() {
runBlocking {
val handler = CoroutineExceptionHandler { _, exception ->
println("CoroutineExceptionHandler got $exception")
}
val job = GlobalScope.launch(handler) {
val inner = launch { // all this stack of coroutines will get cancelled
launch {
launch {
val list = ArrayList<String>()
for (i in 1..1000000) {
list.add("12345".repeat(1000000))
}
}
}
}
try {
inner.join()
} catch (e: java.lang.OutOfMemoryError) {
//oom should capture by handler
exitProcess(33)
}
}
job.join()
}
}

}

0 comments on commit 853a8aa

Please sign in to comment.