1
0
Fork 0

Fix duplicate mutable borrow soundness in buffer

This gets rid of multiple simultaneous &mut references to the
vector (which should be fine, I think), and also replaces the
`.iter_mut()` that just resets the index (which definitely leads to
soundness issues) with an `.iter_mut()` and a `.into_iter()` that don't
let you have concurrent mutable borrows to the same sample data.
This commit is contained in:
Robbert van der Helm 2022-02-13 00:59:25 +01:00
parent eea05cc748
commit eac3fdf612

View file

@ -26,9 +26,9 @@ impl<'a> Buffer<'a> {
} }
/// Iterate over the samples, returning a channel iterator for each sample. /// Iterate over the samples, returning a channel iterator for each sample.
pub fn iter_mut(&mut self) -> Samples<'_, 'a> { pub fn iter_mut(&mut self) -> Samples<'a> {
Samples { Samples {
buffers: &mut self.output_slices, buffers: self.output_slices.as_mut_slice(),
current_sample: 0, current_sample: 0,
} }
} }
@ -49,24 +49,20 @@ impl<'a> Buffer<'a> {
/// An iterator over all samples in the buffer, yielding iterators over each channel for every /// An iterator over all samples in the buffer, yielding iterators over each channel for every
/// sample. This iteration order offers good cache locality for per-sample access. /// sample. This iteration order offers good cache locality for per-sample access.
pub struct Samples<'outer, 'inner> { pub struct Samples<'a> {
/// The raw output buffers. /// The raw output buffers.
pub(self) buffers: &'outer mut [&'inner mut [f32]], pub(self) buffers: *mut [&'a mut [f32]],
pub(self) current_sample: usize, pub(self) current_sample: usize,
} }
impl<'outer, 'inner> Iterator for Samples<'outer, 'inner> { impl<'a> Iterator for Samples<'a> {
type Item = Channels<'outer, 'inner>; type Item = Channels<'a>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.current_sample < self.buffers[0].len() { if self.current_sample < self.len() {
// SAFETY: We guarantee that each sample is only mutably borrowed once in the channels
// iterator
let buffers: &'outer mut _ = unsafe { &mut *(self.buffers as *mut _) };
let channels = Channels { let channels = Channels {
buffers, buffers: self.buffers,
current_sample: self.current_sample, current_sample: self.current_sample,
current_channel: 0,
}; };
self.current_sample += 1; self.current_sample += 1;
@ -78,35 +74,56 @@ impl<'outer, 'inner> Iterator for Samples<'outer, 'inner> {
} }
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = self.buffers[0].len() - self.current_sample; let remaining = unsafe { (*self.buffers)[0].len() } - self.current_sample;
(remaining, Some(remaining)) (remaining, Some(remaining))
} }
} }
impl<'outer, 'inner> ExactSizeIterator for Samples<'outer, 'inner> {} impl<'a> ExactSizeIterator for Samples<'a> {}
/// An iterator over the channel data for a sample, yielded by [Samples]. /// Can construct iterators over actual iterator over the channel data for a sample, yielded by
pub struct Channels<'outer, 'inner> { /// [Samples].
pub struct Channels<'a> {
/// The raw output buffers. /// The raw output buffers.
pub(self) buffers: &'outer mut [&'inner mut [f32]], pub(self) buffers: *mut [&'a mut [f32]],
pub(self) current_sample: usize,
}
/// The actual iterator over the channel data for a sample, yielded by [Channels].
pub struct ChannelsIter<'a> {
/// The raw output buffers.
pub(self) buffers: *mut [&'a mut [f32]],
pub(self) current_sample: usize, pub(self) current_sample: usize,
pub(self) current_channel: usize, pub(self) current_channel: usize,
} }
impl<'outer, 'inner> Iterator for Channels<'outer, 'inner> { impl<'a> IntoIterator for Channels<'a> {
type Item = &'inner mut f32; type Item = &'a mut f32;
type IntoIter = ChannelsIter<'a>;
fn into_iter(self) -> Self::IntoIter {
ChannelsIter {
buffers: self.buffers,
current_sample: self.current_sample,
current_channel: 0,
}
}
}
impl<'a> Iterator for ChannelsIter<'a> {
type Item = &'a mut f32;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.current_channel < self.buffers.len() { if self.current_channel < self.len() {
// SAFETY: These bounds have already been checked // SAFETY: These bounds have already been checked
let sample = unsafe { let sample = unsafe {
self.buffers (*self.buffers)
.get_unchecked_mut(self.current_channel) .get_unchecked_mut(self.current_channel)
.get_unchecked_mut(self.current_sample) .get_unchecked_mut(self.current_sample)
}; };
// SAFETY: It is not possible to have multiple mutable references to the same sample at // SAFETY: It is not possible to have multiple mutable references to the same sample at
// the same time // the same time
let sample: &'inner mut f32 = unsafe { &mut *(sample as *mut f32) }; let sample: &'a mut f32 = unsafe { &mut *(sample as *mut f32) };
self.current_channel += 1; self.current_channel += 1;
@ -117,19 +134,29 @@ impl<'outer, 'inner> Iterator for Channels<'outer, 'inner> {
} }
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = self.buffers.len() - self.current_channel; let remaining = unsafe { (*self.buffers).len() } - self.current_channel;
(remaining, Some(remaining)) (remaining, Some(remaining))
} }
} }
impl<'outer, 'inner> ExactSizeIterator for Channels<'outer, 'inner> {} impl<'a> ExactSizeIterator for ChannelsIter<'a> {}
impl Channels<'_> {
/// Get the number of channels.
pub fn len(&self) -> usize {
unsafe { (*self.buffers).len() }
}
impl<'outer, 'inner> Channels<'outer, 'inner> {
/// A resetting iterator. This lets you iterate over the same channels multiple times. Otherwise /// A resetting iterator. This lets you iterate over the same channels multiple times. Otherwise
/// you don't need to use this function as [Channels] already implements [Iterator]. /// you don't need to use this function as [Channels] already implements [Iterator].
pub fn iter_mut(&mut self) -> &mut Self { pub fn iter_mut<'a: 'b, 'b>(&'a mut self) -> ChannelsIter<'b> {
self.current_channel = 0; // SAFETY: No two [ChannelIters] can exist at a time
self let buffers: *mut [&'b mut [f32]] = unsafe { std::mem::transmute(self.buffers) };
ChannelsIter {
buffers,
current_sample: self.current_sample,
current_channel: 0,
}
} }
/// Access a sample by index. Useful when you would otehrwise iterate over this 'Channels' /// Access a sample by index. Useful when you would otehrwise iterate over this 'Channels'
@ -139,7 +166,7 @@ impl<'outer, 'inner> Channels<'outer, 'inner> {
// SAFETY: The channel bound has already been checked // SAFETY: The channel bound has already been checked
unsafe { unsafe {
Some( Some(
self.buffers (*self.buffers)
.get_mut(channel_index)? .get_mut(channel_index)?
.get_unchecked_mut(self.current_sample), .get_unchecked_mut(self.current_sample),
) )
@ -153,7 +180,7 @@ impl<'outer, 'inner> Channels<'outer, 'inner> {
/// `channel_index` must be in the range `0..Self::len()`. /// `channel_index` must be in the range `0..Self::len()`.
#[inline] #[inline]
pub unsafe fn get_unchecked_mut(&mut self, channel_index: usize) -> &mut f32 { pub unsafe fn get_unchecked_mut(&mut self, channel_index: usize) -> &mut f32 {
self.buffers (*self.buffers)
.get_unchecked_mut(channel_index) .get_unchecked_mut(channel_index)
.get_unchecked_mut(self.current_sample) .get_unchecked_mut(self.current_sample)
} }