Skip to content

Commit

Permalink
Use package name
Browse files Browse the repository at this point in the history
  • Loading branch information
yutannihilation committed Sep 16, 2023
1 parent 7474c99 commit 7a8a3f8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
2 changes: 1 addition & 1 deletion R-package/src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ fn print_list(x: ListSxp) {

let name = if k.is_empty() { "(no name)" } else { k };

savvy::r_print(format!("{name}: {content}\n"));
savvy::r_print(format!("{name}: {content}\n").as_str());
}
}

Expand Down
38 changes: 31 additions & 7 deletions savvy-bindgen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ fn is_marked(attrs: &[syn::Attribute]) -> bool {
attrs.iter().any(|attr| attr == &parse_quote!(#[savvy]))
}

fn parse_file(path: &Path) -> ParsedResult {
fn read_file(path: &Path) -> String {
if !path.exists() {
eprintln!("{} does not exist", path.to_string_lossy());
std::process::exit(1);
Expand All @@ -115,7 +115,11 @@ fn parse_file(path: &Path) -> ParsedResult {
std::process::exit(2);
};

let ast = match syn::parse_str::<syn::File>(&content) {
content
}

fn parse_file(path: &Path) -> ParsedResult {
let ast = match syn::parse_str::<syn::File>(&read_file(path)) {
Ok(ast) => ast,
Err(_) => {
eprintln!("Failed to parse the specified file");
Expand Down Expand Up @@ -148,6 +152,23 @@ fn parse_file(path: &Path) -> ParsedResult {
result
}

// Parse DESCRIPTION file and get the package name
fn parse_description(path: &Path) -> Option<String> {
let content = read_file(path);
for line in content.lines() {
if !line.starts_with("Package") {
continue;
}
let mut s = line.split(":");
s.next();
if let Some(rhs) = s.next() {
return Some(rhs.trim().to_string());
}
}

None
}

const PATH_DESCRIPTION: &str = "DESCRIPTION";
const PATH_LIB_RS: &str = "src/rust/src/lib.rs";
const PATH_C_HEADER: &str = "src/rust/api.h";
Expand All @@ -165,10 +186,13 @@ fn update(path: &Path) {
std::process::exit(1);
}

if !path.join(PATH_DESCRIPTION).exists() {
let pkg_name = parse_description(&path.join(PATH_DESCRIPTION));

if pkg_name.is_none() {
eprintln!("{} is not an R package root", path.to_string_lossy());
std::process::exit(4);
}
let pkg_name = pkg_name.unwrap();

let path_lib_rs = path.join(PATH_LIB_RS);
println!("Parsing {}", path_lib_rs.to_string_lossy());
Expand All @@ -180,11 +204,11 @@ fn update(path: &Path) {

let path_c_impl = path.join(PATH_C_IMPL);
println!("Writing {}", path_c_impl.to_string_lossy());
std::fs::write(path_c_impl, make_c_impl_file(&parsed_result)).unwrap();
std::fs::write(path_c_impl, make_c_impl_file(&parsed_result, &pkg_name)).unwrap();

let path_r_impl = path.join(PATH_R_IMPL);
println!("Writing {}", path_r_impl.to_string_lossy());
std::fs::write(path_r_impl, make_r_impl_file(&parsed_result)).unwrap();
std::fs::write(path_r_impl, make_r_impl_file(&parsed_result, &pkg_name)).unwrap();
}

fn main() {
Expand All @@ -197,11 +221,11 @@ fn main() {
}
Commands::CImpl { file } => {
let parsed_result = parse_file(file.as_path());
println!("{}", make_c_impl_file(&parsed_result));
println!("{}", make_c_impl_file(&parsed_result, "%%PACKAGE_NAME%%"));
}
Commands::RImpl { file } => {
let parsed_result = parse_file(file.as_path());
println!("{}", make_r_impl_file(&parsed_result));
println!("{}", make_r_impl_file(&parsed_result, "%%PACKAGE_NAME%%"));
}
Commands::Update { r_pkg_dir } => {
update(r_pkg_dir.as_path());
Expand Down
23 changes: 14 additions & 9 deletions savvy-bindgen/src/savvy_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,20 @@ SEXP {fn_name}_wrapper({args_sig}) {{
let args = self
.get_c_args()
.iter()
.map(|(pat, _)| pat.as_str())
.collect::<Vec<&str>>()
.join(", ");
.map(|(pat, _)| pat.clone())
.collect::<Vec<String>>();

let mut args_call = args.clone();
args_call.insert(0, fn_name_c.to_string());

let args = args.join(", ");
let args_call = args_call.join(", ");

let body = if self.has_result {
format!(".Call({fn_name_c}, {args})")
format!(".Call({args_call})")
} else {
// If the result is NULL, wrap it with invisible
format!("invisible(.Call({fn_name_c}, {args}))")
format!("invisible(.Call({args_call}))")
};

format!(
Expand Down Expand Up @@ -529,7 +534,7 @@ fn make_c_function_call_entry(fns: &[SavvyFn]) -> String {
.join("\n")
}

pub fn make_c_impl_file(parsed_result: &ParsedResult) -> String {
pub fn make_c_impl_file(parsed_result: &ParsedResult, pkg_name: &str) -> String {
let common_part = r#"
#include <stdint.h>
#include <Rinternals.h>
Expand Down Expand Up @@ -606,7 +611,7 @@ static const R_CallMethodDef CallEntries[] = {{
{{NULL, NULL, 0}}
}};
void R_init_savvy(DllInfo *dll) {{
void R_init_{pkg_name}(DllInfo *dll) {{
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
R_useDynamicSymbols(dll, FALSE);
}}
Expand Down Expand Up @@ -711,7 +716,7 @@ fn make_r_impl_for_impl(savvy_impl: &SavvyImpl) -> String {
)
}

pub fn make_r_impl_file(parsed_result: &ParsedResult) -> String {
pub fn make_r_impl_file(parsed_result: &ParsedResult, pkg_name: &str) -> String {
let r_fns = parsed_result
.bare_fns
.iter()
Expand All @@ -727,7 +732,7 @@ pub fn make_r_impl_file(parsed_result: &ParsedResult) -> String {
.join("\n");

format!(
r#"#' @useDynLib savvy, .registration = TRUE
r#"#' @useDynLib {pkg_name}, .registration = TRUE
#' @keywords internal
"_PACKAGE"
Expand Down

0 comments on commit 7a8a3f8

Please sign in to comment.