Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic language choice #13

Merged
merged 10 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ node_modules/
result
.vscode/
**/.DS_Store
hs_err_pid*

This file was deleted.

This file was deleted.

10 changes: 8 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ lazy val core = crossProject(JVMPlatform)
.jvmSettings(
commonJVMSettings,
libraryDependencies ++= Seq(
"net.java.dev.jna" % "jna" % "5.14.0"
"net.java.dev.jna" % "jna" % "5.15.0"
),
)

Expand All @@ -82,7 +82,13 @@ lazy val bindingsPython = crossProject(JVMPlatform)
lazy val tests = crossProject(JVMPlatform)
.crossType(CrossType.Pure)
.settings(
commonSettings
commonSettings,
run / fork := true,
// options for debugging JNA issues
// Test / javaOptions ++= Seq(
// "-Djna.debug_load=true",
// "-Djna.debug_load.jna=true",
// ),
)
.dependsOn(bindingsPython)
.jvmSettings(commonJVMSettings)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package org.polyvariant.treesitter4s.lowlevel

import org.polyvariant.treesitter4s.internal.TreeSitterLibrary
import com.sun.jna.*
import org.polyvariant.treesitter4s.internal.TreeSitterLibrary

object TreeSitterPlatform {

Expand All @@ -41,11 +41,38 @@ object TreeSitterPlatform {
val NullTree: Tree = null
}

type Language = org.polyvariant.treesitter4s.Language
trait LanguageWrapper {
def lang: org.polyvariant.treesitter4s.Language
}

type Language = LanguageWrapper

val Language: LanguageMethods =
new {
def apply(language: org.polyvariant.treesitter4s.Language): Language = language

def apply(
languageName: String
): Language = {
val library = NativeLibrary.getInstance(s"tree-sitter-$languageName")

val function = library.getFunction(s"tree_sitter_$languageName");

val langg = function
.invoke(classOf[org.polyvariant.treesitter4s.Language], Array())
.asInstanceOf[org.polyvariant.treesitter4s.Language]

new LanguageWrapper {
def lang: org.polyvariant.treesitter4s.Language = {
// but we need to keep a reference to the library for... reasons
// probably related to, but not quite the same, as:
// https://github.com/java-native-access/jna/pull/1378
// basically, segfaults.
library.hashCode()
langg
}
}
}

}

type Node = TreeSitterLibrary.Node
Expand All @@ -56,7 +83,7 @@ object TreeSitterPlatform {
def tsParserSetLanguage(
parser: Parser,
language: Language,
): Boolean = LIBRARY.ts_parser_set_language(parser, language)
): Boolean = LIBRARY.ts_parser_set_language(parser, language.lang)

def tsParserParseString(
parser: Parser,
Expand All @@ -67,9 +94,9 @@ object TreeSitterPlatform {

def tsLanguageSymbolCount(
language: Language
): Long = LIBRARY.ts_language_symbol_count(language)
): Long = LIBRARY.ts_language_symbol_count(language.lang)

def tsLanguageVersion(language: Language): Long = LIBRARY.ts_language_version(language)
def tsLanguageVersion(language: Language): Long = LIBRARY.ts_language_version(language.lang)

def tsNodeChild(node: Node, index: Long): Node = LIBRARY.ts_node_child(node, index)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ trait TreeSitterAPI {

object TreeSitterAPI {

def make(language: (ts: TreeSitter) => ts.Language): TreeSitterAPI = {
def make(language: String): TreeSitterAPI = {
val ts = TreeSitter.instance
val lang = ts.Language(language)

internal.Facade.make(ts, language(ts))
internal.Facade.make(ts, lang)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package org.polyvariant.treesitter4s.internal

import org.polyvariant.treesitter4s.lowlevel.TreeSitter
import org.polyvariant.treesitter4s.TreeSitterAPI
import org.polyvariant.treesitter4s.Tree
import org.polyvariant.treesitter4s.Node
import org.polyvariant.treesitter4s.Tree
import org.polyvariant.treesitter4s.TreeSitterAPI
import org.polyvariant.treesitter4s.lowlevel.TreeSitter

private[treesitter4s] object Facade {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait TreeSitter {
type Language

trait LanguageMethods {
def apply(language: org.polyvariant.treesitter4s.Language): Language
def apply(libraryName: String): Language
}

val Language: LanguageMethods
Expand Down
62 changes: 62 additions & 0 deletions fun-times.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>

// POC of 'reflection' on tree-sitter grammar dylibs in C
// to see if we can apply the same paradigm on all scala platforms
// and avoid language implementors having to write any binding traits whatsoever.
// usage: cc fun-times.c -o get-lang && ./get-lang python
// (requires the dylib to be in the current dir of course)
int main(int argc, char *argv[])
{
// Ensure a language name is provided
if (argc != 2)
{
fprintf(stderr, "Usage: %s <language_name>\n", argv[0]);
return 1;
}

const char *language_name = argv[1]; // Read the language name from argv

// Construct the shared library name dynamically
char lib_name[256];
snprintf(lib_name, sizeof(lib_name), "libtree-sitter-%s.dylib", language_name);

// Load the shared library
void *handle = dlopen(lib_name, RTLD_LAZY);
if (!handle)
{
fprintf(stderr, "Error loading library: %s\n", dlerror());
return 1;
}

const char *func_name_prefix = "tree_sitter"; // Prefix for the function name
char func_name_full[256];

// Construct the full function name dynamically
snprintf(func_name_full, sizeof(func_name_full), "%s_%s", func_name_prefix, language_name);

void *(*func)(); // Adjust the function pointer type to return a pointer

// Use dlsym to get the function address
*(void **)(&func) = dlsym(handle, func_name_full);

// Check for errors in retrieving the function
char *error = dlerror();
if (error != NULL)
{
fprintf(stderr, "Error locating function: %s\n", error);
dlclose(handle);
return 1;
}

// Call the function and get the returned pointer
void *result = func();

// Print the address of the returned pointer
printf("Address of the returned pointer: %p\n", result);

// Clean up
dlclose(handle);
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
package org.polyvariant.treesitter4s.tests

import org.polyvariant.treesitter4s.TreeSitterAPI
import org.polyvariant.treesitter4s.bindings.python.PythonLanguage

object Demo {

def main(args: Array[String]): Unit = {
val ts = TreeSitterAPI.make(PythonLanguage)
val ts = TreeSitterAPI.make("python")

System.out.println(ts.parse("""def main = print("hello world")""").rootNode.map(_.tpe))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ import cats.implicits._
import org.polyvariant.treesitter4s.Tree
import weaver._
import org.polyvariant.treesitter4s.TreeSitterAPI
import org.polyvariant.treesitter4s.bindings.python.PythonLanguage

object BindingTests extends FunSuite {
val tsPython = TreeSitterAPI.make(PythonLanguage)

def parseExample(s: String): Tree = tsPython.parse(s)
val ts = TreeSitterAPI.make("python")
def parseExample(s: String): Tree = ts.parse(s)

test("root node child count") {
val tree = parseExample("def main = print('Hello')\n")
Expand All @@ -41,13 +39,13 @@ object BindingTests extends FunSuite {
// assert.eql(rootNode.map(_.tpe), Some("compilation_unit"))
// }

// test("root node child by index (in range)") {
// val tree = parseExample("class Hello {}")
test("root node child by index (in range)") {
val tree = parseExample("class Hello {}")

// val rootNode = tree.rootNode.getOrElse(sys.error("missing root node"))
val rootNode = tree.rootNode.getOrElse(sys.error("missing root node"))

// assert.eql(rootNode.children.lift(0).isDefined, true)
// }
assert.eql(rootNode.children.lift(0).isDefined, true)
}

// test("root node child by index (out of range)") {
// val tree = parseExample("class Hello {}")
Expand Down