use {
crate::{
constants::*,
nhswps::{coeffs::coeffs, cpsource::cpsource, vertical::vertical, State},
utils::{arr2zero, arr3zero},
},
log::error,
ndarray::{Axis, Zip},
rayon::prelude::*,
std::sync::{Arc, Mutex},
};
/// Solves iteratively for the pressure field `state.pn` and its spectral
/// counterpart `state.ps`, then derives `state.dpn` from the converged `ps`.
///
/// What the body does, in order:
///
/// 1. Sets `state.ri = 1 / (r + 1)` and passes it through
///    `state.spectral.deal3d` (presumably de-aliasing — confirm in `Spectral`).
/// 2. Calls `vertical`, then `coeffs` (fills the coefficient arrays `sigx`,
///    `sigy`, `cpt1`, `cpt2`) and `cpsource` (fills the fixed source `sp0`).
/// 3. Iterates until the relative L2 change of `pn` between successive
///    iterations is at most `toler`:
///    * rebuilds the source `sp` level by level from the current `ps`, using
///      finite differences in the vertical and spectral x/y derivatives;
///    * solves a tridiagonal system along the vertical for every horizontal
///      (spectral) index of `ps`, using the precomputed `htdv`, `etdv`, `ap`;
///    * transforms `ps` back to `pn` for levels `0..=nz-1`.
/// 4. Computes a vertical first derivative of the converged `ps` (central
///    differences refined by a tridiagonal solve with `htd1`/`etd1`) and
///    transforms it into `state.dpn` for levels `1..=nz`.
///
/// Level `nz` of both `state.ps` and `state.pn` is set to zero.
///
/// NOTE(review): the 3-D arrays are assumed to have `nz + 1` levels along
/// axis 2 (inferred from the indexing below). The one-sided stencil at the
/// bottom level reads levels 0..=3, so `nz` must be at least 3 or the
/// indexing panics.
///
/// # Termination
///
/// Logs an error and calls `quit::with_code(1)` when:
/// * the convergence error is NaN (non-finite values appeared),
/// * the error exceeds 1 on any iteration after the first (divergence), or
/// * `nitmax` iterations pass without the error dropping to `toler`.
pub fn psolve(state: &mut State) {
    // Convergence threshold on the relative L2 change of `pn`.
    let toler = 1.0E-9;
    let ng = state.spectral.ng;
    let nz = state.spectral.nz;
    // Vertical grid spacing and the factors used by the finite differences.
    let dz = HBAR / (nz as f64);
    let dzi = 1.0 / dz;
    let dz2 = dz / 2.0;
    let dz6 = dz / 6.0;
    let dzisq = (1.0 / dz).powf(2.0);
    // 1 / (2 dz): the central-difference factor.
    let hdzi = (1.0 / 2.0) * (1.0 / (HBAR / nz as f64));
    let nitmax: usize = 100;
    let mut sp0 = arr3zero(ng, nz);
    let mut sigx = arr3zero(ng, nz);
    let mut sigy = arr3zero(ng, nz);
    let mut cpt1 = arr3zero(ng, nz);
    let mut cpt2 = arr3zero(ng, nz);
    // Shared with the parallel level loop below: only the `iz == nz - 1` task
    // writes `dpdt` and `d2pdt2`, and only the `iz == nz - 2` task writes
    // `wkq`; the upper-boundary code reads them after the loop has finished.
    let dpdt = Arc::new(Mutex::new(arr2zero(ng)));
    let mut d2pdxt = arr2zero(ng);
    let mut d2pdyt = arr2zero(ng);
    let d2pdt2 = arr2zero(ng);
    let mut wkp = arr2zero(ng);
    let sp = arr3zero(ng, nz);
    let mut gg = arr3zero(ng, nz);
    let mut wka = arr2zero(ng);
    let mut wkb = arr2zero(ng);
    let mut wkc = arr2zero(ng);
    let mut wkd = arr2zero(ng);
    let wkq = Arc::new(Mutex::new(arr2zero(ng)));

    // ri = 1 / (1 + r), then passed through deal3d.
    Zip::from(&mut state.ri)
        .and(&state.r)
        .apply(|ri, r| *ri = 1.0 / (r + 1.0));
    state.spectral.deal3d(state.ri.view_mut());
    vertical(state);
    coeffs(
        state,
        sigx.view_mut(),
        sigy.view_mut(),
        cpt1.view_mut(),
        cpt2.view_mut(),
    );
    // Iteration-independent part of the source.
    cpsource(state, sp0.view_mut());

    // `pna` holds `pn` from the previous iteration for the convergence test.
    let mut pna = state.pn.clone();
    let sp = Arc::new(Mutex::new(sp));
    let d2pdt2 = Arc::new(Mutex::new(d2pdt2));
    let mut errp = 1.0;
    let mut iter = 0;
    while errp > toler && iter < nitmax {
        // Spectral coefficients of the current pressure; the top level is zero.
        state
            .spectral
            .ptospc3d(state.pn.view(), state.ps.view_mut(), 0, nz - 1);
        state.ps.index_axis_mut(Axis(2), nz).fill(0.0);

        // ---- Lower boundary (iz = 0) ----
        // One-sided second vertical difference built from levels 0..=3.
        Zip::from(&mut wkd)
            .and(&state.ps.index_axis(Axis(2), 0))
            .and(&state.ps.index_axis(Axis(2), 1))
            .and(&state.ps.index_axis(Axis(2), 2))
            .and(&state.ps.index_axis(Axis(2), 3))
            .apply(|wkd, ps0, ps1, ps2, ps3| {
                *wkd = (2.0 * ps0 - 5.0 * ps1 + 4.0 * ps2 - ps3) * dzisq;
            });
        state
            .spectral
            .d2fft
            .spctop(wkd.view_mut(), d2pdt2.lock().unwrap().view_mut());
        // At iz = 0 only the sp0 and cpt2 * d2pdt2 terms are used.
        Zip::from(&mut wkp)
            .and(sp0.index_axis(Axis(2), 0))
            .and(cpt2.index_axis(Axis(2), 0))
            .and(&(*d2pdt2.lock().unwrap()))
            .apply(|wkp, sp0, cpt2, d2pdt2| *wkp = sp0 + cpt2 * d2pdt2);
        state.spectral.d2fft.ptospc(wkp.view_mut(), wka.view_mut());
        sp.lock().unwrap().index_axis_mut(Axis(2), 0).assign(&wka);

        // ---- Interior levels (1..=nz-1), processed in parallel ----
        // Each task owns its work arrays; results go into `sp` through its lock.
        (1..=nz - 1).into_par_iter().for_each(|iz| {
            let mut wka = arr2zero(ng);
            let mut wkb = arr2zero(ng);
            let mut wkc = arr2zero(ng);
            let mut wkd = arr2zero(ng);
            let mut wkp = arr2zero(ng);
            let mut d2pdxt = arr2zero(ng);
            let mut d2pdyt = arr2zero(ng);
            // Central first and second vertical differences of ps.
            Zip::from(&mut wka)
                .and(&state.ps.index_axis(Axis(2), iz + 1))
                .and(&state.ps.index_axis(Axis(2), iz - 1))
                .apply(|wka, psp, psm| *wka = (psp - psm) * hdzi);
            Zip::from(&mut wkd)
                .and(&state.ps.index_axis(Axis(2), iz + 1))
                .and(&state.ps.index_axis(Axis(2), iz))
                .and(&state.ps.index_axis(Axis(2), iz - 1))
                .apply(|wkd, psp, ps, psm| *wkd = (psp - 2.0 * ps + psm) * dzisq);
            // x/y derivatives of the first vertical difference, taken to
            // physical space. `wka` is read here before it is handed to
            // `spctop` as a mutable view below.
            state
                .spectral
                .d2fft
                .xderiv(&state.spectral.hrkx, wka.view(), wkb.view_mut());
            state
                .spectral
                .d2fft
                .yderiv(&state.spectral.hrky, wka.view(), wkc.view_mut());
            state
                .spectral
                .d2fft
                .spctop(wkb.view_mut(), d2pdxt.view_mut());
            state
                .spectral
                .d2fft
                .spctop(wkc.view_mut(), d2pdyt.view_mut());
            Zip::from(&mut wkp)
                .and(sp0.index_axis(Axis(2), iz))
                .and(sigx.index_axis(Axis(2), iz))
                .and(&d2pdxt)
                .and(sigy.index_axis(Axis(2), iz))
                .and(&d2pdyt)
                .apply(|wkp, sp0, sigx, d2pdxt, sigy, d2pdyt| {
                    *wkp = sp0 + sigx * d2pdxt + sigy * d2pdyt
                });
            if iz == nz - 1 {
                // The top interior level publishes its vertical derivatives
                // for the upper-boundary extrapolation after this loop.
                state
                    .spectral
                    .d2fft
                    .spctop(wka.view_mut(), dpdt.lock().unwrap().view_mut());
                state
                    .spectral
                    .d2fft
                    .spctop(wkd.view_mut(), d2pdt2.lock().unwrap().view_mut());
                Zip::from(&mut wkp)
                    .and(cpt2.index_axis(Axis(2), iz))
                    .and(&(*d2pdt2.lock().unwrap()))
                    .and(cpt1.index_axis(Axis(2), iz))
                    .and(&(*dpdt.lock().unwrap()))
                    .apply(|wkp, cpt2, d2pdt2, cpt1, dpdt| *wkp += cpt2 * d2pdt2 + cpt1 * dpdt);
            } else {
                let mut dpdt_local = arr2zero(ng);
                let mut d2pdt2_local = arr2zero(ng);
                state
                    .spectral
                    .d2fft
                    .spctop(wka.view_mut(), dpdt_local.view_mut());
                state
                    .spectral
                    .d2fft
                    .spctop(wkd.view_mut(), d2pdt2_local.view_mut());
                if iz == nz - 2 {
                    // Second-to-top level: kept for the upper-boundary
                    // extrapolation.
                    wkq.lock().unwrap().assign(&d2pdt2_local);
                }
                Zip::from(&mut wkp)
                    .and(cpt2.index_axis(Axis(2), iz))
                    .and(&d2pdt2_local)
                    .and(cpt1.index_axis(Axis(2), iz))
                    .and(&dpdt_local)
                    .apply(|wkp, cpt2, d2pdt2, cpt1, dpdt| *wkp += cpt2 * d2pdt2 + cpt1 * dpdt);
            }
            state.spectral.d2fft.ptospc(wkp.view_mut(), wka.view_mut());
            sp.lock().unwrap().index_axis_mut(Axis(2), iz).assign(&wka);
        });

        // ---- Upper boundary (iz = nz) ----
        // Extrapolate from level nz-1 (dpdt, d2pdt2) and level nz-2 (wkq):
        //   dpdt   += dz/2 * (3 * d2pdt2 - wkq)
        //   d2pdt2  = 2 * d2pdt2 - wkq
        Zip::from(&mut *dpdt.lock().unwrap())
            .and(&(*d2pdt2.lock().unwrap()))
            .and(&(*wkq.lock().unwrap()))
            .apply(|dpdt, d2pdt2, wkq| *dpdt += dz2 * (3.0 * d2pdt2 - wkq));
        Zip::from(&mut (*d2pdt2.lock().unwrap()))
            .and(&(*wkq.lock().unwrap()))
            .apply(|d2pdt2, wkq| *d2pdt2 = 2.0 * *d2pdt2 - wkq);
        // Horizontal derivatives of the extrapolated dpdt.
        wkp = dpdt.lock().unwrap().clone();
        state.spectral.d2fft.ptospc(wkp.view_mut(), wka.view_mut());
        state
            .spectral
            .d2fft
            .xderiv(&state.spectral.hrkx, wka.view(), wkb.view_mut());
        state
            .spectral
            .d2fft
            .yderiv(&state.spectral.hrky, wka.view(), wkc.view_mut());
        state
            .spectral
            .d2fft
            .spctop(wkb.view_mut(), d2pdxt.view_mut());
        state
            .spectral
            .d2fft
            .spctop(wkc.view_mut(), d2pdyt.view_mut());
        Zip::from(&mut wkp)
            .and(sp0.index_axis(Axis(2), nz))
            .and(sigx.index_axis(Axis(2), nz))
            .and(&d2pdxt)
            .and(sigy.index_axis(Axis(2), nz))
            .and(&d2pdyt)
            .apply(|wkp, sp0, sigx, d2pdxt, sigy, d2pdyt| {
                *wkp = sp0 + sigx * d2pdxt + sigy * d2pdyt
            });
        Zip::from(&mut wkp)
            .and(cpt2.index_axis(Axis(2), nz))
            .and(&(*d2pdt2.lock().unwrap()))
            .and(cpt1.index_axis(Axis(2), nz))
            .and(&(*dpdt.lock().unwrap()))
            .apply(|wkp, cpt2, d2pdt2, cpt1, dpdt| *wkp += cpt2 * d2pdt2 + cpt1 * dpdt);
        state.spectral.d2fft.ptospc(wkp.view_mut(), wka.view_mut());
        sp.lock().unwrap().index_axis_mut(Axis(2), nz).assign(&wka);

        // ---- Tridiagonal solve for ps along the vertical ----
        // Right-hand side gg from sp: weights (1/3, 1/6) at iz = 0 and
        // (1/12, 5/6, 1/12) at interior levels.
        {
            let sp = sp.lock().unwrap();
            Zip::from(gg.index_axis_mut(Axis(2), 0))
                .and(sp.index_axis(Axis(2), 0))
                .and(sp.index_axis(Axis(2), 1))
                .apply(|gg, sp0, sp1| *gg = (1.0 / 3.0) * sp0 + (1.0 / 6.0) * sp1);
            for iz in 1..nz {
                Zip::from(gg.index_axis_mut(Axis(2), iz))
                    .and(sp.index_axis(Axis(2), iz - 1))
                    .and(sp.index_axis(Axis(2), iz + 1))
                    .and(sp.index_axis(Axis(2), iz))
                    .apply(|gg, spm, spp, sp| *gg = (1.0 / 12.0) * (spm + spp) + (5.0 / 6.0) * sp);
            }
        }
        // Forward elimination with htdv / ap ...
        Zip::from(state.ps.index_axis_mut(Axis(2), 0))
            .and(gg.index_axis(Axis(2), 0))
            .and(state.spectral.htdv.index_axis(Axis(2), 0))
            .apply(|ps, gg, htdv| *ps = gg * htdv);
        for iz in 1..nz {
            // Copy of the previous level, so `ps` can be borrowed mutably below.
            let ps1 = state.ps.index_axis(Axis(2), iz - 1).into_owned();
            Zip::from(state.ps.index_axis_mut(Axis(2), iz))
                .and(gg.index_axis(Axis(2), iz))
                .and(&state.spectral.ap)
                .and(&ps1)
                .and(state.spectral.htdv.index_axis(Axis(2), iz))
                .apply(|ps, gg, ap, ps1, htdv| *ps = (gg - ap * ps1) * htdv);
        }
        // ... then back substitution with etdv.
        for iz in (0..=nz - 2).rev() {
            let ps1 = state.ps.index_axis(Axis(2), iz + 1).into_owned();
            Zip::from(state.ps.index_axis_mut(Axis(2), iz))
                .and(state.spectral.etdv.index_axis(Axis(2), iz))
                .and(&ps1)
                .apply(|ps, etdv, ps1| *ps += etdv * ps1);
        }
        state.ps.index_axis_mut(Axis(2), nz).fill(0.0);
        state
            .spectral
            .spctop3d(state.ps.view(), state.pn.view_mut(), 0, nz - 1);
        state.pn.index_axis_mut(Axis(2), nz).fill(0.0);

        // Relative L2 change of pn since the previous iteration; the 1e-20
        // guards against division by zero when pna vanishes.
        errp = (state
            .pn
            .iter()
            .zip(&pna)
            .map(|(a, b)| (a - b).powf(2.0))
            .sum::<f64>()
            / (pna.iter().map(|x| x.powf(2.0)).sum::<f64>() + 1.0E-20))
            .sqrt();
        // Reject NaN explicitly: `NaN > toler` is false, so otherwise the
        // loop would stop here and the NaN pressure would be accepted as
        // converged. Also stop on divergence after the first iteration.
        if errp.is_nan() || (iter > 0 && errp > 1.0) {
            error!("Pressure error too large! Final pressure error = {}", errp);
            quit::with_code(1);
        }
        iter += 1;
        pna = state.pn.clone();
    }
    // The loop ends either on convergence (errp <= toler) or when the
    // iteration budget is spent. Test the error itself rather than `iter`, so
    // that convergence reached on the last permitted iteration is not
    // reported as a failure (NaN has already been rejected inside the loop).
    if errp > toler {
        error!(
            "Exceeded maximum number of iterations to find pressure! Final pressure error = {}",
            errp
        );
        quit::with_code(1);
    }

    // ---- Vertical first derivative of the converged ps -> dpn ----
    // Central differences in the interior, a boundary expression using sp at
    // iz = nz, then a tridiagonal solve. htd1 / etd1 are indexed with an
    // offset of one: htd1[iz - 1] and etd1[iz - 1] apply to level iz.
    {
        for iz in 1..nz {
            Zip::from(gg.index_axis_mut(Axis(2), iz))
                .and(state.ps.index_axis(Axis(2), iz + 1))
                .and(state.ps.index_axis(Axis(2), iz - 1))
                .apply(|gg, psp, psm| *gg = (psp - psm) * hdzi);
        }
        Zip::from(gg.index_axis_mut(Axis(2), nz))
            .and(sp.lock().unwrap().index_axis(Axis(2), nz))
            .and(state.ps.index_axis(Axis(2), nz - 1))
            .apply(|gg, sp, ps| *gg = dz6 * sp - ps * dzi);
        Zip::from(gg.index_axis_mut(Axis(2), 1)).apply(|gg| *gg *= state.spectral.htd1[0]);
        for iz in 2..nz {
            let gg1 = gg.index_axis(Axis(2), iz - 1).into_owned();
            Zip::from(gg.index_axis_mut(Axis(2), iz))
                .and(&gg1)
                .apply(|gg, gg1| *gg = (*gg - (1.0 / 6.0) * gg1) * state.spectral.htd1[iz - 1]);
        }
        {
            let gg1 = gg.index_axis(Axis(2), nz - 1).into_owned();
            Zip::from(gg.index_axis_mut(Axis(2), nz))
                .and(&gg1)
                .apply(|gg, gg1| *gg = (*gg - (1.0 / 3.0) * gg1) * state.spectral.htd1[nz - 1]);
        }
        for iz in (1..nz).rev() {
            let gg1 = gg.index_axis(Axis(2), iz + 1).into_owned();
            Zip::from(gg.index_axis_mut(Axis(2), iz))
                .and(&gg1)
                .apply(|gg, gg1| *gg += state.spectral.etd1[iz - 1] * gg1);
        }
    }
    state
        .spectral
        .spctop3d(gg.view(), state.dpn.view_mut(), 1, nz);
}
#[cfg(test)]
mod test {
    //! Regression tests for `psolve`.
    //!
    //! For each grid size, the input fields are loaded from `testdata/psolve/`
    //! and `psolve` is run once, lazily, into a shared `State`. Each test then
    //! compares one resulting field with the matching `*2.bin` expected-output
    //! file, using an absolute tolerance of 1e-10.
    use {
        super::*,
        crate::{array3_from_file, nhswps::Spectral},
        approx::assert_abs_diff_eq,
        byteorder::ByteOrder,
        lazy_static::lazy_static,
        ndarray::{Array3, ShapeBuilder},
    };

