Skip to content

Commit

Permalink
ffi: migrate to jffi which will provide a path for struct support
Browse files Browse the repository at this point in the history
  • Loading branch information
azenla committed Oct 7, 2023
1 parent 0d8c672 commit 437ab75
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 180 deletions.
2 changes: 2 additions & 0 deletions ffi/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ dependencies {
api(project(":evaluator"))

implementation(project(":common"))
implementation("com.github.jnr:jffi:1.3.12")
implementation("com.github.jnr:jffi:1.3.12:native")
}
7 changes: 7 additions & 0 deletions ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiAddress.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package gay.pizza.pork.ffi

data class FfiAddress(val location: Long) {
companion object {
val Null = FfiAddress(0L)
}
}
46 changes: 0 additions & 46 deletions ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiLibraryCache.kt

This file was deleted.

158 changes: 124 additions & 34 deletions ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiNativeProvider.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package gay.pizza.pork.ffi

import gay.pizza.pork.ast.ArgumentSpec
import com.kenai.jffi.*
import com.kenai.jffi.Function
import gay.pizza.pork.ast.gen.ArgumentSpec
import gay.pizza.pork.evaluator.CallableFunction
import gay.pizza.pork.evaluator.NativeProvider
import gay.pizza.pork.evaluator.None
import java.lang.foreign.*
import java.nio.file.Path
import kotlin.io.path.Path
import kotlin.io.path.absolutePathString
Expand All @@ -15,7 +16,6 @@ class FfiNativeProvider : NativeProvider {

override fun provideNativeFunction(definitions: List<String>, arguments: List<ArgumentSpec>): CallableFunction {
val functionDefinition = FfiFunctionDefinition.parse(definitions[0], definitions[1])
val linker = Linker.nativeLinker()
val functionAddress = lookupSymbol(functionDefinition)

val parameters = functionDefinition.parameters.map { id ->
Expand All @@ -25,47 +25,72 @@ class FfiNativeProvider : NativeProvider {
val returnTypeId = functionDefinition.returnType
val returnType = ffiTypeRegistry.lookup(returnTypeId) ?:
throw RuntimeException("Unknown ffi return type: $returnTypeId")
val parameterArray = parameters.map { typeAsLayout(it) }.toTypedArray()
val descriptor = if (returnType == FfiPrimitiveType.Void)
FunctionDescriptor.ofVoid(*parameterArray)
else FunctionDescriptor.of(typeAsLayout(returnType), *parameterArray)
val handle = linker.downcallHandle(functionAddress, descriptor)
val returnTypeFfi = typeConversion(returnType)
val parameterArray = parameters.map { typeConversion(it) }.toTypedArray()
val function = Function(functionAddress, returnTypeFfi, *parameterArray)
val context = function.callContext
val invoker = Invoker.getInstance()
return CallableFunction { functionArguments, _ ->
Arena.ofConfined().use { arena ->
handle.invokeWithArguments(functionArguments.map { valueAsFfi(it, arena) }) ?: None
val buffer = HeapInvocationBuffer(context)
val freeStringList = mutableListOf<FfiStringWrapper>()
for ((index, spec) in arguments.withIndex()) {
val ffiType = ffiTypeRegistry.lookup(functionDefinition.parameters[index]) ?:
throw RuntimeException("Unknown ffi type: ${functionDefinition.parameters[index]}")
if (spec.multiple) {
val variableArguments = functionArguments
.subList(index, functionArguments.size)
variableArguments.forEach {
var value = it
if (value is String) {
value = FfiStringWrapper(value)
freeStringList.add(value)
}
put(buffer, value)
}
break
} else {
val converted = convert(ffiType, functionArguments[index])
if (converted is FfiStringWrapper) {
freeStringList.add(converted)
}
put(buffer, converted)
}
}
}
}

private fun lookupSymbol(functionDefinition: FfiFunctionDefinition): MemorySegment {
if (functionDefinition.library == "c") {
return SymbolLookup.loaderLookup().find(functionDefinition.function).orElseThrow {
RuntimeException("Unknown function: ${functionDefinition.function}")
try {
return@CallableFunction invoke(invoker, function, buffer, returnType)
} finally {
freeStringList.forEach { it.free() }
}
}
}

private fun lookupSymbol(functionDefinition: FfiFunctionDefinition): Long {
val actualLibraryPath = findLibraryPath(functionDefinition.library)
val functionAddress = FfiLibraryCache.dlsym(actualLibraryPath.absolutePathString(), functionDefinition.function)
if (functionAddress.address() == 0L) {
throw RuntimeException("Unknown function: ${functionDefinition.function} in library $actualLibraryPath")
val library = Library.getCachedInstance(actualLibraryPath.absolutePathString(), Library.NOW)
?: throw RuntimeException("Failed to load library $actualLibraryPath")
val functionAddress = library.getSymbolAddress(functionDefinition.function)
if (functionAddress == 0L) {
throw RuntimeException(
"Failed to find symbol ${functionDefinition.function} in " +
"library ${actualLibraryPath.absolutePathString()}")
}
return functionAddress
}

private fun typeAsLayout(type: FfiType): MemoryLayout = when (type) {
FfiPrimitiveType.UnsignedByte, FfiPrimitiveType.Byte -> ValueLayout.JAVA_BYTE
FfiPrimitiveType.UnsignedInt, FfiPrimitiveType.Int -> ValueLayout.JAVA_INT
FfiPrimitiveType.UnsignedShort, FfiPrimitiveType.Short -> ValueLayout.JAVA_SHORT
FfiPrimitiveType.UnsignedLong, FfiPrimitiveType.Long -> ValueLayout.JAVA_LONG
FfiPrimitiveType.String -> ValueLayout.ADDRESS
FfiPrimitiveType.Pointer -> ValueLayout.ADDRESS
FfiPrimitiveType.Void -> MemoryLayout.sequenceLayout(0, ValueLayout.JAVA_INT)
else -> throw RuntimeException("Unknown ffi type to convert to memory layout: $type")
}

private fun valueAsFfi(value: Any, allocator: SegmentAllocator): Any = when (value) {
is String -> allocator.allocateUtf8String(value)
None -> MemorySegment.NULL
else -> value
private fun typeConversion(type: FfiType): Type = when (type) {
FfiPrimitiveType.UnsignedByte -> Type.UINT8
FfiPrimitiveType.Byte -> Type.SINT8
FfiPrimitiveType.UnsignedInt -> Type.UINT32
FfiPrimitiveType.Int -> Type.SINT32
FfiPrimitiveType.UnsignedShort -> Type.UINT16
FfiPrimitiveType.Short -> Type.SINT16
FfiPrimitiveType.UnsignedLong -> Type.UINT64
FfiPrimitiveType.Long -> Type.SINT64
FfiPrimitiveType.String -> Type.POINTER
FfiPrimitiveType.Pointer -> Type.POINTER
FfiPrimitiveType.Void -> Type.VOID
else -> throw RuntimeException("Unknown ffi type: $type")
}

private fun findLibraryPath(name: String): Path {
Expand All @@ -76,4 +101,69 @@ class FfiNativeProvider : NativeProvider {
return FfiPlatforms.current.platform.findLibrary(name)
?: throw RuntimeException("Unable to find library: $name")
}

private fun convert(type: FfiType, value: Any?): Any {
if (type !is FfiPrimitiveType) {
return value ?: FfiAddress.Null
}

if (type.numberConvert != null) {
return numberConvert(type.id, value, type.numberConvert)
}

if (type.notNullConversion != null) {
return notNullConvert(type.id, value, type.notNullConversion)
}

if (type.nullableConversion != null) {
return nullableConvert(value, type.nullableConversion) ?: FfiAddress.Null
}
return value ?: FfiAddress.Null
}

private fun <T> notNullConvert(type: String, value: Any?, into: Any.() -> T): T {
if (value == null) {
throw RuntimeException("Null values cannot be used for converting to type $type")
}
return into(value)
}

private fun <T> nullableConvert(value: Any?, into: Any.() -> T): T? {
if (value == null || value == None) {
return null
}
return into(value)
}

private fun <T> numberConvert(type: String, value: Any?, into: Number.() -> T): T {
if (value == null || value == None) {
throw RuntimeException("Null values cannot be used for converting to numeric type $type")
}

if (value !is Number) {
throw RuntimeException("Cannot convert value '$value' into type $type")
}
return into(value)
}

private fun put(buffer: InvocationBuffer, value: Any): Unit = when (value) {
is Byte -> buffer.putByte(value.toInt())
is Short -> buffer.putShort(value.toInt())
is Int -> buffer.putInt(value)
is Long -> buffer.putLong(value)
is FfiAddress -> buffer.putAddress(value.location)
is FfiStringWrapper -> buffer.putAddress(value.address)
else -> throw RuntimeException("Unknown buffer insertion: $value (${value.javaClass.name})")
}

private fun invoke(invoker: Invoker, function: Function, buffer: HeapInvocationBuffer, type: FfiType): Any = when (type) {
FfiPrimitiveType.Pointer -> invoker.invokeAddress(function, buffer)
FfiPrimitiveType.UnsignedInt, FfiPrimitiveType.Int -> invoker.invokeInt(function, buffer)
FfiPrimitiveType.Long -> invoker.invokeLong(function, buffer)
FfiPrimitiveType.Void -> invoker.invokeStruct(function, buffer)
FfiPrimitiveType.Double -> invoker.invokeDouble(function, buffer)
FfiPrimitiveType.Float -> invoker.invokeFloat(function, buffer)
FfiPrimitiveType.String -> invoker.invokeAddress(function, buffer)
else -> throw RuntimeException("Unsupported ffi return type: $type")
} ?: None
}
14 changes: 7 additions & 7 deletions ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiPrimitiveType.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gay.pizza.pork.ffi

import gay.pizza.pork.evaluator.None
import java.lang.foreign.MemorySegment

enum class FfiPrimitiveType(
val id: kotlin.String,
Expand All @@ -20,13 +19,14 @@ enum class FfiPrimitiveType(
Long("long", 8, numberConvert = { toLong() }),
UnsignedLong("unsigned long", 8, numberConvert = { toLong() }),
Double("double", 8, numberConvert = { toDouble() }),
String("char*", 8, nullableConversion = { toString() }),
String("char*", 8, nullableConversion = { FfiStringWrapper(toString()) }),
Pointer("void*", 8, nullableConversion = {
if (this is kotlin.Long) {
MemorySegment.ofAddress(this)
} else if (this == None) {
MemorySegment.NULL
} else this as MemorySegment
when (this) {
is FfiAddress -> this
is None -> FfiAddress.Null
is Number -> FfiAddress(this.toLong())
else -> FfiAddress.Null
}
}),
Void("void", 0)
}
17 changes: 17 additions & 0 deletions ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiStringWrapper.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package gay.pizza.pork.ffi

import com.kenai.jffi.MemoryIO

class FfiStringWrapper(input: String) {
val address: Long

init {
val bytes = input.toByteArray()
address = MemoryIO.getInstance().allocateMemory((bytes.size + 1).toLong(), true)
MemoryIO.getInstance().putZeroTerminatedByteArray(address, bytes, 0, bytes.size)
}

fun free() {
MemoryIO.getInstance().freeMemory(address)
}
}
93 changes: 0 additions & 93 deletions ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt

This file was deleted.

0 comments on commit 437ab75

Please sign in to comment.