Skip to content

Commit

Permalink
Merge branch 'lwawrzyniak/fix-static-expression-caching' into 'main'
Browse files Browse the repository at this point in the history
Fix hashing of static expressions

See merge request omniverse/warp!790
  • Loading branch information
nvlukasz committed Oct 10, 2024
2 parents ce149b9 + 6ac19b6 commit 419ebe7
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Fix potential out-of-bounds memory access when a `wp.sparse.BsrMatrix` object is reused for storing matrices of different shapes
- Fix robustness to very low desired tolerance in `wp.fem.utils.symmetric_eigenvalues_qr`
- Fix invalid code generation error messages when nesting dynamic and static for-loops
- Fix caching of kernels with static expressions

## [1.4.0] - 2024-10-01

Expand Down
48 changes: 25 additions & 23 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import ast
import builtins
import ctypes
import functools
import hashlib
Expand All @@ -22,7 +21,6 @@
import weakref
from copy import copy as shallowcopy
from pathlib import Path
from struct import pack as struct_pack
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -1488,30 +1486,16 @@ def hash_adjoint(self, adj):
# hash referenced constants
for name, value in constants.items():
ch.update(bytes(name, "utf-8"))
# hash the referenced object
if isinstance(value, builtins.bool):
# This needs to come before the check for `int` since all boolean
# values are also instances of `int`.
ch.update(struct_pack("?", value))
elif isinstance(value, int):
ch.update(struct_pack("<q", value))
elif isinstance(value, float):
ch.update(struct_pack("<d", value))
elif isinstance(value, warp.types.float16):
# float16 is a special case
p = ctypes.pointer(ctypes.c_float(value.value))
ch.update(p.contents)
elif isinstance(value, tuple(warp.types.scalar_types)):
p = ctypes.pointer(value._type_(value.value))
ch.update(p.contents)
elif isinstance(value, ctypes.Array):
ch.update(bytes(value))
else:
raise RuntimeError(f"Invalid constant type: {type(value)}")
ch.update(self.get_constant_bytes(value))

# hash wp.static() expressions that were evaluated at declaration time
for k, v in adj.static_expressions.items():
ch.update(bytes(f"{k} = {v}", "utf-8"))
ch.update(bytes(k, "utf-8"))
if isinstance(v, Function):
if v not in self.functions_in_progress:
ch.update(self.hash_function(v))
else:
ch.update(self.get_constant_bytes(v))

# hash referenced types
for t in types.keys():
Expand All @@ -1524,6 +1508,24 @@ def hash_adjoint(self, adj):

return ch.digest()

def get_constant_bytes(self, value):
if isinstance(value, int):
# this also handles builtins.bool
return bytes(ctypes.c_int(value))
elif isinstance(value, float):
return bytes(ctypes.c_float(value))
elif isinstance(value, warp.types.float16):
# float16 is a special case
return bytes(ctypes.c_float(value.value))
elif isinstance(value, tuple(warp.types.scalar_and_bool_types)):
return bytes(value._type_(value.value))
elif hasattr(value, "_wp_scalar_type_"):
return bytes(value)
elif isinstance(value, warp.codegen.StructInstance):
return bytes(value._ctype)
else:
raise TypeError(f"Invalid constant type: {type(value)}")

def get_module_hash(self):
return self.module_hash

Expand Down
156 changes: 156 additions & 0 deletions warp/tests/test_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import importlib
import tempfile
import unittest
from typing import Dict, List

Expand All @@ -17,6 +19,23 @@
global_variable = 3


def load_code_as_module(code, name):
file, file_path = tempfile.mkstemp(suffix=".py")

try:
with os.fdopen(file, "w") as f:
f.write(code)

