spectral_processing.f90 Source File


This file depends on

spectral_processing.f90 depends on common.f90.

Files dependent on this one

poisson_fft.f90 depends on spectral_processing.f90; backend.f90 depends on poisson_fft.f90; xcompact.f90 depends on backend.f90.

Source Code

module m_cuda_spectral
  use cudafor

  use m_common, only: dp

  implicit none

contains

  attributes(global) subroutine memcpy3D(dst, src, nx, ny, nz)
    !! Copy data between x3d2 padded arrays and cuFFTMp descriptors:
    !! dst(i, j, k) = src(i, j, k) for i in [1, nx], j in [1, ny].
    !!
    !! Thread mapping: x threads (over all x blocks) give the second index j,
    !! `blockIdx%y` gives the third index k, and each thread walks the first
    !! index i. Threads with j > ny do nothing. `nz` is not referenced; the
    !! range of k is fixed by the launch grid.
    implicit none

    real(dp), device, intent(inout), dimension(:, :, :) :: dst
    real(dp), device, intent(in), dimension(:, :, :) :: src
    integer, value, intent(in) :: nx, ny, nz

    integer :: ix, jy, kz

    jy = (blockIdx%x - 1)*blockDim%x + threadIdx%x
    kz = blockIdx%y

    ! Surplus threads in the last x block have nothing to copy
    if (jy > ny) return

    do ix = 1, nx
      dst(ix, jy, kz) = src(ix, jy, kz)
    end do
  end subroutine memcpy3D

  attributes(global) subroutine memcpy3D_with_transpose(dst, src, nx, ny, nz)
    !! Copy with transpose: src(nx, ny, nz) -> dst(ny, nx, nz)
    !! Used for the 100 case forward FFT.
    !!
    !! `dst` may be padded along its first dimension (e.g. (ny+2, nx, nz)).
    !! Only the leading (ny, nx, nz) region is written. For that reason `dst`
    !! is declared intent(inout) rather than intent(out):
    !!   - the padding is intentionally left untouched;
    !!   - each thread writes only its own slice of the array.
    !! intent(out) would declare the whole of `dst` undefined on entry, which
    !! describes neither point.
    !!
    !! Thread mapping: x threads (over all x blocks) cover i in [1, nx],
    !! `blockIdx%y` selects k, and each thread loops over j in [1, ny].
    implicit none

    real(dp), device, intent(inout), dimension(:, :, :) :: dst  ! (ny+2, nx, nz); only (ny, nx, nz) written
    real(dp), device, intent(in), dimension(:, :, :) :: src     ! (nx, ny, nz)
    integer, value, intent(in) :: nx, ny, nz

    integer :: i, j, k

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! iterates over nx
    k = blockIdx%y  ! nz

    if (i <= nx) then
      do j = 1, ny
        ! Transpose: dst(j, i, k) = src(i, j, k)
        dst(j, i, k) = src(i, j, k)
      end do
    end if

  end subroutine memcpy3D_with_transpose

  attributes(global) subroutine memcpy3D_with_transpose_back( &
    dst, src, nx, ny, nz &
    )
    !! Undo the 100-case transpose after the backward FFT:
    !! dst(i, j, k) = src(j, i, k) for i in [1, nx], j in [1, ny].
    !!
    !! `src` may be padded along its first dimension (e.g. (ny+2, nx, nz));
    !! only its leading (ny, nx, nz) region is read. `dst` is (nx, ny, nz).
    !!
    !! Thread mapping: x threads (over all x blocks) give i, `blockIdx%y`
    !! gives k, and each thread walks j from 1 to ny. Threads with i > nx
    !! do nothing.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: dst  ! (nx, ny, nz)
    real(dp), device, intent(in), dimension(:, :, :) :: src   ! (ny+2, nx, nz), leading (ny, nx, nz) read
    integer, value, intent(in) :: nx, ny, nz

    integer :: row, col, slab

    row = (blockIdx%x - 1)*blockDim%x + threadIdx%x
    slab = blockIdx%y

    ! Surplus threads in the last x block have nothing to do
    if (row > nx) return

    do col = 1, ny
      dst(row, col, slab) = src(col, row, slab)
    end do

  end subroutine memcpy3D_with_transpose_back

  attributes(global) subroutine transpose_xyz_to_zxy(dst, src, nx, ny, nz)
    !! Reorder src(nx, ny, nz) into dst(nz, nx, ny), i.e.
    !! dst(k, i, j) = src(i, j, k). Used before the R2C FFT of the 110 case
    !! so that Z (periodic) becomes the fastest-varying dimension.
    !!
    !! Expected launch: blocks=dim3(nz, (ny-1)/tpb+1, 1), threads=dim3(tpb,1,1)
    !! `blockIdx%x` selects k, the (blockIdx%y, threadIdx%x) pair selects j,
    !! and each thread loops over all i in [1, nx]. Threads with j > ny do
    !! nothing.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: dst  ! (nz, nx, ny)
    real(dp), device, intent(in), dimension(:, :, :) :: src   ! (nx, ny, nz)
    integer, value, intent(in) :: nx, ny, nz

    integer :: ix, jy, kz

    kz = blockIdx%x
    jy = threadIdx%x + (blockIdx%y - 1)*blockDim%x

    if (jy > ny) return

    do ix = 1, nx
      dst(kz, ix, jy) = src(ix, jy, kz)
    end do

  end subroutine transpose_xyz_to_zxy

  attributes(global) subroutine transpose_zxy_to_xyz(dst, src, nx, ny, nz)
    !! Inverse of transpose_xyz_to_zxy: reorder src(nz, nx, ny) back into
    !! dst(nx, ny, nz), i.e. dst(i, j, k) = src(k, i, j). Used after the C2R
    !! FFT of the 110 case.
    !!
    !! Expected launch: blocks=dim3(nz, (ny-1)/tpb+1, 1), threads=dim3(tpb,1,1)
    !! `blockIdx%x` selects k, the (blockIdx%y, threadIdx%x) pair selects j,
    !! and each thread loops over all i in [1, nx]. Threads with j > ny do
    !! nothing.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: dst  ! (nx, ny, nz)
    real(dp), device, intent(in), dimension(:, :, :) :: src   ! (nz, nx, ny)
    integer, value, intent(in) :: nx, ny, nz

    integer :: ix, jy, kz

    kz = blockIdx%x
    jy = threadIdx%x + (blockIdx%y - 1)*blockDim%x

    if (jy > ny) return

    do ix = 1, nx
      dst(ix, jy, kz) = src(kz, ix, jy)
    end do

  end subroutine transpose_zxy_to_xyz

  attributes(global) subroutine process_spectral_000( &
    div_u, waves, nx_spec, ny_spec, y_sp_st, nx, ny, nz, &
    ax, bx, ay, by, az, bz &
    )
    !! Post-processes the divergence of velocity in spectral space, including
    !! scaling w.r.t. grid size, and solves the Poisson equation in place.
    !! NOTE(review): presumably the fully periodic (000) configuration, since
    !! no direction gets a paired even/odd treatment here. Confirm.
    !!
    !! For every spectral mode (i, j, k) the kernel:
    !!   1. divides by nx*ny*nz;
    !!   2. applies the forward z, y and x post-processing rotations;
    !!   3. divides the real and imaginary parts separately by the matching
    !!      parts of `waves`, with negation;
    !!   4. applies the backward z, y and x rotations;
    !!   5. stores the result back into `div_u`.
    !!
    !! Thread mapping: x threads (over all x blocks) cover j in [1, ny_spec],
    !! `blockIdx%y` selects k, and each thread loops over i in [1, nx_spec].
    !!
    !! Ref. JCP 228 (2009), 5989–6015, Sec 4
    implicit none

    !> Divergence of velocity in spectral space
    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u
    !> Spectral equivalence constants
    complex(dp), device, intent(in), dimension(:, :, :) :: waves
    !> Per-direction post-processing coefficients, indexed below with
    !> ix = i, iy = j + y_sp_st and iz = k
    real(dp), device, intent(in), dimension(:) :: ax, bx, ay, by, az, bz
    !> Grid size in spectral space
    integer, value, intent(in) :: nx_spec, ny_spec
    !> Offset in y direction in the permuted slabs in spectral space
    integer, value, intent(in) :: y_sp_st
    !> Grid size
    integer, value, intent(in) :: nx, ny, nz

    integer :: i, j, k, ix, iy, iz
    real(dp) :: tmp_r, tmp_c, div_r, div_c

    j = threadIdx%x + (blockIdx%x - 1)*blockDim%x
    k = blockIdx%y ! nz_spec

    if (j <= ny_spec) then
      do i = 1, nx_spec
        ! normalisation (sequential real divisions, so no integer product
        ! nx*ny*nz is ever formed)
        div_r = real(div_u(i, j, k), kind=dp)/nx/ny/nz
        div_c = aimag(div_u(i, j, k))/nx/ny/nz

        ! Coefficient indices: only y carries a slab offset (y_sp_st)
        ix = i; iy = j + y_sp_st; iz = k

        ! post-process forward
        ! post-process in z; modes with iz > nz/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(iz) + tmp_c*az(iz)
        div_c = tmp_c*bz(iz) - tmp_r*az(iz)
        if (iz > nz/2 + 1) div_r = -div_r
        if (iz > nz/2 + 1) div_c = -div_c

        ! post-process in y; modes with iy > ny/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*by(iy) + tmp_c*ay(iy)
        div_c = tmp_c*by(iy) - tmp_r*ay(iy)
        if (iy > ny/2 + 1) div_r = -div_r
        if (iy > ny/2 + 1) div_c = -div_c

        ! post-process in x (no sign flip; NOTE(review): presumably because
        ! i only spans the nx_spec half-spectrum. Confirm.)
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bx(ix) + tmp_c*ax(ix)
        div_c = tmp_c*bx(ix) - tmp_r*ax(ix)

        ! Solve Poisson: divide each part by the matching part of waves and
        ! negate. The whole mode is zeroed when EITHER part of waves is below
        ! 1e-16. There is no abs() here, so negative waves entries are zeroed
        ! too (the 010/110 kernels test abs() of each part independently).
        tmp_r = real(waves(i, j, k), kind=dp)
        tmp_c = aimag(waves(i, j, k))
        if ((tmp_r < 1.e-16_dp) .or. (tmp_c < 1.e-16_dp)) then
          div_r = 0._dp; div_c = 0._dp
        else
          div_r = -div_r/tmp_r
          div_c = -div_c/tmp_c
        end if

        ! post-process backward
        ! post-process in z. NOTE(review): unlike the forward z step, the
        ! imaginary input enters div_c as -tmp_c*bz. Verify against the
        ! reference before changing.
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(iz) - tmp_c*az(iz)
        div_c = -tmp_c*bz(iz) - tmp_r*az(iz)
        if (iz > nz/2 + 1) div_r = -div_r
        if (iz > nz/2 + 1) div_c = -div_c

        ! post-process in y (same formula as the forward y step)
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*by(iy) + tmp_c*ay(iy)
        div_c = tmp_c*by(iy) - tmp_r*ay(iy)
        if (iy > ny/2 + 1) div_r = -div_r
        if (iy > ny/2 + 1) div_c = -div_c

        ! post-process in x; the imaginary result is the negation of the
        ! forward x formula
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bx(ix) + tmp_c*ax(ix)
        div_c = -tmp_c*bx(ix) + tmp_r*ax(ix)

        ! update the entry
        div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
      end do
    end if

  end subroutine process_spectral_000

  attributes(global) subroutine process_spectral_010( &
    div_u, waves, nx_spec, ny_spec, y_sp_st, nx, ny, nz, &
    ax, bx, ay, by, az, bz &
    )
    !! Post-processes the divergence of velocity in spectral space, including
    !! scaling w.r.t. grid size, and solves the Poisson equation in place.
    !! This is the single-kernel version of the 010 pipeline (non-periodic y):
    !!   1. normalise by nx*ny*nz, then apply the forward z and x rotations;
    !!   2. forward y pair split of modes j and ny_spec-j+2;
    !!   3. Poisson: divide each part by the matching part of `waves`, and
    !!      zero the i == nx/2+1, k == nz/2+1 modes;
    !!   4. backward y pair recombination;
    !!   5. backward z and x rotations.
    !!
    !! Thread mapping: x threads (over all x blocks) cover i in [1, nx_spec],
    !! `blockIdx%y` selects k, and each thread loops over j. Each thread owns
    !! the whole column div_u(i, :, k), so the in-place passes and pair
    !! updates below do not race with other threads.
    !!
    !! Ref. JCP 228 (2009), 5989–6015, Sec 4
    implicit none

    !> Divergence of velocity in spectral space
    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u
    !> Spectral equivalence constants
    complex(dp), device, intent(in), dimension(:, :, :) :: waves
    !> Per-direction post-processing coefficients, indexed below with
    !> ix = i, iy = j + y_sp_st and iz = k
    real(dp), device, intent(in), dimension(:) :: ax, bx, ay, by, az, bz
    !> Grid size in spectral space
    integer, value, intent(in) :: nx_spec, ny_spec
    !> Offset in y direction in the permuted slabs in spectral space
    integer, value, intent(in) :: y_sp_st
    !> Grid size
    integer, value, intent(in) :: nx, ny, nz

    integer :: i, j, k, ix, iy, iz, iy_rev
    real(dp) :: tmp_r, tmp_c, div_r, div_c, l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x
    k = blockIdx%y ! nz_spec

    ! Pass 1: normalisation + forward z and x rotations
    if (i <= nx_spec) then
      do j = 1, ny_spec
        ix = i; iy = j + y_sp_st; iz = k

        ! normalisation
        div_r = real(div_u(i, j, k), kind=dp)/nx/ny/nz
        div_c = aimag(div_u(i, j, k))/nx/ny/nz

        ! postprocess in z; modes with iz > nz/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(iz) + tmp_c*az(iz)
        div_c = tmp_c*bz(iz) - tmp_r*az(iz)
        if (iz > nz/2 + 1) div_r = -div_r
        if (iz > nz/2 + 1) div_c = -div_c

        ! postprocess in x; modes with ix > nx/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bx(ix) + tmp_c*ax(ix)
        div_c = tmp_c*bx(ix) - tmp_r*ax(ix)
        if (ix > nx/2 + 1) div_r = -div_r
        if (ix > nx/2 + 1) div_c = -div_c

        ! update the entry
        div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
      end do
    end if

    ! Pass 2: forward y pair split. Mode j (2 <= j <= ny_spec/2+1) is
    ! combined with its mirror ny_spec-j+2; j = 1 is left unchanged. For even
    ! ny_spec the last j is its own mirror and both writes store the same value.
    if (i <= nx_spec) then
      do j = 2, ny_spec/2 + 1
        ix = i; iy = j + y_sp_st; iz = k
        iy_rev = ny_spec - j + 2 + y_sp_st

        l_r = real(div_u(i, j, k), kind=dp)
        l_c = aimag(div_u(i, j, k))
        r_r = real(div_u(i, ny_spec - j + 2, k), kind=dp)
        r_c = aimag(div_u(i, ny_spec - j + 2, k))

        ! update the entry
        div_u(i, j, k) = 0.5_dp*cmplx( & !&
         l_r*by(iy) + l_c*ay(iy) + r_r*by(iy) - r_c*ay(iy), &
         -l_r*ay(iy) + l_c*by(iy) + r_r*ay(iy) + r_c*by(iy), kind=dp &
         )
        div_u(i, ny_spec - j + 2, k) = 0.5_dp*cmplx( & !&
         r_r*by(iy_rev) + r_c*ay(iy_rev) + l_r*by(iy_rev) - l_c*ay(iy_rev), &
         -r_r*ay(iy_rev) + r_c*by(iy_rev) + l_r*ay(iy_rev) + l_c*by(iy_rev), &
         kind=dp &
         )
      end do
    end if

    ! Solve Poisson: the real and imaginary parts are divided independently
    ! by the matching part of waves (negated); a part whose waves entry has
    ! magnitude below 1e-16 is set to zero.
    if (i <= nx_spec) then
      do j = 1, ny_spec
        div_r = real(div_u(i, j, k), kind=dp)
        div_c = aimag(div_u(i, j, k))

        tmp_r = real(waves(i, j, k), kind=dp)
        tmp_c = aimag(waves(i, j, k))
        if (abs(tmp_r) < 1.e-16_dp) then
          div_r = 0._dp
        else
          div_r = -div_r/tmp_r
        end if
        if (abs(tmp_c) < 1.e-16_dp) then
          div_c = 0._dp
        else
          div_c = -div_c/tmp_c
        end if

        ! update the entry; the (nx/2+1, :, nz/2+1) modes are zeroed.
        ! NOTE(review): local i and k are compared with global nx and nz,
        ! which assumes the x and z spectral dims carry no offset. Confirm.
        div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
        if (i == nx/2 + 1 .and. k == nz/2 + 1) div_u(i, j, k) = 0._dp
      end do
    end if

    ! post-process backward
    ! Pass 4: backward y pair recombination (same pairing as pass 2)
    if (i <= nx_spec) then
      do j = 2, ny_spec/2 + 1
        ix = i; iy = j + y_sp_st; iz = k
        iy_rev = ny_spec - j + 2 + y_sp_st

        l_r = real(div_u(i, j, k), kind=dp)
        l_c = aimag(div_u(i, j, k))
        r_r = real(div_u(i, ny_spec - j + 2, k), kind=dp)
        r_c = aimag(div_u(i, ny_spec - j + 2, k))

        ! update the entry
        div_u(i, j, k) = cmplx( & !&
          l_r*by(iy) - l_c*ay(iy) + r_r*ay(iy) + r_c*by(iy), &
          l_r*ay(iy) + l_c*by(iy) - r_r*by(iy) + r_c*ay(iy), kind=dp &
          )
        div_u(i, ny_spec - j + 2, k) = cmplx( & !&
          r_r*by(iy_rev) - r_c*ay(iy_rev) + l_r*ay(iy_rev) + l_c*by(iy_rev), &
          r_r*ay(iy_rev) + r_c*by(iy_rev) - l_r*by(iy_rev) + l_c*ay(iy_rev), &
          kind=dp &
          )
      end do
    end if

    ! Pass 5: backward z and x rotations (inverse sign pattern of pass 1)
    if (i <= nx_spec) then
      do j = 1, ny_spec
        ix = i; iy = j + y_sp_st; iz = k

        div_r = real(div_u(i, j, k), kind=dp)
        div_c = aimag(div_u(i, j, k))

        ! post-process in z
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(iz) - tmp_c*az(iz)
        div_c = tmp_c*bz(iz) + tmp_r*az(iz)
        if (iz > nz/2 + 1) div_r = -div_r
        if (iz > nz/2 + 1) div_c = -div_c

        ! post-process in x
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bx(ix) - tmp_c*ax(ix)
        div_c = tmp_c*bx(ix) + tmp_r*ax(ix)
        if (ix > nx/2 + 1) div_r = -div_r
        if (ix > nx/2 + 1) div_c = -div_c

        ! update the entry
        div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
      end do
    end if

  end subroutine process_spectral_010

  attributes(global) subroutine process_spectral_010_fw( &
    div_u, nx_spec, ny_spec, y_sp_st, nx, ny, nz, ax, bx, ay, by, az, bz &
    )
    !! Forward half of the split 010 pipeline (non-periodic y). It scales
    !! the divergence of velocity by 1/(nx*ny*nz), applies the forward z and
    !! x post-processing rotations, and then splits the y modes into paired
    !! combinations of j and ny_spec-j+2. The Poisson solve is not done here.
    !!
    !! Thread mapping: x threads (over all x blocks) cover i in [1, nx_spec],
    !! `blockIdx%y` selects k, and each thread loops over j. Each thread owns
    !! the whole column div_u(i, :, k), so the in-place pair updates below do
    !! not race with other threads.
    !!
    !! Ref. JCP 228 (2009), 5989–6015, Sec 4
    implicit none

    !> Divergence of velocity in spectral space
    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u
    !> Per-direction post-processing coefficients, indexed below with
    !> ix = i, iy = j + y_sp_st and iz = k
    real(dp), device, intent(in), dimension(:) :: ax, bx, ay, by, az, bz
    !> Grid size in spectral space
    integer, value, intent(in) :: nx_spec, ny_spec
    !> Offset in y direction in the permuted slabs in spectral space
    integer, value, intent(in) :: y_sp_st
    !> Grid size
    integer, value, intent(in) :: nx, ny, nz

    integer :: i, j, k, ix, iy, iz, iy_rev
    real(dp) :: tmp_r, tmp_c, div_r, div_c, l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x
    k = blockIdx%y ! nz_spec

    ! Pass 1: normalisation + forward z and x rotations
    if (i <= nx_spec) then
      do j = 1, ny_spec
        ix = i; iy = j + y_sp_st; iz = k

        ! normalisation
        div_r = real(div_u(i, j, k), kind=dp)/nx/ny/nz
        div_c = aimag(div_u(i, j, k))/nx/ny/nz

        ! postprocess in z; modes with iz > nz/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(iz) + tmp_c*az(iz)
        div_c = tmp_c*bz(iz) - tmp_r*az(iz)
        if (iz > nz/2 + 1) div_r = -div_r
        if (iz > nz/2 + 1) div_c = -div_c

        ! postprocess in x; modes with ix > nx/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bx(ix) + tmp_c*ax(ix)
        div_c = tmp_c*bx(ix) - tmp_r*ax(ix)
        if (ix > nx/2 + 1) div_r = -div_r
        if (ix > nx/2 + 1) div_c = -div_c

        ! update the entry
        div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
      end do
    end if

    ! Pass 2: forward y pair split. Mode j (2 <= j <= ny_spec/2+1) is
    ! combined with its mirror ny_spec-j+2; j = 1 is left unchanged. For even
    ! ny_spec the last j is its own mirror and both writes store the same value.
    if (i <= nx_spec) then
      do j = 2, ny_spec/2 + 1
        ix = i; iy = j + y_sp_st; iz = k
        iy_rev = ny_spec - j + 2 + y_sp_st

        l_r = real(div_u(i, j, k), kind=dp)
        l_c = aimag(div_u(i, j, k))
        r_r = real(div_u(i, ny_spec - j + 2, k), kind=dp)
        r_c = aimag(div_u(i, ny_spec - j + 2, k))

        ! update the entry
        div_u(i, j, k) = 0.5_dp*cmplx( & !&
         l_r*by(iy) + l_c*ay(iy) + r_r*by(iy) - r_c*ay(iy), &
         -l_r*ay(iy) + l_c*by(iy) + r_r*ay(iy) + r_c*by(iy), kind=dp &
         )
        div_u(i, ny_spec - j + 2, k) = 0.5_dp*cmplx( & !&
         r_r*by(iy_rev) + r_c*ay(iy_rev) + l_r*by(iy_rev) - l_c*ay(iy_rev), &
         -r_r*ay(iy_rev) + r_c*by(iy_rev) + l_r*ay(iy_rev) + l_c*by(iy_rev), &
         kind=dp &
         )
      end do
    end if

  end subroutine process_spectral_010_fw

  attributes(global) subroutine process_spectral_010_poisson( &
    div_u, a_re, a_im, off, inc, nx_spec, n, nx, ny, nz &
    )
    !! Solve the Poisson equation at cell centres with non-periodic BC along y
    !!
    !! For each (i, k) handled by a thread, this solves an n-row pentadiagonal
    !! system along y in place in `div_u`, using Gaussian elimination without
    !! pivoting. The real and imaginary parts are solved as two independent
    !! real systems, with coefficients `a_re` and `a_im` respectively.
    !!
    !! Coefficient layout (as used by the elimination below), for system row j:
    !!   a(i, j, k, 3)  diagonal
    !!   a(i, j, k, 4)  coupling to row j+1
    !!   a(i, j, k, 5)  coupling to row j+2
    !!   a(i, j, k, 2)  coupling to row j-1
    !!   a(i, j, k, 1)  coupling to row j-2
    !!
    !! System row j is stored at y index jm = inc*j + off - inc/2 of `div_u`.
    !! This lets one launch solve only the odd entries (inc=2, off=0), only
    !! the even entries (inc=2, off=1), or all of them (inc=1, off=0).
    !!
    !! Side effects and preconditions:
    !!   - `a_re` and `a_im` are overwritten by the forward elimination, so
    !!     they cannot be reused for another solve as-is. NOTE(review):
    !!     presumably the caller supplies a fresh copy each time. Confirm.
    !!   - A pivot with magnitude <= 1e-16 is treated as singular: the
    !!     corresponding multiplier or solution entry is set to zero instead
    !!     of dividing.
    !!   - The entries with i == nx/2+1 and k == nz/2+1 are set to zero.
    !!   - Requires n >= 2, because row n-1 is accessed unconditionally.
    !!
    !! Thread mapping: x threads (over all x blocks) cover i in [1, nx_spec],
    !! and `blockIdx%y` selects k.
    !!
    !! Ref. JCP 228 (2009), 5989–6015, Sec 4
    implicit none

    !> Divergence of velocity in spectral space
    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u
    !> Pentadiagonal coefficients (real / imaginary systems), modified in place
    real(dp), device, intent(inout), dimension(:, :, :, :) :: a_re, a_im
    !> offset and increment. increment is 2 when considering only odd or even
    integer, value, intent(in) :: off, inc
    !> Grid size in spectral space; n is the number of rows in the system
    integer, value, intent(in) :: nx_spec, n, nx, ny, nz

    integer :: i, j, k, jm, nm
    real(dp) :: tmp_r, tmp_c, div_r, div_c, epsilon

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x
    k = blockIdx%y ! nz_spec

    ! Pivot threshold below which a division is skipped
    epsilon = 1.e-16_dp

    ! Solve Poisson
    if (i <= nx_spec) then
      ! Forward pass for the pentadiagonal matrix: use row j to eliminate the
      ! column-j entries of rows j+1 and j+2, updating the right-hand side too
      do j = 1, n - 2
        ! j mapping based on odd/even
        ! inc=2, off=0 ---> j => 2j - 1
        ! inc=2, off=1 ---> j => 2j
        ! inc=1, off=0 ---> j => j
        jm = inc*j + off - inc/2
        ! eliminate diag-1
        tmp_r = 0._dp
        if (abs(a_re(i, j, k, 3)) > epsilon) then
          tmp_r = a_re(i, j + 1, k, 2)/a_re(i, j, k, 3)
        end if
        tmp_c = 0._dp
        if (abs(a_im(i, j, k, 3)) > epsilon) then
          tmp_c = a_im(i, j + 1, k, 2)/a_im(i, j, k, 3)
        end if
        div_r = real(div_u(i, jm + inc, k) - tmp_r*div_u(i, jm, k), kind=dp)
        div_c = aimag(div_u(i, jm + inc, k) - tmp_c*div_u(i, jm, k))
        div_u(i, jm + inc, k) = cmplx(div_r, div_c, kind=dp)
        ! modify pentadiagonal coefficients in-place
        a_re(i, j + 1, k, 3) = a_re(i, j + 1, k, 3) - tmp_r*a_re(i, j, k, 4)
        a_im(i, j + 1, k, 3) = a_im(i, j + 1, k, 3) - tmp_c*a_im(i, j, k, 4)
        a_re(i, j + 1, k, 4) = a_re(i, j + 1, k, 4) - tmp_r*a_re(i, j, k, 5)
        a_im(i, j + 1, k, 4) = a_im(i, j + 1, k, 4) - tmp_c*a_im(i, j, k, 5)

        ! eliminate diag-2
        tmp_r = 0._dp
        if (abs(a_re(i, j, k, 3)) > epsilon) then
          tmp_r = a_re(i, j + 2, k, 1)/a_re(i, j, k, 3)
        end if
        tmp_c = 0._dp
        if (abs(a_im(i, j, k, 3)) > epsilon) then
          tmp_c = a_im(i, j + 2, k, 1)/a_im(i, j, k, 3)
        end if
        div_r = real(div_u(i, jm + 2*inc, k) - tmp_r*div_u(i, jm, k), kind=dp)
        div_c = aimag(div_u(i, jm + 2*inc, k) - tmp_c*div_u(i, jm, k))
        div_u(i, jm + 2*inc, k) = cmplx(div_r, div_c, kind=dp)
        ! modify pentadiagonal coefficients in-place
        a_re(i, j + 2, k, 2) = a_re(i, j + 2, k, 2) - tmp_r*a_re(i, j, k, 4)
        a_im(i, j + 2, k, 2) = a_im(i, j + 2, k, 2) - tmp_c*a_im(i, j, k, 4)
        a_re(i, j + 2, k, 3) = a_re(i, j + 2, k, 3) - tmp_r*a_re(i, j, k, 5)
        a_im(i, j + 2, k, 3) = a_im(i, j + 2, k, 3) - tmp_c*a_im(i, j, k, 5)
      end do

      ! handle the last row: eliminate row n's sub-diagonal using row n-1;
      ! div_r/div_c then hold the final pivot of row n
      if (abs(a_re(i, n - 1, k, 3)) > epsilon) then
        tmp_r = a_re(i, n, k, 2)/a_re(i, n - 1, k, 3)
      else
        tmp_r = 0._dp
      end if
      if (abs(a_im(i, n - 1, k, 3)) > epsilon) then
        tmp_c = a_im(i, n, k, 2)/a_im(i, n - 1, k, 3)
      else
        tmp_c = 0._dp
      end if
      div_r = a_re(i, n, k, 3) - tmp_r*a_re(i, n - 1, k, 4)
      div_c = a_im(i, n, k, 3) - tmp_c*a_im(i, n - 1, k, 4)

      ! j mapping based on odd/even for last point j=n
      ! x_n = (b_n - l*b_{n-1})/pivot, with l the multiplier computed above
      nm = inc*n + off - inc/2
      if (abs(div_r) > epsilon) then
        tmp_r = tmp_r/div_r
        div_r = real(div_u(i, nm, k), kind=dp)/div_r &
                - tmp_r*real(div_u(i, nm - inc, k), kind=dp)
      else
        tmp_r = 0._dp
        div_r = 0._dp
      end if
      if (abs(div_c) > epsilon) then
        tmp_c = tmp_c/div_c
        div_c = aimag(div_u(i, nm, k))/div_c &
                - tmp_c*aimag(div_u(i, nm - inc, k))
      else
        tmp_c = 0._dp
        div_c = 0._dp
      end if
      div_u(i, nm, k) = cmplx(div_r, div_c, kind=dp)

      ! Row n-1: x_{n-1} = (b_{n-1} - a(n-1,4)*x_n)/a(n-1,3)
      if (abs(a_re(i, n - 1, k, 3)) > epsilon) then
        tmp_r = 1._dp/a_re(i, n - 1, k, 3)
      else
        tmp_r = 0._dp
      end if
      if (abs(a_im(i, n - 1, k, 3)) > epsilon) then
        tmp_c = 1._dp/a_im(i, n - 1, k, 3)
      else
        tmp_c = 0._dp
      end if
      div_r = a_re(i, n - 1, k, 4)*tmp_r
      div_c = a_im(i, n - 1, k, 4)*tmp_c
      div_u(i, nm - inc, k) = cmplx( & !&
        real(div_u(i, nm - inc, k), kind=dp)*tmp_r &
        - real(div_u(i, nm, k), kind=dp)*div_r, &
        aimag(div_u(i, nm - inc, k))*tmp_c &
        - aimag(div_u(i, nm, k))*div_c, &
        kind=dp &
        )

      ! Zero the (nx/2+1, nz/2+1) column before it feeds the back substitution
      if (i == nx/2 + 1 .and. k == nz/2 + 1) then
        div_u(i, nm, k) = 0._dp
        div_u(i, nm - inc, k) = 0._dp
      end if

      ! backward pass:
      ! x_j = (b_j - a(j,4)*x_{j+1} - a(j,5)*x_{j+2})/a(j,3)
      do j = n - 2, 1, -1
        ! j mapping based on odd/even
        jm = inc*j + off - inc/2
        if (abs(a_re(i, j, k, 3)) > epsilon) then
          tmp_r = 1._dp/a_re(i, j, k, 3)
        else
          tmp_r = 0._dp
        end if
        if (abs(a_im(i, j, k, 3)) > epsilon) then
          tmp_c = 1._dp/a_im(i, j, k, 3)
        else
          tmp_c = 0._dp
        end if
        div_u(i, jm, k) = cmplx( & !&
          tmp_r*(real(div_u(i, jm, k), kind=dp) &
                 - a_re(i, j, k, 4)*real(div_u(i, jm + inc, k), kind=dp) &
                 - a_re(i, j, k, 5)*real(div_u(i, jm + 2*inc, k), kind=dp)), &
          tmp_c*(aimag(div_u(i, jm, k)) &
                 - a_im(i, j, k, 4)*aimag(div_u(i, jm + inc, k)) &
                 - a_im(i, j, k, 5)*aimag(div_u(i, jm + 2*inc, k))), &
          kind=dp &
          )
        if (i == nx/2 + 1 .and. k == nz/2 + 1) div_u(i, jm, k) = 0._dp
      end do
    end if

  end subroutine process_spectral_010_poisson

  attributes(global) subroutine process_spectral_010_bw( &
    div_u, nx_spec, ny_spec, y_sp_st, nx, ny, nz, ax, bx, ay, by, az, bz &
    )
    !! Backward half of the split 010 pipeline (non-periodic y). It
    !! recombines the paired y modes j and ny_spec-j+2, then applies the
    !! backward z and x post-processing rotations. These undo the forward
    !! rotations of process_spectral_010_fw and use the inverse sign pattern.
    !! No normalisation is applied here.
    !!
    !! Thread mapping: x threads (over all x blocks) cover i in [1, nx_spec],
    !! `blockIdx%y` selects k, and each thread loops over j. Each thread owns
    !! the whole column div_u(i, :, k), so the in-place pair updates below do
    !! not race with other threads.
    !!
    !! Ref. JCP 228 (2009), 5989–6015, Sec 4
    implicit none

    !> Divergence of velocity in spectral space
    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u
    !> Per-direction post-processing coefficients, indexed below with
    !> ix = i, iy = j + y_sp_st and iz = k
    real(dp), device, intent(in), dimension(:) :: ax, bx, ay, by, az, bz
    !> Grid size in spectral space
    integer, value, intent(in) :: nx_spec, ny_spec
    !> Offset in y direction in the permuted slabs in spectral space
    integer, value, intent(in) :: y_sp_st
    !> Grid size
    integer, value, intent(in) :: nx, ny, nz

    integer :: i, j, k, ix, iy, iz, iy_rev
    real(dp) :: tmp_r, tmp_c, div_r, div_c, l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x
    k = blockIdx%y ! nz_spec

    ! post-process backward
    ! Pass 1: y pair recombination of mode j (2 <= j <= ny_spec/2+1) with its
    ! mirror ny_spec-j+2; j = 1 is left unchanged. For even ny_spec the last j
    ! is its own mirror and both writes store the same value.
    if (i <= nx_spec) then
      do j = 2, ny_spec/2 + 1
        ix = i; iy = j + y_sp_st; iz = k
        iy_rev = ny_spec - j + 2 + y_sp_st

        l_r = real(div_u(i, j, k), kind=dp)
        l_c = aimag(div_u(i, j, k))
        r_r = real(div_u(i, ny_spec - j + 2, k), kind=dp)
        r_c = aimag(div_u(i, ny_spec - j + 2, k))

        ! update the entry
        div_u(i, j, k) = cmplx( & !&
          l_r*by(iy) - l_c*ay(iy) + r_r*ay(iy) + r_c*by(iy), &
          l_r*ay(iy) + l_c*by(iy) - r_r*by(iy) + r_c*ay(iy), kind=dp &
          )
        div_u(i, ny_spec - j + 2, k) = cmplx( & !&
          r_r*by(iy_rev) - r_c*ay(iy_rev) + l_r*ay(iy_rev) + l_c*by(iy_rev), &
          r_r*ay(iy_rev) + r_c*by(iy_rev) - l_r*by(iy_rev) + l_c*ay(iy_rev), &
          kind=dp &
          )
      end do
    end if

    ! Pass 2: backward z and x rotations
    if (i <= nx_spec) then
      do j = 1, ny_spec
        ix = i; iy = j + y_sp_st; iz = k

        div_r = real(div_u(i, j, k), kind=dp)
        div_c = aimag(div_u(i, j, k))

        ! post-process in z; modes with iz > nz/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(iz) - tmp_c*az(iz)
        div_c = tmp_c*bz(iz) + tmp_r*az(iz)
        if (iz > nz/2 + 1) div_r = -div_r
        if (iz > nz/2 + 1) div_c = -div_c

        ! post-process in x; modes with ix > nx/2+1 additionally flip sign
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bx(ix) - tmp_c*ax(ix)
        div_c = tmp_c*bx(ix) + tmp_r*ax(ix)
        if (ix > nx/2 + 1) div_r = -div_r
        if (ix > nx/2 + 1) div_c = -div_c

        ! update the entry
        div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
      end do
    end if

  end subroutine process_spectral_010_bw

  ! ------------------------------------------------------------------
  ! 110 SPECTRAL KERNELS (7 separate launches)
  !
  ! Spectral array layout: (nz/2+1, nx, ny) — Z R2C modes in dim1
  ! Thread layout:
  !   i = threadIdx%x + (blockIdx%x - 1)*blockDim%x → X (dim2, up to nx)
  !   k = blockIdx%y → Y (dim3, up to ny)
  !   serial loop over j → Z R2C modes (dim1, up to nz/2+1)
  !
  ! Launch configuration shared by all seven kernels:
  !   blocks  = dim3((nx-1)/tsize+1, ny, 1)
  !   threads = dim3(min(nx, tsize), 1, 1)
  !
  ! Arguments use PHYSICAL names (nx, ny, nz) rather than spectral-dim
  ! names, to avoid confusion with the transposed layout.
  ! ------------------------------------------------------------------

  attributes(global) subroutine process_spectral_110_norm_z( &
    div_u, nz_h, nx, ny, nz, az, bz &
    )
    !! Step 1 (forward): normalise + Z periodic post-process
    !! Z is dim1 (serial j loop), periodic R2C — no sign flip needed
    !! since j only goes to nz/2+1.
    !!
    !! The normalisation factor nx*ny*nz is formed in real(dp). The product
    !! of default integers overflows once nx*ny*nz exceeds huge(0), which for
    !! the usual 32-bit default integer is about 1290**3 on a cubic grid.
    !! Below that size the result is bit-identical to dividing by the integer
    !! product, because the product is exactly representable in real(dp).
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    real(dp), device, intent(in), dimension(:) :: az, bz
    integer, value, intent(in) :: nz_h  ! nz/2+1
    integer, value, intent(in) :: nx, ny, nz

    integer :: i, j, k
    real(dp) :: tmp_r, tmp_c, div_r, div_c, n_total

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! X index
    k = blockIdx%y                                   ! Y index

    ! Total number of grid points, built in real(dp) so it cannot overflow
    n_total = real(nx, dp)*real(ny, dp)*real(nz, dp)

    if (i <= nx .and. k <= ny) then
      do j = 1, nz_h
        div_r = real(div_u(j, i, k), kind=dp)/n_total
        div_c = aimag(div_u(j, i, k))/n_total

        ! Z periodic post-process (forward)
        tmp_r = div_r
        tmp_c = div_c
        div_r = tmp_r*bz(j) + tmp_c*az(j)
        div_c = tmp_c*bz(j) - tmp_r*az(j)
        ! No sign flip: j <= nz/2+1 always, so j > nz/2+1 is never true

        div_u(j, i, k) = cmplx(div_r, div_c, kind=dp)
      end do
    end if

  end subroutine process_spectral_110_norm_z

  attributes(global) subroutine process_spectral_110_x_pair_fw( &
    div_u, nz_h, nx, ny, x_sp_st, ax, bx &
    )
    !! Step 2 (forward): X paired even/odd split
    !! X is dim2 (thread i). Only i in [2, nx/2+1] executes.
    !! Writes to i and nx-i+2. Race-free (pair doesn't enter).
    !! Column i = 1 is not touched.
    !!
    !! The pairing uses local indices (ix_pair = nx - i + 2), while the
    !! coefficient arrays ax/bx are indexed with x_sp_st added.
    !! NOTE(review): the local pairing presumably assumes the full X extent is
    !! held locally (dim2 has size nx). Confirm how x_sp_st is used by callers.
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    real(dp), device, intent(in), dimension(:) :: ax, bx
    integer, value, intent(in) :: nz_h, nx, ny
    integer, value, intent(in) :: x_sp_st

    integer :: i, j, k, ix, ix_pair
    real(dp) :: l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! X index
    k = blockIdx%y                                   ! Y index
    ! Note: when i == nx/2+1 (even nx), ix_pair == i (Nyquist self-pair).
    ! Both writes produce the same result, second is a harmless overwrite.
    if (i >= 2 .and. i <= nx/2 + 1 .and. k <= ny) then
      ix = i + x_sp_st
      ix_pair = nx - i + 2
      do j = 1, nz_h
        ! l = entry at i, r = entry at its mirror; both read before writing
        l_r = real(div_u(j, i, k), kind=dp)
        l_c = aimag(div_u(j, i, k))
        r_r = real(div_u(j, ix_pair, k), kind=dp)
        r_c = aimag(div_u(j, ix_pair, k))

        div_u(j, i, k) = 0.5_dp*cmplx( & !&
          l_r*bx(ix) + l_c*ax(ix) + r_r*bx(ix) - r_c*ax(ix), &
          -l_r*ax(ix) + l_c*bx(ix) + r_r*ax(ix) + r_c*bx(ix), kind=dp &
          )
        div_u(j, ix_pair, k) = 0.5_dp*cmplx( & !&
          r_r*bx(ix_pair + x_sp_st) + r_c*ax(ix_pair + x_sp_st) &
            + l_r*bx(ix_pair + x_sp_st) - l_c*ax(ix_pair + x_sp_st), &
          -r_r*ax(ix_pair + x_sp_st) + r_c*bx(ix_pair + x_sp_st) &
            + l_r*ax(ix_pair + x_sp_st) + l_c*bx(ix_pair + x_sp_st), &
          kind=dp &
          )
      end do
    end if

  end subroutine process_spectral_110_x_pair_fw

  attributes(global) subroutine process_spectral_110_y_pair_fw( &
    div_u, nz_h, nx, ny, y_sp_st, ay, by &
    )
    !! Step 3 (forward): Y paired even/odd split
    !! Y is dim3 (blockIdx%y = k). Only k in [2, ny/2+1] executes.
    !! Writes to k and ny-k+2. Race-free (pair block doesn't enter).
    !! Slab k = 1 is not touched. For even ny, k = ny/2+1 is its own mirror
    !! and both writes store the same value.
    !!
    !! The pairing uses local indices (iy_pair = ny - k + 2), while the
    !! coefficient arrays ay/by are indexed with y_sp_st added.
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    real(dp), device, intent(in), dimension(:) :: ay, by
    integer, value, intent(in) :: nz_h, nx, ny
    integer, value, intent(in) :: y_sp_st

    integer :: i, j, k, iy, iy_pair
    real(dp) :: l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! X index
    k = blockIdx%y                                   ! Y index

    if (i <= nx .and. k >= 2 .and. k <= ny/2 + 1) then
      iy = k + y_sp_st
      iy_pair = ny - k + 2
      do j = 1, nz_h
        ! l = entry at k, r = entry at its mirror; both read before writing
        l_r = real(div_u(j, i, k), kind=dp)
        l_c = aimag(div_u(j, i, k))
        r_r = real(div_u(j, i, iy_pair), kind=dp)
        r_c = aimag(div_u(j, i, iy_pair))

        div_u(j, i, k) = 0.5_dp*cmplx( & !&
          l_r*by(iy) + l_c*ay(iy) + r_r*by(iy) - r_c*ay(iy), &
          -l_r*ay(iy) + l_c*by(iy) + r_r*ay(iy) + r_c*by(iy), kind=dp &
          )
        div_u(j, i, iy_pair) = 0.5_dp*cmplx( & !&
          r_r*by(iy_pair + y_sp_st) + r_c*ay(iy_pair + y_sp_st) &
            + l_r*by(iy_pair + y_sp_st) - l_c*ay(iy_pair + y_sp_st), &
          -r_r*ay(iy_pair + y_sp_st) + r_c*by(iy_pair + y_sp_st) &
            + l_r*ay(iy_pair + y_sp_st) + l_c*by(iy_pair + y_sp_st), &
          kind=dp &
          )
      end do
    end if

  end subroutine process_spectral_110_y_pair_fw

  attributes(global) subroutine process_spectral_110_poisson( &
    div_u, waves, nz_h, nx, ny, nz, x_sp_st &
    )
    !! Step 4: Poisson solve — divide by waves.
    !!
    !! For each entry, the real and imaginary parts of `div_u` are divided
    !! independently by the matching parts of `waves` and negated. A part
    !! whose `waves` entry has magnitude below 1e-16 is set to zero instead.
    !! Entries with global X index nx/2+1 and Z mode nz/2+1 are zeroed.
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    complex(dp), device, intent(in), dimension(:, :, :) :: waves    ! (nz/2+1, nx, ny)
    integer, value, intent(in) :: nz_h, nx, ny, nz
    integer, value, intent(in) :: x_sp_st

    integer :: ithr, jz, ky
    logical :: on_x_nyquist
    real(dp) :: num_re, num_im, den_re, den_im, out_re, out_im

    ithr = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! X index
    ky = blockIdx%y                                     ! Y index

    if (ithr > nx .or. ky > ny) return

    on_x_nyquist = (ithr + x_sp_st == nx/2 + 1)

    do jz = 1, nz_h
      ! Zero Nyquist modes (X Nyquist column at the last Z mode)
      if (on_x_nyquist .and. jz == nz/2 + 1) then
        div_u(jz, ithr, ky) = (0._dp, 0._dp)
        cycle
      end if

      num_re = real(div_u(jz, ithr, ky), kind=dp)
      num_im = aimag(div_u(jz, ithr, ky))
      den_re = real(waves(jz, ithr, ky), kind=dp)
      den_im = aimag(waves(jz, ithr, ky))

      out_re = 0._dp
      if (abs(den_re) >= 1.e-16_dp) out_re = -num_re/den_re

      out_im = 0._dp
      if (abs(den_im) >= 1.e-16_dp) out_im = -num_im/den_im

      div_u(jz, ithr, ky) = cmplx(out_re, out_im, kind=dp)
    end do

  end subroutine process_spectral_110_poisson

  attributes(global) subroutine process_spectral_110_y_pair_bw( &
    div_u, nz_h, nx, ny, y_sp_st, ay, by &
    )
    !! Step 5 (backward): Y paired even/odd recombine
    !! Y is dim3 (blockIdx%y = k). Only k in [2, ny/2+1] executes, writing to
    !! k and its mirror ny-k+2. Blocks with k > ny/2+1 do not enter, so the
    !! in-place update is race-free. Slab k = 1 is not touched. For even ny,
    !! k = ny/2+1 is its own mirror and both writes store the same value.
    !! Same pairing as step 3, with the backward recombination formulas.
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    real(dp), device, intent(in), dimension(:) :: ay, by
    integer, value, intent(in) :: nz_h, nx, ny
    integer, value, intent(in) :: y_sp_st

    integer :: i, j, k, iy, iy_pair
    real(dp) :: l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! X index
    k = blockIdx%y                                   ! Y index

    if (i <= nx .and. k >= 2 .and. k <= ny/2 + 1) then
      iy = k + y_sp_st
      iy_pair = ny - k + 2
      do j = 1, nz_h
        ! l = entry at k, r = entry at its mirror; both read before writing
        l_r = real(div_u(j, i, k), kind=dp)
        l_c = aimag(div_u(j, i, k))
        r_r = real(div_u(j, i, iy_pair), kind=dp)
        r_c = aimag(div_u(j, i, iy_pair))

        div_u(j, i, k) = cmplx( & !&
          l_r*by(iy) - l_c*ay(iy) + r_r*ay(iy) + r_c*by(iy), &
          l_r*ay(iy) + l_c*by(iy) - r_r*by(iy) + r_c*ay(iy), kind=dp &
          )
        div_u(j, i, iy_pair) = cmplx( & !&
          r_r*by(iy_pair + y_sp_st) - r_c*ay(iy_pair + y_sp_st) &
            + l_r*ay(iy_pair + y_sp_st) + l_c*by(iy_pair + y_sp_st), &
          r_r*ay(iy_pair + y_sp_st) + r_c*by(iy_pair + y_sp_st) &
            - l_r*by(iy_pair + y_sp_st) + l_c*ay(iy_pair + y_sp_st), &
          kind=dp &
          )
      end do
    end if

  end subroutine process_spectral_110_y_pair_bw

  attributes(global) subroutine process_spectral_110_x_pair_bw( &
    div_u, nz_h, nx, ny, x_sp_st, ax, bx &
    )
    !! Step 6 (backward): X paired even/odd recombine
    !! X is dim2 (thread i). Only i in [2, nx/2+1] executes, writing to i and
    !! its mirror nx-i+2. Threads with i > nx/2+1 do not enter, so the
    !! in-place update is race-free. Column i = 1 is not touched. For even nx,
    !! i = nx/2+1 is its own mirror and both writes store the same value.
    !! Same pairing as step 2, with the backward recombination formulas.
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    real(dp), device, intent(in), dimension(:) :: ax, bx
    integer, value, intent(in) :: nz_h, nx, ny
    integer, value, intent(in) :: x_sp_st

    integer :: i, j, k, ix, ix_pair
    real(dp) :: l_r, l_c, r_r, r_c

    i = threadIdx%x + (blockIdx%x - 1)*blockDim%x  ! X index
    k = blockIdx%y                                   ! Y index

    if (i >= 2 .and. i <= nx/2 + 1 .and. k <= ny) then
      ix = i + x_sp_st
      ix_pair = nx - i + 2
      do j = 1, nz_h
        ! l = entry at i, r = entry at its mirror; both read before writing
        l_r = real(div_u(j, i, k), kind=dp)
        l_c = aimag(div_u(j, i, k))
        r_r = real(div_u(j, ix_pair, k), kind=dp)
        r_c = aimag(div_u(j, ix_pair, k))

        div_u(j, i, k) = cmplx( & !&
          l_r*bx(ix) - l_c*ax(ix) + r_r*ax(ix) + r_c*bx(ix), &
          l_r*ax(ix) + l_c*bx(ix) - r_r*bx(ix) + r_c*ax(ix), kind=dp &
          )
        div_u(j, ix_pair, k) = cmplx( & !&
          r_r*bx(ix_pair + x_sp_st) - r_c*ax(ix_pair + x_sp_st) &
            + l_r*ax(ix_pair + x_sp_st) + l_c*bx(ix_pair + x_sp_st), &
          r_r*ax(ix_pair + x_sp_st) + r_c*bx(ix_pair + x_sp_st) &
            - l_r*bx(ix_pair + x_sp_st) + l_c*ax(ix_pair + x_sp_st), &
          kind=dp &
          )
      end do
    end if

  end subroutine process_spectral_110_x_pair_bw

  attributes(global) subroutine process_spectral_110_z_bw( &
    div_u, nz_h, nx, ny, az, bz &
    )
    !! Step 7 (backward): undo the Z periodic rotation.
    !!
    !! Every coefficient div_u(jz, ix, ky), jz = 1..nz_h, is multiplied by
    !! (bz(jz) + i*az(jz)), written out explicitly in real arithmetic.
    !! Thread layout: the x index comes from blockIdx%x/threadIdx%x and the
    !! y plane from blockIdx%y; each thread walks the whole z range.
    implicit none

    complex(dp), device, intent(inout), dimension(:, :, :) :: div_u ! (nz/2+1, nx, ny)
    real(dp), device, intent(in), dimension(:) :: az, bz
    integer, value, intent(in) :: nz_h, nx, ny

    integer :: icol, ky, jz
    real(dp) :: val_re, val_im

    icol = (blockIdx%x - 1)*blockDim%x + threadIdx%x
    ky = blockIdx%y

    if (icol > nx .or. ky > ny) return

    do jz = 1, nz_h
      val_re = real(div_u(jz, icol, ky), kind=dp)
      val_im = aimag(div_u(jz, icol, ky))
      ! jz only runs up to nz_h (nz/2+1 per the declared layout), so no
      ! sign flip is applied here
      div_u(jz, icol, ky) = cmplx(val_re*bz(jz) - val_im*az(jz), &
                                  val_im*bz(jz) + val_re*az(jz), kind=dp)
    end do

  end subroutine process_spectral_110_z_bw

  attributes(global) subroutine enforce_periodicity_x(f_out, f_in, nx)
    !! Reorder along x: odd-indexed samples of f_in first (ascending), then
    !! the even-indexed samples in descending order.
    !!
    !!   f_out(m, :, :) = f_in(2m - 1, :, :)        for m = 1 .. (nx+1)/2
    !!   f_out(m, :, :) = f_in(2(nx - m + 1), :, :) for m = (nx+1)/2 + 1 .. nx
    !!
    !! One thread per (j, k) line: j = threadIdx%x, k = blockIdx%x. There is
    !! no bounds check, so the launch configuration must match the extents.
    !! undo_periodicity_x applies the inverse permutation.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: f_out
    real(dp), device, intent(in), dimension(:, :, :) :: f_in
    integer, value, intent(in) :: nx

    integer :: jy, kz, m, n_odd

    jy = threadIdx%x
    kz = blockIdx%x
    n_odd = (nx + 1)/2  ! count of odd indices in 1..nx

    ! Odd source indices 1, 3, 5, ... fill the first n_odd slots
    do m = 1, n_odd
      f_out(m, jy, kz) = f_in(2*m - 1, jy, kz)
    end do
    ! Even source indices ..., 4, 2 fill the remaining slots in reverse
    do m = n_odd + 1, nx
      f_out(m, jy, kz) = f_in(2*(nx - m + 1), jy, kz)
    end do

  end subroutine enforce_periodicity_x

  attributes(global) subroutine undo_periodicity_x(f_out, f_in, nx)
    !! Inverse of enforce_periodicity_x: scatter the reordered samples back
    !! to natural x order.
    !!
    !!   odd m:  f_out(m, :, :) = f_in((m + 1)/2, :, :)
    !!   even m: f_out(m, :, :) = f_in(nx - m/2 + 1, :, :)
    !!
    !! One thread per (j, k) line: j = threadIdx%x, k = blockIdx%x; no bounds
    !! check is performed.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: f_out
    real(dp), device, intent(in), dimension(:, :, :) :: f_in
    integer, value, intent(in) :: nx

    integer :: jy, kz, m, src

    jy = threadIdx%x
    kz = blockIdx%x

    do m = 1, nx
      ! For odd nx the last slot m = nx is odd and (nx + 1)/2 = nx/2 + 1,
      ! so the odd-length tail needs no special case
      if (mod(m, 2) == 1) then
        src = (m + 1)/2
      else
        src = nx - m/2 + 1
      end if
      f_out(m, jy, kz) = f_in(src, jy, kz)
    end do

  end subroutine undo_periodicity_x

  attributes(global) subroutine enforce_periodicity_y(f_out, f_in, ny)
    !! Reorder along y: odd-indexed planes of f_in first (ascending), then
    !! the even-indexed planes in descending order.
    !!
    !!   f_out(:, m, :) = f_in(:, 2m - 1, :)        for m = 1 .. (ny+1)/2
    !!   f_out(:, m, :) = f_in(:, 2(ny - m + 1), :) for m = (ny+1)/2 + 1 .. ny
    !!
    !! One thread per (i, k) line: i = threadIdx%x, k = blockIdx%x. There is
    !! no bounds check, so the launch configuration must match the extents.
    !! undo_periodicity_y applies the inverse permutation.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: f_out
    real(dp), device, intent(in), dimension(:, :, :) :: f_in
    integer, value, intent(in) :: ny

    integer :: ix, kz, m, n_odd

    ix = threadIdx%x
    kz = blockIdx%x
    n_odd = (ny + 1)/2  ! count of odd indices in 1..ny

    ! Odd source indices 1, 3, 5, ... fill the first n_odd slots
    do m = 1, n_odd
      f_out(ix, m, kz) = f_in(ix, 2*m - 1, kz)
    end do
    ! Even source indices ..., 4, 2 fill the remaining slots in reverse
    do m = n_odd + 1, ny
      f_out(ix, m, kz) = f_in(ix, 2*(ny - m + 1), kz)
    end do

  end subroutine enforce_periodicity_y

  attributes(global) subroutine undo_periodicity_y(f_out, f_in, ny)
    !! Inverse of enforce_periodicity_y: scatter the reordered planes back
    !! to natural y order.
    !!
    !!   odd m:  f_out(:, m, :) = f_in(:, (m + 1)/2, :)
    !!   even m: f_out(:, m, :) = f_in(:, ny - m/2 + 1, :)
    !!
    !! One thread per (i, k) line: i = threadIdx%x, k = blockIdx%x; no bounds
    !! check is performed.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: f_out
    real(dp), device, intent(in), dimension(:, :, :) :: f_in
    integer, value, intent(in) :: ny

    integer :: ix, kz, m, src

    ix = threadIdx%x
    kz = blockIdx%x

    do m = 1, ny
      ! For odd ny the last slot m = ny is odd and (ny + 1)/2 = ny/2 + 1,
      ! so the odd-length tail needs no special case
      if (mod(m, 2) == 1) then
        src = (m + 1)/2
      else
        src = ny - m/2 + 1
      end if
      f_out(ix, m, kz) = f_in(ix, src, kz)
    end do

  end subroutine undo_periodicity_y

  attributes(global) subroutine enforce_periodicity_xy(f_out, f_in, nx, ny)
    !! Combined X and Y periodicity enforcement (interleave shuffle).
    !!
    !! Performs the enforce_periodicity_x and enforce_periodicity_y reorderings
    !! in one pass: f_out(i, j, k) = f_in(p(i; nx), p(j; ny), k), where
    !!   p(m; n) = 2m - 1        for m <= (n+1)/2
    !!   p(m; n) = 2(n - m + 1)  otherwise.
    !! k = blockIdx%x; j is built from blockIdx%y/threadIdx%x/blockDim%x and
    !! threads with j > ny exit immediately; each thread loops over all i.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: f_out
    real(dp), device, intent(in), dimension(:, :, :) :: f_in
    integer, value, intent(in) :: nx, ny

    integer :: ix, jy, kz
    integer :: x_odd, y_odd
    integer :: sx, sy

    kz = blockIdx%x
    jy = threadIdx%x + (blockIdx%y - 1)*blockDim%x
    if (jy > ny) return

    ! Number of odd indices along each axis; for odd n this slot holds the
    ! last source index n itself, since 2*((n+1)/2) - 1 = n
    x_odd = (nx + 1)/2
    y_odd = (ny + 1)/2

    if (jy > y_odd) then
      sy = 2*(ny - jy + 1)
    else
      sy = 2*jy - 1
    end if

    do ix = 1, nx
      if (ix > x_odd) then
        sx = 2*(nx - ix + 1)
      else
        sx = 2*ix - 1
      end if
      f_out(ix, jy, kz) = f_in(sx, sy, kz)
    end do

  end subroutine enforce_periodicity_xy

  attributes(global) subroutine undo_periodicity_xy(f_out, f_in, nx, ny)
    !! Combined X and Y periodicity undo (reverse interleave shuffle).
    !!
    !! Inverse of enforce_periodicity_xy:
    !! f_out(i, j, k) = f_in(q(i; nx), q(j; ny), k), where
    !!   q(m; n) = (m + 1)/2      for odd m
    !!   q(m; n) = n - m/2 + 1    for even m.
    !! k = blockIdx%x; j is built from blockIdx%y/threadIdx%x/blockDim%x and
    !! threads with j > ny exit immediately; each thread loops over all i.
    implicit none

    real(dp), device, intent(out), dimension(:, :, :) :: f_out
    real(dp), device, intent(in), dimension(:, :, :) :: f_in
    integer, value, intent(in) :: nx, ny

    integer :: ix, jy, kz
    integer :: sx, sy

    kz = blockIdx%x
    jy = threadIdx%x + (blockIdx%y - 1)*blockDim%x
    if (jy > ny) return

    ! For odd n the last index m = n is odd and (n + 1)/2 = n/2 + 1, so the
    ! odd-length tail is covered by the odd-m formula without a special case
    if (mod(jy, 2) == 0) then
      sy = ny - jy/2 + 1
    else
      sy = (jy + 1)/2
    end if

    do ix = 1, nx
      if (mod(ix, 2) == 0) then
        sx = nx - ix/2 + 1
      else
        sx = (ix + 1)/2
      end if
      f_out(ix, jy, kz) = f_in(sx, sy, kz)
    end do

  end subroutine undo_periodicity_xy

end module m_cuda_spectral