use {crate::utils::*, core::f64::consts::PI, log::error, ndarray::ArrayViewMut2};
mod forward;
mod reverse;
pub use forward::*;
pub use reverse::*;
/// Builds the twiddle-factor table consumed by `forfft` and `revfft`.
///
/// `factors` is overwritten with the radix counts of `n` computed by
/// `factorisen` (order: radix 6, 4, 2, 3, 5). `trig` must hold exactly `2 * n`
/// values. For every radix stage (in that same radix order) the angles
/// `k * l * base` are generated for `k in 1..radix`, `l in 0..rem`, where `base`
/// starts at `2π / n` and is multiplied by the radix after each stage. The first
/// half of `trig` then receives the cosines of those angles and the second half
/// their negated sines. Only the first `n - 1` entries of each half are written;
/// the last entry of each half is left as the caller supplied it.
///
/// # Panics
///
/// Panics if `trig.len() != 2 * n`.
pub fn initfft(n: usize, factors: &mut [usize; 5], trig: &mut [f64]) {
    assert_eq!(2 * n, trig.len());
    let radices: [usize; 5] = [6, 4, 2, 3, 5];
    factorisen(n, factors);

    // Pass 1: write the raw stage angles into the front of `trig`.
    let mut base = 2.0 * PI / (n as f64);
    let mut rem = n;
    let mut slot = 0;
    for (&radix, &count) in radices.iter().zip(factors.iter()) {
        for _ in 0..count {
            rem /= radix;
            for k in 1..radix {
                for l in 0..rem {
                    trig[slot] = base * ((k * l) as f64);
                    slot += 1;
                }
            }
            base *= radix as f64;
        }
    }

    // Pass 2: replace each angle θ with cos θ and store -sin θ in the
    // matching position of the second half.
    let (cosines, sines) = trig.split_at_mut(n);
    for (c, s) in cosines
        .iter_mut()
        .zip(sines.iter_mut())
        .take(n.saturating_sub(1))
    {
        let theta = *c;
        *s = -theta.sin();
        *c = theta.cos();
    }
}
/// Decomposes the transform length `n` into the radices supported by the FFT
/// kernels and records how often each one is used.
///
/// The counts are written to `factors` in radix order `[6, 4, 2, 3, 5]`. Each
/// radix is divided out greedily, in that order, before moving on to the next.
/// For example, `n = 8640` yields `[3, 1, 1, 0, 1]` (6³ · 4 · 2 · 5).
///
/// `n = 1` is accepted and yields all-zero counts (a transform with no radix
/// stages). If `n` is zero or has a prime factor other than 2, 3 or 5, the
/// message "Factorization failed" is logged and `quit::with_code(1)` is called.
pub fn factorisen(n: usize, factors: &mut [usize; 5]) {
    const RADICES: [usize; 5] = [6, 4, 2, 3, 5];
    let mut rem = n;
    for (count, &radix) in factors.iter_mut().zip(RADICES.iter()) {
        *count = 0;
        // `rem > 1` stops once `n` is fully factorised. It also prevents an
        // endless loop for n == 0 (0 % radix == 0 and 0 / radix == 0).
        while rem > 1 && rem % radix == 0 {
            *count += 1;
            rem /= radix;
        }
    }
    if rem != 1 {
        error!("Factorization failed");
        quit::with_code(1);
    }
}
pub fn forfft(m: usize, n: usize, mut xs: ArrayViewMut2<f64>, trig: &[f64], factors: &[usize; 5]) {
let xs = xs.as_slice_memory_order_mut().unwrap();
assert_eq!(m * n, xs.len());
assert_eq!(2 * n, trig.len());
let normfac: f64;
let mut wk = vec![0.0; m * n];
let mut rem = n;
let mut cum = 1;
let mut iloc: usize;
let mut orig = true;
for _ in 0..factors[4] {
rem /= 5;
iloc = (rem - 1) * 5 * cum;
let cosine = view2d(&trig[iloc..], cum, 4);
let sine = view2d(&trig[n + iloc..], cum, 4);
if orig {
let a = view3d(xs, m * rem, 5, cum);
let b = viewmut3d(&mut wk, m * rem, cum, 5);
forrdx5(a, b, m * rem, cum, cosine, sine);
} else {
let a = view3d(&wk, m * rem, 5, cum);
let b = viewmut3d(xs, m * rem, cum, 5);
forrdx5(a, b, m * rem, cum, cosine, sine);
}
orig = !orig;
cum *= 5;
}
for _ in 0..factors[3] {
rem /= 3;
iloc = (rem - 1) * 3 * cum;
let cosine = view2d(&trig[iloc..], cum, 2);
let sine = view2d(&trig[n + iloc..], cum, 2);
if orig {
let a = view3d(xs, m * rem, 3, cum);
let b = viewmut3d(&mut wk, m * rem, cum, 3);
forrdx3(a, b, m * rem, cum, cosine, sine);
} else {
let a = view3d(&wk, m * rem, 3, cum);
let b = viewmut3d(xs, m * rem, cum, 3);
forrdx3(a, b, m * rem, cum, cosine, sine);
}
orig = !orig;
cum *= 3;
}
for _ in 0..factors[2] {
rem /= 2;
iloc = (rem - 1) * 2 * cum;
let cosine = view2d(&trig[iloc..], cum, 1);
let sine = view2d(&trig[n + iloc..], cum, 1);
if orig {
let a = view3d(xs, m * rem, 2, cum);
let b = viewmut3d(&mut wk, m * rem, cum, 2);
forrdx2(a, b, m * rem, cum, cosine, sine);
} else {
let a = view3d(&wk, m * rem, 2, cum);
let b = viewmut3d(xs, m * rem, cum, 2);
forrdx2(a, b, m * rem, cum, cosine, sine);
}
orig = !orig;
cum *= 2;
}
for _ in 0..factors[1] {
rem /= 4;
iloc = (rem - 1) * 4 * cum;
let cosine = view2d(&trig[iloc..], cum, 3);
let sine = view2d(&trig[n + iloc..], cum, 3);
if orig {
let a = view3d(xs, m * rem, 4, cum);
let b = viewmut3d(&mut wk, m * rem, cum, 4);
forrdx4(a, b, m * rem, cum, cosine, sine);
} else {
let a = view3d(&wk, m * rem, 4, cum);
let b = viewmut3d(xs, m * rem, cum, 4);
forrdx4(a, b, m * rem, cum, cosine, sine);
}
orig = !orig;
cum *= 4;
}
for _ in 0..factors[0] {
rem /= 6;
iloc = (rem - 1) * 6 * cum;
let cosine = view2d(&trig[iloc..], cum, 5);
let sine = view2d(&trig[n + iloc..], cum, 5);
if orig {
let a = view3d(xs, m * rem, 6, cum);
let b = viewmut3d(&mut wk, m * rem, cum, 6);
forrdx6(a, b, m * rem, cum, cosine, sine);
} else {
let a = view3d(&wk, m * rem, 6, cum);
let b = viewmut3d(xs, m * rem, cum, 6);
forrdx6(a, b, m * rem, cum, cosine, sine);
}
orig = !orig;
cum *= 6;
}
normfac = 1.0 / (n as f64).sqrt();
for (i, x) in xs.iter_mut().enumerate() {
if orig {
*x *= normfac;
} else {
*x = wk[i] * normfac;
}
}
}
pub fn revfft(m: usize, n: usize, mut xs: ArrayViewMut2<f64>, trig: &[f64], factors: &[usize; 5]) {
let xs = xs.as_slice_memory_order_mut().unwrap();
assert_eq!(m * n, xs.len());
assert_eq!(2 * n, trig.len());
let normfac: f64;
let mut wk = vec![0.0; m * n];
let mut rem = n;
let mut cum = 1;
let mut iloc: usize;
let mut orig = true;
for elem in xs.iter_mut().skip((n / 2 + 1) * m) {
*elem = -*elem;
}
for elem in xs.iter_mut().take(m) {
*elem *= 0.5;
}
if n % 2 == 0 {
let k = m * n / 2;
for i in 0..m {
xs[k + i] *= 0.5;
}
}
for _ in 0..factors[0] {
rem /= 6;
iloc = (cum - 1) * 6 * rem;
let cosine = view2d(&trig[iloc..], rem, 5);
let sine = view2d(&trig[n + iloc..], rem, 5);
if orig {
let a = view3d(xs, m * cum, rem, 6);
let b = viewmut3d(&mut wk, m * cum, 6, rem);
revrdx6(a, b, m * cum, rem, cosine, sine);
} else {
let a = view3d(&wk, m * cum, rem, 6);
let b = viewmut3d(xs, m * cum, 6, rem);
revrdx6(a, b, m * cum, rem, cosine, sine);
}
orig = !orig;
cum *= 6;
}
for _ in 0..factors[1] {
rem /= 4;
iloc = (cum - 1) * 4 * rem;
let cosine = view2d(&trig[iloc..], rem, 3);
let sine = view2d(&trig[n + iloc..], rem, 3);
if orig {
let a = view3d(xs, m * cum, rem, 4);
let b = viewmut3d(&mut wk, m * cum, 4, rem);
revrdx4(a, b, m * cum, rem, cosine, sine);
} else {
let a = view3d(&wk, m * cum, rem, 4);
let b = viewmut3d(xs, m * cum, 4, rem);
revrdx4(a, b, m * cum, rem, cosine, sine);
}
orig = !orig;
cum *= 4;
}
for _ in 0..factors[2] {
rem /= 2;
iloc = (cum - 1) * 2 * rem;
let cosine = view2d(&trig[iloc..], rem, 1);
let sine = view2d(&trig[n + iloc..], rem, 1);
if orig {
let a = view3d(xs, m * cum, rem, 2);
let b = viewmut3d(&mut wk, m * cum, 2, rem);
revrdx2(a, b, m * cum, rem, cosine, sine);
} else {
let a = view3d(&wk, m * cum, rem, 2);
let b = viewmut3d(xs, m * cum, 2, rem);
revrdx2(a, b, m * cum, rem, cosine, sine);
}
orig = !orig;
cum *= 2;
}
for _ in 0..factors[3] {
rem /= 3;
iloc = (cum - 1) * 3 * rem;
let cosine = view2d(&trig[iloc..], rem, 2);
let sine = view2d(&trig[n + iloc..], rem, 2);
if orig {
let a = view3d(xs, m * cum, rem, 3);
let b = viewmut3d(&mut wk, m * cum, 3, rem);
revrdx3(a, b, m * cum, rem, cosine, sine);
} else {
let a = view3d(&wk, m * cum, rem, 3);
let b = viewmut3d(xs, m * cum, 3, rem);
revrdx3(a, b, m * cum, rem, cosine, sine);
}
orig = !orig;
cum *= 3;
}
for _ in 0..factors[4] {
rem /= 5;
iloc = (cum - 1) * 5 * rem;
let cosine = view2d(&trig[iloc..], rem, 4);
let sine = view2d(&trig[n + iloc..], rem, 4);
if orig {
let a = view3d(xs, m * cum, rem, 5);
let b = viewmut3d(&mut wk, m * cum, 5, rem);
revrdx5(a, b, m * cum, rem, cosine, sine);
} else {
let a = view3d(&wk, m * cum, rem, 5);
let b = viewmut3d(xs, m * cum, 5, rem);
revrdx5(a, b, m * cum, rem, cosine, sine);
}
orig = !orig;
cum *= 5;
}
normfac = 2.0 / (n as f64).sqrt();
for (i, x) in xs.iter_mut().enumerate() {
if orig {
*x *= normfac;
} else {
*x = wk[i] * normfac;
}
}
}
#[cfg(test)]
mod test {
    use {
        super::*,
        crate::array2_from_file,
        byteorder::{ByteOrder, NetworkEndian},
        insta::assert_debug_snapshot,
        ndarray::{Array2, ShapeBuilder},
    };