    lazy_static! {
        // 24x24 horizontal grid, 4 vertical levels, after one `psolve` call.
        static ref STATE_24_4: State = {
            let ng = 24;
            let nz = 4;
            let ri = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_ri.bin");
            let r = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_r.bin");
            let u = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_u.bin");
            let v = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_v.bin");
            let w = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_w.bin");
            let zeta = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_zeta.bin");
            let z = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_z.bin");
            let zx = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_zx.bin");
            let zy = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_zy.bin");
            let ps = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_ps.bin");
            let pn = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_pn.bin");
            let dpn = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_dpn.bin");
            let aa = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_aa.bin");
            let qs = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_qs.bin");
            let ds = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_ds.bin");
            let gs = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/24_4_gs.bin");
            let mut state = State {
                spectral: Spectral::new(ng, nz),
                u,
                v,
                w,
                z,
                zx,
                zy,
                r,
                ri,
                aa,
                zeta,
                pn,
                dpn,
                ps,
                qs,
                ds,
                gs,
                t: 0.0,
                ngsave: 0,
                itime: 0,
                jtime: 0,
                ggen: false,
            };
            psolve(&mut state);
            state
        };
        // 32x32 horizontal grid, 4 vertical levels, after one `psolve` call.
        static ref STATE_32_4: State = {
            let ng = 32;
            let nz = 4;
            let ri = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_ri.bin");
            let r = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_r.bin");
            let u = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_u.bin");
            let v = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_v.bin");
            let w = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_w.bin");
            let zeta = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_zeta.bin");
            let z = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_z.bin");
            let zx = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_zx.bin");
            let zy = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_zy.bin");
            let ps = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_ps.bin");
            let pn = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_pn.bin");
            let dpn = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_dpn.bin");
            let aa = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_aa.bin");
            let qs = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_qs.bin");
            let ds = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_ds.bin");
            let gs = array3_from_file!(ng, ng, nz + 1, "testdata/psolve/32_4_gs.bin");
            let mut state = State {
                spectral: Spectral::new(ng, nz),
                u,
                v,
                w,
                z,
                zx,
                zy,
                r,
                ri,
                aa,
                zeta,
                pn,
                dpn,
                ps,
                qs,
                ds,
                gs,
                t: 0.0,
                ngsave: 0,
                itime: 0,
                jtime: 0,
                ggen: false,
            };
            psolve(&mut state);
            state
        };
    }

