Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump OptiX ABI version to 87 (OptiX 8.0) #117

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions src/optix_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
*/

#define DR_OPTIX_SYM(...) __VA_ARGS__ = nullptr;
#define DR_OPTIX_ABI_VERSION 55
#define DR_OPTIX_FUNCTION_TABLE_SIZE 43
#define DR_OPTIX_ABI_VERSION 87
#define DR_OPTIX_FUNCTION_TABLE_SIZE 48

#include "optix.h"
#include "optix_api.h"
Expand Down Expand Up @@ -43,8 +43,8 @@ static const char *jitc_optix_table_names[DR_OPTIX_FUNCTION_TABLE_SIZE] = {
"optixDeviceContextGetCacheEnabled",
"optixDeviceContextGetCacheLocation",
"optixDeviceContextGetCacheDatabaseSizes",
"optixModuleCreateFromPTX",
"optixModuleCreateFromPTXWithTasks",
"optixModuleCreate",
"optixModuleCreateWithTasks",
"optixModuleGetCompilationState",
"optixModuleDestroy",
"optixBuiltinISModuleGet",
Expand All @@ -61,9 +61,14 @@ static const char *jitc_optix_table_names[DR_OPTIX_FUNCTION_TABLE_SIZE] = {
"optixAccelCheckRelocationCompatibility",
"optixAccelRelocate",
"optixAccelCompact",
"optixAccelEmitProperty",
"optixConvertPointerToTraversableHandle",
"reserved1",
"reserved2",
"optixOpacityMicromapArrayComputeMemoryUsage",
"optixOpacityMicromapArrayBuild",
"optixOpacityMicromapArrayGetRelocationInfo",
"optixOpacityMicromapArrayRelocate",
"optixDisplacementMicromapArrayComputeMemoryUsage",
"optixDisplacementMicromapArrayBuild",
"optixSbtRecordPackHeader",
"optixLaunch",
"optixDenoiserCreate",
Expand All @@ -80,16 +85,6 @@ bool jitc_optix_api_init() {
if (jitc_optix_handle)
return true;

if (jitc_cuda_version_major == 11 && jitc_cuda_version_minor == 5) {
jitc_log(
Warn,
"jit_optix_api_init(): DrJit considers the driver of your graphics "
"card buggy and prone to miscompilation (we explicitly do not "
"support OptiX with CUDA 11.5, which roughly corresponds to driver "
"versions >= 495 and < 510). Please install an older or newer driver.");
return false;
}

if (jitc_cuda_version_major == 12 && jitc_cuda_version_minor == 7) {
jitc_log(
Warn,
Expand Down Expand Up @@ -152,7 +147,7 @@ bool jitc_optix_api_init() {
"jit_optix_api_init(): Failed to load OptiX library! Very likely, "
"your NVIDIA graphics driver is too old and not compatible "
"with the version of OptiX that is being used. In particular, "
"OptiX 7.4 requires driver revision R495.89 or newer.");
"OptiX 8.0 requires driver revision R535 or newer.");
jitc_optix_api_shutdown();
return false;
}
Expand All @@ -165,8 +160,8 @@ bool jitc_optix_api_init() {
LOAD(optixDeviceContextDestroy);
LOAD(optixDeviceContextSetCacheEnabled);
LOAD(optixDeviceContextSetCacheLocation);
LOAD(optixModuleCreateFromPTX);
LOAD(optixModuleCreateFromPTXWithTasks);
LOAD(optixModuleCreate);
LOAD(optixModuleCreateWithTasks);
LOAD(optixModuleGetCompilationState);
LOAD(optixModuleDestroy);
LOAD(optixTaskExecute);
Expand All @@ -181,7 +176,7 @@ bool jitc_optix_api_init() {

#undef LOAD

jitc_log(Info, "jit_optix_api_init(): loaded OptiX (via 7.4 ABI).");
jitc_log(Info, "jit_optix_api_init(): loaded OptiX (via 8.0 ABI).");

return true;
}
Expand All @@ -205,8 +200,9 @@ void jitc_optix_api_shutdown() {
#define Z(x) x = nullptr
Z(optixGetErrorName); Z(optixGetErrorString); Z(optixDeviceContextCreate);
Z(optixDeviceContextDestroy); Z(optixDeviceContextSetCacheEnabled);
Z(optixDeviceContextSetCacheLocation); Z(optixModuleCreateFromPTX);
Z(optixModuleDestroy); Z(optixProgramGroupCreate);
Z(optixDeviceContextSetCacheLocation); Z(optixModuleCreate);
Z(optixModuleCreateWithTasks); Z(optixModuleGetCompilationState);
Z(optixModuleDestroy); Z(optixTaskExecute); Z(optixProgramGroupCreate);
Z(optixProgramGroupDestroy); Z(optixPipelineCreate);
Z(optixPipelineDestroy); Z(optixLaunch); Z(optixSbtRecordPackHeader);
Z(optixPipelineSetStackSize); Z(optixProgramGroupGetStackSize);
Expand Down
10 changes: 5 additions & 5 deletions src/optix_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,11 @@ struct OptixModuleCompileOptions {
const void *boundValues;
unsigned int numBoundValues;
unsigned int numPayloadTypes;
OptixPayloadType *payloadTypes;
const OptixPayloadType *payloadTypes;
};

struct OptixPipelineLinkOptions {
unsigned int maxTraceDepth;
OptixCompileDebugLevel debugLevel;
};

struct OptixProgramGroupSingleModule {
Expand Down Expand Up @@ -140,11 +139,11 @@ DR_OPTIX_SYM(
OptixResult (*optixDeviceContextSetCacheEnabled)(OptixDeviceContext, int));
DR_OPTIX_SYM(OptixResult (*optixDeviceContextSetCacheLocation)(
OptixDeviceContext, const char *));
DR_OPTIX_SYM(OptixResult (*optixModuleCreateFromPTX)(
DR_OPTIX_SYM(OptixResult (*optixModuleCreate)(
OptixDeviceContext, const OptixModuleCompileOptions *,
const OptixPipelineCompileOptions *, const char *, size_t, char *, size_t *,
OptixModule *));
DR_OPTIX_SYM(OptixResult (*optixModuleCreateFromPTXWithTasks)(
DR_OPTIX_SYM(OptixResult (*optixModuleCreateWithTasks)(
OptixDeviceContext, const OptixModuleCompileOptions *,
const OptixPipelineCompileOptions *, const char *, size_t, char *, size_t *,
OptixModule *, OptixTask *));
Expand All @@ -170,4 +169,5 @@ DR_OPTIX_SYM(OptixResult (*optixSbtRecordPackHeader)(OptixProgramGroup,
DR_OPTIX_SYM(OptixResult (*optixPipelineSetStackSize)(
OptixPipeline, unsigned int, unsigned int, unsigned int, unsigned int));
DR_OPTIX_SYM(OptixResult (*optixProgramGroupGetStackSize)(OptixProgramGroup,
OptixStackSizes *));
OptixStackSizes *,
OptixPipeline));
16 changes: 6 additions & 10 deletions src/optix_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ OptixDeviceContext jitc_optix_context() {
size_t log_size = sizeof(log);

OptixModule mod;
jitc_optix_check(optixModuleCreateFromPTX(
jitc_optix_check(optixModuleCreate(
ctx, &mco, &pco, minimal, strlen(minimal), log, &log_size, &mod));

OptixProgramGroupDesc pgd { };
Expand Down Expand Up @@ -229,7 +229,7 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,
const char *kern_name, Kernel &kernel) {
char error_log[16384];

if (!optixModuleCreateFromPTXWithTasks)
if (!optixModuleCreateWithTasks)
jitc_fail("jit_optix_compile(): OptiX not initialized, make sure "
"evaluation happens before Optix shutdown!");

Expand All @@ -253,13 +253,13 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,

OptixTask task;
error_log[0] = '\0';
int rv = optixModuleCreateFromPTXWithTasks(
int rv = optixModuleCreateWithTasks(
optix_context, &mco, &pipeline.compile_options, buf, buf_size,
error_log, &log_size, &kernel.optix.mod, &task);

if (rv) {
jitc_log(Error, "jit_optix_compile(): "
"optixModuleCreateFromPTXWithTasks() failed. Please see the "
"optixModuleCreateWithTasks() failed. Please see the "
"PTX assembly listing and error message below:\n\n%s\n\n%s",
buf, error_log);
jitc_optix_check(rv);
Expand Down Expand Up @@ -362,11 +362,6 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,

OptixPipelineLinkOptions link_options {};
link_options.maxTraceDepth = 1;
#ifndef DRJIT_ENABLE_OPTIX_DEBUG_VALIDATION_ON
link_options.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_NONE;
#else
link_options.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_FULL;
#endif

size_t size_before = pipeline.program_groups.size();

Expand Down Expand Up @@ -400,7 +395,8 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,
OptixStackSizes ssp = {};
for (size_t i = 0; i < pipeline.program_groups.size(); ++i) {
OptixStackSizes ss;
rv = optixProgramGroupGetStackSize(pipeline.program_groups[i], &ss);
rv = optixProgramGroupGetStackSize(pipeline.program_groups[i], &ss,
kernel.optix.pipeline);
if (rv) {
jitc_log(Error, "jit_optix_compile(): optixProgramGroupGetStackSize() "
"failed:\n\n%s", error_log);
Expand Down
2 changes: 1 addition & 1 deletion tests/optix_stubs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ void init_optix_api() {
L(optixAccelComputeMemoryUsage);
L(optixAccelBuild);
L(optixAccelCompact);
L(optixModuleCreateFromPTX);
L(optixModuleCreate);
L(optixModuleDestroy)
L(optixProgramGroupCreate);
L(optixProgramGroupDestroy)
Expand Down
5 changes: 4 additions & 1 deletion tests/optix_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ using OptixProgramGroupKind = int;
#define OPTIX_SBT_RECORD_HEADER_SIZE 32

#define OPTIX_COMPILE_DEBUG_LEVEL_NONE 0x2350
#define OPTIX_COMPILE_DEBUG_LEVEL_FULL 0x2352
#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_0 0x2340
#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_3 0x2343

#define OPTIX_BUILD_FLAG_ALLOW_COMPACTION 2
#define OPTIX_BUILD_FLAG_PREFER_FAST_TRACE 4
Expand Down Expand Up @@ -186,7 +189,7 @@ D(optixAccelComputeMemoryUsage, OptixDeviceContext,
D(optixAccelBuild, OptixDeviceContext, CUstream, const OptixAccelBuildOptions *,
const OptixBuildInput *, unsigned int, CUdeviceptr, size_t, CUdeviceptr,
size_t, OptixTraversableHandle *, const OptixAccelEmitDesc *, unsigned int);
D(optixModuleCreateFromPTX, OptixDeviceContext,
D(optixModuleCreate, OptixDeviceContext,
const OptixModuleCompileOptions *, const OptixPipelineCompileOptions *,
const char *, size_t, char *, size_t *, OptixModule *);
D(optixModuleDestroy, OptixModule);
Expand Down
21 changes: 15 additions & 6 deletions tests/triangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ using UInt32 = dr::CUDAArray<uint32_t>;
using UInt64 = dr::CUDAArray<uint64_t>;
using Mask = dr::CUDAArray<bool>;

#if !defined(NDEBUG) || defined(DRJIT_ENABLE_OPTIX_DEBUG_VALIDATION)
#define DRJIT_ENABLE_OPTIX_DEBUG_VALIDATION_ON
#endif

void demo() {
OptixDeviceContext context = jit_optix_context();
jit_cuda_push_context(jit_cuda_context());
Expand All @@ -50,20 +54,20 @@ void demo() {
// =====================================================

const uint32_t triangle_input_flags[1] = { OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT };
OptixBuildInput triangle_input { };
OptixBuildInput triangle_input {};
triangle_input.type = OPTIX_BUILD_INPUT_TYPE_TRIANGLES;
triangle_input.triangleArray.vertexFormat = OPTIX_VERTEX_FORMAT_FLOAT3;
triangle_input.triangleArray.numVertices = 3;
triangle_input.triangleArray.vertexBuffers = &d_vertices;
triangle_input.triangleArray.flags = triangle_input_flags;
triangle_input.triangleArray.numSbtRecords = 1;

OptixAccelBuildOptions accel_options {};
OptixAccelBuildOptions accel_options{};
accel_options.operation = OPTIX_BUILD_OPERATION_BUILD;
accel_options.buildFlags =
OPTIX_BUILD_FLAG_ALLOW_COMPACTION | OPTIX_BUILD_FLAG_PREFER_FAST_TRACE;

OptixAccelBufferSizes gas_buffer_sizes;
OptixAccelBufferSizes gas_buffer_sizes{};
jit_optix_check(optixAccelComputeMemoryUsage(
context, &accel_options, &triangle_input, 1, &gas_buffer_sizes));

Expand Down Expand Up @@ -110,8 +114,13 @@ void demo() {
// =====================================================

OptixModuleCompileOptions module_compile_options { };
module_compile_options.debugLevel =
OPTIX_COMPILE_DEBUG_LEVEL_NONE;
#ifndef DRJIT_ENABLE_OPTIX_DEBUG_VALIDATION_ON
module_compile_options.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_NONE;
module_compile_options.optLevel = OPTIX_COMPILE_OPTIMIZATION_LEVEL_3;
#else
module_compile_options.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_FULL;
module_compile_options.optLevel = OPTIX_COMPILE_OPTIMIZATION_LEVEL_0;
#endif

OptixPipelineCompileOptions pipeline_compile_options { };
pipeline_compile_options.usesMotionBlur = false;
Expand Down Expand Up @@ -139,7 +148,7 @@ void demo() {
size_t log_size = sizeof(log);

OptixModule mod;
int rv = optixModuleCreateFromPTX(
int rv = optixModuleCreate(
context, &module_compile_options, &pipeline_compile_options,
miss_and_closesthit_ptx, strlen(miss_and_closesthit_ptx), log,
&log_size, &mod);
Expand Down