    /// Decodes a fixture made of consecutive big-endian (network byte order) `f64`s.
    fn decode_f64s(bytes: &[u8]) -> Vec<f64> {
        bytes.chunks(8).map(NetworkEndian::read_f64).collect()
    }

    /// Defines a test that builds the twiddle table for length `$n` and compares
    /// it approximately with the reference table stored at `$reference`.
    macro_rules! initfft_case {
        ($name:ident, $n:tt, $reference:tt) => {
            #[test]
            fn $name() {
                let expected = decode_f64s(include_bytes!($reference));
                let mut factors = [0; 5];
                let mut trig = [0.0; 2 * $n];
                initfft($n, &mut factors, &mut trig);
                assert_approx_eq_slice(&expected, &trig);
            }
        };
    }

    /// Defines a test that runs `$transform` on a `$size` x `$size` fixture with
    /// the given radix counts and requires an exact match with `$expected`.
    macro_rules! transform_case {
        ($name:ident, $transform:ident, $size:tt, $factors:tt, $input:tt, $expected:tt, $trig:tt) => {
            #[test]
            fn $name() {
                let mut x = array2_from_file!($size, $size, $input);
                let x2 = array2_from_file!($size, $size, $expected);
                let trig = decode_f64s(include_bytes!($trig));
                let factors: [usize; 5] = $factors;
                $transform($size, $size, x.view_mut(), &trig, &factors);
                assert_eq!(x2, x);
            }
        };
    }

