1
0
Fork 0

Supper non-buffer and non-sidechain inputs in STFT

This commit is contained in:
Robbert van der Helm 2022-03-06 22:26:37 +01:00
parent e61a42e96f
commit b06e67bde7
3 changed files with 233 additions and 29 deletions

View file

@ -138,10 +138,9 @@ impl Plugin for Stft {
self.stft.process_overlap_add(
buffer,
[],
&self.window_function,
OVERLAP_TIMES,
|_channel_idx, _, real_fft_scratch_buffer| {
|_channel_idx, real_fft_scratch_buffer| {
// Forward FFT, the helper has already applied window function
self.plan
.r2c_plan

View file

@ -240,6 +240,7 @@ impl ExactSizeIterator for BlockChannelsIter<'_, '_> {}
impl<'a> Buffer<'a> {
/// Return the numer of samples in this buffer.
#[inline]
pub fn len(&self) -> usize {
if self.output_slices.is_empty() {
0
@ -331,6 +332,7 @@ impl<'a> Buffer<'a> {
impl<'slice, 'sample> Channels<'slice, 'sample> {
/// Get the number of channels.
#[allow(clippy::len_without_is_empty)]
#[inline]
pub fn len(&self) -> usize {
unsafe { (*self.buffers).len() }
}
@ -461,10 +463,17 @@ impl<'slice, 'sample> Channels<'slice, 'sample> {
impl<'slice, 'sample> Block<'slice, 'sample> {
/// Get the number of samples (not channels) in the block.
#[allow(clippy::len_without_is_empty)]
#[inline]
pub fn len(&self) -> usize {
self.current_block_end - self.current_block_start
}
/// Return the numer of channels in this buffer.
#[inline]
pub fn channels(&self) -> usize {
unsafe { (*self.buffers).len() }
}
/// A resetting iterator. This lets you iterate over the same block multiple times. Otherwise
/// you don't need to use this function as [`Block`] already implements [`Iterator`]. You can
/// also use the direct accessor functions on this block instead.
@ -481,6 +490,32 @@ impl<'slice, 'sample> Block<'slice, 'sample> {
/// Access a channel by index. Useful when you would otherwise iterate over this [`Block`]
/// multiple times.
#[inline]
pub fn get(&self, channel_index: usize) -> Option<&[f32]> {
// SAFETY: The block bound has already been checked
unsafe {
Some(
(*self.buffers)
.get(channel_index)?
.get_unchecked(self.current_block_start..self.current_block_end),
)
}
}
/// The same as [`get()`][Self::get], but without any bounds checking.
///
/// # Safety
///
/// `channel_index` must be in the range `0..Self::len()`.
#[inline]
pub unsafe fn get_unchecked(&self, channel_index: usize) -> &[f32] {
(*self.buffers)
.get_unchecked(channel_index)
.get_unchecked(self.current_block_start..self.current_block_end)
}
/// Access a mutable channel by index. Useful when you would otherwise iterate over this
/// [`Block`] multiple times.
#[inline]
pub fn get_mut(&mut self, channel_index: usize) -> Option<&mut [f32]> {
// SAFETY: The block bound has already been checked
unsafe {

View file

@ -1,7 +1,25 @@
//! Utilities for buffering audio, likely used as part of a short-term Fourier transform.
use super::window::multiply_with_window;
use crate::buffer::Buffer;
use crate::buffer::{Block, Buffer};
/// Some buffer that can be used with the [`StftHelper`].
pub trait StftInput {
/// The number of samples in this input.
fn num_samples(&self) -> usize;
/// The number of channels in this input.
fn num_channels(&self) -> usize;
/// Index the buffer without any bounds checks.
unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32;
}
/// The same as [`StftInput`], but with support for writing results back to the buffer
pub trait StftInputMut: StftInput {
/// Get a mutable reference to a sample in the buffer without any bounds checks.
unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32;
}
/// Process the input buffer in equal sized blocks, running a callback on each block to transform
/// the block and then writing back the results from the previous block to the buffer. This
@ -32,6 +50,127 @@ pub struct StftHelper<const NUM_SIDECHAIN_INPUTS: usize = 0> {
current_pos: usize,
}
/// Marker struct for the version wtihout sidechaining.
struct NoSidechain;
impl StftInput for Buffer<'_> {
#[inline]
fn num_samples(&self) -> usize {
self.len()
}
#[inline]
fn num_channels(&self) -> usize {
self.channels()
}
#[inline]
unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
*self
.as_slice_immutable()
.get_unchecked(channel)
.get_unchecked(sample_idx)
}
}
impl StftInputMut for Buffer<'_> {
#[inline]
unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
self.as_slice()
.get_unchecked_mut(channel)
.get_unchecked_mut(sample_idx)
}
}
impl StftInput for Block<'_, '_> {
#[inline]
fn num_samples(&self) -> usize {
self.len()
}
#[inline]
fn num_channels(&self) -> usize {
self.channels()
}
#[inline]
unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
*self.get_unchecked(channel).get_unchecked(sample_idx)
}
}
impl StftInputMut for Block<'_, '_> {
#[inline]
unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
self.get_unchecked_mut(channel)
.get_unchecked_mut(sample_idx)
}
}
impl StftInput for [&[f32]] {
#[inline]
fn num_samples(&self) -> usize {
if self.is_empty() {
0
} else {
self[0].len()
}
}
#[inline]
fn num_channels(&self) -> usize {
self.len()
}
#[inline]
unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
*self.get_unchecked(channel).get_unchecked(sample_idx)
}
}
impl StftInput for [&mut [f32]] {
#[inline]
fn num_samples(&self) -> usize {
if self.is_empty() {
0
} else {
self[0].len()
}
}
#[inline]
fn num_channels(&self) -> usize {
self.len()
}
#[inline]
unsafe fn get_sample_unchecked(&self, channel: usize, sample_idx: usize) -> f32 {
*self.get_unchecked(channel).get_unchecked(sample_idx)
}
}
impl StftInputMut for [&mut [f32]] {
#[inline]
unsafe fn get_sample_unchecked_mut(&mut self, channel: usize, sample_idx: usize) -> &mut f32 {
self.get_unchecked_mut(channel)
.get_unchecked_mut(sample_idx)
}
}
impl StftInput for NoSidechain {
fn num_samples(&self) -> usize {
0
}
fn num_channels(&self) -> usize {
0
}
unsafe fn get_sample_unchecked(&self, _channel: usize, _sample_idx: usize) -> f32 {
0.0
}
}
impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
/// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the
/// given maximum block size. Call [`set_block_size()`][`Self::set_block_size()`] afterwards if
@ -91,13 +230,13 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
self.main_input_ring_buffers[0].len() as u32
}
/// Process the audio in `main_buffer` and in any sidechain buffers in small overlapping blocks
/// with a window function applied, adding up the results for the main buffer so they can be
/// written back to the host. The window overlap amount is compensated automatically when adding
/// up these samples. Whenever a new block is available, `process_cb()` gets called with a new
/// audio block of the specified size with the windowing function already applied. The summed
/// reults will then be written back to `main_buffer` exactly one block later, which means that
/// this function will introduce one block of latency. This can be compensated by calling
/// Process the audio in `main_buffer` in small overlapping blocks with a window function
/// applied, adding up the results for the main buffer so they can be written back to the host.
/// The window overlap amount is compensated automatically when adding up these samples.
/// Whenever a new block is available, `process_cb()` gets called with a new audio block of the
/// specified size with the windowing function already applied. The summed reults will then be
/// written back to `main_buffer` exactly one block later, which means that this function will
/// introduce one block of latency. This can be compensated by calling
/// [`ProcessContext::set_latency()`][`crate::prelude::ProcessContext::set_latency()`] in your
/// plugin's initialization function.
///
@ -106,9 +245,7 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
///
/// For efficiency's sake this function will reuse the same vector for all calls to
/// `process_cb`. This means you can only access a single channel's worth of windowed data at a
/// time. The arguments to that function are `process_cb(channel_idx, sidechain_buffer_idx,
/// real_fft_buffer)`, where `sidechain_buffer_idx` will be `None` for the main buffer. If there
/// are any sidechain buffers, then they will be processed before the main buffer.
/// time. The arguments to that function are `process_cb(channel_idx, real_fft_buffer)`.
/// `real_fft_buffer` will be a slice of `block_size` real valued samples. This can be passed
/// directly to an FFT algorithm.
///
@ -118,29 +255,61 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
/// channels as this [`StftHelper`], if the sidechain buffers do not contain the same number of
/// samples as the main buffer, or if the window function does not match the block size.
///
/// TODO: Maybe introduce a trait here so this can be used with things that aren't whole buffers
/// TODO: And also introduce that aforementioned read-only process function (`analyze()?`)
/// TODO: Add more useful ways to do STFT and other buffered operations. I just went with this
/// approach because it's what I needed myself, but generic combinators like this could
/// also be useful for other operations.
pub fn process_overlap_add<F>(
pub fn process_overlap_add<M, F>(
&mut self,
main_buffer: &mut Buffer,
sidechain_buffers: [&Buffer; NUM_SIDECHAIN_INPUTS],
main_buffer: &mut M,
window_function: &[f32],
overlap_times: usize,
mut process_cb: F,
) where
M: StftInputMut,
F: FnMut(usize, &mut [f32]),
{
self.process_overlap_add_sidechain(
main_buffer,
[&NoSidechain; NUM_SIDECHAIN_INPUTS],
window_function,
overlap_times,
|channel_idx, sidechain_idx, real_fft_scratch_buffer| {
if sidechain_idx.is_none() {
process_cb(channel_idx, real_fft_scratch_buffer);
}
},
);
}
/// The same as [`process_overlap_add()`][Self::process_overlap_add()], but with sidechain
/// inputs that can be analyzed before the main input gets processed.
///
/// The extra argument in the process function is `sidechain_buffer_idx`, which will be `None`
/// for the main buffer.
pub fn process_overlap_add_sidechain<M, S, F>(
&mut self,
main_buffer: &mut M,
sidechain_buffers: [&S; NUM_SIDECHAIN_INPUTS],
window_function: &[f32],
overlap_times: usize,
mut process_cb: F,
) where
M: StftInputMut,
S: StftInput,
F: FnMut(usize, Option<usize>, &mut [f32]),
{
assert_eq!(main_buffer.channels(), self.main_input_ring_buffers.len());
assert_eq!(
main_buffer.num_channels(),
self.main_input_ring_buffers.len()
);
assert_eq!(window_function.len(), self.main_input_ring_buffers[0].len());
assert!(overlap_times > 0);
// We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying
// already processed samples from `main_ring_buffers` in into `main_buffer`
let main_buffer_len = main_buffer.len();
let num_channels = main_buffer.channels();
let main_buffer_len = main_buffer.num_samples();
let num_channels = main_buffer.num_channels();
let block_size = self.main_input_ring_buffers[0].len();
let window_interval = (block_size / overlap_times) as i32;
let mut already_processed_samples = 0;
@ -156,13 +325,14 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
// TODO: This might be able to be sped up a bit with SIMD
{
// For the main buffer
let main_buffer = main_buffer.as_slice();
for sample_offset in 0..samples_to_process {
for channel_idx in 0..num_channels {
// let main_buffer = main_buffer.as_slice();
let sample = unsafe {
main_buffer
.get_unchecked_mut(channel_idx)
.get_unchecked_mut(already_processed_samples + sample_offset)
main_buffer.get_sample_unchecked_mut(
channel_idx,
already_processed_samples + sample_offset,
)
};
let input_ring_buffer_sample = unsafe {
self.main_input_ring_buffers
@ -186,20 +356,20 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
.iter()
.zip(self.sidechain_ring_buffers.iter_mut())
{
let sidechain_buffer = sidechain_buffer.as_slice_immutable();
for sample_offset in 0..samples_to_process {
for channel_idx in 0..num_channels {
let sample = unsafe {
sidechain_buffer
.get_unchecked(channel_idx)
.get_unchecked(already_processed_samples + sample_offset)
sidechain_buffer.get_sample_unchecked(
channel_idx,
already_processed_samples + sample_offset,
)
};
let ring_buffer_sample = unsafe {
sidechain_ring_buffers
.get_unchecked_mut(channel_idx)
.get_unchecked_mut(self.current_pos + sample_offset)
};
*ring_buffer_sample = *sample;
*ring_buffer_sample = sample;
}
}
}