From 0762cc763c25e1f3cd2901e00b3359e8fe962918 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Sat, 20 Nov 2021 21:52:29 -0800 Subject: [PATCH] Implement clear_buffers on Metal Since clearing functionality is not built-in, use a compute shader. Simplify tests greatly; they don't need the workaround. --- piet-gpu-hal/src/metal.rs | 19 ++++++++- piet-gpu-hal/src/metal/clear.rs | 68 +++++++++++++++++++++++++++++++++ tests/src/clear.rs | 4 -- tests/src/linkedlist.rs | 52 ++++++------------------- tests/src/message_passing.rs | 61 ++++------------------------- tests/src/prefix.rs | 33 +++------------- 6 files changed, 109 insertions(+), 128 deletions(-) create mode 100644 piet-gpu-hal/src/metal/clear.rs diff --git a/piet-gpu-hal/src/metal.rs b/piet-gpu-hal/src/metal.rs index 78c0682..4b8acb8 100644 --- a/piet-gpu-hal/src/metal.rs +++ b/piet-gpu-hal/src/metal.rs @@ -14,6 +14,7 @@ // // Also licensed under MIT license, at your choice. +mod clear; mod util; use std::mem; @@ -39,6 +40,7 @@ pub struct MtlDevice { device: metal::Device, cmd_queue: Arc>, gpu_info: GpuInfo, + helpers: Arc, } pub struct MtlSurface { @@ -78,6 +80,7 @@ pub struct Semaphore; pub struct CmdBuf { cmd_buf: metal::CommandBuffer, + helpers: Arc, } pub struct QueryPool; @@ -93,6 +96,10 @@ pub struct DescriptorSet { images: Vec, } +struct Helpers { + clear_pipeline: metal::ComputePipelineState, +} + impl MtlInstance { pub fn new( window_handle: Option<&dyn HasRawWindowHandle>, @@ -172,10 +179,14 @@ impl MtlInstance { has_memory_model: false, use_staging_buffers, }; + let helpers = Arc::new(Helpers { + clear_pipeline: clear::make_clear_pipeline(&device), + }); Ok(MtlDevice { device, cmd_queue: Arc::new(Mutex::new(cmd_queue)), gpu_info, + helpers, }) } else { Err("can't create system default Metal device".into()) @@ -292,7 +303,8 @@ impl crate::backend::Device for MtlDevice { // consider new_command_buffer_with_unretained_references for performance let cmd_buf = cmd_queue.new_command_buffer(); let cmd_buf = autoreleasepool(|| cmd_buf.to_owned()); - Ok(CmdBuf { cmd_buf }) + let helpers = self.helpers.clone(); + Ok(CmdBuf { cmd_buf, helpers }) } unsafe fn destroy_cmd_buf(&self, _cmd_buf: Self::CmdBuf) -> Result<(), Error> { @@ -467,7 +479,10 @@ impl crate::backend::CmdBuf for CmdBuf { } unsafe fn clear_buffer(&self, buffer: &Buffer, size: Option) { - todo!() + let size = size.unwrap_or(buffer.size); + let encoder = self.cmd_buf.new_compute_command_encoder(); + clear::encode_clear(&encoder, &self.helpers.clear_pipeline, &buffer.buffer, size); + encoder.end_encoding() } unsafe fn copy_buffer(&self, src: &Buffer, dst: &Buffer) { diff --git a/piet-gpu-hal/src/metal/clear.rs b/piet-gpu-hal/src/metal/clear.rs new file mode 100644 index 0000000..2d58a66 --- /dev/null +++ b/piet-gpu-hal/src/metal/clear.rs @@ -0,0 +1,68 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! The compute shader and stage for clearing buffers. + +use metal::{ComputePipelineState, Device}; + +const CLEAR_MSL: &str = r#" +using namespace metal; + +struct ConfigBuf +{ + uint size; + uint value; +}; + +kernel void main0(const device ConfigBuf& config [[buffer(0)]], device uint *data [[buffer(1)]], uint3 gid [[thread_position_in_grid]]) +{ + uint ix = gid.x; + if (ix < config.size) + { + data[ix] = config.value; + } +} +"#; + +pub fn make_clear_pipeline(device: &Device) -> ComputePipelineState { + let options = metal::CompileOptions::new(); + let library = device.new_library_with_source(CLEAR_MSL, &options).unwrap(); + let function = library.get_function("main0", None).unwrap(); + device + .new_compute_pipeline_state_with_function(&function).unwrap() + +} + +pub fn encode_clear(encoder: &metal::ComputeCommandEncoderRef, clear_pipeline: &ComputePipelineState, buffer: &metal::Buffer, size: u64) { + // TODO: should be more careful with overflow + let size_in_u32s = (size / 4) as u32; + encoder.set_compute_pipeline_state(&clear_pipeline); + let config = [size_in_u32s, 0]; + encoder.set_bytes(0, std::mem::size_of_val(&config) as u64, config.as_ptr() as *const _); + encoder.set_buffer(1, Some(buffer), 0); + let n_wg = (size_in_u32s + 255) / 256; + let workgroup_count = metal::MTLSize { + width: n_wg as u64, + height: 1, + depth: 1, + }; + let workgroup_size = metal::MTLSize { + width: 256, + height: 1, + depth: 1, + }; + encoder.dispatch_thread_groups(workgroup_count, workgroup_size); +} diff --git a/tests/src/clear.rs b/tests/src/clear.rs index 6e46e1f..7d8bee0 100644 --- a/tests/src/clear.rs +++ b/tests/src/clear.rs @@ -84,10 +84,6 @@ impl ClearCode { } impl ClearStage { - pub unsafe fn new(runner: &mut Runner, n_elements: u64) -> ClearStage { - Self::new_with_value(runner, n_elements, 0) - } - pub unsafe fn new_with_value(runner: &mut Runner, n_elements: u64, value: u32) -> ClearStage { let config = [n_elements as u32, value]; let config_buf = runner diff --git a/tests/src/linkedlist.rs b/tests/src/linkedlist.rs index f8ed826..b3d03ed 100644 --- a/tests/src/linkedlist.rs +++ b/tests/src/linkedlist.rs @@ -14,10 +14,9 @@ // // Also licensed under MIT license, at your choice. -use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet}; +use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet}; use piet_gpu_hal::{Buffer, Pipeline}; -use crate::clear::{ClearBinding, ClearCode, ClearStage}; use crate::runner::{Commands, Runner}; use crate::test_result::TestResult; use crate::Config; @@ -27,16 +26,12 @@ const N_BUCKETS: u64 = 65536; struct LinkedListCode { pipeline: Pipeline, - clear_code: Option, } -struct LinkedListStage { - clear_stage: Option, -} +struct LinkedListStage; struct LinkedListBinding { descriptor_set: DescriptorSet, - clear_binding: Option, } pub unsafe fn run_linkedlist_test(runner: &mut Runner, config: &Config) -> TestResult { @@ -77,26 +72,17 @@ impl LinkedListCode { .session .create_compute_pipeline(code, &[BindType::Buffer]) .unwrap(); - let clear_code = if runner.backend_type() == BackendType::Metal { - Some(ClearCode::new(runner)) - } else { - None - }; - LinkedListCode { - pipeline, - clear_code, - } + LinkedListCode { pipeline } } } impl LinkedListStage { - unsafe fn new(runner: &mut Runner, code: &LinkedListCode, n_buckets: u64) -> LinkedListStage { - let clear_stage = if code.clear_code.is_some() { - Some(ClearStage::new(runner, n_buckets)) - } else { - None - }; - LinkedListStage { clear_stage } + unsafe fn new( + _runner: &mut Runner, + _code: &LinkedListCode, + _n_buckets: u64, + ) -> LinkedListStage { + LinkedListStage } unsafe fn bind( @@ -109,15 +95,7 @@ impl LinkedListStage { .session .create_simple_descriptor_set(&code.pipeline, &[mem_buf]) .unwrap(); - let clear_binding = if let Some(stage) = &self.clear_stage { - Some(stage.bind(runner, &code.clear_code.as_ref().unwrap(), mem_buf)) - } else { - None - }; - LinkedListBinding { - descriptor_set, - clear_binding, - } + LinkedListBinding { descriptor_set } } unsafe fn record( @@ -127,15 +105,7 @@ impl LinkedListStage { bindings: &LinkedListBinding, out_buf: &Buffer, ) { - if let Some(stage) = &self.clear_stage { - stage.record( - commands, - code.clear_code.as_ref().unwrap(), - bindings.clear_binding.as_ref().unwrap(), - ); - } else { - commands.cmd_buf.clear_buffer(out_buf, None); - } + commands.cmd_buf.clear_buffer(out_buf, None); commands.cmd_buf.memory_barrier(); let n_workgroups = N_BUCKETS / WG_SIZE; commands.cmd_buf.dispatch( diff --git a/tests/src/message_passing.rs b/tests/src/message_passing.rs index c0f85af..c5d989b 100644 --- a/tests/src/message_passing.rs +++ b/tests/src/message_passing.rs @@ -14,10 +14,9 @@ // // Also licensed under MIT license, at your choice. -use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode}; +use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet, ShaderCode}; use piet_gpu_hal::{Buffer, Pipeline}; -use crate::clear::{ClearBinding, ClearCode, ClearStage}; use crate::config::Config; use crate::runner::{Commands, Runner}; use crate::test_result::TestResult; @@ -27,19 +26,16 @@ const N_ELEMENTS: u64 = 65536; /// The shader code forMessagePassing sum example. struct MessagePassingCode { pipeline: Pipeline, - clear_code: Option, } /// The stage resources for the prefix sum example. struct MessagePassingStage { data_buf: Buffer, - clear_stages: Option<(ClearStage, ClearBinding, ClearStage)>, } /// The binding for the prefix sum example. struct MessagePassingBinding { descriptor_set: DescriptorSet, - clear_binding: Option, } #[derive(Debug)] @@ -56,7 +52,7 @@ pub unsafe fn run_message_passing_test( let mut result = TestResult::new(format!("message passing litmus, {:?}", variant)); let out_buf = runner.buf_down(4, BufferUsage::CLEAR); let code = MessagePassingCode::new(runner, variant); - let stage = MessagePassingStage::new(runner, &code); + let stage = MessagePassingStage::new(runner); let binding = stage.bind(runner, &code, &out_buf.dev_buf); let n_iter = config.n_iter; let mut total_elapsed = 0.0; @@ -92,22 +88,12 @@ impl MessagePassingCode { .session .create_compute_pipeline(code, &[BindType::Buffer, BindType::Buffer]) .unwrap(); - // Currently, Metal backend doesn't support buffer clearing, so use a - // compute shader as a workaround. - let clear_code = if runner.backend_type() == BackendType::Metal { - Some(ClearCode::new(runner)) - } else { - None - }; - MessagePassingCode { - pipeline, - clear_code, - } + MessagePassingCode { pipeline } } } impl MessagePassingStage { - unsafe fn new(runner: &mut Runner, code: &MessagePassingCode) -> MessagePassingStage { + unsafe fn new(runner: &mut Runner) -> MessagePassingStage { let data_buf_size = 8 * N_ELEMENTS; let data_buf = runner .session @@ -116,18 +102,7 @@ impl MessagePassingStage { BufferUsage::STORAGE | BufferUsage::COPY_DST | BufferUsage::CLEAR, ) .unwrap(); - let clear_stages = if let Some(clear_code) = &code.clear_code { - let stage0 = ClearStage::new(runner, N_ELEMENTS * 2); - let binding0 = stage0.bind(runner, clear_code, &data_buf); - let stage1 = ClearStage::new(runner, 1); - Some((stage0, binding0, stage1)) - } else { - None - }; - MessagePassingStage { - data_buf, - clear_stages, - } + MessagePassingStage { data_buf } } unsafe fn bind( @@ -140,21 +115,7 @@ impl MessagePassingStage { .session .create_simple_descriptor_set(&code.pipeline, &[&self.data_buf, out_buf]) .unwrap(); - let clear_binding = if let Some(clear_code) = &code.clear_code { - Some( - self.clear_stages - .as_ref() - .unwrap() - .2 - .bind(runner, clear_code, out_buf), - ) - } else { - None - }; - MessagePassingBinding { - descriptor_set, - clear_binding, - } + MessagePassingBinding { descriptor_set } } unsafe fn record( @@ -164,14 +125,8 @@ impl MessagePassingStage { bindings: &MessagePassingBinding, out_buf: &Buffer, ) { - if let Some((stage0, binding0, stage1)) = &self.clear_stages { - let code = code.clear_code.as_ref().unwrap(); - stage0.record(commands, code, binding0); - stage1.record(commands, code, bindings.clear_binding.as_ref().unwrap()); - } else { - commands.cmd_buf.clear_buffer(&self.data_buf, None); - commands.cmd_buf.clear_buffer(out_buf, None); - } + commands.cmd_buf.clear_buffer(&self.data_buf, None); + commands.cmd_buf.clear_buffer(out_buf, None); commands.cmd_buf.memory_barrier(); commands.cmd_buf.dispatch( &code.pipeline, diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs index bfbc5b6..71be865 100644 --- a/tests/src/prefix.rs +++ b/tests/src/prefix.rs @@ -14,10 +14,9 @@ // // Also licensed under MIT license, at your choice. -use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode}; +use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet, ShaderCode}; use piet_gpu_hal::{Buffer, Pipeline}; -use crate::clear::{ClearBinding, ClearCode, ClearStage}; use crate::config::Config; use crate::runner::{Commands, Runner}; use crate::test_result::TestResult; @@ -31,7 +30,6 @@ const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS; /// A code struct can be created once and reused any number of times. struct PrefixCode { pipeline: Pipeline, - clear_code: Option, } /// The stage resources for the prefix sum example. @@ -43,7 +41,6 @@ struct PrefixStage { // treat it as a capacity. n_elements: u64, state_buf: Buffer, - clear_stage: Option<(ClearStage, ClearBinding)>, } /// The binding for the prefix sum example. @@ -79,7 +76,7 @@ pub unsafe fn run_prefix_test( .unwrap(); let out_buf = runner.buf_down(data_buf.size(), BufferUsage::empty()); let code = PrefixCode::new(runner, variant); - let stage = PrefixStage::new(runner, &code, n_elements); + let stage = PrefixStage::new(runner, n_elements); let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); let n_iter = config.n_iter; let mut total_elapsed = 0.0; @@ -121,20 +118,12 @@ impl PrefixCode { .unwrap(); // Currently, DX12 and Metal backends don't support buffer clearing, so use a // compute shader as a workaround. - let clear_code = if runner.backend_type() == BackendType::Metal { - Some(ClearCode::new(runner)) - } else { - None - }; - PrefixCode { - pipeline, - clear_code, - } + PrefixCode { pipeline } } } impl PrefixStage { - unsafe fn new(runner: &mut Runner, code: &PrefixCode, n_elements: u64) -> PrefixStage { + unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage { let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; let state_buf_size = 4 + 12 * n_workgroups; let state_buf = runner @@ -144,17 +133,9 @@ impl PrefixStage { BufferUsage::STORAGE | BufferUsage::COPY_DST | BufferUsage::CLEAR, ) .unwrap(); - let clear_stage = if let Some(clear_code) = &code.clear_code { - let stage = ClearStage::new(runner, state_buf_size / 4); - let binding = stage.bind(runner, clear_code, &state_buf); - Some((stage, binding)) - } else { - None - }; PrefixStage { n_elements, state_buf, - clear_stage, } } @@ -174,11 +155,7 @@ impl PrefixStage { unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) { let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; - if let Some((stage, binding)) = &self.clear_stage { - stage.record(commands, code.clear_code.as_ref().unwrap(), binding); - } else { - commands.cmd_buf.clear_buffer(&self.state_buf, None); - } + commands.cmd_buf.clear_buffer(&self.state_buf, None); commands.cmd_buf.memory_barrier(); commands.cmd_buf.dispatch( &code.pipeline,