Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
CrisBarreiro committed Jan 30, 2025
1 parent 6e4a289 commit 9622274
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor(
return null
}

if (request.isForMainFrame || (isForIframe(request) && documentUri?.host == request.requestHeaders["Referer"]?.toUri()?.host)) {
val belongsToCurrentPage = documentUri?.host == request.requestHeaders["Referer"]?.toUri()?.host
if (request.isForMainFrame || (isForIframe(request) && belongsToCurrentPage)) {
if (checkMaliciousUrl(decodedUrl, confirmationCallback)) {
return WebResourceResponse(null, null, null)
} else {
Expand Down Expand Up @@ -156,6 +157,7 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor(
): Boolean {
val checkId = currentCheckId.incrementAndGet()
return maliciousSiteProtection.isMalicious(url.toUri()) {
// if another load has started, we should ignore the result
val isMalicious = if (checkId == currentCheckId.get()) {
it
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest {
}

@Test
fun `shouldIntercept callback with true received after a new request is intercepted then return false`() = runTest {
fun `if a new page load triggering is malicious is started, isMalicious callback result should be ignored for the first page`() = runTest {
val request = mock(WebResourceRequest::class.java)
whenever(request.url).thenReturn(maliciousUri)
whenever(request.isForMainFrame).thenReturn(true)
Expand Down Expand Up @@ -192,7 +192,7 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest {
}

@Test
fun `shouldIntercept callback with true received before a new request is intercepted then return true`() = runTest {
fun `isMalicious callback result should be processed if no new page loads triggering isMalicious have started`() = runTest {
val request = mock(WebResourceRequest::class.java)
whenever(request.url).thenReturn(maliciousUri)
whenever(request.isForMainFrame).thenReturn(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ class MaliciousSiteProtectionFiltersUpdateWorker(
if (maliciousSiteProtectionFeature.isFeatureEnabled().not()) {
return@withContext Result.success()
}
try {
maliciousSiteRepository.loadFilters()
return@withContext Result.success()
} catch (e: Exception) {
return@withContext Result.retry()
return@withContext if (maliciousSiteRepository.loadFilters().isSuccess) {
Result.success()
} else {
Result.retry()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ class MaliciousSiteProtectionHashPrefixesUpdateWorker(
if (maliciousSiteProtectionFeature.isFeatureEnabled().not()) {
return@withContext Result.success()
}
try {
maliciousSiteRepository.loadHashPrefixes()
return@withContext Result.success()
} catch (e: Exception) {
return@withContext Result.retry()
return@withContext if (maliciousSiteRepository.loadHashPrefixes().isSuccess) {
Result.success()
} else {
Result.retry()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,114 +16,49 @@

package com.duckduckgo.malicioussiteprotection.impl.data

import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.di.scopes.AppScope
import com.duckduckgo.malicioussiteprotection.impl.data.db.MaliciousSiteDao
import com.duckduckgo.malicioussiteprotection.impl.data.db.RevisionEntity
import com.duckduckgo.malicioussiteprotection.impl.data.network.FilterResponse
import com.duckduckgo.malicioussiteprotection.impl.data.network.FilterSetResponse
import com.duckduckgo.malicioussiteprotection.impl.data.network.HashPrefixResponse
import com.duckduckgo.malicioussiteprotection.impl.data.network.MaliciousSiteService
import com.duckduckgo.malicioussiteprotection.impl.models.Feed
import com.duckduckgo.malicioussiteprotection.impl.models.Feed.MALWARE
import com.duckduckgo.malicioussiteprotection.impl.models.Feed.PHISHING
import com.duckduckgo.malicioussiteprotection.impl.models.Filter
import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision
import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision.MalwareFilterSetWithRevision
import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision.PhishingFilterSetWithRevision
import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision
import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision.MalwareHashPrefixesWithRevision
import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision.PhishingHashPrefixesWithRevision
import com.duckduckgo.malicioussiteprotection.impl.models.Match
import com.duckduckgo.malicioussiteprotection.impl.models.Type
import com.duckduckgo.malicioussiteprotection.impl.models.Type.FILTER_SET
import com.duckduckgo.malicioussiteprotection.impl.models.Type.HASH_PREFIXES
import com.squareup.anvil.annotations.ContributesBinding
import dagger.SingleInstanceIn
import javax.inject.Inject
import timber.log.Timber
import kotlinx.coroutines.withContext

interface MaliciousSiteRepository {
suspend fun containsHashPrefix(hashPrefix: String): Boolean
suspend fun getFilters(hash: String): List<Filter>?
suspend fun matches(hashPrefix: String): List<Match>
suspend fun loadFilters()
suspend fun loadHashPrefixes()
suspend fun loadFilters(): Result<Unit>
suspend fun loadHashPrefixes(): Result<Unit>
}

@ContributesBinding(AppScope::class)
@SingleInstanceIn(AppScope::class)
class RealMaliciousSiteRepository @Inject constructor(
private val maliciousSiteDao: MaliciousSiteDao,
private val maliciousSiteService: MaliciousSiteService,
private val dispatcherProvider: DispatcherProvider,
) : MaliciousSiteRepository {

override suspend fun loadFilters() {
try {
val networkRevision = maliciousSiteService.getRevision().revision

(maliciousSiteDao.getLatestRevision()?.filter { it.type == Type.FILTER_SET.name } ?: listOf()).let { latestRevision ->
val phishingFilterSetRevision = latestRevision.firstOrNull() { it.feed == PHISHING.name }?.revision ?: 0
val phishingFilterSet: FilterSetResponse? = if (networkRevision > phishingFilterSetRevision) {
maliciousSiteService.getPhishingFilterSet(phishingFilterSetRevision)
} else {
null
}
val malwareFilterSetRevision = latestRevision.firstOrNull() { it.feed == MALWARE.name }?.revision ?: 0
val malwareFilterSet: FilterSetResponse? = if (networkRevision > malwareFilterSetRevision) {
maliciousSiteService.getMalwareFilterSet(malwareFilterSetRevision)
} else {
null
}

maliciousSiteDao.updateFilters(
phishingFilterSet?.let {
PhishingFilterSetWithRevision(
it.insert.map { insert -> Filter(insert.hash, insert.regex) }.toSet(),
it.delete.map { delete -> Filter(delete.hash, delete.regex) }.toSet(),
it.revision,
it.replace,
)
},
)
maliciousSiteDao.updateFilters(
malwareFilterSet?.let {
MalwareFilterSetWithRevision(
it.insert.map { insert -> Filter(insert.hash, insert.regex) }.toSet(),
it.delete.map { delete -> Filter(delete.hash, delete.regex) }.toSet(),
it.revision,
it.replace,
)
},
)
}
} catch (e: Exception) {
Timber.e(e, "Failed to download malicious site protection list")
}
}

override suspend fun loadHashPrefixes() {
try {
val networkRevision = maliciousSiteService.getRevision().revision

(maliciousSiteDao.getLatestRevision()?.filter { it.type == Type.HASH_PREFIXES.name } ?: listOf()).let {
val phishingHashPrefixesRevision = it.firstOrNull() { it.feed == PHISHING.name }?.revision ?: 0
val phishingHashPrefixes: HashPrefixResponse? = if (networkRevision > phishingHashPrefixesRevision) {
maliciousSiteService.getPhishingHashPrefixes(phishingHashPrefixesRevision)
} else {
null
}
val malwareHashPrefixesRevision = it.firstOrNull() { it.feed == MALWARE.name }?.revision ?: 0
val malwareHashPrefixes: HashPrefixResponse? = if (networkRevision > malwareHashPrefixesRevision) {
maliciousSiteService.getMalwareHashPrefixes(malwareHashPrefixesRevision)
} else {
null
}

maliciousSiteDao.updateHashPrefixes(
phishingHashPrefixes?.let { PhishingHashPrefixesWithRevision(it.insert, it.delete, it.revision, it.replace) },
)
maliciousSiteDao.updateHashPrefixes(
malwareHashPrefixes?.let { MalwareHashPrefixesWithRevision(it.insert, it.delete, it.revision, it.replace) },
)
}
} catch (e: Exception) {
Timber.e(e, "Failed to download malicious site protection list")
}
}

override suspend fun containsHashPrefix(hashPrefix: String): Boolean {
return maliciousSiteDao.getHashPrefix(hashPrefix) != null
}
Expand All @@ -145,4 +80,108 @@ class RealMaliciousSiteRepository @Inject constructor(
listOf()
}
}

override suspend fun loadFilters(): Result<Unit> {
return loadDataOfType(FILTER_SET) { latestRevision, networkRevision, feed -> loadFilters(latestRevision, networkRevision, feed) }
}

override suspend fun loadHashPrefixes(): Result<Unit> {
return loadDataOfType(HASH_PREFIXES) { latestRevision, networkRevision, feed -> loadHashPrefixes(latestRevision, networkRevision, feed) }
}

private suspend fun loadDataOfType(
type: Type,
loadData: suspend (revisions: List<RevisionEntity>, networkRevision: Int, feed: Feed) -> Unit,
): Result<Unit> {
return withContext(dispatcherProvider.io()) {
val networkRevision = maliciousSiteService.getRevision().revision

val localRevisions = getLocalRevisions(type)

val result = Feed.entries.fold(Result.success(Unit)) { acc, feed ->
try {
loadData(localRevisions, networkRevision, feed)
acc
} catch (e: Exception) {
Result.failure(e)
}
}
result
}
}

private suspend fun <T> loadAndUpdateData(
latestRevision: List<RevisionEntity>,
networkRevision: Int,
feed: Feed,
getFunction: suspend (Int) -> T?,
updateFunction: suspend (T?) -> Unit,
) {
val revision = latestRevision.getRevisionForFeed(feed)
val data: T? = if (networkRevision > revision) {
getFunction(revision)
} else {
null
}

updateFunction(data)
}

private suspend fun loadFilters(
latestRevision: List<RevisionEntity>,
networkRevision: Int,
feed: Feed,
) {
loadAndUpdateData(
latestRevision,
networkRevision,
feed,
when (feed) {
PHISHING -> maliciousSiteService::getPhishingFilterSet
MALWARE -> maliciousSiteService::getMalwareFilterSet
},
) { maliciousSiteDao.updateFilters(it?.toFilterSetWithRevision(feed)) }
}

private suspend fun loadHashPrefixes(
latestRevision: List<RevisionEntity>,
networkRevision: Int,
feed: Feed,
) {
loadAndUpdateData(
latestRevision,
networkRevision,
feed,
when (feed) {
PHISHING -> maliciousSiteService::getPhishingHashPrefixes
MALWARE -> maliciousSiteService::getMalwareHashPrefixes
},
) { maliciousSiteDao.updateHashPrefixes(it?.toHashPrefixesWithRevision(feed)) }
}

private fun FilterSetResponse.toFilterSetWithRevision(feed: Feed): FilterSetWithRevision {
val insert = insert.toFilterSet()
val delete = delete.toFilterSet()
return when (feed) {
PHISHING -> PhishingFilterSetWithRevision(insert, delete, revision, replace)
MALWARE -> MalwareFilterSetWithRevision(insert, delete, revision, replace)
}
}

private fun HashPrefixResponse.toHashPrefixesWithRevision(feed: Feed): HashPrefixesWithRevision {
return when (feed) {
PHISHING -> PhishingHashPrefixesWithRevision(insert, delete, revision, replace)
MALWARE -> MalwareHashPrefixesWithRevision(insert, delete, revision, replace)
}
}

private suspend fun getLocalRevisions(type: Type) = (maliciousSiteDao.getLatestRevision()?.filter { it.type == type.name } ?: listOf())

private fun Set<FilterResponse>.toFilterSet(): Set<Filter> {
return map { Filter(it.hash, it.regex) }.toSet()
}

private fun List<RevisionEntity>.getRevisionForFeed(feed: Feed): Int {
return firstOrNull { it.feed == feed.name }?.revision ?: 0
}
}
Loading

0 comments on commit 9622274

Please sign in to comment.