Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bounds checks on pointers #2451

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 202 additions & 76 deletions src/back/msl/writer.rs

Large diffs are not rendered by default.

96 changes: 95 additions & 1 deletion src/proc/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Definitions for index bounds checking.
*/

use crate::{valid, Handle, UniqueArena};
use crate::{valid, FastHashSet, Handle, UniqueArena};
use bit_set::BitSet;
use std::iter::{self, zip};

/// How should code generated by Naga do bounds checks?
///
Expand Down Expand Up @@ -339,6 +340,99 @@ pub fn access_needs_check(
Some(length)
}

/// Returns an iterator over the chain of `Access` and `AccessIndex`
/// expressions starting from `chain`.
///
/// They're yielded as `(base, index)` pairs, where `base` is the expression
/// being indexed into and `index` is the index being used.
///
/// The index is `None` if `base` is a struct, since you never need bounds
/// checks for accessing struct fields.
///
/// If `chain` isn't an `Access` or `AccessIndex` expression, this just
/// yields nothing.
pub fn access_chain<'a>(
mut chain: Handle<crate::Expression>,
module: &'a crate::Module,
function: &'a crate::Function,
info: &'a valid::FunctionInfo,
) -> impl Iterator<Item = (Handle<crate::Expression>, Option<GuardedIndex>)> + 'a {
iter::from_fn(move || {
let (base, index) = match function.expressions[chain] {
crate::Expression::Access { base, index } => {
(base, Some(GuardedIndex::Expression(index)))
}
crate::Expression::AccessIndex { base, index } => {
// Don't try to check indices into structs. Validation already took
// care of them, and needs_guard doesn't handle that case.
let mut base_inner = info[base].ty.inner_with(&module.types);
if let crate::TypeInner::Pointer { base, .. } = *base_inner {
base_inner = &module.types[base].inner;
}
match *base_inner {
crate::TypeInner::Struct { .. } => (base, None),
_ => (base, Some(GuardedIndex::Known(index))),
}
}
_ => return None,
};
chain = base;
Some((base, index))
})
}

/// Returns all the types which we need out-of-bounds locals for; that is,
/// all of the types which the code might attempt to get an out-of-bounds
/// pointer to, in which case we yield a pointer to the out-of-bounds local
/// of the correct type.
pub fn oob_locals(
module: &crate::Module,
function: &crate::Function,
info: &valid::FunctionInfo,
) -> FastHashSet<Handle<crate::Type>> {
let mut result = FastHashSet::default();
for statement in &function.body {
// The only situation in which we end up actually needing to create an
// out-of-bounds pointer is when passing one to a function.
//
// This is because pointers are never baked; so they're just inlined everywhere
// they're used. That means that loads can just return 0, and stores can just do
// nothing; functions are the only case where you actually *have* to produce a
// pointer.
if let crate::Statement::Call {
function: callee,
ref arguments,
..
} = *statement
{
// Now go through the arguments of the function looking for pointers which need bounds checks.
for (arg_info, &arg) in zip(&module.functions[callee].arguments, arguments) {
match module.types[arg_info.ty].inner {
crate::TypeInner::ValuePointer { .. } => {
// `ValuePointer`s should only ever be used when resolving the types of
// expressions, since the arena can no longer be modified at that point; things
// in the arena should always use proper `Pointer`s.
unreachable!("`ValuePointer` found in arena")
}
crate::TypeInner::Pointer { base, .. } => {
if access_chain(arg, module, function, info).any(|(base, index)| {
index
.and_then(|index| {
access_needs_check(base, index, module, function, info)
})
.is_some()
}) {
result.insert(base);
}
}
_ => continue,
};
}
}
}
result
}

