Add proper overlap-add to the StftHelper
Doesn't make much sense without it.
This commit is contained in:
parent
963696cbff
commit
3c62670164
|
@ -1,6 +1,8 @@
|
||||||
use nih_plug::prelude::*;
|
use nih_plug::prelude::*;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
const WINDOW_SIZE: usize = 2048;
|
||||||
|
|
||||||
struct Stft {
|
struct Stft {
|
||||||
params: Pin<Box<StftParams>>,
|
params: Pin<Box<StftParams>>,
|
||||||
|
|
||||||
|
@ -15,7 +17,7 @@ impl Default for Stft {
|
||||||
Self {
|
Self {
|
||||||
params: Box::pin(StftParams::default()),
|
params: Box::pin(StftParams::default()),
|
||||||
|
|
||||||
stft: util::StftHelper::new(2, 512),
|
stft: util::StftHelper::new(2, WINDOW_SIZE),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,7 +58,7 @@ impl Plugin for Stft {
|
||||||
) -> bool {
|
) -> bool {
|
||||||
// Normally we'd also initialize the STFT helper for the correct channel count here, but we
|
// Normally we'd also initialize the STFT helper for the correct channel count here, but we
|
||||||
// only do stereo so that's not necessary
|
// only do stereo so that's not necessary
|
||||||
self.stft.set_block_size(512);
|
self.stft.set_block_size(WINDOW_SIZE);
|
||||||
context.set_latency_samples(self.stft.latency_samples());
|
context.set_latency_samples(self.stft.latency_samples());
|
||||||
|
|
||||||
true
|
true
|
||||||
|
|
232
src/util/stft.rs
232
src/util/stft.rs
|
@ -1,7 +1,5 @@
|
||||||
//! Utilities for buffering audio, likely used as part of a short-term Fourier transform.
|
//! Utilities for buffering audio, likely used as part of a short-term Fourier transform.
|
||||||
|
|
||||||
use std::mem;
|
|
||||||
|
|
||||||
use crate::buffer::Buffer;
|
use crate::buffer::Buffer;
|
||||||
|
|
||||||
/// Process the input buffer in equal sized blocks, running a callback on each block to transform
|
/// Process the input buffer in equal sized blocks, running a callback on each block to transform
|
||||||
|
@ -16,16 +14,17 @@ use crate::buffer::Buffer;
|
||||||
/// TODO: We may need something like this purely for analysis, e.g. for showing spectrums in a GUI.
|
/// TODO: We may need something like this purely for analysis, e.g. for showing spectrums in a GUI.
|
||||||
/// Figure out the cleanest way to adapt this for the non-processing use case.
|
/// Figure out the cleanest way to adapt this for the non-processing use case.
|
||||||
pub struct StftHelper<const NUM_SIDECHAIN_INPUTS: usize = 0> {
|
pub struct StftHelper<const NUM_SIDECHAIN_INPUTS: usize = 0> {
|
||||||
// These ring buffers store both the input samples and the already processed output. Whenever we
|
// These ring buffers store the input samples and the already processed output produced by
|
||||||
// wrap around,we'll write the already calculated outputs to the main buffer passed to the
|
// adding overlapping windows. Whenever we reach a new overlapping window, we'll write the
|
||||||
// process function and process a new block.
|
// already calculated outputs to the main buffer passed to the process function and then process
|
||||||
main_ring_buffers: Vec<Vec<f32>>,
|
// a new block.
|
||||||
|
main_input_ring_buffers: Vec<Vec<f32>>,
|
||||||
|
main_output_ring_buffers: Vec<Vec<f32>>,
|
||||||
sidechain_ring_buffers: [Vec<Vec<f32>>; NUM_SIDECHAIN_INPUTS],
|
sidechain_ring_buffers: [Vec<Vec<f32>>; NUM_SIDECHAIN_INPUTS],
|
||||||
|
|
||||||
// To make this more convenient, we'll provide slices into the above buffers to the block
|
/// Results from the ring buffers are copied to this scratch buffer before being passed to the
|
||||||
// process callback
|
/// plugin. Needed to handle overlap.
|
||||||
main_block_buffer: Buffer<'static>,
|
scratch_buffer: Vec<f32>,
|
||||||
sidechain_block_buffers: [Buffer<'static>; NUM_SIDECHAIN_INPUTS],
|
|
||||||
|
|
||||||
/// The current position in our ring buffers. Whenever this wraps around to 0, we'll process
|
/// The current position in our ring buffers. Whenever this wraps around to 0, we'll process
|
||||||
/// a block.
|
/// a block.
|
||||||
|
@ -36,36 +35,25 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
|
||||||
/// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the
|
/// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the
|
||||||
/// given maximum block size. Call [`set_block_size()`][`Self::set_block_size()`] afterwards if
|
/// given maximum block size. Call [`set_block_size()`][`Self::set_block_size()`] afterwards if
|
||||||
/// you do not need the full capacity upfront.
|
/// you do not need the full capacity upfront.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `num_channels == 0 || max_block_size == 0`.
|
||||||
pub fn new(num_channels: usize, max_block_size: usize) -> Self {
|
pub fn new(num_channels: usize, max_block_size: usize) -> Self {
|
||||||
nih_debug_assert_ne!(num_channels, 0);
|
assert_ne!(num_channels, 0);
|
||||||
nih_debug_assert_ne!(max_block_size, 0);
|
assert_ne!(max_block_size, 0);
|
||||||
|
|
||||||
let mut helper = Self {
|
Self {
|
||||||
main_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
|
main_input_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
|
||||||
|
main_output_ring_buffers: vec![vec![0.0; max_block_size]; num_channels],
|
||||||
// Kinda hacky way to initialize an array of non-copy types
|
// Kinda hacky way to initialize an array of non-copy types
|
||||||
sidechain_ring_buffers: [(); NUM_SIDECHAIN_INPUTS]
|
sidechain_ring_buffers: [(); NUM_SIDECHAIN_INPUTS]
|
||||||
.map(|_| vec![vec![0.0; max_block_size]; num_channels]),
|
.map(|_| vec![vec![0.0; max_block_size]; num_channels]),
|
||||||
|
|
||||||
main_block_buffer: Buffer::default(),
|
scratch_buffer: vec![0.0; max_block_size],
|
||||||
sidechain_block_buffers: [(); NUM_SIDECHAIN_INPUTS].map(|_| Buffer::default()),
|
|
||||||
|
|
||||||
current_pos: 0,
|
current_pos: 0,
|
||||||
};
|
|
||||||
|
|
||||||
// Preallocate the output slices. We'll point them to the ring buffers at the start of the
|
|
||||||
// process call.
|
|
||||||
unsafe {
|
|
||||||
helper.main_block_buffer.with_raw_vec(|main_block_slices| {
|
|
||||||
main_block_slices.resize_with(num_channels, || &mut [])
|
|
||||||
});
|
|
||||||
for sidechain_block_buffer in &mut helper.sidechain_block_buffers {
|
|
||||||
sidechain_block_buffer.with_raw_vec(|main_block_slices| {
|
|
||||||
main_block_slices.resize_with(num_channels, || &mut [])
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
helper
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Change the current block size. This will clear the buffers, causing the next block to output
|
/// Change the current block size. This will clear the buffers, causing the next block to output
|
||||||
|
@ -75,12 +63,18 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
|
||||||
///
|
///
|
||||||
/// WIll panic if `block_size > max_block_size`.
|
/// WIll panic if `block_size > max_block_size`.
|
||||||
pub fn set_block_size(&mut self, block_size: usize) {
|
pub fn set_block_size(&mut self, block_size: usize) {
|
||||||
assert!(block_size <= self.main_ring_buffers[0].capacity());
|
assert!(block_size <= self.main_input_ring_buffers[0].capacity());
|
||||||
|
|
||||||
for main_ring_buffer in &mut self.main_ring_buffers {
|
for main_ring_buffer in &mut self.main_input_ring_buffers {
|
||||||
main_ring_buffer.resize(block_size, 0.0);
|
main_ring_buffer.resize(block_size, 0.0);
|
||||||
main_ring_buffer.fill(0.0);
|
main_ring_buffer.fill(0.0);
|
||||||
}
|
}
|
||||||
|
for main_ring_buffer in &mut self.main_output_ring_buffers {
|
||||||
|
main_ring_buffer.resize(block_size, 0.0);
|
||||||
|
main_ring_buffer.fill(0.0);
|
||||||
|
}
|
||||||
|
self.scratch_buffer.resize(block_size, 0.0);
|
||||||
|
self.scratch_buffer.fill(0.0);
|
||||||
for sidechain_ring_buffers in &mut self.sidechain_ring_buffers {
|
for sidechain_ring_buffers in &mut self.sidechain_ring_buffers {
|
||||||
for sidechain_ring_buffer in sidechain_ring_buffers {
|
for sidechain_ring_buffer in sidechain_ring_buffers {
|
||||||
sidechain_ring_buffer.resize(block_size, 0.0);
|
sidechain_ring_buffer.resize(block_size, 0.0);
|
||||||
|
@ -93,74 +87,56 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
|
||||||
|
|
||||||
/// The amount of latency introduced when processing audio throug hthis [`StftHelper`].
|
/// The amount of latency introduced when processing audio throug hthis [`StftHelper`].
|
||||||
pub fn latency_samples(&self) -> u32 {
|
pub fn latency_samples(&self) -> u32 {
|
||||||
self.main_ring_buffers[0].len() as u32
|
self.main_input_ring_buffers[0].len() as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process the audio in `main_buffer` and in any sidechain buffers in small blocks. Whenever a
|
/// Process the audio in `main_buffer` and in any sidechain buffers in small overlapping blocks
|
||||||
/// new block is available, `process_cb()` gets called with a new audio block of the specified
|
/// with a window function applied, adding up the results for the main buffer so they can be
|
||||||
/// side. The results written to the buffer will then be written back to `main_buffer` exactly
|
/// written back to the host. Whenever a new block is available, `process_cb()` gets called with
|
||||||
/// one block later, which means that this function will introduce one block of latency. This
|
/// a new audio block of the specified size with the windowing function already applied. The
|
||||||
/// can be compensated by calling
|
/// summed reults will then be written back to `main_buffer` exactly one block later, which
|
||||||
/// [`ProcessContext::set_latency()`][`crate::prelude::ProcessContext::set_latency()`] in your
|
/// means that this function will introduce one block of latency. This can be compensated by
|
||||||
/// plugin's initialization function.
|
/// calling [`ProcessContext::set_latency()`][`crate::prelude::ProcessContext::set_latency()`]
|
||||||
|
/// in your plugin's initialization function.
|
||||||
|
///
|
||||||
|
/// For efficiency's sake this function will reuse the same vector for all calls to
|
||||||
|
/// `process_cb`. This means you can only access a single channel's worth of windowed data at a
|
||||||
|
/// time. The arguments to that function are `process_cb(channel_idx, sidechain_buffer_idx,
|
||||||
|
/// data)`, where `sidechain_buffer_idx` will be `None` for the main buffer. If there are any
|
||||||
|
/// sidechain buffers, then they will be processed before the main buffer.
|
||||||
///
|
///
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if `main_buffer` or the buffers in `sidechain_buffers` do not have the same number of
|
/// Panics if `main_buffer` or the buffers in `sidechain_buffers` do not have the same number of
|
||||||
/// channels as this [`StftHelper`].
|
/// channels as this [`StftHelper`], if the sidechain buffers do not contain the same number of
|
||||||
|
/// samples as the main buffer, or if the window function does not match the block size.
|
||||||
///
|
///
|
||||||
/// TODO: Maybe introduce a trait here so this can be used with things that aren't whole buffers
|
/// TODO: Maybe introduce a trait here so this can be used with things that aren't whole buffers
|
||||||
/// TODO: And also introduce that aforementioned read-only process function (`analyze()?`)
|
/// TODO: And also introduce that aforementioned read-only process function (`analyze()?`)
|
||||||
pub fn process<F>(
|
pub fn process_overlap_add<F>(
|
||||||
&mut self,
|
&mut self,
|
||||||
main_buffer: &mut Buffer,
|
main_buffer: &mut Buffer,
|
||||||
sidechain_buffers: [&Buffer; NUM_SIDECHAIN_INPUTS],
|
sidechain_buffers: [&Buffer; NUM_SIDECHAIN_INPUTS],
|
||||||
|
window_function: &[f32],
|
||||||
|
overlap_times: usize,
|
||||||
mut process_cb: F,
|
mut process_cb: F,
|
||||||
) where
|
) where
|
||||||
F: FnMut(&mut Buffer, &[Buffer; NUM_SIDECHAIN_INPUTS]),
|
F: FnMut(usize, Option<usize>, &mut [f32]),
|
||||||
{
|
{
|
||||||
assert_eq!(main_buffer.channels(), self.main_ring_buffers.len());
|
assert_eq!(main_buffer.channels(), self.main_input_ring_buffers.len());
|
||||||
|
assert_eq!(window_function.len(), self.main_input_ring_buffers[0].len());
|
||||||
// Since the `StftHelper` object may move in between process calls, we need to make sure
|
|
||||||
// that these slices point to our ring buffers at the start of each call
|
|
||||||
unsafe {
|
|
||||||
self.main_block_buffer.with_raw_vec(|main_block_slices| {
|
|
||||||
assert_eq!(main_block_slices.len(), self.main_ring_buffers.len());
|
|
||||||
for (channel_idx, channel_slice) in main_block_slices.iter_mut().enumerate() {
|
|
||||||
// SAFETY: This is equivalent to splitting on each channel, and these block
|
|
||||||
// slices will only be used here as part of the callback when the ring
|
|
||||||
// buffers are not mutably borrwed
|
|
||||||
*channel_slice =
|
|
||||||
&mut *(self.main_ring_buffers[channel_idx].as_mut_slice() as *mut _);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
for (sidechain_block_buffer, sidechain_ring_buffer) in self
|
|
||||||
.sidechain_block_buffers
|
|
||||||
.iter_mut()
|
|
||||||
.zip(self.sidechain_ring_buffers.iter_mut())
|
|
||||||
{
|
|
||||||
sidechain_block_buffer.with_raw_vec(|sidechain_block_slices| {
|
|
||||||
assert_eq!(sidechain_block_slices.len(), sidechain_ring_buffer.len());
|
|
||||||
for (channel_idx, channel_slice) in
|
|
||||||
sidechain_block_slices.iter_mut().enumerate()
|
|
||||||
{
|
|
||||||
*channel_slice =
|
|
||||||
&mut *(sidechain_ring_buffer[channel_idx].as_mut_slice() as *mut _);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying
|
// We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying
|
||||||
// already processed samples from `main_ring_buffers` in into `main_buffer`
|
// already processed samples from `main_ring_buffers` in into `main_buffer`
|
||||||
let main_buffer_len = main_buffer.len();
|
let main_buffer_len = main_buffer.len();
|
||||||
let num_channels = main_buffer.channels();
|
let num_channels = main_buffer.channels();
|
||||||
let block_len = self.main_ring_buffers[0].len();
|
let block_size = self.main_input_ring_buffers[0].len();
|
||||||
|
let window_interval = block_size / overlap_times;
|
||||||
let mut already_processed_samples = 0;
|
let mut already_processed_samples = 0;
|
||||||
while already_processed_samples < main_buffer_len {
|
while already_processed_samples < main_buffer_len {
|
||||||
let remaining_samples = main_buffer_len - already_processed_samples;
|
let remaining_samples = main_buffer_len - already_processed_samples;
|
||||||
let samples_until_next_block = block_len - self.current_pos;
|
let samples_until_next_window = (window_interval - self.current_pos) % window_interval;
|
||||||
let samples_to_process = samples_until_next_block.min(remaining_samples);
|
let samples_to_process = samples_until_next_window.min(remaining_samples);
|
||||||
|
|
||||||
// Copy the input from `main_buffer` to the ring buffer while copying last block's
|
// Copy the input from `main_buffer` to the ring buffer while copying last block's
|
||||||
// result from the buffer to `main_buffer`
|
// result from the buffer to `main_buffer`
|
||||||
|
@ -175,12 +151,20 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
|
||||||
.get_unchecked_mut(channel_idx)
|
.get_unchecked_mut(channel_idx)
|
||||||
.get_unchecked_mut(already_processed_samples + sample_offset)
|
.get_unchecked_mut(already_processed_samples + sample_offset)
|
||||||
};
|
};
|
||||||
let ring_buffer_sample = unsafe {
|
let input_ring_buffer_sample = unsafe {
|
||||||
self.main_ring_buffers
|
self.main_input_ring_buffers
|
||||||
.get_unchecked_mut(channel_idx)
|
.get_unchecked_mut(channel_idx)
|
||||||
.get_unchecked_mut(self.current_pos + sample_offset)
|
.get_unchecked_mut(self.current_pos + sample_offset)
|
||||||
};
|
};
|
||||||
mem::swap(sample, ring_buffer_sample);
|
let output_ring_buffer_sample = unsafe {
|
||||||
|
self.main_output_ring_buffers
|
||||||
|
.get_unchecked_mut(channel_idx)
|
||||||
|
.get_unchecked_mut(self.current_pos + sample_offset)
|
||||||
|
};
|
||||||
|
*input_ring_buffer_sample = *sample;
|
||||||
|
*sample = *output_ring_buffer_sample;
|
||||||
|
// Very important, or else we'll overlap-add ourselves into a feedback hell
|
||||||
|
*output_ring_buffer_sample = 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,16 +192,88 @@ impl<const NUM_SIDECHAIN_INPUTS: usize> StftHelper<NUM_SIDECHAIN_INPUTS> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
already_processed_samples += samples_to_process;
|
|
||||||
self.current_pos += samples_to_process;
|
|
||||||
|
|
||||||
// At this point we either have `already_processed_samples == main_buffer_len`, or
|
// At this point we either have `already_processed_samples == main_buffer_len`, or
|
||||||
// `self.current_pos == block_len`. If it's the latter, then we can process a new block.
|
// `self.current_pos % window_interval == 0`. If it's the latter, then we can process a
|
||||||
if self.current_pos == block_len {
|
// new block.
|
||||||
process_cb(&mut self.main_block_buffer, &self.sidechain_block_buffers);
|
if samples_to_process == samples_until_next_window {
|
||||||
|
// Because we're processing in smaller windows, the input ring buffers sadly does
|
||||||
|
// not always contain the full contiguous range we're interested in because they map
|
||||||
|
// wrap around. Because premade FFT algorithms typically can't handle this, we'll
|
||||||
|
// start with copying
|
||||||
|
|
||||||
self.current_pos = 0;
|
// TODO: Sdiechain
|
||||||
|
|
||||||
|
for (channel_idx, (input_ring_buffer, output_ring_buffer)) in self
|
||||||
|
.main_input_ring_buffers
|
||||||
|
.iter()
|
||||||
|
.zip(self.main_output_ring_buffers.iter_mut())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
copy_ring_to_scratch_buffer(
|
||||||
|
&mut self.scratch_buffer,
|
||||||
|
self.current_pos,
|
||||||
|
input_ring_buffer,
|
||||||
|
);
|
||||||
|
multiply_scratch_buffer(&mut self.scratch_buffer, window_function);
|
||||||
|
process_cb(channel_idx, None, &mut self.scratch_buffer);
|
||||||
|
|
||||||
|
// The actual overlap-add part of the equation
|
||||||
|
add_scratch_to_ring_buffer(
|
||||||
|
&self.scratch_buffer,
|
||||||
|
self.current_pos,
|
||||||
|
output_ring_buffer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do this after handling the block or else we'll copy the wrong samples.
|
||||||
|
already_processed_samples += samples_to_process;
|
||||||
|
self.current_pos = (self.current_pos + samples_to_process) % block_size;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Copy data from the the specified ring buffer (borrowed from `self`) to the scratch buffers at
|
||||||
|
/// the current position. This is a free function because you cannot pass an immutable reference to
|
||||||
|
/// a field from `&self` to a `&mut self` method.
|
||||||
|
#[inline]
|
||||||
|
fn copy_ring_to_scratch_buffer(
|
||||||
|
scratch_buffer: &mut [f32],
|
||||||
|
current_pos: usize,
|
||||||
|
ring_buffer: &[f32],
|
||||||
|
) {
|
||||||
|
let block_size = scratch_buffer.len();
|
||||||
|
let num_copy_before_wrap = block_size - current_pos;
|
||||||
|
scratch_buffer[0..num_copy_before_wrap].copy_from_slice(&ring_buffer[current_pos..block_size]);
|
||||||
|
scratch_buffer[num_copy_before_wrap..block_size].copy_from_slice(&ring_buffer[0..current_pos]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multiply the scratch buffer by some window function. Also free function because you can't do
|
||||||
|
/// split borrows with methods.
|
||||||
|
#[inline]
|
||||||
|
fn multiply_scratch_buffer(scratch_buffer: &mut [f32], window_function: &[f32]) {
|
||||||
|
for (sample, window_sample) in scratch_buffer.iter_mut().zip(window_function) {
|
||||||
|
*sample *= window_sample;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add data from the scratch buffer to the specified ring buffer. When writing samples from this
|
||||||
|
/// ring buffer back to the host's outputs they must be cleared to prevent infinite feedback.
|
||||||
|
#[inline]
|
||||||
|
fn add_scratch_to_ring_buffer(scratch_buffer: &[f32], current_pos: usize, ring_buffer: &mut [f32]) {
|
||||||
|
// TODO: This could also use some SIMD
|
||||||
|
let block_size = scratch_buffer.len();
|
||||||
|
let num_copy_before_wrap = block_size - current_pos;
|
||||||
|
for (scratch_sample, ring_sample) in scratch_buffer[0..num_copy_before_wrap]
|
||||||
|
.iter()
|
||||||
|
.zip(&mut ring_buffer[current_pos..block_size])
|
||||||
|
{
|
||||||
|
*ring_sample += *scratch_sample;
|
||||||
|
}
|
||||||
|
for (scratch_sample, ring_sample) in scratch_buffer[num_copy_before_wrap..block_size]
|
||||||
|
.iter()
|
||||||
|
.zip(&mut ring_buffer[0..current_pos])
|
||||||
|
{
|
||||||
|
*ring_sample += *scratch_sample;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue