From 5e1188f9687173d8039598b8043e0986b1f027d9 Mon Sep 17 00:00:00 2001 From: Chad Brokaw Date: Thu, 11 May 2023 12:37:36 -0400 Subject: [PATCH] replace branches with chained selects This exchanges the per-pixel branching for additional ALU + selects. My expectation is that this will be faster, but that may be hardware/driver dependent and likely requires profiling and examination of generated code. The original code is kept in a comment with notes to explain the more obfuscated select version. --- shader/fine.wgsl | 125 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 89 insertions(+), 36 deletions(-) diff --git a/shader/fine.wgsl b/shader/fine.wgsl index 4047ebd..bab1b63 100644 --- a/shader/fine.wgsl +++ b/shader/fine.wgsl @@ -115,22 +115,33 @@ fn read_end_clip(cmd_ix: u32) -> CmdEndClip { } fn extend_mode(t: f32, mode: u32) -> f32 { - // This can be replaced with two selects, exchanging the cost - // of a branch for additional ALU - switch mode { - // PAD - case 0u: { - return clamp(t, 0.0, 1.0); - } - // REPEAT - case 1u: { - return fract(t); - } - // REFLECT (2) - default: { - return abs(t - 2.0 * round(0.5 * t)); - } - } + let EXTEND_PAD = 0u; + let EXTEND_REPEAT = 1u; + let EXTEND_REFLECT = 2u; + // Branching version of the code below: + // + // switch mode { + // // EXTEND_PAD + // case 0u: { + // return clamp(t, 0.0, 1.0); + // } + // // EXTEND_REPEAT + // case 1u: { + // return fract(t); + // } + // // EXTEND_REFLECT + // default: { + // return abs(t - 2.0 * round(0.5 * t)); + // } + // } + let pad = clamp(t, 0.0, 1.0); + let repeat = fract(t); + let reflect = abs(t - 2.0 * round(0.5 * t)); + return select( + select(pad, repeat, mode == EXTEND_REPEAT), + reflect, + mode == EXTEND_REFLECT + ); } #else @@ -304,9 +315,9 @@ fn main( let is_circular = rad.kind == RAD_GRAD_KIND_CIRCULAR; let is_focal_on_circle = rad.kind == RAD_GRAD_KIND_FOCAL_ON_CIRCLE; let is_swapped = (rad.flags & RAD_GRAD_SWAPPED) != 0u; + let is_greater = radius > 1.0; let 
inv_r1 = select(1.0 / radius, 0.0, is_circular); - let root_f = select(1.0, -1.0, is_swapped || one_minus_focal_x < 0.0); - let t_base_scale = select(vec2(0.0, -1.0), vec2(1.0, 1.0), is_swapped); + let less_scale = select(1.0, -1.0, is_swapped || one_minus_focal_x < 0.0); let t_sign = sign(one_minus_focal_x); for (var i = 0u; i < PIXELS_PER_THREAD; i += 1u) { let my_xy = vec2(xy.x + f32(i), xy.y); @@ -316,25 +327,67 @@ fn main( let xx = x * x; let yy = y * y; let x_inv_r1 = x * inv_r1; - var t = 0.0; - var valid = true; - if is_strip { - let a = radius - yy; - t = sqrt(a) + x; - valid = a >= 0.0; - } else if is_focal_on_circle { - t = (xx + yy) / x; - valid = t >= 0.0; - } else if radius > 1.0 { - t = sqrt(xx + yy) - x_inv_r1; - } else { - let a = xx - yy; - t = root_f * sqrt(a) - x_inv_r1; - valid = a >= 0.0 && t >= 0.0; - } - if valid { + // This is the branching version of the code implemented + // by the chained selects below: + // + // var t = 0.0; + // var is_valid = true; + // if is_strip { + // let a = radius - yy; + // t = sqrt(a) + x; + // is_valid = a >= 0.0; + // } else if is_focal_on_circle { + // t = (xx + yy) / x; + // is_valid = t >= 0.0; + // } else if radius > 1.0 { + // t = sqrt(xx + yy) - x_inv_r1; + // } else { + // let a = xx - yy; + // t = less_scale * sqrt(a) - x_inv_r1; + // is_valid = a >= 0.0 && t >= 0.0; + // } + // + // The pattern is that these can all be computed with + // the expression: a * sqrt(b) + c + // + // The parameters to the expression are computed up front + // and chosen with chained selects based on their + // respective conditions. The same process is done + // for determining the validity of the resulting value. 
+ var strip_params = vec3(1.0, radius - yy, x); + var foc_params = vec3(1.0, 0.0, (xx + yy) / x); + var greater_params = vec3(1.0, xx + yy, -x_inv_r1); + var less_params = vec3(less_scale, xx - yy, -x_inv_r1); + var params = select( + select( + select( + less_params, + greater_params, + is_greater, + ), + foc_params, + is_focal_on_circle, + ), + strip_params, + is_strip, + ); + var t = params.x * sqrt(params.y) + params.z; + let is_valid = select( + select( + select( + params.y >= 0.0 && t >= 0.0, + true, + is_greater + ), + t >= 0.0 && x != 0.0, + is_focal_on_circle, + ), + params.y >= 0.0, + is_strip, + ); + if is_valid { t = extend_mode(focal_x + t_sign * t, rad.extend_mode); - t = (t_base_scale.x - t) * t_base_scale.y; + t = select(t, 1.0 - t, is_swapped); let x = i32(round(t * f32(GRADIENT_WIDTH - 1))); let fg_rgba = textureLoad(gradients, vec2(x, i32(rad.index)), 0); let fg_i = fg_rgba * area[i];