Skip to content

Commit

Permalink
Analyics: stop double reporting posthog utds
Browse files Browse the repository at this point in the history
  • Loading branch information
BillCarsonFr committed Apr 4, 2024
1 parent 5ccc486 commit 0a284bb
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package im.vector.app.features

import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import im.vector.app.InstrumentedTest
import im.vector.app.features.analytics.ReportedDecryptionFailurePersistence
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo
import org.junit.Test
import org.junit.runner.RunWith

@RunWith(AndroidJUnit4::class)
class ReportedDecryptionFailurePersistenceTest : InstrumentedTest {

private val context = InstrumentationRegistry.getInstrumentation().targetContext

@Test
fun shouldPersistReportedUtds() = runTest {
val persistence = ReportedDecryptionFailurePersistence(context)
persistence.load()

val eventIds = listOf("$0000", "$0001", "$0002", "$0003")
eventIds.forEach {
persistence.markAsReported(it)
}

eventIds.forEach {
persistence.hasBeenReported(it) shouldBeEqualTo true
}

persistence.hasBeenReported("$0004") shouldBeEqualTo false

persistence.persist()

// Load a new one
val persistence2 = ReportedDecryptionFailurePersistence(context)
persistence2.load()

eventIds.forEach {
persistence2.hasBeenReported(it) shouldBeEqualTo true
}
}

@Test
fun testSaturation() = runTest {
val persistence = ReportedDecryptionFailurePersistence(context)

for (i in 1..6000) {
persistence.markAsReported("000$i")
}

// This should have saturated the bloom filter, making the rate of false positives too high.
// A new bloom filter should have been created to avoid that and the recent reported events should still be in the new filter.
for (i in 5800..6000) {
persistence.hasBeenReported("000$i") shouldBeEqualTo true
}

// Old ones should not be there though
for (i in 1..1000) {
persistence.hasBeenReported("000$i") shouldBeEqualTo false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ private const val MAX_WAIT_MILLIS = 60_000
class DecryptionFailureTracker @Inject constructor(
private val analyticsTracker: AnalyticsTracker,
private val sessionDataSource: ActiveSessionDataSource,
private val decryptionFailurePersistence: ReportedDecryptionFailurePersistence,
private val clock: Clock
) : Session.Listener, LiveEventListener {

Expand All @@ -76,9 +77,6 @@ class DecryptionFailureTracker @Inject constructor(
// Only accessed on a `post` call, ensuring sequential access
private val trackedEventsMap = mutableMapOf<String, DecryptionFailure>()

// List of eventId that have been reported, to avoid double reporting
private val alreadyReported = mutableListOf<String>()

// Mutex to ensure sequential access to internal state
private val mutex = Mutex()

Expand All @@ -98,10 +96,16 @@ class DecryptionFailureTracker @Inject constructor(
this.scope = scope
}
observeActiveSession()
post {
decryptionFailurePersistence.load()
}
}

fun stop() {
Timber.v("Stop DecryptionFailureTracker")
post {
decryptionFailurePersistence.persist()
}
activeSessionSourceDisposable.cancel(CancellationException("Closing DecryptionFailureTracker"))

activeSession?.removeListener(this)
Expand All @@ -123,6 +127,7 @@ class DecryptionFailureTracker @Inject constructor(
delay(CHECK_INTERVAL)
post {
checkFailures()
decryptionFailurePersistence.persist()
currentTicker = null
if (trackedEventsMap.isNotEmpty()) {
// Reschedule
Expand All @@ -136,15 +141,15 @@ class DecryptionFailureTracker @Inject constructor(
.distinctUntilChanged()
.onEach {
Timber.v("Active session changed ${it.getOrNull()?.myUserId}")
it.orNull()?.let { session ->
it.getOrNull()?.let { session ->
post {
onSessionActive(session)
}
}
}.launchIn(scope)
}

private fun onSessionActive(session: Session) {
private suspend fun onSessionActive(session: Session) {
Timber.v("onSessionActive ${session.myUserId} previous: ${activeSession?.myUserId}")
val sessionId = session.sessionId
if (sessionId == activeSession?.sessionId) {
Expand Down Expand Up @@ -201,7 +206,8 @@ class DecryptionFailureTracker @Inject constructor(
// already tracked
return
}
if (alreadyReported.contains(eventId)) {
if (decryptionFailurePersistence.hasBeenReported(eventId)) {
Timber.v("Event $eventId already reported")
// already reported
return
}
Expand Down Expand Up @@ -236,7 +242,7 @@ class DecryptionFailureTracker @Inject constructor(
}
}

private fun handleEventDecrypted(eventId: String) {
private suspend fun handleEventDecrypted(eventId: String) {
Timber.v("Handle event decrypted $eventId time: ${clock.epochMillis()}")
// Only consider if it was tracked as a failure
val trackedFailure = trackedEventsMap[eventId] ?: return
Expand Down Expand Up @@ -269,7 +275,7 @@ class DecryptionFailureTracker @Inject constructor(
}

// This will mutate the trackedEventsMap, so don't call it while iterating on it.
private fun reportFailure(decryptionFailure: DecryptionFailure) {
private suspend fun reportFailure(decryptionFailure: DecryptionFailure) {
Timber.v("Report failure for event ${decryptionFailure.failedEventId}")
val error = decryptionFailure.toAnalyticsEvent()

Expand All @@ -278,10 +284,10 @@ class DecryptionFailureTracker @Inject constructor(
// now remove from tracked
trackedEventsMap.remove(decryptionFailure.failedEventId)
// mark as already reported
alreadyReported.add(decryptionFailure.failedEventId)
decryptionFailurePersistence.markAsReported(decryptionFailure.failedEventId)
}

private fun checkFailures() {
private suspend fun checkFailures() {
val now = clock.epochMillis()
Timber.v("Check failures now $now")
// report the definitely failed
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2024 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package im.vector.app.features.analytics

import android.content.Context
import android.util.LruCache
import com.google.common.hash.BloomFilter
import com.google.common.hash.Funnels
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import timber.log.Timber
import java.io.File
import java.io.FileOutputStream
import javax.inject.Inject

private const val REPORTED_UTD_FILE_NAME = "im.vector.analytics.reported_utd"
private const val EXPECTED_INSERTIONS = 5000

/**
* This class is used to keep track of the reported decryption failures to avoid double reporting.
* It uses a bloom filter to limit the memory/disk usage.
*/
class ReportedDecryptionFailurePersistence @Inject constructor(
private val context: Context,
) {

// Keep a cache of recent reported failures in memory.
// They will be persisted to the a new bloom filter if the previous one is getting saturated.
// Should be around 30KB max in memory.
// Also allows to have 0% false positive rate for recent failures.
private val inMemoryReportedFailures: LruCache<String, Unit> = LruCache(300)

// Thread-safe and lock-free.
// The expected insertions is 5000, and expected false positive probability of 3% when close to max capability.
// The persisted size is expected to be around 5KB (100 times less than if it was raw strings).
private var bloomFilter: BloomFilter<String> = BloomFilter.create<String>(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS)

/**
* Mark an event as reported.
* @param eventId the event id to mark as reported.
*/
suspend fun markAsReported(eventId: String) {
// Add to in memory cache.
inMemoryReportedFailures.put(eventId, Unit)
bloomFilter.put(eventId)

// check if the filter is getting saturated? and then replace
if (bloomFilter.approximateElementCount() > EXPECTED_INSERTIONS - 500) {
// The filter is getting saturated, and the false positive rate is increasing.
// It's time to replace the filter with a new one. And move the in-memory cache to the new filter.
bloomFilter = BloomFilter.create<String>(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS)
inMemoryReportedFailures.snapshot().keys.forEach {
bloomFilter.put(it)
}
persist()
}
Timber.v("## Bloom filter stats: expectedFpp: ${bloomFilter.expectedFpp()}, size: ${bloomFilter.approximateElementCount()}")
}

/**
* Check if an event has been reported.
* @param eventId the event id to check.
* @return true if the event has been reported.
*/
fun hasBeenReported(eventId: String): Boolean {
// First check in memory cache.
if (inMemoryReportedFailures.get(eventId) != null) {
return true
}
return bloomFilter.mightContain(eventId)
}

/**
* Load the reported failures from disk.
*/
suspend fun load() {
withContext(Dispatchers.IO) {
try {
val file = File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME)
if (file.exists()) {
file.inputStream().use {
bloomFilter = BloomFilter.readFrom(it, Funnels.stringFunnel(Charsets.UTF_8))
}
}
} catch (e: Throwable) {
Timber.e(e, "## Failed to load reported failures")
}
}
}

/**
* Persist the reported failures to disk.
*/
suspend fun persist() {
withContext(Dispatchers.IO) {
try {
val file = File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME)
if (!file.exists()) file.createNewFile()
FileOutputStream(file).buffered().use {
bloomFilter.writeTo(it)
}
Timber.v("## Successfully saved reported failures, size: ${file.length()}")
} catch (e: Throwable) {
Timber.e(e, "## Failed to save reported failures")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import im.vector.app.test.fakes.FakeAnalyticsTracker
import im.vector.app.test.fakes.FakeClock
import im.vector.app.test.fakes.FakeSession
import im.vector.app.test.shared.createTimberTestRule
import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
Expand Down Expand Up @@ -60,9 +61,24 @@ class DecryptionFailureTrackerTest {

private val fakeClock = FakeClock()

val reportedEvents = mutableSetOf<String>()

private val fakePersistence = mockk<ReportedDecryptionFailurePersistence> {

coEvery { load() } just runs
coEvery { persist() } just runs
coEvery { markAsReported(any()) } coAnswers {
reportedEvents.add(firstArg())
}
every { hasBeenReported(any()) } answers {
reportedEvents.contains(firstArg())
}
}

private val decryptionFailureTracker = DecryptionFailureTracker(
fakeAnalyticsTracker,
fakeActiveSessionDataSource.instance,
fakePersistence,
fakeClock
)

Expand Down Expand Up @@ -101,6 +117,7 @@ class DecryptionFailureTrackerTest {

@Before
fun setupTest() {
reportedEvents.clear()
fakeMxOrgTestSession.fakeCryptoService.fakeCrossSigningService.givenIsCrossSigningVerifiedReturns(false)
}

Expand Down

0 comments on commit 0a284bb

Please sign in to comment.