spec = importlib.util.spec_from_file_location(name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
finally:
os.remove(file_path)

# return Warp module
return wp.get_module(module.__name__)


@wp.func
def static_global_variable_func():
static_var = warp.static(global_variable + 2)
Expand Down Expand Up @@ -383,6 +402,140 @@ def static_condition3(results: wp.array(dtype=int)):
assert_np_equal(counts["else"], 0)


static_builtin_constant_template = """
import warp as wp
# Python builtin literal like 17, 42.0, or True
C = {value}
@wp.kernel
def k():
print(wp.static(C))
"""

static_warp_constant_template = """
import warp as wp
# Warp scalar value like wp.uint8(17)
C = wp.{dtype}({value})
@wp.kernel
def k():
print(wp.static(C))
"""

static_struct_constant_template = """
import warp as wp
@wp.struct
class SimpleStruct:
x: float
C = SimpleStruct()
C.x = {value}
@wp.kernel
def k():
print(wp.static(C))
"""

static_func_template = """
import warp as wp
@wp.func
def f():
# modify the function to verify hashing
return {value}
@wp.kernel
def k():
print(wp.static(f)())
"""


def test_static_constant_hash(test, _):
# Python literals
# (type, value1, value2)
literals = [
(int, 17, 42),
(float, 17.5, 42.5),
(bool, True, False),
]

for builtin_type, value1, value2 in literals:
type_name = builtin_type.__name__
with test.subTest(msg=f"{type_name}"):
source1 = static_builtin_constant_template.format(value=value1)
source2 = static_builtin_constant_template.format(value=value2)
source3 = static_builtin_constant_template.format(value=value1)

module1 = load_code_as_module(source1, f"aux_static_constant_builtin_{type_name}_1")
module2 = load_code_as_module(source2, f"aux_static_constant_builtin_{type_name}_2")
module3 = load_code_as_module(source3, f"aux_static_constant_builtin_{type_name}_3")

hash1 = module1.hash_module()
hash2 = module2.hash_module()
hash3 = module3.hash_module()

test.assertNotEqual(hash1, hash2)
test.assertEqual(hash1, hash3)

# Warp types (scalars, vectors, matrices)
for warp_type in [*wp.types.scalar_types, *wp.types.vector_types]:
type_name = warp_type.__name__
with test.subTest(msg=f"wp.{type_name}"):
value1 = ", ".join([str(17)] * warp_type._length_)
value2 = ", ".join([str(42)] * warp_type._length_)
source1 = static_warp_constant_template.format(dtype=type_name, value=value1)
source2 = static_warp_constant_template.format(dtype=type_name, value=value2)
source3 = static_warp_constant_template.format(dtype=type_name, value=value1)

module1 = load_code_as_module(source1, f"aux_static_constant_wp_{type_name}_1")
module2 = load_code_as_module(source2, f"aux_static_constant_wp_{type_name}_2")
module3 = load_code_as_module(source3, f"aux_static_constant_wp_{type_name}_3")

hash1 = module1.hash_module()
hash2 = module2.hash_module()
hash3 = module3.hash_module()

test.assertNotEqual(hash1, hash2)
test.assertEqual(hash1, hash3)

# structs
with test.subTest(msg="struct"):
source1 = static_struct_constant_template.format(value=17)
source2 = static_struct_constant_template.format(value=42)
source3 = static_struct_constant_template.format(value=17)

module1 = load_code_as_module(source1, "aux_static_constant_struct_1")
module2 = load_code_as_module(source2, "aux_static_constant_struct_2")
module3 = load_code_as_module(source3, "aux_static_constant_struct_3")

hash1 = module1.hash_module()
hash2 = module2.hash_module()
hash3 = module3.hash_module()

test.assertNotEqual(hash1, hash2)
test.assertEqual(hash1, hash3)


def test_static_function_hash(test, _):
source1 = static_func_template.format(value=17)
source2 = static_func_template.format(value=42)
source3 = static_func_template.format(value=17)

module1 = load_code_as_module(source1, "aux_static_func1")
module2 = load_code_as_module(source2, "aux_static_func2")
module3 = load_code_as_module(source3, "aux_static_func3")

hash1 = module1.hash_module()
hash2 = module2.hash_module()
hash3 = module3.hash_module()

test.assertNotEqual(hash1, hash2)
test.assertEqual(hash1, hash3)


devices = get_test_devices()


Expand All @@ -406,6 +559,9 @@ def test_static_python_call(self):
add_function_test(TestStatic, "test_static_for_loop", test_static_for_loop, devices=devices)
add_function_test(TestStatic, "test_static_if_else_elif", test_static_if_else_elif, devices=devices)

add_function_test(TestStatic, "test_static_constant_hash", test_static_constant_hash, devices=None)
add_function_test(TestStatic, "test_static_function_hash", test_static_function_hash, devices=None)


if __name__ == "__main__":
wp.clear_kernel_cache()
Expand Down

0 comments on commit 419ebe7

Please sign in to comment.