Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[samples] create multiple variants of NsNet2 #32

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codegen/compiler/src/Quidditch/ConvertToRISCV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void ConvertToRISCV::runOnOperation() {
llvm::FileRemover stdinFileRemove(stdinFile);
{
llvm::raw_fd_ostream ss(stdinFd, /*shouldClose=*/true);
func.print(ss, OpPrintingFlags().printGenericOpForm().useLocalScope());
func.print(ss, OpPrintingFlags().useLocalScope());
}

SmallString<64> stdoutFile;
Expand Down
38 changes: 24 additions & 14 deletions runtime/cmake/quidditch_module.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,23 @@ find_program(XDSL_OPT_PATH xdsl-opt
# The resulting library is the source file's name with the extension removed and
# '_module' appended.
function(quidditch_module)
cmake_parse_arguments(_RULE "LLVM;ASSERT_XDSL" "SRC" "FLAGS;DEPENDS" ${ARGN})
cmake_parse_arguments(_RULE "LLVM;ASSERT_XDSL" "SRC;N_THREADS;DST" "FLAGS;DEPENDS" ${ARGN})

set(_MLIR_SRC "${_RULE_SRC}")
if (NOT _RULE_DST)
cmake_path(GET _MLIR_SRC STEM _RULE_DST)
set(_RULE_DST "${_RULE_DST}")
endif ()

cmake_path(GET _MLIR_SRC STEM filename)
if (NOT _RULE_N_THREADS)
set(_RULE_N_THREADS 8)
endif ()

get_filename_component(_MLIR_SRC "${_MLIR_SRC}" REALPATH)
set(_O_QUIDDITCH_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${filename}/${filename}.o")
set(_O_LLVM_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${filename}/${filename}_llvm.o")
set(_H_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${filename}/${filename}_module.h")
set(_MODULE_NAME "${filename}_module")
set(_O_QUIDDITCH_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}.o")
set(_O_LLVM_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}_llvm.o")
set(_H_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}_module.h")
set(_MODULE_NAME "${_RULE_DST}_module")

set(_COMPILER_ARGS ${_RULE_FLAGS})
list(APPEND _COMPILER_ARGS "--iree-vm-bytecode-module-strip-source-map=true")
Expand All @@ -72,10 +78,7 @@ function(quidditch_module)
list(APPEND _COMPILER_ARGS "--iree-input-demote-f64-to-f32=0")

set(_OUTPUT_FILES "${_H_FILE_NAME}")
string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_LLVM_FILE_NAME}")
list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}" "${_O_LLVM_FILE_NAME}")

set(_OBJECT_FILES "${_O_LLVM_FILE_NAME}")
set(_OBJECT_FILES)

set(_EXTRA_DEPENDS ${_RULE_DEPENDS})
if (NOT _RULE_LLVM)
Expand All @@ -92,11 +95,18 @@ function(quidditch_module)

list(APPEND _OUTPUT_FILES "${_O_QUIDDITCH_FILE_NAME}")
list(APPEND _OBJECT_FILES "${_O_QUIDDITCH_FILE_NAME}")
list(APPEND _OBJECT_FILES "${_O_LLVM_FILE_NAME}")

string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_QUIDDITCH_FILE_NAME}")
list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}")
else ()
set(_O_LLVM_FILE_NAME ${_O_QUIDDITCH_FILE_NAME})
list(APPEND _OBJECT_FILES "${_O_LLVM_FILE_NAME}")
endif ()

string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_LLVM_FILE_NAME}")
list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}" "${_O_LLVM_FILE_NAME}")

list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-debug-symbols=true")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-triple=riscv32-unknown-elf")
Expand All @@ -106,7 +116,7 @@ function(quidditch_module)
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-float-abi=hard")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-embedded=false")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-static")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-number-of-threads=8")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-number-of-threads=${_RULE_N_THREADS}")
list(APPEND _COMPILER_ARGS "--iree-llvmcpu-static-library-output-path=${_O_LLVM_FILE_NAME}")

