diff --git a/strikt-core/src/main/kotlin/strikt/api/Assertion.kt b/strikt-core/src/main/kotlin/strikt/api/Assertion.kt index c629c161..41963bb5 100644 --- a/strikt-core/src/main/kotlin/strikt/api/Assertion.kt +++ b/strikt-core/src/main/kotlin/strikt/api/Assertion.kt @@ -1,8 +1,12 @@ package strikt.api +import filepeek.FileInfo import filepeek.LambdaBody +import filepeek.SourceFileNotFoundException import strikt.internal.FilePeek +import java.io.File import java.util.Locale +import java.util.concurrent.ConcurrentHashMap import kotlin.jvm.internal.CallableReference import kotlin.reflect.KFunction import kotlin.reflect.KProperty @@ -233,6 +237,24 @@ interface Assertion { description: String, function: T.() -> R, block: Builder.() -> Unit + ): Builder = with({ description }, function, block) + + /** + * Runs a group of assertions on the subject returned by [function]. + * + * The [description] is only invoked if the test fails. + * + * @param description a lambda that produces a description of the mapped result. + * @param function a lambda whose receiver is the current assertion subject. + * @param block a closure that can perform multiple assertions that will all + * be evaluated regardless of whether preceding ones pass or fail. + * @param R the mapped subject type. + * @return this builder, to facilitate chaining. + */ + fun with( + description: () -> String, + function: T.() -> R, + block: Builder.() -> Unit ): Builder /** @@ -262,6 +284,23 @@ interface Assertion { fun get( description: String, function: T.() -> R + ): DescribeableBuilder = get({ description }, function) + + /** + * Maps the assertion subject to the result of [function]. + * This is useful for chaining to property values or method call results on + * the subject. + * + * The [description] is only invoked if the test fails. + * + * @param description a lambda that produces a description of the mapped result. + * @param function a lambda whose receiver is the current assertion subject. + * @return an assertion builder whose subject is the value returned by + * [function]. + */ + fun get( + description: () -> String, + function: T.() -> R ): DescribeableBuilder /** @@ -322,23 +361,117 @@ interface Assertion { } } -private fun (Receiver.() -> Result).describe(): String = +private val DESCRIBE_CACHE = ConcurrentHashMap, () -> String>() + +private fun (Receiver.() -> Result).describe(): () -> String = when (this) { - is KProperty<*> -> - "value of property $name" - is KFunction<*> -> - "return value of $name" - is CallableReference -> "value of $propertyName" + is KProperty<*> -> { + { "value of property $name" } + } + + is KFunction<*> -> { + { "return value of $name" } + } + + is CallableReference -> { + { "value of $propertyName" } + } + else -> { - try { - val line = FilePeek.filePeek.getCallerFileInfo().line - LambdaBody("get", line).body.trim() - } catch (e: Exception) { - "%s" + var lambda = DESCRIBE_CACHE[javaClass] + + if (lambda == null) { + lambda = captureGet(RuntimeException()) + DESCRIBE_CACHE.putIfAbsent(javaClass, lambda) } + + lambda + } + } + +private fun captureGet(ex: Throwable): () -> String { + return { + try { + val line = FilePeek.filePeek.specialGetCallerInfo(ex).line + LambdaBody("get", line).body.trim() + } catch (e: Exception) { + "%s" + } + } +} + +private fun Sequence.takeWhileInclusive(pred: (T) -> Boolean): Sequence { + var shouldContinue = true + return takeWhile { + val result = shouldContinue + shouldContinue = pred(it) + result + } +} + +private val FS = File.separator + +private val ignoredPackages = listOf( + "strikt.internal", + "strikt.api", + "filepeek" +) + +private val sourceRoots: List = listOf("src${FS}test${FS}kotlin", "src${FS}test${FS}java") + +private fun filepeek.FilePeek.specialGetCallerInfo(ex: Throwable): FileInfo { + val callerStackTraceElement = ex.stackTrace.first { el -> + ignoredPackages + .none { el.className.startsWith(it) } + } + val className = callerStackTraceElement.className.substringBefore('$') + val clazz = javaClass.classLoader.loadClass(className)!! + val classFilePath = File(clazz.protectionDomain.codeSource.location.path) + .absolutePath + + val buildDir = when { + classFilePath.contains("${FS}out$FS") -> "out${FS}test${FS}classes" // running inside IDEA + classFilePath.contains("build${FS}classes${FS}java") -> "build${FS}classes${FS}java${FS}test" // gradle 4.x java source + classFilePath.contains("build${FS}classes${FS}kotlin") -> "build${FS}classes${FS}kotlin${FS}test" // gradle 4.x kotlin sources + classFilePath.contains("target${FS}classes") -> "target${FS}classes" // maven + else -> "build${FS}classes${FS}test" // older gradle + } + + val sourceFileCandidates = sourceRoots + .map { sourceRoot -> + val sourceFileWithoutExtension = + classFilePath.replace(buildDir, sourceRoot) + .plus(FS + className.replace(".", FS)) + + File(sourceFileWithoutExtension).parentFile + .resolve(callerStackTraceElement.fileName!!) } + val sourceFile = sourceFileCandidates.singleOrNull(File::exists) ?: throw SourceFileNotFoundException( + classFilePath, + className, + sourceFileCandidates + ) + + val callerLine = sourceFile.bufferedReader().useLines { lines -> + var braceDelta = 0 + lines.drop(callerStackTraceElement.lineNumber - 1) + .takeWhileInclusive { line -> + val openBraces = line.count { it == '{' } + val closeBraces = line.count { it == '}' } + braceDelta += openBraces - closeBraces + braceDelta != 0 + }.map { it.trim() }.joinToString(separator = "") } + return FileInfo( + callerStackTraceElement.lineNumber, + sourceFileName = sourceFile.absolutePath, + line = callerLine.trim(), + methodName = callerStackTraceElement.methodName + + ) +} + private val CallableReference.propertyName: String get() = "^get(.+)$".toRegex().find(name).let { match -> return when (match) { diff --git a/strikt-core/src/main/kotlin/strikt/internal/AssertionBuilder.kt b/strikt-core/src/main/kotlin/strikt/internal/AssertionBuilder.kt index 37f3da02..0d3f4db3 100644 --- a/strikt-core/src/main/kotlin/strikt/internal/AssertionBuilder.kt +++ b/strikt-core/src/main/kotlin/strikt/internal/AssertionBuilder.kt @@ -17,14 +17,14 @@ internal class AssertionBuilder( override fun describedAs(description: String): Builder { if (context is DescribedNode<*>) { - context.description = description + context.description = { description } } return this } override fun describedAs(descriptor: T.() -> String): Builder { if (context is DescribedNode<*>) { - context.description = context.subject.descriptor() + context.description = { context.subject.descriptor() } } return this } @@ -98,14 +98,14 @@ internal class AssertionBuilder( } override fun get( - description: String, + description: () -> String, function: (T) -> R ): DescribeableBuilder = if (context.allowChain) { runCatching { function(context.subject) } - .getOrElse { ex -> throw MappingFailed(description, ex) } + .getOrElse { ex -> throw MappingFailed(description(), ex) } .let { AssertionBuilder( AssertionSubject(context, it, description), @@ -121,7 +121,7 @@ internal class AssertionBuilder( } override fun with( - description: String, + description: () -> String, function: T.() -> R, block: Builder.() -> Unit ): Builder { @@ -135,7 +135,7 @@ internal class AssertionBuilder( strategy.evaluate(nestedContext) } } - .onFailure { ex -> throw MappingFailed(description, ex) } + .onFailure { ex -> throw MappingFailed(description(), ex) } return this } diff --git a/strikt-core/src/main/kotlin/strikt/internal/AssertionNode.kt b/strikt-core/src/main/kotlin/strikt/internal/AssertionNode.kt index fa267221..9d7f90e6 100644 --- a/strikt-core/src/main/kotlin/strikt/internal/AssertionNode.kt +++ b/strikt-core/src/main/kotlin/strikt/internal/AssertionNode.kt @@ -23,7 +23,7 @@ internal interface AssertionNode { } internal interface DescribedNode : AssertionNode { - var description: String + var description: () -> String } internal interface AssertionGroup : AssertionNode { @@ -39,7 +39,7 @@ internal interface AssertionResult : DescribedNode { internal class AssertionSubject( override val parent: AssertionGroup<*>?, override val subject: S, - override var description: String = "%s" + override var description: () -> String = { "%s" } ) : AssertionGroup, DescribedNode { constructor(value: S) : this(null, value) @@ -138,7 +138,7 @@ internal class AssertionChainedGroup( internal abstract class AtomicAssertionNode( final override val parent: AssertionGroup, - override var description: String, + override var description: () -> String, override val expected: Any? = null ) : AssertionResult, AtomicAssertion { @@ -156,7 +156,7 @@ internal abstract class AtomicAssertionNode( internal abstract class CompoundAssertionNode( final override val parent: AssertionGroup, - override var description: String, + override var description: () -> String, override val expected: Any? = null ) : AssertionGroup, AssertionResult, CompoundAssertion { diff --git a/strikt-core/src/main/kotlin/strikt/internal/AssertionStrategy.kt b/strikt-core/src/main/kotlin/strikt/internal/AssertionStrategy.kt index 310a3d49..4fc0c3c3 100644 --- a/strikt-core/src/main/kotlin/strikt/internal/AssertionStrategy.kt +++ b/strikt-core/src/main/kotlin/strikt/internal/AssertionStrategy.kt @@ -20,7 +20,7 @@ internal sealed class AssertionStrategy { ): AtomicAssertionNode = object : AtomicAssertionNode( context, - provideDescription(description), + { provideDescription(description) }, expected ) { @@ -62,7 +62,7 @@ internal sealed class AssertionStrategy { ): CompoundAssertionNode = object : CompoundAssertionNode( context, - provideDescription(description), + { provideDescription(description) }, expected ) { diff --git a/strikt-core/src/main/kotlin/strikt/internal/reporting/DefaultResultWriter.kt b/strikt-core/src/main/kotlin/strikt/internal/reporting/DefaultResultWriter.kt index f71794fd..ff753101 100644 --- a/strikt-core/src/main/kotlin/strikt/internal/reporting/DefaultResultWriter.kt +++ b/strikt-core/src/main/kotlin/strikt/internal/reporting/DefaultResultWriter.kt @@ -82,6 +82,7 @@ internal open class DefaultResultWriter : ResultWriter { if (isRoot) { writer.append("Expect that ") } + val description = this.description() // if the value spans > 1 line, this is how much to indent following lines val valueIndent = (description.indexOf("%")).coerceAtLeast(0) + 14 + (indent * 2) @@ -104,6 +105,7 @@ internal open class DefaultResultWriter : ResultWriter { val failed = status as? Failed when { failed?.comparison != null -> { + val description = this.description() val formattedComparison = failed.comparison.formatValues() val failedDescription = failed.description ?: "found %s" val descriptionIndent = description.indexOf("%") @@ -143,10 +145,10 @@ internal open class DefaultResultWriter : ResultWriter { } failed?.description != null -> writer - .append(description.format(formatValue(expected))) + .append(description().format(formatValue(expected))) .append(" : ") .append(failed.description) - else -> writer.append(description.format(formatValue(expected))) + else -> writer.append(description().format(formatValue(expected))) } } diff --git a/strikt-core/src/test/kotlin/strikt/Mapping.kt b/strikt-core/src/test/kotlin/strikt/Mapping.kt index 2c232a91..2204120d 100644 --- a/strikt-core/src/test/kotlin/strikt/Mapping.kt +++ b/strikt-core/src/test/kotlin/strikt/Mapping.kt @@ -15,6 +15,7 @@ import strikt.assertions.first import strikt.assertions.flatMap import strikt.assertions.get import strikt.assertions.isEqualTo +import strikt.assertions.isIn import strikt.assertions.isNotNull import strikt.assertions.isNull import strikt.assertions.last @@ -22,9 +23,78 @@ import strikt.assertions.map import strikt.assertions.message import strikt.assertions.single import java.time.LocalDate +import kotlin.random.Random @DisplayName("mapping assertions") internal class Mapping { + @Test + fun `get perf`() { + class Person( + val id: Int = Random.nextInt(), + val id1: Int = Random.nextInt(), + val id2: Int = Random.nextInt(), + val id3: Int = Random.nextInt(), + val id4: Int = Random.nextInt(), + val id5: Int = Random.nextInt(), + val id6: Int = Random.nextInt(), + val id7: Int = Random.nextInt(), + val id8: Int = Random.nextInt(), + val id9: Int = Random.nextInt(), + val id10: Int = Random.nextInt(), + val id11: Int = Random.nextInt(), + val id12: Int = Random.nextInt(), + val id13: Int = Random.nextInt(), + val id14: Int = Random.nextInt(), + val id15: Int = Random.nextInt(), + val id16: Int = Random.nextInt(), + val id17: Int = Random.nextInt(), + val id18: Int = Random.nextInt(), + val id19: Int = Random.nextInt(), + val id20: Int = Random.nextInt(), + val id21: Int = Random.nextInt(), + val id22: Int = Random.nextInt(), + val id23: Int = Random.nextInt(), + val id24: Int = Random.nextInt(), + val id25: Int = Random.nextInt(), + val id26: Int = Random.nextInt(), + val id27: Int = Random.nextInt(), + ) + val range = Int.MIN_VALUE..Int.MAX_VALUE + + repeat(10_000) { + expectThat(Person(Random.nextInt())) { + get { id } isIn range + get { id1 } isIn range + get { id2 } isIn range + get { id3 } isIn range + get { id4 } isIn range + get { id5 } isIn range + get { id6 } isIn range + get { id7 } isIn range + get { id8 } isIn range + get { id9 } isIn range + get { id10 } isIn range + get { id11 } isIn range + get { id12 } isIn range + get { id13 } isIn range + get { id14 } isIn range + get { id15 } isIn range + get { id16 } isIn range + get { id17 } isIn range + get { id18 } isIn range + get { id19 } isIn range + get { id20 } isIn range + get { id21 } isIn range + get { id22 } isIn range + get { id23 } isIn range + get { id24 } isIn range + get { id25 } isIn range + get { id26 } isIn range + get { id27 } isIn range + } + } + } + @Test fun `map() on iterable subjects maps to an iterable`() { val subject = listOf("catflap", "rubberplant", "marzipan")