diff --git a/plugins/examples/stft/src/lib.rs b/plugins/examples/stft/src/lib.rs index 98601ad8..27575a86 100644 --- a/plugins/examples/stft/src/lib.rs +++ b/plugins/examples/stft/src/lib.rs @@ -138,10 +138,9 @@ impl Plugin for Stft { self.stft.process_overlap_add( buffer, - [], &self.window_function, OVERLAP_TIMES, - |_channel_idx, _, real_fft_scratch_buffer| { + |_channel_idx, real_fft_scratch_buffer| { // Forward FFT, the helper has already applied window function self.plan .r2c_plan diff --git a/src/buffer.rs b/src/buffer.rs index d4ba8aa9..ee7b3850 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -240,6 +240,7 @@ impl ExactSizeIterator for BlockChannelsIter<'_, '_> {} impl<'a> Buffer<'a> { /// Return the numer of samples in this buffer. + #[inline] pub fn len(&self) -> usize { if self.output_slices.is_empty() { 0 @@ -331,6 +332,7 @@ impl<'a> Buffer<'a> { impl<'slice, 'sample> Channels<'slice, 'sample> { /// Get the number of channels. #[allow(clippy::len_without_is_empty)] + #[inline] pub fn len(&self) -> usize { unsafe { (*self.buffers).len() } } @@ -461,10 +463,17 @@ impl<'slice, 'sample> Channels<'slice, 'sample> { impl<'slice, 'sample> Block<'slice, 'sample> { /// Get the number of samples (not channels) in the block. #[allow(clippy::len_without_is_empty)] + #[inline] pub fn len(&self) -> usize { self.current_block_end - self.current_block_start } + /// Return the numer of channels in this buffer. + #[inline] + pub fn channels(&self) -> usize { + unsafe { (*self.buffers).len() } + } + /// A resetting iterator. This lets you iterate over the same block multiple times. Otherwise /// you don't need to use this function as [`Block`] already implements [`Iterator`]. You can /// also use the direct accessor functions on this block instead. @@ -481,6 +490,32 @@ impl<'slice, 'sample> Block<'slice, 'sample> { /// Access a channel by index. Useful when you would otherwise iterate over this [`Block`] /// multiple times. #[inline] + pub fn get(&self, channel_index: usize) -> Option<&[f32]> { + // SAFETY: The block bound has already been checked + unsafe { + Some( + (*self.buffers) + .get(channel_index)? + .get_unchecked(self.current_block_start..self.current_block_end), + ) + } + } + + /// The same as [`get()`][Self::get], but without any bounds checking. + /// + /// # Safety + /// + /// `channel_index` must be in the range `0..Self::len()`. + #[inline] + pub unsafe fn get_unchecked(&self, channel_index: usize) -> &[f32] { + (*self.buffers) + .get_unchecked(channel_index) + .get_unchecked(self.current_block_start..self.current_block_end) + } + + /// Access a mutable channel by index. Useful when you would otherwise iterate over this + /// [`Block`] multiple times. + #[inline] pub fn get_mut(&mut self, channel_index: usize) -> Option<&mut [f32]> { // SAFETY: The block bound has already been checked unsafe { diff --git a/src/util/stft.rs b/src/util/stft.rs index 3662a9f4..3c737fac 100644 --- a/src/util/stft.rs +++ b/src/util/stft.rs @@ -1,7 +1,25 @@ //! Utilities for buffering audio, likely used as part of a short-term Fourier transform. use super::window::multiply_with_window; -use crate::buffer::Buffer; +use crate::buffer::{Block, Buffer}; + +/// Some buffer that can be used with the [`StftHelper`]. +pub trait StftInput { + /// The number of samples in this input. + fn num_samples(&self) -> usize; + + /// The number of channels in this input. + fn num_channels(&self) -> usize; + + /// Index the buffer without any bounds checks. + unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32; +} + +/// The same as [`StftInput`], but with support for writing results back to the buffer +pub trait StftInputMut: StftInput { + /// Get a mutable reference to a sample in the buffer without any bounds checks. + unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32; +} /// Process the input buffer in equal sized blocks, running a callback on each block to transform /// the block and then writing back the results from the previous block to the buffer. This @@ -32,6 +50,127 @@ pub struct StftHelper { current_pos: usize, } +/// Marker struct for the version wtihout sidechaining. +struct NoSidechain; + +impl StftInput for Buffer<'_> { + #[inline] + fn num_samples(&self) -> usize { + self.len() + } + + #[inline] + fn num_channels(&self) -> usize { + self.channels() + } + + #[inline] + unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 { + *self + .as_slice_immutable() + .get_unchecked(channel) + .get_unchecked(sample_idx) + } +} + +impl StftInputMut for Buffer<'_> { + #[inline] + unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 { + self.as_slice() + .get_unchecked_mut(channel) + .get_unchecked_mut(sample_idx) + } +} + +impl StftInput for Block<'_, '_> { + #[inline] + fn num_samples(&self) -> usize { + self.len() + } + + #[inline] + fn num_channels(&self) -> usize { + self.channels() + } + + #[inline] + unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 { + *self.get_unchecked(channel).get_unchecked(sample_idx) + } +} + +impl StftInputMut for Block<'_, '_> { + #[inline] + unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 { + self.get_unchecked_mut(channel) + .get_unchecked_mut(sample_idx) + } +} + +impl StftInput for [&[f32]] { + #[inline] + fn num_samples(&self) -> usize { + if self.is_empty() { + 0 + } else { + self[0].len() + } + } + + #[inline] + fn num_channels(&self) -> usize { + self.len() + } + + #[inline] + unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 { + *self.get_unchecked(channel).get_unchecked(sample_idx) + } +} + +impl StftInput for [&mut [f32]] { + #[inline] + fn num_samples(&self) -> usize { + if self.is_empty() { + 0 + } else { + self[0].len() + } + } + + #[inline] + fn num_channels(&self) -> usize { + self.len() + } + + #[inline] + unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 { + *self.get_unchecked(channel).get_unchecked(sample_idx) + } +} + +impl StftInputMut for [&mut [f32]] { + #[inline] + unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 { + self.get_unchecked_mut(channel) + .get_unchecked_mut(sample_idx) + } +} + +impl StftInput for NoSidechain { + fn num_samples(&self) -> usize { + 0 + } + + fn num_channels(&self) -> usize { + 0 + } + + unsafe fn get_sample_unchecked(&self, _channel: usize, _sample_idx: usize) -> f32 { + 0.0 + } +} + impl StftHelper { /// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the /// given maximum block size. Call [`set_block_size()`][`Self::set_block_size()`] afterwards if @@ -91,13 +230,13 @@ impl StftHelper { self.main_input_ring_buffers[0].len() as u32 } - /// Process the audio in `main_buffer` and in any sidechain buffers in small overlapping blocks - /// with a window function applied, adding up the results for the main buffer so they can be - /// written back to the host. The window overlap amount is compensated automatically when adding - /// up these samples. Whenever a new block is available, `process_cb()` gets called with a new - /// audio block of the specified size with the windowing function already applied. The summed - /// reults will then be written back to `main_buffer` exactly one block later, which means that - /// this function will introduce one block of latency. This can be compensated by calling + /// Process the audio in `main_buffer` in small overlapping blocks with a window function + /// applied, adding up the results for the main buffer so they can be written back to the host. + /// The window overlap amount is compensated automatically when adding up these samples. + /// Whenever a new block is available, `process_cb()` gets called with a new audio block of the + /// specified size with the windowing function already applied. The summed reults will then be + /// written back to `main_buffer` exactly one block later, which means that this function will + /// introduce one block of latency. This can be compensated by calling /// [`ProcessContext::set_latency()`][`crate::prelude::ProcessContext::set_latency()`] in your /// plugin's initialization function. /// @@ -106,9 +245,7 @@ impl StftHelper { /// /// For efficiency's sake this function will reuse the same vector for all calls to /// `process_cb`. This means you can only access a single channel's worth of windowed data at a - /// time. The arguments to that function are `process_cb(channel_idx, sidechain_buffer_idx, - /// real_fft_buffer)`, where `sidechain_buffer_idx` will be `None` for the main buffer. If there - /// are any sidechain buffers, then they will be processed before the main buffer. + /// time. The arguments to that function are `process_cb(channel_idx, real_fft_buffer)`. /// `real_fft_buffer` will be a slice of `block_size` real valued samples. This can be passed /// directly to an FFT algorithm. /// @@ -118,29 +255,61 @@ impl StftHelper { /// channels as this [`StftHelper`], if the sidechain buffers do not contain the same number of /// samples as the main buffer, or if the window function does not match the block size. /// - /// TODO: Maybe introduce a trait here so this can be used with things that aren't whole buffers /// TODO: And also introduce that aforementioned read-only process function (`analyze()?`) /// TODO: Add more useful ways to do STFT and other buffered operations. I just went with this /// approach because it's what I needed myself, but generic combinators like this could /// also be useful for other operations. - pub fn process_overlap_add( + pub fn process_overlap_add( &mut self, - main_buffer: &mut Buffer, - sidechain_buffers: [&Buffer; NUM_SIDECHAIN_INPUTS], + main_buffer: &mut M, window_function: &[f32], overlap_times: usize, mut process_cb: F, ) where + M: StftInputMut, + F: FnMut(usize, &mut [f32]), + { + self.process_overlap_add_sidechain( + main_buffer, + [&NoSidechain; NUM_SIDECHAIN_INPUTS], + window_function, + overlap_times, + |channel_idx, sidechain_idx, real_fft_scratch_buffer| { + if sidechain_idx.is_none() { + process_cb(channel_idx, real_fft_scratch_buffer); + } + }, + ); + } + + /// The same as [`process_overlap_add()`][Self::process_overlap_add()], but with sidechain + /// inputs that can be analyzed before the main input gets processed. + /// + /// The extra argument in the process function is `sidechain_buffer_idx`, which will be `None` + /// for the main buffer. + pub fn process_overlap_add_sidechain( + &mut self, + main_buffer: &mut M, + sidechain_buffers: [&S; NUM_SIDECHAIN_INPUTS], + window_function: &[f32], + overlap_times: usize, + mut process_cb: F, + ) where + M: StftInputMut, + S: StftInput, F: FnMut(usize, Option, &mut [f32]), { - assert_eq!(main_buffer.channels(), self.main_input_ring_buffers.len()); + assert_eq!( + main_buffer.num_channels(), + self.main_input_ring_buffers.len() + ); assert_eq!(window_function.len(), self.main_input_ring_buffers[0].len()); assert!(overlap_times > 0); // We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying // already processed samples from `main_ring_buffers` in into `main_buffer` - let main_buffer_len = main_buffer.len(); - let num_channels = main_buffer.channels(); + let main_buffer_len = main_buffer.num_samples(); + let num_channels = main_buffer.num_channels(); let block_size = self.main_input_ring_buffers[0].len(); let window_interval = (block_size / overlap_times) as i32; let mut already_processed_samples = 0; @@ -156,13 +325,14 @@ impl StftHelper { // TODO: This might be able to be sped up a bit with SIMD { // For the main buffer - let main_buffer = main_buffer.as_slice(); for sample_offset in 0..samples_to_process { for channel_idx in 0..num_channels { + // let main_buffer = main_buffer.as_slice(); let sample = unsafe { - main_buffer - .get_unchecked_mut(channel_idx) - .get_unchecked_mut(already_processed_samples + sample_offset) + main_buffer.get_sample_unchecked_mut( + channel_idx, + already_processed_samples + sample_offset, + ) }; let input_ring_buffer_sample = unsafe { self.main_input_ring_buffers @@ -186,20 +356,20 @@ impl StftHelper { .iter() .zip(self.sidechain_ring_buffers.iter_mut()) { - let sidechain_buffer = sidechain_buffer.as_slice_immutable(); for sample_offset in 0..samples_to_process { for channel_idx in 0..num_channels { let sample = unsafe { - sidechain_buffer - .get_unchecked(channel_idx) - .get_unchecked(already_processed_samples + sample_offset) + sidechain_buffer.get_sample_unchecked( + channel_idx, + already_processed_samples + sample_offset, + ) }; let ring_buffer_sample = unsafe { sidechain_ring_buffers .get_unchecked_mut(channel_idx) .get_unchecked_mut(self.current_pos + sample_offset) }; - *ring_buffer_sample = *sample; + *ring_buffer_sample = sample; } } }