module m_cuda_tdsops
  !! CUDA Fortran extension of the tridiagonal solver operators (tdsops_t).
  use iso_fortran_env, only: stderr => error_unit

  use m_common, only: dp
  use m_tdsops, only: tdsops_t, tdsops_init

  implicit none

  type, extends(tdsops_t) :: cuda_tdsops_t
    !! CUDA extension of the Tridiagonal Solver Operators class.
    !!
    !! Regular tdsops_t class is initiated and the coefficient arrays are
    !! copied into device arrays so that cuda kernels can use them.
    !! Each `*_dev` component mirrors the inherited host component with the
    !! same name minus the `_dev` suffix (e.g. dist_fw_dev <- dist_fw).
    real(dp), device, allocatable :: dist_fw_dev(:), dist_bw_dev(:), &
                                     dist_sa_dev(:), dist_sc_dev(:), &
                                     dist_af_dev(:)
    real(dp), device, allocatable :: thom_f_dev(:), thom_s_dev(:), &
                                     thom_w_dev(:), thom_p_dev(:)
    real(dp), device, allocatable :: coeffs_dev(:), &
                                     coeffs_s_dev(:, :), coeffs_e_dev(:, :)
  end type cuda_tdsops_t

  interface cuda_tdsops_t
    module procedure cuda_tdsops_init
  end interface cuda_tdsops_t

contains

  function cuda_tdsops_init(n, delta, operation, scheme, bc_start, bc_end, &
                            n_halo, from_to, sym, c_nu, nu0_nu) result(tdsops)
    !! Constructor function for the cuda_tdsops_t class.
    !! See tdsops_t for details.
    !!
    !! Every argument is forwarded unchanged to tdsops_init (absent optional
    !! arguments stay absent). The host coefficient arrays it produces are
    !! then mirrored into the device components of the result.
    implicit none

    type(cuda_tdsops_t) :: tdsops !! return value of the function

    integer, intent(in) :: n
    real(dp), intent(in) :: delta
    character(*), intent(in) :: operation, scheme
    integer, intent(in) :: bc_start, bc_end
    integer, optional, intent(in) :: n_halo
    character(*), optional, intent(in) :: from_to
    logical, optional, intent(in) :: sym
    real(dp), optional, intent(in) :: c_nu, nu0_nu

    ! Build the host-side operator first; the device arrays below mirror it.
    tdsops%tdsops_t = tdsops_init(n, delta, operation, scheme, bc_start, &
                                  bc_end, n_halo, from_to, sym, c_nu, nu0_nu)

    ! Size every device array from the host array it mirrors instead of
    ! re-deriving the extents here, so the copies below always conform
    ! even if tdsops_init sizes its arrays differently from `n`.
    allocate (tdsops%dist_fw_dev(size(tdsops%dist_fw)))
    allocate (tdsops%dist_bw_dev(size(tdsops%dist_bw)))
    allocate (tdsops%dist_sa_dev(size(tdsops%dist_sa)))
    allocate (tdsops%dist_sc_dev(size(tdsops%dist_sc)))
    allocate (tdsops%dist_af_dev(size(tdsops%dist_af)))

    allocate (tdsops%thom_f_dev(size(tdsops%thom_f)))
    allocate (tdsops%thom_s_dev(size(tdsops%thom_s)))
    allocate (tdsops%thom_w_dev(size(tdsops%thom_w)))
    allocate (tdsops%thom_p_dev(size(tdsops%thom_p)))

    allocate (tdsops%coeffs_dev(size(tdsops%coeffs)))
    allocate (tdsops%coeffs_s_dev(size(tdsops%coeffs_s, 1), &
                                  size(tdsops%coeffs_s, 2)))
    allocate (tdsops%coeffs_e_dev(size(tdsops%coeffs_e, 1), &
                                  size(tdsops%coeffs_e, 2)))

    ! Copy the host coefficients into the device arrays.
    tdsops%dist_fw_dev(:) = tdsops%dist_fw(:)
    tdsops%dist_bw_dev(:) = tdsops%dist_bw(:)
    tdsops%dist_sa_dev(:) = tdsops%dist_sa(:)
    tdsops%dist_sc_dev(:) = tdsops%dist_sc(:)
    tdsops%dist_af_dev(:) = tdsops%dist_af(:)

    tdsops%thom_f_dev(:) = tdsops%thom_f(:)
    tdsops%thom_s_dev(:) = tdsops%thom_s(:)
    tdsops%thom_w_dev(:) = tdsops%thom_w(:)
    tdsops%thom_p_dev(:) = tdsops%thom_p(:)

    tdsops%coeffs_dev(:) = tdsops%coeffs(:)
    tdsops%coeffs_s_dev(:, :) = tdsops%coeffs_s(:, :)
    tdsops%coeffs_e_dev(:, :) = tdsops%coeffs_e(:, :)

  end function cuda_tdsops_init

end module m_cuda_tdsops