Skip to content

Commit

Permalink
Modify MaliciousSiteProtection API
Browse files Browse the repository at this point in the history
  • Loading branch information
CrisBarreiro committed Dec 11, 2024
1 parent 17d0ff2 commit 7fd8100
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class BrowserWebViewClient @Inject constructor(

private var shouldOpenDuckPlayerInNewTab: Boolean = true

private val onSiteBlockedAsync: () -> Unit = {
private val confirmationCallback: (isMalicious: Boolean) -> Unit = {
// TODO (cbarreiro): Handle site blocked asynchronously
}

Expand Down Expand Up @@ -170,7 +170,7 @@ class BrowserWebViewClient @Inject constructor(
url,
webView.url?.toUri(),
isForMainFrame,
onSiteBlockedAsync,
confirmationCallback,
)
}
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ class WebViewRequestInterceptor(
): WebResourceResponse? {
val url: Uri? = request.url

val onSiteBlockedAsync: () -> Unit = {
val confirmationCallback: (isMalicious: Boolean) -> Unit = {
// TODO (cbarreiro): Handle site blocked asynchronously
}

maliciousSiteProtectionWebViewIntegration.shouldIntercept(request, documentUri, onSiteBlockedAsync)?.let {
maliciousSiteProtectionWebViewIntegration.shouldIntercept(request, documentUri, confirmationCallback)?.let {
// TODO (cbarreiro): Handle site blocked synchronously
return it
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import com.duckduckgo.app.pixels.remoteconfig.AndroidBrowserConfigFeature
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.di.scopes.AppScope
import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection
import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult.MALICIOUS
import com.duckduckgo.privacy.config.api.PrivacyConfigCallbackPlugin
import com.squareup.anvil.annotations.ContributesBinding
import com.squareup.anvil.annotations.ContributesMultibinding
Expand All @@ -41,14 +42,14 @@ interface MaliciousSiteBlockerWebViewIntegration {
suspend fun shouldIntercept(
request: WebResourceRequest,
documentUri: Uri?,
onSiteBlockedAsync: () -> Unit,
confirmationCallback: (isMalicious: Boolean) -> Unit,
): WebResourceResponse?

suspend fun shouldOverrideUrlLoading(
url: Uri,
webViewUrl: Uri?,
isForMainFrame: Boolean,
onSiteBlockedAsync: () -> Unit,
confirmationCallback: (isMalicious: Boolean) -> Unit,
): Boolean

fun onPageLoadStarted()
Expand Down Expand Up @@ -87,7 +88,7 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor(
override suspend fun shouldIntercept(
request: WebResourceRequest,
documentUri: Uri?,
onSiteBlockedAsync: () -> Unit,
confirmationCallback: (isMalicious: Boolean) -> Unit,
): WebResourceResponse? {
if (!isFeatureEnabled) {
return null
Expand All @@ -109,12 +110,12 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor(
}

if (request.isForMainFrame && decodedUrl.toUri() == documentUri) {
if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), onSiteBlockedAsync)) {
if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), confirmationCallback) == MALICIOUS) {
return WebResourceResponse(null, null, null)
}
processedUrls.add(decodedUrl)
} else if (isForIframe(request) && documentUri?.host == request.requestHeaders["Referer"]?.toUri()?.host) {
if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), onSiteBlockedAsync)) {
if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), confirmationCallback) == MALICIOUS) {
return WebResourceResponse(null, null, null)
}
processedUrls.add(decodedUrl)
Expand All @@ -126,7 +127,7 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor(
url: Uri,
webViewUrl: Uri?,
isForMainFrame: Boolean,
onSiteBlockedAsync: () -> Unit,
confirmationCallback: (isMalicious: Boolean) -> Unit,
): Boolean {
if (!isFeatureEnabled) {
return false
Expand All @@ -140,7 +141,7 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor(
}

if (isForMainFrame && decodedUrl.toUri() == webViewUrl) {
if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), onSiteBlockedAsync)) {
if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), confirmationCallback) == MALICIOUS) {
return true
}
processedUrls.add(decodedUrl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.duckduckgo.common.test.CoroutineTestRule
import com.duckduckgo.feature.toggles.api.FakeFeatureToggleFactory
import com.duckduckgo.feature.toggles.api.Toggle.State
import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection
import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult.MALICIOUS
import junit.framework.TestCase.assertTrue
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertFalse
Expand Down Expand Up @@ -75,7 +76,7 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest {
val request = mock(WebResourceRequest::class.java)
whenever(request.url).thenReturn(maliciousUri)
whenever(request.isForMainFrame).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldIntercept(request, maliciousUri) {}
assertNotNull(result)
Expand All @@ -87,7 +88,7 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest {
whenever(request.url).thenReturn(maliciousUri)
whenever(request.isForMainFrame).thenReturn(true)
whenever(request.requestHeaders).thenReturn(mapOf("Sec-Fetch-Dest" to "iframe"))
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldIntercept(request, maliciousUri) {}
assertNotNull(result)
Expand All @@ -98,39 +99,39 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest {
val request = mock(WebResourceRequest::class.java)
whenever(request.url).thenReturn(maliciousUri)
whenever(request.isForMainFrame).thenReturn(false)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldIntercept(request, maliciousUri) {}
assertNull(result)
}

@Test
fun `shouldOverride returns false when feature is enabled, is malicious, and is not mainframe`() = runTest {
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldOverrideUrlLoading(maliciousUri, maliciousUri, false) {}
assertFalse(result)
}

@Test
fun `shouldOverride returns true when feature is enabled, is malicious, and is mainframe`() = runTest {
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldOverrideUrlLoading(maliciousUri, maliciousUri, true) {}
assertTrue(result)
}

@Test
fun `shouldOverride returns false when feature is enabled, is malicious, and not mainframe nor iframe`() = runTest {
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldOverrideUrlLoading(maliciousUri, maliciousUri, false) {}
assertFalse(result)
}

@Test
fun `shouldOverride returns true when feature is enabled, is malicious, and is mainframe but webView has different host`() = runTest {
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(true)
whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS)

val result = testee.shouldOverrideUrlLoading(maliciousUri, exampleUri, true) {}
assertFalse(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,11 @@ import android.net.Uri

interface MaliciousSiteProtection {

suspend fun isMalicious(url: Uri, onSiteBlockedAsync: () -> Unit): Boolean
suspend fun isMalicious(url: Uri, confirmationCallback: (isMalicious: Boolean) -> Unit): IsMaliciousResult

enum class IsMaliciousResult {
MALICIOUS,
SAFE,
WAIT_FOR_CONFIRMATION,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.duckduckgo.app.di.IsMainProcess
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.di.scopes.AppScope
import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection
import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult
import com.duckduckgo.privacy.config.api.PrivacyConfigCallbackPlugin
import com.squareup.anvil.annotations.ContributesBinding
import com.squareup.anvil.annotations.ContributesMultibinding
Expand Down Expand Up @@ -67,9 +68,9 @@ class RealMaliciousSiteProtection @Inject constructor(
}
}

override suspend fun isMalicious(url: Uri, onSiteBlockedAsync: () -> Unit): Boolean {
override suspend fun isMalicious(url: Uri, confirmationCallback: (isMalicious: Boolean) -> Unit): IsMaliciousResult {
Timber.tag("MaliciousSiteProtection").d("isMalicious $url")
// TODO (cbarreiro): Implement the logic to check if the URL is malicious
return false
return IsMaliciousResult.SAFE
}
}

0 comments on commit 7fd8100

Please sign in to comment.