    #[test]
    fn _32_4_z() {
        let z2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_z2.bin");
        assert_abs_diff_eq!(&z2, &STATE_32_4.z, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_zx() {
        let zx2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_zx2.bin");
        assert_abs_diff_eq!(&zx2, &STATE_32_4.zx, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_zy() {
        let zy2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_zy2.bin");
        assert_abs_diff_eq!(&zy2, &STATE_32_4.zy, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_w() {
        let w2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_w2.bin");
        assert_abs_diff_eq!(&w2, &STATE_32_4.w, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_aa() {
        let aa2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_aa2.bin");
        assert_abs_diff_eq!(&aa2, &STATE_32_4.aa, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_ri() {
        let ri2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_ri2.bin");
        assert_abs_diff_eq!(&ri2, &STATE_32_4.ri, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_pn() {
        let pn2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_pn2.bin");
        assert_abs_diff_eq!(&pn2, &STATE_32_4.pn, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_ps() {
        let ps2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_ps2.bin");
        assert_abs_diff_eq!(&ps2, &STATE_32_4.ps, epsilon = 1.0E-10);
    }

    #[test]
    fn _32_4_dpn() {
        let dpn2 = array3_from_file!(32, 32, 5, "testdata/psolve/32_4_dpn2.bin");
        assert_abs_diff_eq!(&dpn2, &STATE_32_4.dpn, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_z() {
        let z2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_z2.bin");
        assert_abs_diff_eq!(&z2, &STATE_24_4.z, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_zx() {
        let zx2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_zx2.bin");
        assert_abs_diff_eq!(&zx2, &STATE_24_4.zx, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_zy() {
        let zy2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_zy2.bin");
        assert_abs_diff_eq!(&zy2, &STATE_24_4.zy, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_w() {
        let w2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_w2.bin");
        assert_abs_diff_eq!(&w2, &STATE_24_4.w, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_aa() {
        let aa2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_aa2.bin");
        assert_abs_diff_eq!(&aa2, &STATE_24_4.aa, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_ri() {
        let ri2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_ri2.bin");
        assert_abs_diff_eq!(&ri2, &STATE_24_4.ri, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_pn() {
        let pn2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_pn2.bin");
        assert_abs_diff_eq!(&pn2, &STATE_24_4.pn, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_ps() {
        let ps2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_ps2.bin");
        assert_abs_diff_eq!(&ps2, &STATE_24_4.ps, epsilon = 1.0E-10);
    }

    #[test]
    fn _24_4_dpn() {
        let dpn2 = array3_from_file!(24, 24, 5, "testdata/psolve/24_4_dpn2.bin");
        assert_abs_diff_eq!(&dpn2, &STATE_24_4.dpn, epsilon = 1.0E-10);
    }
}