From b177e3114c6e0a3fcaab5633382b9db72d460257 Mon Sep 17 00:00:00 2001
From: Robbert van der Helm <mail@robbertvanderhelm.nl>
Date: Tue, 1 Mar 2022 16:03:13 +0100
Subject: [PATCH] Add per-block iterators to Buffer

---
 src/buffer.rs | 225 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 219 insertions(+), 6 deletions(-)

diff --git a/src/buffer.rs b/src/buffer.rs
index 4085ab05..34eac5b9 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -16,6 +16,8 @@ pub struct Buffer<'a> {
     output_slices: Vec<&'a mut [f32]>,
 }
 
+// Per-sample per-channel iterators
+
 /// An iterator over all samples in the buffer, yielding iterators over each channel for every
 /// sample. This iteration order offers good cache locality for per-sample access.
 pub struct SamplesIter<'slice, 'sample: 'slice> {
@@ -26,7 +28,9 @@ pub struct SamplesIter<'slice, 'sample: 'slice> {
 }
 
 /// Can construct iterators over actual iterator over the channel data for a sample, yielded by
-/// [Samples].
+/// [Samples]. Can be turned into an iterator, or [Channels::iter_mut()] can be used to iterate over
+/// the channel data multiple times, or more efficiently you can use [Channels::get_unchecked_mut()]
+/// to do the same thing.
 pub struct Channels<'slice, 'sample: 'slice> {
     /// The raw output buffers.
     pub(self) buffers: *mut [&'sample mut [f32]],
@@ -43,6 +47,39 @@ pub struct ChannelsIter<'slice, 'sample: 'slice> {
     pub(self) _marker: PhantomData<&'slice mut [&'sample mut [f32]]>,
 }
 
+// Per-block per-channel per-sample iterators
+
+/// An iterator over all samples in the buffer, slicing over the sample-dimension with a maximum
+/// size of [Self::max_block_size]. See [Buffer::iter_blocks()].
+pub struct BlocksIter<'slice, 'sample: 'slice> {
+    /// The raw output buffers.
+    pub(self) buffers: *mut [&'sample mut [f32]],
+    pub(self) max_block_size: usize,
+    pub(self) current_block_start: usize,
+    pub(self) _marker: PhantomData<&'slice mut [&'sample mut [f32]]>,
+}
+
+/// A block yielded by [BlocksIter]. Can be iterated over once or multiple times, and also supports
+/// direct access to the block's samples if needed.
+pub struct Block<'slice, 'sample: 'slice> {
+    /// The raw output buffers.
+    pub(self) buffers: *mut [&'sample mut [f32]],
+    pub(self) current_block_start: usize,
+    pub(self) current_block_end: usize,
+    pub(self) _marker: PhantomData<&'slice mut [&'sample mut [f32]]>,
+}
+
+/// An iterator over all channels in a block yielded by [Block]. Analogous to [ChannelsIter] but for
+/// blocks.
+pub struct BlockChannelsIter<'slice, 'sample: 'slice> {
+    /// The raw output buffers.
+    pub(self) buffers: *mut [&'sample mut [f32]],
+    pub(self) current_block_start: usize,
+    pub(self) current_block_end: usize,
+    pub(self) current_channel: usize,
+    pub(self) _marker: PhantomData<&'slice mut [&'sample mut [f32]]>,
+}
+
 impl<'slice, 'sample> Iterator for SamplesIter<'slice, 'sample> {
     type Item = Channels<'slice, 'sample>;
 
@@ -68,6 +105,65 @@ impl<'slice, 'sample> Iterator for SamplesIter<'slice, 'sample> {
     }
 }
 
