Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache criteria to CachedTool #791

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,44 @@ abstract class CachedTool<Input, Output>(
private val timeCachePolicy: Duration = 1.days
) : Tool<Input, Output> {

override suspend fun invoke(input: Input): Output {
return cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
}
/**
* Logic to be executed when the cache is missed.
*
* @return the output.
*/
abstract suspend fun onCacheMissed(input: Input): Output

/**
* Criteria to check if the cache should be used for the given [input]. By default, it returns
* true, meaning always use the cache if available.
*
* @return true if the cache should be used.
*/
open suspend fun shouldUseCache(input: Input): Boolean = true

/**
* Criteria to check if the result should be cached based on the given [input] and [output]. By
* default, it returns true, meaning always cache the result.
*
* @return true if the result should be cached.
*/
open suspend fun shouldCacheOutput(input: Input, output: Output): Boolean = true

/**
* Caches the result of [onCacheMissed] if [shouldCacheOutput] returns true. Otherwise, returns
* the result of [onCacheMissed].
*
* @return the output.
*/
override suspend fun invoke(input: Input): Output =
if (shouldUseCache(input)) cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
else onCacheMissed(input)

/**
* Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and
* [timeCachePolicy]. Removes expired cache entries.
*
* @return the map of input to output.
*/
suspend fun getCache(): Map<Input, Output> {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
Expand All @@ -42,8 +73,6 @@ abstract class CachedTool<Input, Output>(
return withoutExpired.map { it.key.value to it.value.value }.toMap()
}

abstract suspend fun onCacheMissed(input: Input): Output

private suspend fun cache(input: CachedToolKey<Input>, block: suspend () -> Output): Output {
val cachedToolInfo = cache.get()[input]
if (cachedToolInfo != null) {
Expand All @@ -55,7 +84,9 @@ abstract class CachedTool<Input, Output>(
}
}
val response = block()
cache.get()[input] = CachedToolValue(response, timeInMillis())
if (shouldCacheOutput(input.value, response)) {
cache.get()[input] = CachedToolValue(response, timeInMillis())
}
return response
}
}
Loading