mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-25 18:56:35 +11:00
Implement clear_buffers on Metal
Since clearing functionality is not built-in, use a compute shader. Simplify tests greatly; they don't need the workaround.
This commit is contained in:
parent
657f219ce8
commit
0762cc763c
6 changed files with 109 additions and 128 deletions
|
@ -14,6 +14,7 @@
|
||||||
//
|
//
|
||||||
// Also licensed under MIT license, at your choice.
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
mod clear;
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
@ -39,6 +40,7 @@ pub struct MtlDevice {
|
||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
cmd_queue: Arc<Mutex<metal::CommandQueue>>,
|
cmd_queue: Arc<Mutex<metal::CommandQueue>>,
|
||||||
gpu_info: GpuInfo,
|
gpu_info: GpuInfo,
|
||||||
|
helpers: Arc<Helpers>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MtlSurface {
|
pub struct MtlSurface {
|
||||||
|
@ -78,6 +80,7 @@ pub struct Semaphore;
|
||||||
|
|
||||||
pub struct CmdBuf {
|
pub struct CmdBuf {
|
||||||
cmd_buf: metal::CommandBuffer,
|
cmd_buf: metal::CommandBuffer,
|
||||||
|
helpers: Arc<Helpers>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct QueryPool;
|
pub struct QueryPool;
|
||||||
|
@ -93,6 +96,10 @@ pub struct DescriptorSet {
|
||||||
images: Vec<Image>,
|
images: Vec<Image>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Helpers {
|
||||||
|
clear_pipeline: metal::ComputePipelineState,
|
||||||
|
}
|
||||||
|
|
||||||
impl MtlInstance {
|
impl MtlInstance {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
window_handle: Option<&dyn HasRawWindowHandle>,
|
window_handle: Option<&dyn HasRawWindowHandle>,
|
||||||
|
@ -172,10 +179,14 @@ impl MtlInstance {
|
||||||
has_memory_model: false,
|
has_memory_model: false,
|
||||||
use_staging_buffers,
|
use_staging_buffers,
|
||||||
};
|
};
|
||||||
|
let helpers = Arc::new(Helpers {
|
||||||
|
clear_pipeline: clear::make_clear_pipeline(&device),
|
||||||
|
});
|
||||||
Ok(MtlDevice {
|
Ok(MtlDevice {
|
||||||
device,
|
device,
|
||||||
cmd_queue: Arc::new(Mutex::new(cmd_queue)),
|
cmd_queue: Arc::new(Mutex::new(cmd_queue)),
|
||||||
gpu_info,
|
gpu_info,
|
||||||
|
helpers,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Err("can't create system default Metal device".into())
|
Err("can't create system default Metal device".into())
|
||||||
|
@ -292,7 +303,8 @@ impl crate::backend::Device for MtlDevice {
|
||||||
// consider new_command_buffer_with_unretained_references for performance
|
// consider new_command_buffer_with_unretained_references for performance
|
||||||
let cmd_buf = cmd_queue.new_command_buffer();
|
let cmd_buf = cmd_queue.new_command_buffer();
|
||||||
let cmd_buf = autoreleasepool(|| cmd_buf.to_owned());
|
let cmd_buf = autoreleasepool(|| cmd_buf.to_owned());
|
||||||
Ok(CmdBuf { cmd_buf })
|
let helpers = self.helpers.clone();
|
||||||
|
Ok(CmdBuf { cmd_buf, helpers })
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn destroy_cmd_buf(&self, _cmd_buf: Self::CmdBuf) -> Result<(), Error> {
|
unsafe fn destroy_cmd_buf(&self, _cmd_buf: Self::CmdBuf) -> Result<(), Error> {
|
||||||
|
@ -467,7 +479,10 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn clear_buffer(&self, buffer: &Buffer, size: Option<u64>) {
|
unsafe fn clear_buffer(&self, buffer: &Buffer, size: Option<u64>) {
|
||||||
todo!()
|
let size = size.unwrap_or(buffer.size);
|
||||||
|
let encoder = self.cmd_buf.new_compute_command_encoder();
|
||||||
|
clear::encode_clear(&encoder, &self.helpers.clear_pipeline, &buffer.buffer, size);
|
||||||
|
encoder.end_encoding()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer) {
|
unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer) {
|
||||||
|
|
68
piet-gpu-hal/src/metal/clear.rs
Normal file
68
piet-gpu-hal/src/metal/clear.rs
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
//! The compute shader and stage for clearing buffers.
|
||||||
|
|
||||||
|
use metal::{ComputePipelineState, Device};
|
||||||
|
|
||||||
|
const CLEAR_MSL: &str = r#"
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
struct ConfigBuf
|
||||||
|
{
|
||||||
|
uint size;
|
||||||
|
uint value;
|
||||||
|
};
|
||||||
|
|
||||||
|
kernel void main0(const device ConfigBuf& config [[buffer(0)]], device uint *data [[buffer(1)]], uint3 gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint ix = gid.x;
|
||||||
|
if (ix < config.size)
|
||||||
|
{
|
||||||
|
data[ix] = config.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
pub fn make_clear_pipeline(device: &Device) -> ComputePipelineState {
|
||||||
|
let options = metal::CompileOptions::new();
|
||||||
|
let library = device.new_library_with_source(CLEAR_MSL, &options).unwrap();
|
||||||
|
let function = library.get_function("main0", None).unwrap();
|
||||||
|
device
|
||||||
|
.new_compute_pipeline_state_with_function(&function).unwrap()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_clear(encoder: &metal::ComputeCommandEncoderRef, clear_pipeline: &ComputePipelineState, buffer: &metal::Buffer, size: u64) {
|
||||||
|
// TODO: should be more careful with overflow
|
||||||
|
let size_in_u32s = (size / 4) as u32;
|
||||||
|
encoder.set_compute_pipeline_state(&clear_pipeline);
|
||||||
|
let config = [size_in_u32s, 0];
|
||||||
|
encoder.set_bytes(0, std::mem::size_of_val(&config) as u64, config.as_ptr() as *const _);
|
||||||
|
encoder.set_buffer(1, Some(buffer), 0);
|
||||||
|
let n_wg = (size_in_u32s + 255) / 256;
|
||||||
|
let workgroup_count = metal::MTLSize {
|
||||||
|
width: n_wg as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let workgroup_size = metal::MTLSize {
|
||||||
|
width: 256,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
encoder.dispatch_thread_groups(workgroup_count, workgroup_size);
|
||||||
|
}
|
|
@ -84,10 +84,6 @@ impl ClearCode {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClearStage {
|
impl ClearStage {
|
||||||
pub unsafe fn new(runner: &mut Runner, n_elements: u64) -> ClearStage {
|
|
||||||
Self::new_with_value(runner, n_elements, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub unsafe fn new_with_value(runner: &mut Runner, n_elements: u64, value: u32) -> ClearStage {
|
pub unsafe fn new_with_value(runner: &mut Runner, n_elements: u64, value: u32) -> ClearStage {
|
||||||
let config = [n_elements as u32, value];
|
let config = [n_elements as u32, value];
|
||||||
let config_buf = runner
|
let config_buf = runner
|
||||||
|
|
|
@ -14,10 +14,9 @@
|
||||||
//
|
//
|
||||||
// Also licensed under MIT license, at your choice.
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet};
|
use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet};
|
||||||
use piet_gpu_hal::{Buffer, Pipeline};
|
use piet_gpu_hal::{Buffer, Pipeline};
|
||||||
|
|
||||||
use crate::clear::{ClearBinding, ClearCode, ClearStage};
|
|
||||||
use crate::runner::{Commands, Runner};
|
use crate::runner::{Commands, Runner};
|
||||||
use crate::test_result::TestResult;
|
use crate::test_result::TestResult;
|
||||||
use crate::Config;
|
use crate::Config;
|
||||||
|
@ -27,16 +26,12 @@ const N_BUCKETS: u64 = 65536;
|
||||||
|
|
||||||
struct LinkedListCode {
|
struct LinkedListCode {
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
clear_code: Option<ClearCode>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LinkedListStage {
|
struct LinkedListStage;
|
||||||
clear_stage: Option<ClearStage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct LinkedListBinding {
|
struct LinkedListBinding {
|
||||||
descriptor_set: DescriptorSet,
|
descriptor_set: DescriptorSet,
|
||||||
clear_binding: Option<ClearBinding>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn run_linkedlist_test(runner: &mut Runner, config: &Config) -> TestResult {
|
pub unsafe fn run_linkedlist_test(runner: &mut Runner, config: &Config) -> TestResult {
|
||||||
|
@ -77,26 +72,17 @@ impl LinkedListCode {
|
||||||
.session
|
.session
|
||||||
.create_compute_pipeline(code, &[BindType::Buffer])
|
.create_compute_pipeline(code, &[BindType::Buffer])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let clear_code = if runner.backend_type() == BackendType::Metal {
|
LinkedListCode { pipeline }
|
||||||
Some(ClearCode::new(runner))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
LinkedListCode {
|
|
||||||
pipeline,
|
|
||||||
clear_code,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LinkedListStage {
|
impl LinkedListStage {
|
||||||
unsafe fn new(runner: &mut Runner, code: &LinkedListCode, n_buckets: u64) -> LinkedListStage {
|
unsafe fn new(
|
||||||
let clear_stage = if code.clear_code.is_some() {
|
_runner: &mut Runner,
|
||||||
Some(ClearStage::new(runner, n_buckets))
|
_code: &LinkedListCode,
|
||||||
} else {
|
_n_buckets: u64,
|
||||||
None
|
) -> LinkedListStage {
|
||||||
};
|
LinkedListStage
|
||||||
LinkedListStage { clear_stage }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn bind(
|
unsafe fn bind(
|
||||||
|
@ -109,15 +95,7 @@ impl LinkedListStage {
|
||||||
.session
|
.session
|
||||||
.create_simple_descriptor_set(&code.pipeline, &[mem_buf])
|
.create_simple_descriptor_set(&code.pipeline, &[mem_buf])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let clear_binding = if let Some(stage) = &self.clear_stage {
|
LinkedListBinding { descriptor_set }
|
||||||
Some(stage.bind(runner, &code.clear_code.as_ref().unwrap(), mem_buf))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
LinkedListBinding {
|
|
||||||
descriptor_set,
|
|
||||||
clear_binding,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn record(
|
unsafe fn record(
|
||||||
|
@ -127,15 +105,7 @@ impl LinkedListStage {
|
||||||
bindings: &LinkedListBinding,
|
bindings: &LinkedListBinding,
|
||||||
out_buf: &Buffer,
|
out_buf: &Buffer,
|
||||||
) {
|
) {
|
||||||
if let Some(stage) = &self.clear_stage {
|
|
||||||
stage.record(
|
|
||||||
commands,
|
|
||||||
code.clear_code.as_ref().unwrap(),
|
|
||||||
bindings.clear_binding.as_ref().unwrap(),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
commands.cmd_buf.clear_buffer(out_buf, None);
|
commands.cmd_buf.clear_buffer(out_buf, None);
|
||||||
}
|
|
||||||
commands.cmd_buf.memory_barrier();
|
commands.cmd_buf.memory_barrier();
|
||||||
let n_workgroups = N_BUCKETS / WG_SIZE;
|
let n_workgroups = N_BUCKETS / WG_SIZE;
|
||||||
commands.cmd_buf.dispatch(
|
commands.cmd_buf.dispatch(
|
||||||
|
|
|
@ -14,10 +14,9 @@
|
||||||
//
|
//
|
||||||
// Also licensed under MIT license, at your choice.
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode};
|
use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet, ShaderCode};
|
||||||
use piet_gpu_hal::{Buffer, Pipeline};
|
use piet_gpu_hal::{Buffer, Pipeline};
|
||||||
|
|
||||||
use crate::clear::{ClearBinding, ClearCode, ClearStage};
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::runner::{Commands, Runner};
|
use crate::runner::{Commands, Runner};
|
||||||
use crate::test_result::TestResult;
|
use crate::test_result::TestResult;
|
||||||
|
@ -27,19 +26,16 @@ const N_ELEMENTS: u64 = 65536;
|
||||||
/// The shader code forMessagePassing sum example.
|
/// The shader code forMessagePassing sum example.
|
||||||
struct MessagePassingCode {
|
struct MessagePassingCode {
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
clear_code: Option<ClearCode>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The stage resources for the prefix sum example.
|
/// The stage resources for the prefix sum example.
|
||||||
struct MessagePassingStage {
|
struct MessagePassingStage {
|
||||||
data_buf: Buffer,
|
data_buf: Buffer,
|
||||||
clear_stages: Option<(ClearStage, ClearBinding, ClearStage)>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The binding for the prefix sum example.
|
/// The binding for the prefix sum example.
|
||||||
struct MessagePassingBinding {
|
struct MessagePassingBinding {
|
||||||
descriptor_set: DescriptorSet,
|
descriptor_set: DescriptorSet,
|
||||||
clear_binding: Option<ClearBinding>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -56,7 +52,7 @@ pub unsafe fn run_message_passing_test(
|
||||||
let mut result = TestResult::new(format!("message passing litmus, {:?}", variant));
|
let mut result = TestResult::new(format!("message passing litmus, {:?}", variant));
|
||||||
let out_buf = runner.buf_down(4, BufferUsage::CLEAR);
|
let out_buf = runner.buf_down(4, BufferUsage::CLEAR);
|
||||||
let code = MessagePassingCode::new(runner, variant);
|
let code = MessagePassingCode::new(runner, variant);
|
||||||
let stage = MessagePassingStage::new(runner, &code);
|
let stage = MessagePassingStage::new(runner);
|
||||||
let binding = stage.bind(runner, &code, &out_buf.dev_buf);
|
let binding = stage.bind(runner, &code, &out_buf.dev_buf);
|
||||||
let n_iter = config.n_iter;
|
let n_iter = config.n_iter;
|
||||||
let mut total_elapsed = 0.0;
|
let mut total_elapsed = 0.0;
|
||||||
|
@ -92,22 +88,12 @@ impl MessagePassingCode {
|
||||||
.session
|
.session
|
||||||
.create_compute_pipeline(code, &[BindType::Buffer, BindType::Buffer])
|
.create_compute_pipeline(code, &[BindType::Buffer, BindType::Buffer])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Currently, Metal backend doesn't support buffer clearing, so use a
|
MessagePassingCode { pipeline }
|
||||||
// compute shader as a workaround.
|
|
||||||
let clear_code = if runner.backend_type() == BackendType::Metal {
|
|
||||||
Some(ClearCode::new(runner))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
MessagePassingCode {
|
|
||||||
pipeline,
|
|
||||||
clear_code,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MessagePassingStage {
|
impl MessagePassingStage {
|
||||||
unsafe fn new(runner: &mut Runner, code: &MessagePassingCode) -> MessagePassingStage {
|
unsafe fn new(runner: &mut Runner) -> MessagePassingStage {
|
||||||
let data_buf_size = 8 * N_ELEMENTS;
|
let data_buf_size = 8 * N_ELEMENTS;
|
||||||
let data_buf = runner
|
let data_buf = runner
|
||||||
.session
|
.session
|
||||||
|
@ -116,18 +102,7 @@ impl MessagePassingStage {
|
||||||
BufferUsage::STORAGE | BufferUsage::COPY_DST | BufferUsage::CLEAR,
|
BufferUsage::STORAGE | BufferUsage::COPY_DST | BufferUsage::CLEAR,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let clear_stages = if let Some(clear_code) = &code.clear_code {
|
MessagePassingStage { data_buf }
|
||||||
let stage0 = ClearStage::new(runner, N_ELEMENTS * 2);
|
|
||||||
let binding0 = stage0.bind(runner, clear_code, &data_buf);
|
|
||||||
let stage1 = ClearStage::new(runner, 1);
|
|
||||||
Some((stage0, binding0, stage1))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
MessagePassingStage {
|
|
||||||
data_buf,
|
|
||||||
clear_stages,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn bind(
|
unsafe fn bind(
|
||||||
|
@ -140,21 +115,7 @@ impl MessagePassingStage {
|
||||||
.session
|
.session
|
||||||
.create_simple_descriptor_set(&code.pipeline, &[&self.data_buf, out_buf])
|
.create_simple_descriptor_set(&code.pipeline, &[&self.data_buf, out_buf])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let clear_binding = if let Some(clear_code) = &code.clear_code {
|
MessagePassingBinding { descriptor_set }
|
||||||
Some(
|
|
||||||
self.clear_stages
|
|
||||||
.as_ref()
|
|
||||||
.unwrap()
|
|
||||||
.2
|
|
||||||
.bind(runner, clear_code, out_buf),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
MessagePassingBinding {
|
|
||||||
descriptor_set,
|
|
||||||
clear_binding,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn record(
|
unsafe fn record(
|
||||||
|
@ -164,14 +125,8 @@ impl MessagePassingStage {
|
||||||
bindings: &MessagePassingBinding,
|
bindings: &MessagePassingBinding,
|
||||||
out_buf: &Buffer,
|
out_buf: &Buffer,
|
||||||
) {
|
) {
|
||||||
if let Some((stage0, binding0, stage1)) = &self.clear_stages {
|
|
||||||
let code = code.clear_code.as_ref().unwrap();
|
|
||||||
stage0.record(commands, code, binding0);
|
|
||||||
stage1.record(commands, code, bindings.clear_binding.as_ref().unwrap());
|
|
||||||
} else {
|
|
||||||
commands.cmd_buf.clear_buffer(&self.data_buf, None);
|
commands.cmd_buf.clear_buffer(&self.data_buf, None);
|
||||||
commands.cmd_buf.clear_buffer(out_buf, None);
|
commands.cmd_buf.clear_buffer(out_buf, None);
|
||||||
}
|
|
||||||
commands.cmd_buf.memory_barrier();
|
commands.cmd_buf.memory_barrier();
|
||||||
commands.cmd_buf.dispatch(
|
commands.cmd_buf.dispatch(
|
||||||
&code.pipeline,
|
&code.pipeline,
|
||||||
|
|
|
@ -14,10 +14,9 @@
|
||||||
//
|
//
|
||||||
// Also licensed under MIT license, at your choice.
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode};
|
use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet, ShaderCode};
|
||||||
use piet_gpu_hal::{Buffer, Pipeline};
|
use piet_gpu_hal::{Buffer, Pipeline};
|
||||||
|
|
||||||
use crate::clear::{ClearBinding, ClearCode, ClearStage};
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::runner::{Commands, Runner};
|
use crate::runner::{Commands, Runner};
|
||||||
use crate::test_result::TestResult;
|
use crate::test_result::TestResult;
|
||||||
|
@ -31,7 +30,6 @@ const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
|
||||||
/// A code struct can be created once and reused any number of times.
|
/// A code struct can be created once and reused any number of times.
|
||||||
struct PrefixCode {
|
struct PrefixCode {
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
clear_code: Option<ClearCode>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The stage resources for the prefix sum example.
|
/// The stage resources for the prefix sum example.
|
||||||
|
@ -43,7 +41,6 @@ struct PrefixStage {
|
||||||
// treat it as a capacity.
|
// treat it as a capacity.
|
||||||
n_elements: u64,
|
n_elements: u64,
|
||||||
state_buf: Buffer,
|
state_buf: Buffer,
|
||||||
clear_stage: Option<(ClearStage, ClearBinding)>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The binding for the prefix sum example.
|
/// The binding for the prefix sum example.
|
||||||
|
@ -79,7 +76,7 @@ pub unsafe fn run_prefix_test(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let out_buf = runner.buf_down(data_buf.size(), BufferUsage::empty());
|
let out_buf = runner.buf_down(data_buf.size(), BufferUsage::empty());
|
||||||
let code = PrefixCode::new(runner, variant);
|
let code = PrefixCode::new(runner, variant);
|
||||||
let stage = PrefixStage::new(runner, &code, n_elements);
|
let stage = PrefixStage::new(runner, n_elements);
|
||||||
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
|
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
|
||||||
let n_iter = config.n_iter;
|
let n_iter = config.n_iter;
|
||||||
let mut total_elapsed = 0.0;
|
let mut total_elapsed = 0.0;
|
||||||
|
@ -121,20 +118,12 @@ impl PrefixCode {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Currently, DX12 and Metal backends don't support buffer clearing, so use a
|
// Currently, DX12 and Metal backends don't support buffer clearing, so use a
|
||||||
// compute shader as a workaround.
|
// compute shader as a workaround.
|
||||||
let clear_code = if runner.backend_type() == BackendType::Metal {
|
PrefixCode { pipeline }
|
||||||
Some(ClearCode::new(runner))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
PrefixCode {
|
|
||||||
pipeline,
|
|
||||||
clear_code,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PrefixStage {
|
impl PrefixStage {
|
||||||
unsafe fn new(runner: &mut Runner, code: &PrefixCode, n_elements: u64) -> PrefixStage {
|
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage {
|
||||||
let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||||
let state_buf_size = 4 + 12 * n_workgroups;
|
let state_buf_size = 4 + 12 * n_workgroups;
|
||||||
let state_buf = runner
|
let state_buf = runner
|
||||||
|
@ -144,17 +133,9 @@ impl PrefixStage {
|
||||||
BufferUsage::STORAGE | BufferUsage::COPY_DST | BufferUsage::CLEAR,
|
BufferUsage::STORAGE | BufferUsage::COPY_DST | BufferUsage::CLEAR,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let clear_stage = if let Some(clear_code) = &code.clear_code {
|
|
||||||
let stage = ClearStage::new(runner, state_buf_size / 4);
|
|
||||||
let binding = stage.bind(runner, clear_code, &state_buf);
|
|
||||||
Some((stage, binding))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
PrefixStage {
|
PrefixStage {
|
||||||
n_elements,
|
n_elements,
|
||||||
state_buf,
|
state_buf,
|
||||||
clear_stage,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,11 +155,7 @@ impl PrefixStage {
|
||||||
|
|
||||||
unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) {
|
unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) {
|
||||||
let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||||
if let Some((stage, binding)) = &self.clear_stage {
|
|
||||||
stage.record(commands, code.clear_code.as_ref().unwrap(), binding);
|
|
||||||
} else {
|
|
||||||
commands.cmd_buf.clear_buffer(&self.state_buf, None);
|
commands.cmd_buf.clear_buffer(&self.state_buf, None);
|
||||||
}
|
|
||||||
commands.cmd_buf.memory_barrier();
|
commands.cmd_buf.memory_barrier();
|
||||||
commands.cmd_buf.dispatch(
|
commands.cmd_buf.dispatch(
|
||||||
&code.pipeline,
|
&code.pipeline,
|
||||||
|
|
Loading…
Add table
Reference in a new issue