+impl<'slice, 'sample> Iterator for BlockChannelsIter<'slice, 'sample> {
+    type Item = &'sample mut [f32];
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_channel < unsafe { (*self.buffers).len() } {
+            // SAFETY: These bounds have already been checked
+            // SAFETY: It is also not possible to have multiple mutable references to the same
+            //         sample at the same time
+            let slice = unsafe {
+                (*self.buffers)
+                    .get_unchecked_mut(self.current_channel)
+                    .get_unchecked_mut(self.current_block_start..self.current_block_end)
+            };
+
+            self.current_channel += 1;
+
+            Some(slice)
+        } else {
+            None
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let remaining = unsafe { (*self.buffers).len() } - self.current_channel;
+        (remaining, Some(remaining))
+    }
+}
+
+impl<'slice, 'sample> Iterator for BlocksIter<'slice, 'sample> {
+    type Item = Block<'slice, 'sample>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let buffer_len = unsafe { (*self.buffers)[0].len() };
+        if self.current_block_start < buffer_len {
+            let current_block_end =
+                (self.current_block_start + self.max_block_size).min(buffer_len);
+            let block = Block {
+                buffers: self.buffers,
+                current_block_start: self.current_block_start,
+                current_block_end,
+                _marker: self._marker,
+            };
+
+            self.current_block_start += self.max_block_size;
+
+            Some(block)
+        } else {
+            None
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let remaining = ((unsafe { (*self.buffers)[0].len() } - self.current_block_start) as f32
+            / self.max_block_size as f32)
+            .ceil() as usize;
+        (remaining, Some(remaining))
+    }
+}
+
 impl<'slice, 'sample> Iterator for ChannelsIter<'slice, 'sample> {
     type Item = &'sample mut f32;
 
@@ -110,9 +206,25 @@ impl<'slice, 'sample> IntoIterator for Channels<'slice, 'sample> {
     }
 }
 
-impl ExactSizeIterator for SamplesIter<'_, '_> {}
+impl<'slice, 'sample> IntoIterator for Block<'slice, 'sample> {
+    type Item = &'sample mut [f32];
+    type IntoIter = BlockChannelsIter<'slice, 'sample>;
 
+    fn into_iter(self) -> Self::IntoIter {
+        BlockChannelsIter {
+            buffers: self.buffers,
+            current_block_start: self.current_block_start,
+            current_block_end: self.current_block_end,
+            current_channel: 0,
+            _marker: self._marker,
+        }
+    }
+}
+
+impl ExactSizeIterator for SamplesIter<'_, '_> {}
 impl ExactSizeIterator for ChannelsIter<'_, '_> {}
+impl ExactSizeIterator for BlocksIter<'_, '_> {}
+impl ExactSizeIterator for BlockChannelsIter<'_, '_> {}
 
 impl<'a> Buffer<'a> {
     /// Returns true if this buffer does not contain any samples.
@@ -134,7 +246,28 @@ impl<'a> Buffer<'a> {
         }
     }
 
-    /// Access the raw output slice vector. This neds to be resized to match the number of output
+    /// Iterate over the buffer in blocks with the specified maximum size. The ideal maximum block
+    /// size depends on the plugin in question, but 64 or 128 samples works for most plugins. Since
+    /// the buffer's total size may not be cleanly divisble by the maximum size, the returned
+    /// buffers may have any size in `[1, max_block_size]`. This is useful when using algorithms
+    /// that work on entire blocks of audio, like those that would otherwise need to perform
+    /// expensive per-sample branching or that can use per-sample SIMD as opposed to per-channel
+    /// SIMD.
+    ///
+    /// The parameter smoothers can also produce smoothed values for an entire block using
+    /// [crate::Smoother::next_block()]. Before using this, you will need to call
+    /// [crate::Plugin::initialize_block_smoothers()] with the same `max_block_size` in your
+    /// initialization function first.
+    pub fn iter_blocks<'slice>(&'slice mut self, max_block_size: usize) -> BlocksIter<'slice, 'a> {
+        BlocksIter {
+            buffers: self.output_slices.as_mut_slice(),
+            max_block_size,
+            current_block_start: 0,
+            _marker: PhantomData,
+        }
+    }
+
+    /// Access the raw output slice vector. This needs to be resized to match the number of output
     /// channels during the plugin's initialization. Then during audio processing, these slices
     /// should be updated to point to the plugin's audio buffers.
     ///
