1
0
Fork 0

Use more typical convolution FFT in STFT example

This commit is contained in:
Robbert van der Helm 2022-05-08 02:21:48 +02:00
parent 55eeb689dd
commit 3fe24e7dc6
3 changed files with 28 additions and 29 deletions

View file

@ -86,7 +86,6 @@ impl SpectrumInput {
|channel_idx, real_fft_scratch_buffer| {
multiply_with_window(real_fft_scratch_buffer, &self.compensated_window_function);
// Forward FFT, the helper has already applied window function
self.plan
.process_with_scratch(
real_fft_scratch_buffer,

View file

@ -4,19 +4,25 @@ use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
use std::f32;
use std::sync::Arc;
const WINDOW_SIZE: usize = 2048;
const OVERLAP_TIMES: usize = 4;
/// The size of the windows we'll process at a time.
const WINDOW_SIZE: usize = 64;
/// The length of the filter's impulse response.
const FILTER_WINDOW_SIZE: usize = 33;
/// The length of the FFT window we will use to perform FFT convolution. This includes padding to
/// prevent time domain aliasing as a result of cyclic convolution.
const FFT_WINDOW_SIZE: usize = WINDOW_SIZE + FILTER_WINDOW_SIZE - 1;
/// The gain compensation we need to apply for the STFT process.
const GAIN_COMPENSATION: f32 = 1.0 / WINDOW_SIZE as f32;
struct Stft {
params: Arc<StftParams>,
/// An adapter that performs most of the overlap-add algorithm for us.
stft: util::StftHelper,
/// A Hann window function, applied after the IDFT operation to minimize time domain aliasing.
window_function: Vec<f32>,
/// The FFT of a simple low-pass FIR filter.
lp_filter_kernel: Vec<Complex32>,
lp_filter_spectrum: Vec<Complex32>,
/// The algorithm for the FFT operation.
r2c_plan: Arc<dyn RealToComplex<f32>>,
@ -32,21 +38,19 @@ struct StftParams {}
impl Default for Stft {
fn default() -> Self {
let mut planner = RealFftPlanner::new();
let r2c_plan = planner.plan_fft_forward(WINDOW_SIZE);
let c2r_plan = planner.plan_fft_inverse(WINDOW_SIZE);
let r2c_plan = planner.plan_fft_forward(FFT_WINDOW_SIZE);
let c2r_plan = planner.plan_fft_inverse(FFT_WINDOW_SIZE);
let mut real_fft_buffer = r2c_plan.make_input_vec();
let mut complex_fft_buffer = r2c_plan.make_output_vec();
// Build a super simple low-pass filter from one of the built in window function
const FILTER_WINDOW_SIZE: usize = 33;
let filter_window = util::window::hann(FILTER_WINDOW_SIZE);
real_fft_buffer[0..FILTER_WINDOW_SIZE].copy_from_slice(&filter_window);
// Build a super simple low-pass filter from one of the built in window functions
let mut filter_window = util::window::hann(FILTER_WINDOW_SIZE);
// And make sure to normalize this so convolution sums to 1
let filter_normalization_factor = real_fft_buffer.iter().sum::<f32>().recip();
for sample in &mut real_fft_buffer {
let filter_normalization_factor = filter_window.iter().sum::<f32>().recip();
for sample in &mut filter_window {
*sample *= filter_normalization_factor;
}
real_fft_buffer[0..FILTER_WINDOW_SIZE].copy_from_slice(&filter_window);
// RustFFT doesn't actually need a scratch buffer here, so we'll pass an empty buffer
// instead
@ -57,10 +61,12 @@ impl Default for Stft {
Self {
params: Arc::new(StftParams::default()),
stft: util::StftHelper::new(2, WINDOW_SIZE, 0),
window_function: util::window::hann(WINDOW_SIZE),
// We'll process the input in `WINDOW_SIZE` chunks, but our FFT window is slightly
// larger to account for time domain aliasing so we'll need to add some padding ot each
// block.
stft: util::StftHelper::new(2, WINDOW_SIZE, FFT_WINDOW_SIZE - WINDOW_SIZE),
lp_filter_kernel: complex_fft_buffer.clone(),
lp_filter_spectrum: complex_fft_buffer.clone(),
r2c_plan,
c2r_plan,
@ -122,13 +128,11 @@ impl Plugin for Stft {
buffer: &mut Buffer,
_context: &mut impl ProcessContext,
) -> ProcessStatus {
// Compensate for the window function, the overlap, and the extra gain introduced by the
// IDFT operation
const GAIN_COMPENSATION: f32 = 1.0 / (OVERLAP_TIMES as f32 / 2.0) / WINDOW_SIZE as f32;
self.stft
.process_overlap_add(buffer, OVERLAP_TIMES, |_channel_idx, real_fft_buffer| {
// Forward FFT, the helper has already applied window function
.process_overlap_add(buffer, 1, |_channel_idx, real_fft_buffer| {
// Forward FFT, `real_fft_buffer` already is already padded with zeroes, and the
// padding from the last iteration will have already been added back to the start of
// the buffer
self.r2c_plan
.process_with_scratch(real_fft_buffer, &mut self.complex_fft_buffer, &mut [])
.unwrap();
@ -138,7 +142,7 @@ impl Plugin for Stft {
for (fft_bin, kernel_bin) in self
.complex_fft_buffer
.iter_mut()
.zip(&self.lp_filter_kernel)
.zip(&self.lp_filter_spectrum)
{
*fft_bin *= *kernel_bin * GAIN_COMPENSATION;
}
@ -148,9 +152,6 @@ impl Plugin for Stft {
self.c2r_plan
.process_with_scratch(&mut self.complex_fft_buffer, real_fft_buffer, &mut [])
.unwrap();
// Apply the window function. We can do this either before the DFT or after the IDFT
util::window::multiply_with_window(real_fft_buffer, &self.window_function);
});
ProcessStatus::Normal

View file

@ -236,7 +236,6 @@ impl Plugin for PubertySimulator {
// Negated because pitching down should cause us to take values from higher frequency bins
let frequency_multiplier = 2.0f32.powf(-smoothed_pitch_value);
// Forward FFT, the helper has already applied window function
// RustFFT doesn't actually need a scratch buffer here, so we'll pass an empty
// buffer instead
fft_plan