    #[test]
    fn factorisen_snapshot_1() {
        let mut factors = [0; 5];
        factorisen(16, &mut factors);
        assert_debug_snapshot!(factors);
    }

    #[test]
    fn factorisen_snapshot_2() {
        let mut factors = [0; 5];
        factorisen(8640, &mut factors);
        assert_debug_snapshot!(factors);
    }

    initfft_case!(initfft_snapshot_1, 30, "testdata/initfft/30_trig.bin");
    initfft_case!(initfft_snapshot_2, 32, "testdata/initfft/32_trig.bin");
    initfft_case!(initfft_snapshot_3, 18, "testdata/initfft/18_trig.bin");

    transform_case!(
        forfft_ng12_1, forfft, 12, [1, 0, 1, 0, 0],
        "testdata/forfft/forfft_ng12_1_x.bin",
        "testdata/forfft/forfft_ng12_1_x2.bin",
        "testdata/forfft/forfft_ng12_1_trig.bin"
    );
    transform_case!(
        forfft_ng12_2, forfft, 12, [1, 0, 1, 0, 0],
        "testdata/forfft/forfft_ng12_2_x.bin",
        "testdata/forfft/forfft_ng12_2_x2.bin",
        "testdata/forfft/forfft_ng12_2_trig.bin"
    );
    transform_case!(
        forfft_ng15_1, forfft, 15, [0, 0, 0, 1, 1],
        "testdata/forfft/forfft_ng15_1_x.bin",
        "testdata/forfft/forfft_ng15_1_x2.bin",
        "testdata/forfft/forfft_ng15_1_trig.bin"
    );
    transform_case!(
        forfft_ng15_2, forfft, 15, [0, 0, 0, 1, 1],
        "testdata/forfft/forfft_ng15_2_x.bin",
        "testdata/forfft/forfft_ng15_2_x2.bin",
        "testdata/forfft/forfft_ng15_2_trig.bin"
    );
    transform_case!(
        forfft_ng16_1, forfft, 16, [0, 2, 0, 0, 0],
        "testdata/forfft/forfft_ng16_1_x.bin",
        "testdata/forfft/forfft_ng16_1_x2.bin",
        "testdata/forfft/forfft_ng16_1_trig.bin"
    );
    transform_case!(
        forfft_ng16_2, forfft, 16, [0, 2, 0, 0, 0],
        "testdata/forfft/forfft_ng16_2_x.bin",
        "testdata/forfft/forfft_ng16_2_x2.bin",
        "testdata/forfft/forfft_ng16_2_trig.bin"
    );
    transform_case!(
        forfft_ng18_1, forfft, 18, [1, 0, 0, 1, 0],
        "testdata/forfft/forfft_ng18_1_x.bin",
        "testdata/forfft/forfft_ng18_1_x2.bin",
        "testdata/forfft/forfft_ng18_1_trig.bin"
    );
    transform_case!(
        forfft_ng18_2, forfft, 18, [1, 0, 0, 1, 0],
        "testdata/forfft/forfft_ng18_2_x.bin",
        "testdata/forfft/forfft_ng18_2_x2.bin",
        "testdata/forfft/forfft_ng18_2_trig.bin"
    );
    transform_case!(
        forfft_ng24_1, forfft, 24, [1, 1, 0, 0, 0],
        "testdata/forfft/forfft_ng24_1_x.bin",
        "testdata/forfft/forfft_ng24_1_x2.bin",
        "testdata/forfft/forfft_ng24_1_trig.bin"
    );
    transform_case!(
        forfft_ng24_2, forfft, 24, [1, 1, 0, 0, 0],
        "testdata/forfft/forfft_ng24_2_x.bin",
        "testdata/forfft/forfft_ng24_2_x2.bin",
        "testdata/forfft/forfft_ng24_2_trig.bin"
    );

