From 61db6b94f91be2402aadf922d14f7e68505d13e7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 26 Oct 2023 23:17:43 -0700 Subject: [PATCH] Search temporal reprojection in a quad of pixels --- blade-render/code/blur.wgsl | 40 ++++++++++++++++++++++++------- blade-render/code/camera.inc.wgsl | 10 +++++--- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/blade-render/code/blur.wgsl b/blade-render/code/blur.wgsl index 903c8737..053fba62 100644 --- a/blade-render/code/blur.wgsl +++ b/blade-render/code/blur.wgsl @@ -21,6 +21,16 @@ var input: texture_2d; var prev_input: texture_2d; var output: texture_storage_2d; +fn get_projected_pixel_quad(cp: CameraParams, point: vec3) -> array, 4> { + let pixel = get_projected_pixel_float(cp, point); + return array, 4>( + vec2(vec2(pixel.x - 0.5, pixel.y - 0.5)), + vec2(vec2(pixel.x + 0.5, pixel.y - 0.5)), + vec2(vec2(pixel.x + 0.5, pixel.y + 0.5)), + vec2(vec2(pixel.x - 0.5, pixel.y + 0.5)), + ); +} + fn read_surface(pixel: vec2) -> Surface { var surface = Surface(); surface.flat_normal = normalize(textureLoad(t_flat_normal, pixel, 0).xyz); @@ -44,17 +54,29 @@ fn temporal_accum(@builtin(global_invocation_id) global_id: vec3) { let cur_radiance = textureLoad(input, pixel, 0).xyz; let surface = read_surface(pixel); let pos_world = camera.position + surface.depth * get_ray_direction(camera, pixel); - let prev_pixel = get_projected_pixel(prev_camera, pos_world); + // considering all samples in 2x2 quad, to help with edges + var prev_pixels = get_projected_pixel_quad(prev_camera, pos_world); + var best_index = 0; + var best_weight = 0.0; + for (var i = 0; i < 4; i += 1) { + let prev_pixel = prev_pixels[i]; + if (all(prev_pixel >= vec2(0)) && all(prev_pixel < params.extent)) { + let prev_surface = read_prev_surface(prev_pixel); + let projected_distance = length(pos_world - prev_camera.position); + let weight = compare_flat_normals(surface.flat_normal, prev_surface.flat_normal) + * compare_depths(surface.depth, projected_distance); + if (weight > best_weight) { + best_index = i; + best_weight = weight; + } + } + } + var prev_radiance = cur_radiance; - var history_weight = 1.0 - params.temporal_weight; - if (all(prev_pixel >= vec2(0)) && all(prev_pixel < params.extent)) { - prev_radiance = textureLoad(prev_input, prev_pixel, 0).xyz; - let prev_surface = read_prev_surface(prev_pixel); - let projected_distance = length(pos_world - prev_camera.position); - history_weight *= compare_flat_normals(surface.flat_normal, prev_surface.flat_normal); - history_weight *= compare_depths(surface.depth, projected_distance); + if (best_weight > 0.01) { + prev_radiance = textureLoad(prev_input, prev_pixels[best_index], 0).xyz; } - let radiance = mix(cur_radiance, prev_radiance, history_weight); + let radiance = mix(cur_radiance, prev_radiance, best_weight * (1.0 - params.temporal_weight)); textureStore(output, global_id.xy, vec4(radiance, 0.0)); } diff --git a/blade-render/code/camera.inc.wgsl b/blade-render/code/camera.inc.wgsl index 0c6e3975..bffbbbd9 100644 --- a/blade-render/code/camera.inc.wgsl +++ b/blade-render/code/camera.inc.wgsl @@ -14,12 +14,16 @@ fn get_ray_direction(cp: CameraParams, pixel: vec2) -> vec3 { return normalize(qrot(cp.orientation, local_dir)); } -fn get_projected_pixel(cp: CameraParams, point: vec3) -> vec2 { +fn get_projected_pixel_float(cp: CameraParams, point: vec3) -> vec2 { let local_dir = qrot(qinv(cp.orientation), point - cp.position); if local_dir.z >= 0.0 { - return vec2(-1); + return vec2(-1.0); } let ndc = local_dir.xy / (-local_dir.z * tan(0.5 * cp.fov)); let half_size = 0.5 * vec2(cp.target_size); - return vec2((ndc + vec2(1.0)) * half_size); + return (ndc + vec2(1.0)) * half_size; +} + +fn get_projected_pixel(cp: CameraParams, point: vec3) -> vec2 { + return vec2(get_projected_pixel_float(cp, point)); }