@@ -161,15 +294,15 @@ impl<'slice, 'sample> Channels<'slice, 'sample> {
             buffers: self.buffers,
             current_sample: self.current_sample,
             current_channel: 0,
-            _marker: PhantomData,
+            _marker: self._marker,
         }
     }
 
-    /// Access a sample by index. Useful when you would otehrwise iterate over this 'Channels'
+    /// Access a sample by index. Useful when you would otherwise iterate over this 'Channels'
     /// iterator multiple times.
     #[inline]
     pub fn get_mut(&mut self, channel_index: usize) -> Option<&mut f32> {
-        // SAFETY: The channel bound has already been checked
+        // SAFETY: The sample bound has already been checked
         unsafe {
             Some(
                 (*self.buffers)
@@ -192,6 +325,52 @@ impl<'slice, 'sample> Channels<'slice, 'sample> {
     }
 }
 
+impl<'slice, 'sample> Block<'slice, 'sample> {
+    /// Get the number of channels in the block.
+    pub fn len(&self) -> usize {
+        unsafe { (*self.buffers).len() }
+    }
+
+    /// A resetting iterator. This lets you iterate over the same block multiple times. Otherwise
+    /// you don't need to use this function as [Block] already implements [Iterator]. You can also
+    /// use the direct accessor functions on this block instead.
+    pub fn iter_mut(&mut self) -> BlockChannelsIter<'slice, 'sample> {
+        BlockChannelsIter {
+            buffers: self.buffers,
+            current_block_start: self.current_block_start,
+            current_block_end: self.current_block_end,
+            current_channel: 0,
+            _marker: self._marker,
+        }
+    }
+
+    /// Access a channel by index. Useful when you would otherwise iterate over this [Block]
+    /// multiple times.
+    #[inline]
+    pub fn get_mut(&mut self, channel_index: usize) -> Option<&mut [f32]> {
+        // SAFETY: The block bound has already been checked
+        unsafe {
+            Some(
+                (*self.buffers)
+                    .get_mut(channel_index)?
+                    .get_unchecked_mut(self.current_block_start..self.current_block_end),
+            )
+        }
+    }
+
+    /// The same as [Self::get_mut], but without any bounds checking.
+    ///
+    /// # Safety
+    ///
+    /// `channel_index` must be in the range `0..Self::len()`.
+    #[inline]
+    pub unsafe fn get_unchecked_mut(&mut self, channel_index: usize) -> &mut [f32] {
+        (*self.buffers)
+            .get_unchecked_mut(channel_index)
+            .get_unchecked_mut(self.current_block_start..self.current_block_end)
+    }
+}
+
 #[cfg(miri)]
 mod miri {
     use super::*;
@@ -223,4 +402,38 @@ mod miri {
 
         assert_eq!(real_buffers[0][0], 0.003);
     }
+
+    #[test]
+    fn repeated_slices() {
+        let mut real_buffers = vec![vec![0.0; 512]; 2];
+        let mut buffer = Buffer::default();
+        unsafe {
+            buffer.with_raw_vec(|output_slices| {
+                let (first_channel, other_channels) = real_buffers.split_at_mut(1);
+                *output_slices = vec![&mut first_channel[0], &mut other_channels[0]];
+            })
+        };
+
+        // These iterators should not alias
+        let mut blocks = buffer.iter_blocks(16);
+        let block1 = blocks.next().unwrap();
+        let block2 = blocks.next().unwrap();
+        for channel in block1 {
+            for sample in channel.iter_mut() {
+                *sample += 0.001;
+            }
+        }
+        for channel in block2 {
+            for sample in channel.iter_mut() {
+                *sample += 0.001;
+            }
+        }
+
+        for i in 0..32 {
+            assert_eq!(real_buffers[0][i], 0.001);
+        }
+        for i in 32..48 {
+            assert_eq!(real_buffers[0][i], 0.0);
+        }
+    }
 }