    transform_case!(
        revfft_ng12_1, revfft, 12, [1, 0, 1, 0, 0],
        "testdata/revfft/revfft_ng12_1_x.bin",
        "testdata/revfft/revfft_ng12_1_x2.bin",
        "testdata/revfft/revfft_ng12_1_trig.bin"
    );
    transform_case!(
        revfft_ng12_2, revfft, 12, [1, 0, 1, 0, 0],
        "testdata/revfft/revfft_ng12_2_x.bin",
        "testdata/revfft/revfft_ng12_2_x2.bin",
        "testdata/revfft/revfft_ng12_2_trig.bin"
    );
    transform_case!(
        revfft_ng15_1, revfft, 15, [0, 0, 0, 1, 1],
        "testdata/revfft/revfft_ng15_1_x.bin",
        "testdata/revfft/revfft_ng15_1_x2.bin",
        "testdata/revfft/revfft_ng15_1_trig.bin"
    );
    transform_case!(
        revfft_ng15_2, revfft, 15, [0, 0, 0, 1, 1],
        "testdata/revfft/revfft_ng15_2_x.bin",
        "testdata/revfft/revfft_ng15_2_x2.bin",
        "testdata/revfft/revfft_ng15_2_trig.bin"
    );
    transform_case!(
        revfft_ng16_1, revfft, 16, [0, 2, 0, 0, 0],
        "testdata/revfft/revfft_ng16_1_x.bin",
        "testdata/revfft/revfft_ng16_1_x2.bin",
        "testdata/revfft/revfft_ng16_1_trig.bin"
    );
    transform_case!(
        revfft_ng16_2, revfft, 16, [0, 2, 0, 0, 0],
        "testdata/revfft/revfft_ng16_2_x.bin",
        "testdata/revfft/revfft_ng16_2_x2.bin",
        "testdata/revfft/revfft_ng16_2_trig.bin"
    );
    transform_case!(
        revfft_ng18_1, revfft, 18, [1, 0, 0, 1, 0],
        "testdata/revfft/revfft_ng18_1_x.bin",
        "testdata/revfft/revfft_ng18_1_x2.bin",
        "testdata/revfft/revfft_ng18_1_trig.bin"
    );
    transform_case!(
        revfft_ng18_2, revfft, 18, [1, 0, 0, 1, 0],
        "testdata/revfft/revfft_ng18_2_x.bin",
        "testdata/revfft/revfft_ng18_2_x2.bin",
        "testdata/revfft/revfft_ng18_2_trig.bin"
    );
    transform_case!(
        revfft_ng24_1, revfft, 24, [1, 1, 0, 0, 0],
        "testdata/revfft/revfft_ng24_1_x.bin",
        "testdata/revfft/revfft_ng24_1_x2.bin",
        "testdata/revfft/revfft_ng24_1_trig.bin"
    );
    transform_case!(
        revfft_ng24_2, revfft, 24, [1, 1, 0, 0, 0],
        "testdata/revfft/revfft_ng24_2_x.bin",
        "testdata/revfft/revfft_ng24_2_x2.bin",
        "testdata/revfft/revfft_ng24_2_trig.bin"
    );
}