1
0
Fork 0

Use more typical convolution FFT in STFT example

This commit is contained in:
Robbert van der Helm 2022-05-08 02:21:48 +02:00
parent 55eeb689dd
commit 3fe24e7dc6
3 changed files with 28 additions and 29 deletions

View file

@ -86,7 +86,6 @@ impl SpectrumInput {
|channel_idx, real_fft_scratch_buffer| { |channel_idx, real_fft_scratch_buffer| {
multiply_with_window(real_fft_scratch_buffer, &self.compensated_window_function); multiply_with_window(real_fft_scratch_buffer, &self.compensated_window_function);
// Forward FFT, the helper has already applied window function
self.plan self.plan
.process_with_scratch( .process_with_scratch(
real_fft_scratch_buffer, real_fft_scratch_buffer,

View file

@ -4,19 +4,25 @@ use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
use std::f32; use std::f32;
use std::sync::Arc; use std::sync::Arc;
const WINDOW_SIZE: usize = 2048; /// The size of the windows we'll process at a time.
const OVERLAP_TIMES: usize = 4; const WINDOW_SIZE: usize = 64;
/// The length of the filter's impulse response.
const FILTER_WINDOW_SIZE: usize = 33;
/// The length of the FFT window we will use to perform FFT convolution. This includes padding to
/// prevent time domain aliasing as a result of cyclic convolution.
const FFT_WINDOW_SIZE: usize = WINDOW_SIZE + FILTER_WINDOW_SIZE - 1;
/// The gain compensation we need to apply for the STFT process.
const GAIN_COMPENSATION: f32 = 1.0 / WINDOW_SIZE as f32;
struct Stft { struct Stft {
params: Arc<StftParams>, params: Arc<StftParams>,
/// An adapter that performs most of the overlap-add algorithm for us. /// An adapter that performs most of the overlap-add algorithm for us.
stft: util::StftHelper, stft: util::StftHelper,
/// A Hann window function, applied after the IDFT operation to minimize time domain aliasing.
window_function: Vec<f32>,
/// The FFT of a simple low-pass FIR filter. /// The FFT of a simple low-pass FIR filter.
lp_filter_kernel: Vec<Complex32>, lp_filter_spectrum: Vec<Complex32>,
/// The algorithm for the FFT operation. /// The algorithm for the FFT operation.
r2c_plan: Arc<dyn RealToComplex<f32>>, r2c_plan: Arc<dyn RealToComplex<f32>>,
@ -32,21 +38,19 @@ struct StftParams {}
impl Default for Stft { impl Default for Stft {
fn default() -> Self { fn default() -> Self {
let mut planner = RealFftPlanner::new(); let mut planner = RealFftPlanner::new();
let r2c_plan = planner.plan_fft_forward(WINDOW_SIZE); let r2c_plan = planner.plan_fft_forward(FFT_WINDOW_SIZE);
let c2r_plan = planner.plan_fft_inverse(WINDOW_SIZE); let c2r_plan = planner.plan_fft_inverse(FFT_WINDOW_SIZE);
let mut real_fft_buffer = r2c_plan.make_input_vec(); let mut real_fft_buffer = r2c_plan.make_input_vec();
let mut complex_fft_buffer = r2c_plan.make_output_vec(); let mut complex_fft_buffer = r2c_plan.make_output_vec();
// Build a super simple low-pass filter from one of the built in window function // Build a super simple low-pass filter from one of the built in window functions
const FILTER_WINDOW_SIZE: usize = 33; let mut filter_window = util::window::hann(FILTER_WINDOW_SIZE);
let filter_window = util::window::hann(FILTER_WINDOW_SIZE);
real_fft_buffer[0..FILTER_WINDOW_SIZE].copy_from_slice(&filter_window);
// And make sure to normalize this so convolution sums to 1 // And make sure to normalize this so convolution sums to 1
let filter_normalization_factor = real_fft_buffer.iter().sum::<f32>().recip(); let filter_normalization_factor = filter_window.iter().sum::<f32>().recip();
for sample in &mut real_fft_buffer { for sample in &mut filter_window {
*sample *= filter_normalization_factor; *sample *= filter_normalization_factor;
} }
real_fft_buffer[0..FILTER_WINDOW_SIZE].copy_from_slice(&filter_window);
// RustFFT doesn't actually need a scratch buffer here, so we'll pass an empty buffer // RustFFT doesn't actually need a scratch buffer here, so we'll pass an empty buffer
// instead // instead
@ -57,10 +61,12 @@ impl Default for Stft {
Self { Self {
params: Arc::new(StftParams::default()), params: Arc::new(StftParams::default()),
stft: util::StftHelper::new(2, WINDOW_SIZE, 0), // We'll process the input in `WINDOW_SIZE` chunks, but our FFT window is slightly
window_function: util::window::hann(WINDOW_SIZE), // larger to account for time domain aliasing so we'll need to add some padding ot each
// block.
stft: util::StftHelper::new(2, WINDOW_SIZE, FFT_WINDOW_SIZE - WINDOW_SIZE),
lp_filter_kernel: complex_fft_buffer.clone(), lp_filter_spectrum: complex_fft_buffer.clone(),
r2c_plan, r2c_plan,
c2r_plan, c2r_plan,
@ -122,13 +128,11 @@ impl Plugin for Stft {
buffer: &mut Buffer, buffer: &mut Buffer,
_context: &mut impl ProcessContext, _context: &mut impl ProcessContext,
) -> ProcessStatus { ) -> ProcessStatus {
// Compensate for the window function, the overlap, and the extra gain introduced by the
// IDFT operation
const GAIN_COMPENSATION: f32 = 1.0 / (OVERLAP_TIMES as f32 / 2.0) / WINDOW_SIZE as f32;
self.stft self.stft
.process_overlap_add(buffer, OVERLAP_TIMES, |_channel_idx, real_fft_buffer| { .process_overlap_add(buffer, 1, |_channel_idx, real_fft_buffer| {
// Forward FFT, the helper has already applied window function // Forward FFT, `real_fft_buffer` already is already padded with zeroes, and the
// padding from the last iteration will have already been added back to the start of
// the buffer
self.r2c_plan self.r2c_plan
.process_with_scratch(real_fft_buffer, &mut self.complex_fft_buffer, &mut []) .process_with_scratch(real_fft_buffer, &mut self.complex_fft_buffer, &mut [])
.unwrap(); .unwrap();
@ -138,7 +142,7 @@ impl Plugin for Stft {
for (fft_bin, kernel_bin) in self for (fft_bin, kernel_bin) in self
.complex_fft_buffer .complex_fft_buffer
.iter_mut() .iter_mut()
.zip(&self.lp_filter_kernel) .zip(&self.lp_filter_spectrum)
{ {
*fft_bin *= *kernel_bin * GAIN_COMPENSATION; *fft_bin *= *kernel_bin * GAIN_COMPENSATION;
} }
@ -148,9 +152,6 @@ impl Plugin for Stft {
self.c2r_plan self.c2r_plan
.process_with_scratch(&mut self.complex_fft_buffer, real_fft_buffer, &mut []) .process_with_scratch(&mut self.complex_fft_buffer, real_fft_buffer, &mut [])
.unwrap(); .unwrap();
// Apply the window function. We can do this either before the DFT or after the IDFT
util::window::multiply_with_window(real_fft_buffer, &self.window_function);
}); });
ProcessStatus::Normal ProcessStatus::Normal

View file

@ -236,7 +236,6 @@ impl Plugin for PubertySimulator {
// Negated because pitching down should cause us to take values from higher frequency bins // Negated because pitching down should cause us to take values from higher frequency bins
let frequency_multiplier = 2.0f32.powf(-smoothed_pitch_value); let frequency_multiplier = 2.0f32.powf(-smoothed_pitch_value);
// Forward FFT, the helper has already applied window function
// RustFFT doesn't actually need a scratch buffer here, so we'll pass an empty // RustFFT doesn't actually need a scratch buffer here, so we'll pass an empty
// buffer instead // buffer instead
fft_plan fft_plan