impl GuardedIndex {
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
///
Expand Down
2 changes: 2 additions & 0 deletions src/proc/namer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ pub enum NameKey {
Function(Handle<crate::Function>),
FunctionArgument(Handle<crate::Function>, u32),
FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
FunctionOobLocal(Handle<crate::Function>, Handle<crate::Type>),
EntryPoint(EntryPointIndex),
EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
EntryPointArgument(EntryPointIndex, u32),
EntryPointOobLocal(EntryPointIndex, Handle<crate::Type>),
}

/// This processor assigns names to all the things in a module
Expand Down
7 changes: 7 additions & 0 deletions tests/in/access.param.ron
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
(
bounds_check_policies: (
index: ReadZeroSkipWrite,
buffer: ReadZeroSkipWrite,
image_load: ReadZeroSkipWrite,
image_store: ReadZeroSkipWrite,
binding_array: ReadZeroSkipWrite,
),
spv: (
version: (1, 1),
debug: true,
Expand Down
21 changes: 21 additions & 0 deletions tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ fn test_matrix_within_struct_accesses() {
t.m[0][idx] = 20.0;
t.m[idx][1] = 30.0;
t.m[idx][idx] = 40.0;

// passing pointers to a function
// FIXME: these are currently commented out because getting pointers to
// vector/matrix elements is broken in Metal and HLSL.
// let pl0 = read_from_private(&t.m[0][1]);
// let pl1 = read_from_private(&t.m[0][idx]);
// let pl2 = read_from_private(&t.m[idx][1]);
// let pl3 = read_from_private(&t.m[idx][idx]);
}

struct MatCx2InArray {
Expand Down Expand Up @@ -97,12 +105,24 @@ fn test_matrix_within_array_within_struct_accesses() {
t.am[0][0][idx] = 20.0;
t.am[0][idx][1] = 30.0;
t.am[0][idx][idx] = 40.0;

// passing pointers to a function
// FIXME: these are currently commented out because getting pointers to
// vector/matrix elements is broken in Metal and HLSL.
// let pl0 = read_from_private(&t.am[0][0][1]);
// let pl1 = read_from_private(&t.am[0][0][idx]);
// let pl2 = read_from_private(&t.am[0][idx][1]);
// let pl3 = read_from_private(&t.am[0][idx][idx]);
}

fn read_from_private(foo: ptr<function, f32>) -> f32 {
return *foo;
}

fn read_i32_from_private(foo: ptr<function, i32>) -> i32 {
return *foo;
}

fn test_arr_as_arg(a: array<array<f32, 10>, 5>) -> f32 {
return a[4][9];
}
Expand Down Expand Up @@ -133,6 +153,7 @@ fn foo_vert(@builtin(vertex_index) vi: u32) -> @builtin(position) vec4<f32> {
var c2 = array<i32, 5>(a, i32(b), 3, 4, 5);
c2[vi + 1u] = 42;
let value = c2[vi];
let value_again = read_i32_from_private(&c2[vi]);

test_arr_as_arg(array<array<f32, 10>, 5>());

Expand Down
87 changes: 74 additions & 13 deletions tests/out/analysis/access.info.ron
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("SIZED | COPY | ARGUMENT"),
("SIZED | COPY | ARGUMENT"),
("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
Expand Down Expand Up @@ -2873,7 +2874,46 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(25),
ty: Handle(24),
),
(
uniformity: (
non_uniform_result: Some(1),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(3),
),
],
sampling: [],
),
(
flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"),
available_stages: ("VERTEX | FRAGMENT | COMPUTE"),
uniformity: (
non_uniform_result: Some(1),
requirements: (""),
),
may_kill: false,
sampling_set: [],
global_uses: [
(""),
(""),
(""),
(""),
(""),
(""),
],
expressions: [
(
uniformity: (
non_uniform_result: Some(1),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(26),
),
(
uniformity: (
Expand All @@ -2894,7 +2934,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(24),
ty: Handle(25),
),
(
uniformity: (
Expand Down Expand Up @@ -2945,7 +2985,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(31),
ty: Handle(32),
),
(
uniformity: (
Expand Down Expand Up @@ -2987,7 +3027,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(33),
ty: Handle(34),
),
(
uniformity: (
Expand Down Expand Up @@ -3046,7 +3086,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(32),
ty: Handle(33),
),
],
sampling: [],
Expand Down Expand Up @@ -3076,7 +3116,7 @@
non_uniform_result: Some(1),
requirements: (""),
),
ref_count: 2,
ref_count: 3,
assignable_global: None,
ty: Handle(1),
),
Expand Down Expand Up @@ -3567,17 +3607,17 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(28),
ty: Handle(29),
),
(
uniformity: (
non_uniform_result: Some(41),
requirements: (""),
),
ref_count: 3,
ref_count: 4,
assignable_global: None,
ty: Value(Pointer(
base: 28,
base: 29,
space: Function,
)),
),
Expand Down Expand Up @@ -3647,14 +3687,35 @@
assignable_global: None,
ty: Handle(3),
),
(
uniformity: (
non_uniform_result: Some(41),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 3,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: Some(1),
requirements: (""),
),
ref_count: 0,
assignable_global: None,
ty: Handle(3),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(25),
ty: Handle(26),
),
(
uniformity: (
Expand Down Expand Up @@ -3723,7 +3784,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(26),
ty: Handle(27),
),
],
sampling: [],
Expand Down Expand Up @@ -4260,7 +4321,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(32),
ty: Handle(33),
),
(
uniformity: (
Expand All @@ -4270,7 +4331,7 @@
ref_count: 2,
assignable_global: None,
ty: Value(Pointer(
base: 32,
base: 33,
space: Function,
)),
),
Expand Down
9 changes: 7 additions & 2 deletions tests/out/glsl/access.assign_through_ptr.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ float read_from_private(inout float foo_1) {
return _e1;
}

int read_i32_from_private(inout int foo_2) {
int _e1 = foo_2;
return _e1;
}

float test_arr_as_arg(float a[5][10]) {
return a[4][9];
}
Expand All @@ -36,8 +41,8 @@ void assign_through_ptr_fn(inout uint p) {
return;
}

void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
foo_2 = vec4[2](vec4(1.0), vec4(2.0));
void assign_array_through_ptr_fn(inout vec4 foo_3[2]) {
foo_3 = vec4[2](vec4(1.0), vec4(2.0));
return;
}

Expand Down
Loading
Loading