list(APPEND _COMPILER_ARGS "--output-format=vm-c")
Expand All @@ -126,15 +136,15 @@ function(quidditch_module)
#define EMITC_IMPLEMENTATION
#include "@[email protected]"
]] @ONLY)
add_library(${_MODULE_NAME}
add_library(${_RULE_DST}
STATIC ${_C_FILE_NAME} ${_OBJECT_FILES}
${_H_FILE_NAME}
)
target_link_libraries(${_MODULE_NAME}
target_link_libraries(${_RULE_DST}
PUBLIC
iree::vm
)
target_include_directories(${_MODULE_NAME} INTERFACE ${CMAKE_CURRENT_BINARY_DIR}/${filename})
target_include_directories(${_RULE_DST} INTERFACE ${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST})
endfunction()

# Use iree-turbine to convert a PyTorch model to MLIR.
Expand Down
39 changes: 33 additions & 6 deletions runtime/samples/nsnet2/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@

iree_turbine(SRC NsNet2.py DST ${CMAKE_CURRENT_BINARY_DIR}/nsnet2.mlirbc DTYPE "f64")
quidditch_module(SRC ${CMAKE_CURRENT_BINARY_DIR}/nsnet2.mlirbc)
quidditch_module(SRC ${CMAKE_CURRENT_BINARY_DIR}/nsnet2.mlirbc DST nsnet2)
quidditch_module(SRC ${CMAKE_CURRENT_BINARY_DIR}/nsnet2.mlirbc LLVM DST nsnet2_llvm)
quidditch_module(SRC ${CMAKE_CURRENT_BINARY_DIR}/nsnet2.mlirbc DST nsnet2st)
quidditch_module(SRC ${CMAKE_CURRENT_BINARY_DIR}/nsnet2.mlirbc LLVM DST nsnet2st_llvm)

add_executable(NsNet2LLVM main.c)
target_link_libraries(
NsNet2LLVM
add_library(nsnet2_util nsnet2_util.c)
target_link_libraries(nsnet2_util
PRIVATE
samples_util
nsnet2_module
snRuntime
snRuntimeInterface
Quidditch::dispatch::dispatch
)
target_include_directories(nsnet2_util INTERFACE ${CMAKE_CURRENT_LIST_DIR})

macro(create_experiment_variant target_name iree_module query_func)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/${target_name}.c "\
#include <${iree_module}.h>

#include \"nsnet2_util.h\"

int main() {
return run_nsnet2_experiment(${query_func});
}
")
add_executable(${target_name} ${CMAKE_CURRENT_BINARY_DIR}/${target_name}.c)
target_link_libraries(
${target_name}
PRIVATE
nsnet2_util
${iree_module}
snRuntime
)
endmacro()

create_experiment_variant(NsNet2 nsnet2 "quidditch_compiled_ns_net2_linked_quidditch_library_query")
create_experiment_variant(NsNet2LLVM nsnet2_llvm "compiled_ns_net2_linked_llvm_cpu_library_query")
create_experiment_variant(NsNet2ST nsnet2st "quidditch_compiled_ns_net2_linked_quidditch_library_query")
create_experiment_variant(NsNet2STLLVM nsnet2st_llvm "compiled_ns_net2_linked_llvm_cpu_library_query")
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#include "nsnet2_util.h"

#include <Quidditch/dispatch/dispatch.h>

#include <nsnet2.h>
#include <nsnet2_module.h>
#include <iree/base/alignment.h>

#include <team_decls.h>
#include <util/run_model.h>

int main() {
iree_status_t compiled_ns_net2_create(iree_vm_instance_t *, iree_allocator_t,
iree_vm_module_t **);

int run_nsnet2_experiment(
iree_hal_executable_library_query_fn_t implementation) {
if (!snrt_is_dm_core()) return quidditch_dispatch_enter_worker_loop();

double data[161];
Expand All @@ -15,25 +21,23 @@ int main() {
}

model_config_t config = {
.libraries =
(iree_hal_executable_library_query_fn_t[]){
quidditch_compiled_ns_net2_linked_quidditch_library_query},
.libraries = (iree_hal_executable_library_query_fn_t[]){implementation},
.num_libraries = 1,
.module_constructor = compiled_ns_net2_create,
.main_function = iree_make_cstring_view("compiled_ns_net2.main"),

.element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_64,

.num_inputs = 1,
.input_data = (const void*[]){data, data},
.input_data = (const void *[]){data, data},
.input_sizes = (const iree_host_size_t[]){IREE_ARRAYSIZE(data)},
.input_ranks = (const iree_host_size_t[]){3},
.input_shapes = (const iree_hal_dim_t*[]){(iree_hal_dim_t[]){1, 1, 161}},
.input_shapes = (const iree_hal_dim_t *[]){(iree_hal_dim_t[]){1, 1, 161}},

.num_outputs = 1,
.output_data = (void*[]){data},
.output_data = (void *[]){data},
.output_sizes = (const iree_host_size_t[]){IREE_ARRAYSIZE(data)},
.device_allocator = l3_allocator(),
.device_allocator = l1_allocator(),
};

IREE_CHECK_OK(run_model(&config));
Expand Down
7 changes: 7 additions & 0 deletions runtime/samples/nsnet2/nsnet2_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#pragma once

#include <iree/hal/local/executable_library.h>

int run_nsnet2_experiment(
iree_hal_executable_library_query_fn_t implementation);
2 changes: 1 addition & 1 deletion runtime/samples/vec_multiply/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ target_link_libraries(
vec_multiply
PRIVATE
samples_util
simple_add_module
simple_add
snRuntime
Quidditch::dispatch::dispatch
)
Expand Down