checkpoint_io.f90 Source File


This file depends on

sourcefile~~checkpoint_io.f90~~EfferentGraph sourcefile~checkpoint_io.f90 checkpoint_io.f90 sourcefile~adios2_io.f90 adios2_io.f90 sourcefile~checkpoint_io.f90->sourcefile~adios2_io.f90 sourcefile~common.f90~2 common.f90 sourcefile~checkpoint_io.f90->sourcefile~common.f90~2 sourcefile~config.f90 config.f90 sourcefile~checkpoint_io.f90->sourcefile~config.f90 sourcefile~field.f90 field.f90 sourcefile~checkpoint_io.f90->sourcefile~field.f90 sourcefile~solver.f90 solver.f90 sourcefile~checkpoint_io.f90->sourcefile~solver.f90 sourcefile~adios2_io.f90->sourcefile~common.f90~2 sourcefile~config.f90->sourcefile~common.f90~2 sourcefile~field.f90->sourcefile~common.f90~2 sourcefile~solver.f90->sourcefile~common.f90~2 sourcefile~solver.f90->sourcefile~config.f90 sourcefile~solver.f90->sourcefile~field.f90 sourcefile~allocator.f90~2 allocator.f90 sourcefile~solver.f90->sourcefile~allocator.f90~2 sourcefile~backend.f90~2 backend.f90 sourcefile~solver.f90->sourcefile~backend.f90~2 sourcefile~ibm.f90 ibm.f90 sourcefile~solver.f90->sourcefile~ibm.f90 sourcefile~mesh.f90 mesh.f90 sourcefile~solver.f90->sourcefile~mesh.f90 sourcefile~tdsops.f90 tdsops.f90 sourcefile~solver.f90->sourcefile~tdsops.f90 sourcefile~time_integrator.f90 time_integrator.f90 sourcefile~solver.f90->sourcefile~time_integrator.f90 sourcefile~vector_calculus.f90 vector_calculus.f90 sourcefile~solver.f90->sourcefile~vector_calculus.f90 sourcefile~allocator.f90~2->sourcefile~common.f90~2 sourcefile~allocator.f90~2->sourcefile~field.f90 sourcefile~backend.f90~2->sourcefile~common.f90~2 sourcefile~backend.f90~2->sourcefile~field.f90 sourcefile~backend.f90~2->sourcefile~allocator.f90~2 sourcefile~backend.f90~2->sourcefile~mesh.f90 sourcefile~backend.f90~2->sourcefile~tdsops.f90 sourcefile~poisson_fft.f90 poisson_fft.f90 sourcefile~backend.f90~2->sourcefile~poisson_fft.f90 sourcefile~ibm.f90->sourcefile~adios2_io.f90 sourcefile~ibm.f90->sourcefile~common.f90~2 sourcefile~ibm.f90->sourcefile~field.f90 
sourcefile~ibm.f90->sourcefile~allocator.f90~2 sourcefile~ibm.f90->sourcefile~backend.f90~2 sourcefile~ibm.f90->sourcefile~mesh.f90 sourcefile~mesh.f90->sourcefile~common.f90~2 sourcefile~mesh.f90->sourcefile~field.f90 sourcefile~decomp_dummy.f90 decomp_dummy.f90 sourcefile~mesh.f90->sourcefile~decomp_dummy.f90 sourcefile~mesh_content.f90 mesh_content.f90 sourcefile~mesh.f90->sourcefile~mesh_content.f90 sourcefile~tdsops.f90->sourcefile~common.f90~2 sourcefile~time_integrator.f90->sourcefile~common.f90~2 sourcefile~time_integrator.f90->sourcefile~field.f90 sourcefile~time_integrator.f90->sourcefile~allocator.f90~2 sourcefile~time_integrator.f90->sourcefile~backend.f90~2 sourcefile~vector_calculus.f90->sourcefile~common.f90~2 sourcefile~vector_calculus.f90->sourcefile~field.f90 sourcefile~vector_calculus.f90->sourcefile~allocator.f90~2 sourcefile~vector_calculus.f90->sourcefile~backend.f90~2 sourcefile~vector_calculus.f90->sourcefile~tdsops.f90 sourcefile~decomp_dummy.f90->sourcefile~mesh_content.f90 sourcefile~mesh_content.f90->sourcefile~common.f90~2 sourcefile~poisson_fft.f90->sourcefile~common.f90~2 sourcefile~poisson_fft.f90->sourcefile~field.f90 sourcefile~poisson_fft.f90->sourcefile~mesh.f90 sourcefile~poisson_fft.f90->sourcefile~tdsops.f90

Source Code

module m_checkpoint_manager_base
!! Base module defining abstract interface for checkpoint functionality
!!
!! Exposes a single public entity, `checkpoint_manager_base_t`: an abstract
!! type that stores the checkpoint configuration and declares the four
!! deferred bindings (`init`, `handle_restart`, `handle_io_step`,
!! `finalise`) that a concrete checkpoint manager must implement.
  use m_solver, only: solver_t
  use m_config, only: checkpoint_config_t

  implicit none

  private
  public :: checkpoint_manager_base_t

  type, abstract :: checkpoint_manager_base_t
    !! Abstract parent of all checkpoint managers.
    type(checkpoint_config_t) :: checkpoint_cfg
      !! Configuration record carried by every implementation
  contains
    procedure(init_interface), deferred :: init
    procedure(handle_restart_interface), deferred :: handle_restart
    procedure(handle_io_step_interface), deferred :: handle_io_step
    procedure(finalise_interface), deferred :: finalise
  end type checkpoint_manager_base_t

  abstract interface
    subroutine init_interface(self, comm)
      !! Set up the manager. `comm` is required here, unlike in the other
      !! bindings where it is optional.
      import :: checkpoint_manager_base_t
      class(checkpoint_manager_base_t), intent(inout) :: self
      integer, intent(in) :: comm
        !! Integer communicator handle
    end subroutine init_interface

    subroutine handle_restart_interface(self, solver, comm)
      !! Restore solver state; the solver is `intent(inout)` because the
      !! implementation is expected to modify it.
      import :: checkpoint_manager_base_t, solver_t
      class(checkpoint_manager_base_t), intent(inout) :: self
      class(solver_t), intent(inout) :: solver
      integer, intent(in), optional :: comm
        !! Optional integer communicator handle
    end subroutine handle_restart_interface

    subroutine handle_io_step_interface(self, solver, timestep, comm)
      !! Perform whatever output is due at `timestep`; the solver is
      !! read-only (`intent(in)`) for this call.
      import :: checkpoint_manager_base_t, solver_t
      class(checkpoint_manager_base_t), intent(inout) :: self
      class(solver_t), intent(in) :: solver
      integer, intent(in) :: timestep
      integer, intent(in), optional :: comm
        !! Optional integer communicator handle
    end subroutine handle_io_step_interface

    subroutine finalise_interface(self)
      !! Release whatever the manager holds; takes no arguments besides
      !! the passed object.
      import :: checkpoint_manager_base_t
      class(checkpoint_manager_base_t), intent(inout) :: self
    end subroutine finalise_interface
  end interface
end module m_checkpoint_manager_base

module m_checkpoint_manager_impl
!! Implementation of checkpoint manager when ADIOS2 is enabled
  use mpi, only: MPI_COMM_WORLD, MPI_Comm_rank, MPI_Abort
  use m_common, only: dp, i8, DIR_C, VERT, get_argument
  use m_field, only: field_t
  use m_solver, only: solver_t
  use m_adios2_io, only: adios2_writer_t, adios2_reader_t, adios2_file_t, &
                         adios2_mode_write, adios2_mode_read
  use m_config, only: checkpoint_config_t
  use m_checkpoint_manager_base, only: checkpoint_manager_base_t

  implicit none

  private
  public :: checkpoint_manager_adios2_t

  ! type for dynamic field buffer mapping
  type :: field_buffer_map_t
    !! Associates a field name with its own allocatable 3D buffer.
    !! NOTE(review): instances are held in
    !! `checkpoint_manager_adios2_t%field_buffers`; how they are filled is
    !! not visible in this part of the file.
    character(len=32) :: field_name
      !! Name of the field (fixed length, blank padded)
    real(dp), dimension(:, :, :), allocatable :: buffer
      !! Rank-3 double-precision storage for that field
  end type field_buffer_map_t

  type, extends(checkpoint_manager_base_t) :: checkpoint_manager_adios2_t
    !! ADIOS2-backed checkpoint manager.
    !!
    !! Uses two independent writers: one for restart checkpoints and one
    !! for visualisation snapshots. Several components are scratch state
    !! that is reused between calls (coordinate arrays, cached output
    !! shape, generated VTK XML).
    type(adios2_writer_t) :: adios2_writer                          !! Writer for checkpoints
    type(adios2_writer_t) :: snapshot_writer                        !! Writer for snapshots
    ! Timestep of the most recent checkpoint; -1 until one has been written.
    ! write_checkpoint uses it to build the name of the file to remove.
    integer :: last_checkpoint_step = -1
    integer, dimension(3) :: output_stride = [1, 1, 1]              !! Stride factors for snapshots (default: full resolution)
    integer, dimension(3) :: full_resolution = [1, 1, 1]           !! Full resolution factors for checkpoints (no downsampling)
    real(dp), dimension(:, :, :), allocatable :: output_buffer     !! Fallback buffer for extra fields
    type(field_buffer_map_t), allocatable :: field_buffers(:)       !! Dynamic field buffer mapping for true async I/O
    ! Per-axis coordinate arrays filled by generate_coordinates; shaped
    ! (nx,1,1), (1,ny,1) and (1,1,nz) respectively.
    real(dp), dimension(:, :, :), allocatable :: coords_x, coords_y, coords_z
    ! Cache used by get_output_dimensions: the last global shape and stride
    ! seen, and the strided global shape computed for them. All zero until
    ! the first strided call.
    integer(i8), dimension(3) :: last_shape_dims = 0
    integer, dimension(3) :: last_stride_factors = 0
    integer(i8), dimension(3) :: last_output_shape = 0
    character(len=4096) :: vtk_xml = ""                             !! VTK XML string for ParaView compatibility
    ! File handle for the snapshot currently being written by write_snapshot
    type(adios2_file_t) :: snapshot_file
  contains
    ! Implementations of the deferred bindings of checkpoint_manager_base_t
    procedure :: init
    procedure :: handle_restart
    procedure :: handle_io_step
    procedure :: finalise
    ! Internal helpers, not callable from outside this module
    procedure, private :: write_checkpoint
    procedure, private :: write_snapshot
    procedure, private :: restart_checkpoint
    procedure, private :: configure
    procedure, private :: write_fields
    procedure, private :: stride_data
    procedure, private :: stride_data_to_buffer
    procedure, private :: cleanup_output_buffers
    procedure, private :: get_output_dimensions
    procedure, private :: generate_coordinates
    procedure, private :: generate_vtk_xml
  end type checkpoint_manager_adios2_t

  type :: field_ptr_t
    !! Wrapper around a polymorphic field pointer so that an array of
    !! pointers can be declared (Fortran has no arrays of pointers as such).
    !! Used by write_checkpoint and write_snapshot for both the solver
    !! fields and the host blocks they are copied into.
    class(field_t), pointer :: ptr => null()
      !! Target field; disassociated by default
  end type field_ptr_t

contains

  subroutine init(self, comm)
    !! Initialise the checkpoint manager.
    !!
    !! Creates the checkpoint and snapshot ADIOS2 writers on `comm`, resets
    !! the configuration to its defaults, reads it from the namelist file
    !! named by the first command-line argument, and finally lets
    !! `configure` pick up the snapshot stride and print the settings.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    integer, intent(in) :: comm  !! Communicator handle used for both writers

    call self%adios2_writer%init(comm, "checkpoint_writer")
    call self%snapshot_writer%init(comm, "snapshot_writer")

    self%checkpoint_cfg = checkpoint_config_t()
    call self%checkpoint_cfg%read(nml_file=get_argument(1))

    ! Only the stride and the communicator are forwarded. The remaining
    ! settings already live in self%checkpoint_cfg; passing those components
    ! back in would make configure assign each of them to itself through
    ! `self` while it is also associated with an intent(in) dummy, which the
    ! Fortran argument-aliasing rules do not permit.
    call self%configure( &
      output_stride=self%checkpoint_cfg%output_stride, &
      comm=comm &
      )
  end subroutine init

  subroutine handle_restart(self, solver, comm)
    !! Restart the solver from the checkpoint named in the configuration.
    !!
    !! Reads `checkpoint_cfg%restart_file` through `restart_checkpoint` and
    !! sets `solver%current_iter` to the timestep stored in that file. The
    !! root rank reports progress on standard output.
    !!
    !! This routine does not itself decide whether a restart is wanted; it
    !! always attempts one, so the caller is responsible for that decision.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    class(solver_t), intent(inout) :: solver
    integer, intent(in), optional :: comm
      !! Communicator handle; MPI_COMM_WORLD is used when absent

    character(len=:), allocatable :: restart_file
    integer :: restart_timestep
    integer :: comm_to_use
    real(dp) :: restart_time

    ! restart_checkpoint takes a mandatory communicator, so an absent
    ! optional must not be forwarded to it; resolve the default here.
    comm_to_use = MPI_COMM_WORLD
    if (present(comm)) comm_to_use = comm

    ! Deferred length keeps the full configured path (no 256-character
    ! truncation) and carries no blank padding.
    restart_file = trim(self%checkpoint_cfg%restart_file)
    if (solver%mesh%par%is_root()) then
      print *, 'Restarting from checkpoint: ', restart_file
    end if
    call self%restart_checkpoint(solver, restart_file, restart_timestep, &
                                 restart_time, comm_to_use)

    solver%current_iter = restart_timestep

    if (solver%mesh%par%is_root()) then
      print *, 'Successfully restarted from checkpoint at iteration ', &
        restart_timestep, ' with time ', restart_time
    end if
  end subroutine handle_restart

  subroutine handle_io_step(self, solver, timestep, comm)
    !! Per-timestep output hook: gives the checkpoint writer and then the
    !! snapshot writer the chance to produce output for `timestep`. Each of
    !! them decides on its own whether anything is due.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    integer, intent(in) :: timestep
    integer, intent(in), optional :: comm
      !! Communicator handle; MPI_COMM_WORLD is used when absent

    integer :: io_comm

    if (present(comm)) then
      io_comm = comm
    else
      io_comm = MPI_COMM_WORLD
    end if

    call self%write_checkpoint(solver, timestep, io_comm)
    call self%write_snapshot(solver, timestep, io_comm)
  end subroutine handle_io_step

  subroutine configure( &
    self, checkpoint_freq, snapshot_freq, keep_checkpoint, &
    checkpoint_prefix, snapshot_prefix, output_stride, comm &
    )
    !! Configure checkpoint and snapshot settings
    !!
    !! Every argument is optional; only the ones supplied overwrite the
    !! corresponding stored setting. After updating, rank 0 of the
    !! communicator prints the active checkpoint settings (when the
    !! checkpoint frequency is positive) and the active snapshot settings
    !! (when the snapshot frequency is positive).
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    integer, intent(in), optional :: checkpoint_freq, snapshot_freq
      !! Output intervals in timesteps; values <= 0 disable that output
    logical, intent(in), optional :: keep_checkpoint
      !! Keep every checkpoint instead of only the latest
    character(len=*), intent(in), optional :: checkpoint_prefix, &
                                              snapshot_prefix
      !! File name prefixes for checkpoints and snapshots
    integer, dimension(3), intent(in), optional :: output_stride
      !! Per-direction stride applied to snapshot output
    integer, intent(in), optional :: comm
      !! Communicator handle; MPI_COMM_WORLD is used when absent

    integer :: comm_to_use, myrank, ierr

    comm_to_use = MPI_COMM_WORLD
    if (present(comm)) comm_to_use = comm
    call MPI_Comm_rank(comm_to_use, myrank, ierr)

    if (present(checkpoint_freq)) &
      self%checkpoint_cfg%checkpoint_freq = checkpoint_freq
    if (present(snapshot_freq)) &
      self%checkpoint_cfg%snapshot_freq = snapshot_freq
    if (present(keep_checkpoint)) &
      self%checkpoint_cfg%keep_checkpoint = keep_checkpoint
    if (present(checkpoint_prefix)) &
      self%checkpoint_cfg%checkpoint_prefix = checkpoint_prefix
    if (present(snapshot_prefix)) &
      self%checkpoint_cfg%snapshot_prefix = snapshot_prefix
    ! Take the stride from the argument itself: previously the argument was
    ! only tested for presence and the value stored in checkpoint_cfg was
    ! copied instead, so a caller-supplied stride was silently ignored.
    if (present(output_stride)) &
      self%output_stride = output_stride

    if (myrank == 0) then
      if (self%checkpoint_cfg%checkpoint_freq > 0) then
        print *, 'Checkpoint frequency: ', self%checkpoint_cfg%checkpoint_freq
        print *, 'Keep all checkpoints: ', self%checkpoint_cfg%keep_checkpoint
        print *, 'Checkpoint prefix: ', &
          trim(self%checkpoint_cfg%checkpoint_prefix)
      end if

      if (self%checkpoint_cfg%snapshot_freq > 0) then
        print *, 'Snapshot frequency: ', self%checkpoint_cfg%snapshot_freq
        print *, 'Snapshot prefix: ', trim(self%checkpoint_cfg%snapshot_prefix)
        print *, 'Output stride: ', self%output_stride
      end if
    end if
  end subroutine configure

  subroutine write_checkpoint(self, solver, timestep, comm)
    !! Write a checkpoint file for simulation restart
    !!
    !! Does nothing unless the checkpoint frequency is positive and
    !! `timestep` is a multiple of it. Otherwise the velocity fields u, v, w
    !! are copied to host blocks and written, together with the timestep,
    !! simulation time, dt and data location, to `<prefix>_temp.bp`. Rank 0
    !! then renames that file to `<prefix>_<timestep>.bp` and, when only the
    !! latest checkpoint is to be kept, removes the previous one.
    !!
    !! The previous checkpoint is removed only after the new one is
    !! confirmed to be in place, so a failed write never leaves the run
    !! without any checkpoint.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    integer, intent(in) :: timestep
    integer, intent(in), optional :: comm
      !! Communicator handle; MPI_COMM_WORLD is used when absent

    character(len=256) :: filename, temp_filename, old_filename
    type(adios2_file_t) :: file
    integer :: ierr, myrank, move_stat
    integer :: comm_to_use, i
    character(len=*), parameter :: field_names(*) = ["u", "v", "w"]
    real(dp) :: simulation_time
    logical :: file_exists, checkpoint_ok
    type(field_ptr_t), allocatable :: field_ptrs(:), host_fields(:)
    integer, parameter :: num_fields = size(field_names)
    integer :: data_loc

    if (self%checkpoint_cfg%checkpoint_freq <= 0) return
    if (mod(timestep, self%checkpoint_cfg%checkpoint_freq) /= 0) return

    comm_to_use = MPI_COMM_WORLD
    if (present(comm)) comm_to_use = comm
    call MPI_Comm_rank(comm_to_use, myrank, ierr)

    write (filename, '(A,A,I0.6,A)') &
      trim(self%checkpoint_cfg%checkpoint_prefix), '_', timestep, '.bp'
    write (temp_filename, '(A,A)') &
      trim(self%checkpoint_cfg%checkpoint_prefix), '_temp.bp'
    if (myrank == 0) print *, 'Writing checkpoint: ', trim(filename)

    ! Write under a temporary name first so that an interrupted write never
    ! masquerades as a complete checkpoint.
    file = self%adios2_writer%open(temp_filename, adios2_mode_write, &
                                   comm_to_use)
    call self%adios2_writer%begin_step(file)

    simulation_time = timestep*solver%dt
    data_loc = solver%u%data_loc
    call self%adios2_writer%write_data("timestep", timestep, file)
    call self%adios2_writer%write_data("time", real(simulation_time, dp), file)
    call self%adios2_writer%write_data("dt", real(solver%dt, dp), file)
    call self%adios2_writer%write_data("data_loc", data_loc, file)

    ! Resolve each field name to the corresponding solver field
    allocate (field_ptrs(num_fields))
    allocate (host_fields(num_fields))
    do i = 1, num_fields
      select case (trim(field_names(i)))
      case ("u")
        field_ptrs(i)%ptr => solver%u
      case ("v")
        field_ptrs(i)%ptr => solver%v
      case ("w")
        field_ptrs(i)%ptr => solver%w
      case default
        if (solver%mesh%par%is_root()) then
          print *, 'ERROR: Unknown field name in checkpoint list: ', &
            trim(field_names(i))
        end if
        error stop 1
      end select
    end do

    ! Copy every field into its own host block before writing
    do i = 1, num_fields
      host_fields(i)%ptr => solver%host_allocator%get_block( &
                            DIR_C, field_ptrs(i)%ptr%data_loc)
      call solver%backend%get_field_data( &
        host_fields(i)%ptr%data, field_ptrs(i)%ptr)
    end do

    call self%write_fields( &
      field_names, host_fields, &
      solver, file, data_loc, use_stride=.false., writer=self%adios2_writer &
      )
    deallocate (field_ptrs)

    ! NOTE(review): unlike write_snapshot there is no explicit end_step
    ! here; presumably close() ends the open step - confirm in m_adios2_io.
    call self%adios2_writer%close(file)

    ! Host blocks are released only after the file has been closed
    do i = 1, num_fields
      call solver%host_allocator%release_block(host_fields(i)%ptr)
    end do

    deallocate (host_fields)

    if (myrank == 0) then
      ! Move temporary file to final checkpoint filename
      move_stat = 0
      call execute_command_line('mv '//trim(temp_filename)//' '// &
                                trim(filename), exitstat=move_stat)

      inquire (file=trim(filename), exist=file_exists)
      checkpoint_ok = (move_stat == 0) .and. file_exists
      if (.not. checkpoint_ok) then
        print *, 'ERROR: Checkpoint file not created: ', trim(filename)
      end if

      ! Remove old checkpoint if configured to keep only the latest.
      ! Skipped when the new checkpoint is missing (the old one is then the
      ! only usable restart point) and when the "old" step equals the
      ! current one (the name would refer to the file just written).
      if (checkpoint_ok &
          .and. .not. self%checkpoint_cfg%keep_checkpoint &
          .and. self%last_checkpoint_step > 0 &
          .and. self%last_checkpoint_step /= timestep) then
        write (old_filename, '(A,A,I0.6,A)') &
          trim(self%checkpoint_cfg%checkpoint_prefix), '_', &
          self%last_checkpoint_step, '.bp'
        inquire (file=trim(old_filename), exist=file_exists)
        if (file_exists) then
          ierr = 0
          call execute_command_line('rm -rf '//trim(old_filename), &
                                    exitstat=ierr)
          if (ierr /= 0) then
            print *, 'WARNING: failed to remove old checkpoint: ', &
              trim(old_filename)
          end if
        end if
      end if

    end if

    self%last_checkpoint_step = timestep
  end subroutine write_checkpoint

  subroutine write_snapshot(self, solver, timestep, comm)
    !! Write a snapshot file for visualisation
    !!
    !! Does nothing unless the snapshot frequency is positive and `timestep`
    !! is a multiple of it. Otherwise a file `<prefix>_<timestep>.bp` is
    !! created containing a "vtk.xml" attribute describing the (strided)
    !! grid, the simulation time, and the velocity fields u, v, w sampled
    !! with `self%output_stride`.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    integer, intent(in) :: timestep
    integer, intent(in), optional :: comm
      !! Communicator handle; MPI_COMM_WORLD is used when absent

    character(len=*), parameter :: field_names(*) = ["u", "v", "w"]
    integer, parameter :: num_fields = size(field_names)
    ! Snapshots always use striding. Declared as a named constant: an
    ! initialised local variable would silently acquire the SAVE attribute.
    logical, parameter :: snapshot_uses_stride = .true.

    integer :: myrank, ierr, comm_to_use, i
    character(len=256) :: filename
    integer(i8), dimension(3) :: output_shape_dims
    integer, dimension(3) :: global_dims, output_dims
    type(field_ptr_t), allocatable :: field_ptrs(:), host_fields(:)
    real(dp), dimension(3) :: origin, original_spacing, output_spacing
    real(dp) :: simulation_time

    if (self%checkpoint_cfg%snapshot_freq <= 0) return
    if (mod(timestep, self%checkpoint_cfg%snapshot_freq) /= 0) return

    comm_to_use = MPI_COMM_WORLD
    if (present(comm)) comm_to_use = comm
    call MPI_Comm_rank(comm_to_use, myrank, ierr)

    write (filename, '(A,A,I0.6,A)') &
      trim(self%checkpoint_cfg%snapshot_prefix), '_', timestep, '.bp'

    global_dims = solver%mesh%get_global_dims(VERT)
    origin = solver%mesh%get_coordinates(1, 1, 1)
    original_spacing = solver%mesh%geo%d

    if (snapshot_uses_stride) then
      ! Strided grid: spacing grows by the stride, point count is the
      ! ceiling of dims/stride.
      output_spacing = original_spacing*real(self%output_stride, dp)
      output_dims = (global_dims + self%output_stride - 1)/self%output_stride
    else
      output_spacing = original_spacing
      output_dims = global_dims
    end if
    output_shape_dims = int(output_dims, i8)

    call self%generate_vtk_xml( &
      output_shape_dims, field_names, origin, output_spacing &
      )

    self%snapshot_file = self%snapshot_writer%open( &
                         filename, adios2_mode_write, comm_to_use)
    if (myrank == 0) print *, 'Creating snapshot file: ', trim(filename)

    ! Write VTK XML attributes for ParaView compatibility
    if (myrank == 0) then
      call self%snapshot_writer%write_attribute( &
        "vtk.xml", self%vtk_xml, self%snapshot_file)
    end if

    call self%snapshot_writer%begin_step(self%snapshot_file)

    simulation_time = timestep*solver%dt
    if (myrank == 0) then
      print *, 'Writing snapshot for time =', simulation_time, &
        ' iteration =', timestep
    end if

    call self%snapshot_writer%write_data( &
      "time", real(simulation_time, dp), self%snapshot_file)

    allocate (field_ptrs(num_fields))
    allocate (host_fields(num_fields))

    ! point field_ptrs to actual solver data
    do i = 1, num_fields
      select case (trim(field_names(i)))
      case ("u")
        field_ptrs(i)%ptr => solver%u
      case ("v")
        field_ptrs(i)%ptr => solver%v
      case ("w")
        field_ptrs(i)%ptr => solver%w
      case default
        if (solver%mesh%par%is_root()) then
          print *, 'ERROR: Unknown field name in snapshot list: ', &
            trim(field_names(i))
        end if
        error stop 1
      end select
    end do

    ! allocate a unique host buffer
    do i = 1, num_fields
      host_fields(i)%ptr => solver%host_allocator%get_block( &
                            DIR_C, field_ptrs(i)%ptr%data_loc)
      call solver%backend%get_field_data( &
        host_fields(i)%ptr%data, field_ptrs(i)%ptr)
    end do

    call self%write_fields( &
      field_names, host_fields, &
      solver, self%snapshot_file, solver%u%data_loc, &
      use_stride=snapshot_uses_stride, &
      writer=self%snapshot_writer &
      )

    call self%snapshot_writer%end_step(self%snapshot_file)
    call self%snapshot_writer%close(self%snapshot_file)

    ! Host blocks are released only after the file has been closed
    do i = 1, num_fields
      call solver%host_allocator%release_block(host_fields(i)%ptr)
    end do

    deallocate (field_ptrs)
    deallocate (host_fields)
  end subroutine write_snapshot

  subroutine generate_vtk_xml(self, dims, fields, origin, spacing)
    !! Build the VTK ImageData XML description consumed by ParaView's
    !! ADIOS2VTXReader and store it in `self%vtk_xml`.
    !!
    !! The document lists one `DataArray` per entry of `fields` plus a
    !! `TIME` array bound to the variable `time`.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    integer(i8), dimension(3), intent(in) :: dims
      !! Number of points per direction of the written grid
    character(len=*), dimension(:), intent(in) :: fields
      !! Names of the field variables to expose
    real(dp), dimension(3), intent(in) :: origin, spacing
      !! Grid origin and point spacing

    character(len=1), parameter :: nl = new_line('a')
    character(len=4096) :: doc
    character(len=96) :: extent_txt, origin_txt, spacing_txt
    integer :: ifield

    ! Extent is "0 N-1" per axis, with the axes taken in reverse order of
    ! `dims`.
    write (extent_txt, '(3("0 ", I0, :, 1X))') dims(3:1:-1) - 1
    write (origin_txt, '(G0, 2(1X, G0))') origin
    write (spacing_txt, '(G0, 2(1X, G0))') spacing

    ! Header up to and including the opening <PointData> tag
    doc = '<?xml version="1.0"?>'//nl// &
          '<VTKFile type="ImageData" version="0.1">'//nl// &
          '  <ImageData WholeExtent=" '//trim(adjustl(extent_txt))//'" '// &
          'Origin="'//trim(adjustl(origin_txt))//'" '// &
          'Spacing="'//trim(adjustl(spacing_txt))//'">'//nl// &
          '    <Piece Extent="'//trim(adjustl(extent_txt))//'">'//nl// &
          '      <PointData>'//nl

    ! One DataArray element per field
    do ifield = 1, size(fields)
      associate (name => trim(fields(ifield)))
        doc = trim(doc)//'      <DataArray Name="'//name//'">'// &
              name//'</DataArray>'//nl
      end associate
    end do

    ! Time array followed by the closing tags
    doc = trim(doc)//'        <DataArray Name="TIME">time</DataArray>'//nl// &
          '      </PointData>'//nl// &
          '    </Piece>'//nl// &
          '  </ImageData>'//nl// &
          '</VTKFile>'

    self%vtk_xml = doc
  end subroutine generate_vtk_xml

  subroutine generate_coordinates( &
    self, solver, file, shape_dims, start_dims, count_dims, data_loc &
    )
    !! Fill the per-axis coordinate arrays from the mesh and write them to
    !! `file` as "coordinates/x", "coordinates/y" and "coordinates/z".
    !!
    !! The coordinate arrays are kept on `self` and reallocated only when
    !! the local extent along their axis changes.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    type(adios2_file_t), intent(inout) :: file
    integer(i8), dimension(3), intent(in) :: shape_dims, start_dims, count_dims
      !! Global shape, local start offset and local count per direction
    integer, intent(in) :: data_loc
      !! Data location passed through to the mesh coordinate query

    integer :: i, nx, ny, nz
    real(dp), dimension(3) :: coords
    integer(i8), dimension(3) :: x_shape, y_shape, z_shape
    integer(i8), dimension(3) :: x_start, y_start, z_start
    integer(i8), dimension(3) :: x_count, y_count, z_count

    nx = int(count_dims(1))
    ny = int(count_dims(2))
    nz = int(count_dims(3))

    ! coordinates are structured as 3D arrays for ParaView ADIOS2 reader compatibility
    !
    ! The size test is nested inside the allocation test: Fortran does not
    ! guarantee short-circuit evaluation of .or., so size() must never
    ! appear in the same expression as the allocated() check.
    if (allocated(self%coords_x)) then
      if (size(self%coords_x) /= nx) deallocate (self%coords_x)
    end if
    if (.not. allocated(self%coords_x)) allocate (self%coords_x(nx, 1, 1))

    if (allocated(self%coords_y)) then
      if (size(self%coords_y) /= ny) deallocate (self%coords_y)
    end if
    if (.not. allocated(self%coords_y)) allocate (self%coords_y(1, ny, 1))

    if (allocated(self%coords_z)) then
      if (size(self%coords_z) /= nz) deallocate (self%coords_z)
    end if
    if (.not. allocated(self%coords_z)) allocate (self%coords_z(1, 1, nz))

    ! Walk along each axis with the other two indices fixed at 1 and keep
    ! only the component belonging to that axis.
    do i = 1, nx
      coords = solver%mesh%get_coordinates(i, 1, 1, data_loc)
      self%coords_x(i, 1, 1) = coords(1)
    end do

    do i = 1, ny
      coords = solver%mesh%get_coordinates(1, i, 1, data_loc)
      self%coords_y(1, i, 1) = coords(2)
    end do

    do i = 1, nz
      coords = solver%mesh%get_coordinates(1, 1, i, data_loc)
      self%coords_z(1, 1, i) = coords(3)
    end do

    ! NOTE(review): the shape/start/count triplets list the axes in the
    ! reverse order of the Fortran array shapes above (x varies in the
    ! last slot); presumably this matches the ordering convention of the
    ! writer - confirm against m_adios2_io.
    x_shape = [1_i8, 1_i8, shape_dims(1)]
    x_start = [0_i8, 0_i8, start_dims(1)]
    x_count = [1_i8, 1_i8, count_dims(1)]

    y_shape = [1_i8, shape_dims(2), 1_i8]
    y_start = [0_i8, start_dims(2), 0_i8]
    y_count = [1_i8, count_dims(2), 1_i8]

    z_shape = [shape_dims(3), 1_i8, 1_i8]
    z_start = [start_dims(3), 0_i8, 0_i8]
    z_count = [count_dims(3), 1_i8, 1_i8]

    call self%adios2_writer%write_data( &
      "coordinates/x", self%coords_x, file, x_shape, x_start, x_count &
      )
    call self%adios2_writer%write_data( &
      "coordinates/y", self%coords_y, file, y_shape, y_start, y_count &
      )
    call self%adios2_writer%write_data( &
      "coordinates/z", self%coords_z, file, z_shape, z_start, z_count &
      )
  end subroutine generate_coordinates

  function stride_data( &
    self, input_data, dims, stride, output_dims_out) &
    result(output_data)
    !! Return a copy of `input_data` sampled every `stride(d)` points along
    !! direction d, starting at index 1.
    !!
    !! With a unit stride in all directions the whole of `input_data` is
    !! copied. Otherwise only the leading `dims` extent is sampled and the
    !! result holds ceiling(dims/stride) points per direction, the same
    !! count `stride_data_to_buffer` produces.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    real(dp), dimension(:, :, :), intent(in) :: input_data
    integer, dimension(3), intent(in) :: dims
      !! Extent of `input_data` to sample, per direction
    integer, dimension(3), intent(in) :: stride
      !! Sampling stride per direction (must be >= 1)
    integer, dimension(3), intent(out) :: output_dims_out
      !! Extent of the returned array, per direction

    real(dp), dimension(:, :, :), allocatable :: output_data

    if (all(stride == 1)) then
      allocate (output_data(dims(1), dims(2), dims(3)))
      output_data = input_data
      output_dims_out = dims
      return
    end if

    ! Ceiling division. The earlier form (dims - 1)/stride + 1 agrees for
    ! dims >= 1 but yields 1 for an empty direction, whereas the section
    ! copied below then has extent 0, so the reported size was wrong.
    output_dims_out = (dims + stride - 1)/stride
    allocate (output_data(output_dims_out(1), output_dims_out(2), &
                          output_dims_out(3)))

    output_data = input_data(1:dims(1):stride(1), &
                             1:dims(2):stride(2), 1:dims(3):stride(3))
  end function stride_data

  subroutine stride_data_to_buffer( &
    self, input_data, dims, stride, out_buffer, output_dims_out &
    )
    !! Sample `input_data` with the given stride into a caller-owned buffer.
    !!
    !! `out_buffer` is (re)allocated only when it is unallocated or its
    !! shape differs from the required one, so repeated calls with the same
    !! sizes reuse the existing storage.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    real(dp), dimension(:, :, :), intent(in) :: input_data
    integer, dimension(3), intent(in) :: dims
      !! Extent of `input_data` to sample, per direction
    integer, dimension(3), intent(in) :: stride
      !! Sampling stride per direction
    real(dp), dimension(:, :, :), allocatable, intent(inout) :: out_buffer
    integer, dimension(3), intent(out) :: output_dims_out
      !! Extent of the sampled data, per direction

    logical :: unit_stride

    unit_stride = all(stride == 1)

    ! Ceiling of dims/stride; reduces to dims itself for a unit stride
    output_dims_out = (dims + stride - 1)/stride

    ! Drop a buffer of the wrong shape, then allocate if there is none
    if (allocated(out_buffer)) then
      if (any(shape(out_buffer) /= output_dims_out)) deallocate (out_buffer)
    end if
    if (.not. allocated(out_buffer)) then
      allocate (out_buffer(output_dims_out(1), output_dims_out(2), &
                           output_dims_out(3)))
    end if

    if (unit_stride) then
      ! Full-resolution copy of the whole input array
      out_buffer = input_data
    else
      out_buffer = input_data(1:dims(1):stride(1), &
                              1:dims(2):stride(2), 1:dims(3):stride(3))
    end if
  end subroutine stride_data_to_buffer

  subroutine get_output_dimensions( &
    self, shape_dims, start_dims, count_dims, stride_factors, &
    output_shape, output_start, output_count, output_dims_local)
    !! Translate the global shape and the local start/count of a field into
    !! the corresponding values for strided output.
    !!
    !! For a unit stride the inputs are returned unchanged. Otherwise the
    !! shape and local count are divided by the stride rounding up, and the
    !! start offset is divided rounding down. The strided global shape is
    !! cached on `self` and reused while shape and stride stay the same.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    integer(i8), dimension(3), intent(in) :: shape_dims, start_dims, count_dims
    integer, dimension(3), intent(in) :: stride_factors
    integer(i8), dimension(3), intent(out) :: output_shape, output_start
    integer(i8), dimension(3), intent(out) :: output_count
    integer, dimension(3), intent(out) :: output_dims_local

    integer(i8), dimension(3) :: stride_i8
    logical :: cache_valid

    if (all(stride_factors == 1)) then
      ! Full resolution: pass everything through untouched
      output_shape = shape_dims
      output_start = start_dims
      output_count = count_dims
      output_dims_local = int(count_dims)
    else
      stride_i8 = int(stride_factors, i8)

      cache_valid = all(shape_dims == self%last_shape_dims) .and. &
                    all(stride_factors == self%last_stride_factors) .and. &
                    all(self%last_output_shape > 0)

      if (cache_valid) then
        output_shape = self%last_output_shape
      else
        output_shape = (shape_dims + stride_i8 - 1_i8)/stride_i8

        ! Remember the inputs and the result for the next call
        self%last_shape_dims = shape_dims
        self%last_stride_factors = stride_factors
        self%last_output_shape = output_shape
      end if

      output_start = start_dims/stride_i8
      output_dims_local = (int(count_dims) + stride_factors - 1) &
                          /stride_factors
      output_count = int(output_dims_local, i8)
    end if
  end subroutine get_output_dimensions

  subroutine restart_checkpoint( &
    self, solver, filename, timestep, restart_time, comm &
    )
    !! Load solver state from the checkpoint `filename`.
    !!
    !! Reads the stored timestep, time and data location, tags the solver's
    !! u, v and w fields with that data location and then fills them from
    !! the variables "u", "v" and "w" of the file. If the file does not
    !! exist the root rank reports it and MPI_Abort is called on `comm`.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    class(solver_t), intent(inout) :: solver
    character(len=*), intent(in) :: filename
    integer, intent(out) :: timestep
      !! Timestep stored in the checkpoint
    real(dp), intent(out) :: restart_time
      !! Simulation time stored in the checkpoint
    integer, intent(in) :: comm
      !! Communicator handle (mandatory here)

    type(adios2_reader_t) :: reader
    type(adios2_file_t) :: file
    integer :: ierr
    integer :: dims(3)
    integer(i8), dimension(3) :: start_dims, count_dims
    logical :: file_exists
    integer :: data_loc

    inquire (file=filename, exist=file_exists)
    if (.not. file_exists) then
      if (solver%mesh%par%is_root()) then
        print *, 'ERROR: Checkpoint file not found: ', trim(filename)
      end if
      call MPI_Abort(comm, 1, ierr)
      return
    end if

    ! Scalar metadata comes from the main reader, which stays open until
    ! the end of the routine.
    call reader%init(comm, "checkpoint_reader")

    file = reader%open(filename, adios2_mode_read, comm)
    call reader%begin_step(file)
    call reader%read_data("timestep", timestep, file)
    call reader%read_data("time", restart_time, file)
    call reader%read_data("data_loc", data_loc, file)

    ! This rank's portion of the global arrays
    dims = solver%mesh%get_dims(data_loc)
    start_dims = int(solver%mesh%par%n_offset, i8)
    count_dims = int(dims, i8)

    call solver%u%set_data_loc(data_loc)
    call solver%v%set_data_loc(data_loc)
    call solver%w%set_data_loc(data_loc)

    ! use a separate reader for each variable to avoid data corruption.
    block
      real(dp), allocatable, target :: u_data(:, :, :)
      real(dp), allocatable, target :: v_data(:, :, :)
      real(dp), allocatable, target :: w_data(:, :, :)
      type(adios2_reader_t) :: var_reader
      type(adios2_file_t) :: var_file

      ! --- u ---
      call var_reader%init(comm, "u_reader")
      var_file = var_reader%open(filename, adios2_mode_read, comm)
      call var_reader%begin_step(var_file)
      call var_reader%read_data( &
        "u", u_data, var_file, start_dims, count_dims)
      call var_reader%close(var_file)
      call var_reader%finalise()

      ! --- v ---
      call var_reader%init(comm, "v_reader")
      var_file = var_reader%open(filename, adios2_mode_read, comm)
      call var_reader%begin_step(var_file)
      call var_reader%read_data( &
        "v", v_data, var_file, start_dims, count_dims)
      call var_reader%close(var_file)
      call var_reader%finalise()

      ! --- w ---
      call var_reader%init(comm, "w_reader")
      var_file = var_reader%open(filename, adios2_mode_read, comm)
      call var_reader%begin_step(var_file)
      call var_reader%read_data( &
        "w", w_data, var_file, start_dims, count_dims)
      call var_reader%close(var_file)
      call var_reader%finalise()

      ! Hand the data to the backend only once all three reads are done,
      ! in the order u, v, w.
      call solver%backend%set_field_data(solver%u, u_data)
      call solver%backend%set_field_data(solver%v, v_data)
      call solver%backend%set_field_data(solver%w, w_data)

      deallocate (u_data, v_data, w_data)
    end block

    call reader%close(file)
    call reader%finalise()
  end subroutine restart_checkpoint

  subroutine write_fields( &
    self, field_names, host_fields, solver, file, data_loc, &
    use_stride, writer &
    )
    !! Write field data, optionally with striding, using a separate buffer
    !! for each field so that asynchronous I/O stays race-free.
    !!
    !! Arguments:
    !!
    !! - `field_names`: names under which the fields are written;
    !!   `field_names(i)` is paired with `host_fields(i)`.
    !! - `host_fields`: pointers to the fields whose `data` is written.
    !! - `solver`: supplies the mesh used to query local/global dimensions
    !!   and the parallel offset.
    !! - `file`: ADIOS2 file handle forwarded to `writer%write_data`.
    !! - `data_loc`: data location used for all mesh dimension queries.
    !! - `use_stride`: when present and true the data is strided with
    !!   `self%output_stride`; otherwise `self%full_resolution` is used.
    !! - `writer`: writer that performs the actual `write_data` calls.
    !!
    !! The output geometry depends only on the mesh, `data_loc` and the
    !! stride, so it is computed once here and shared by the internal
    !! procedures through host association.
    class(checkpoint_manager_adios2_t), intent(inout) :: self
    character(len=*), dimension(:), intent(in) :: field_names
    class(field_ptr_t), dimension(:), target, intent(in) :: host_fields
    class(solver_t), intent(in) :: solver
    type(adios2_file_t), intent(inout) :: file
    integer, intent(in) :: data_loc
    logical, intent(in), optional :: use_stride
    type(adios2_writer_t), intent(inout) :: writer

    integer :: i_field
    logical :: apply_stride
    integer, dimension(3) :: dims, stride_factors, output_dims_local
    integer(i8), dimension(3) :: shape_dims, start_dims, count_dims
    integer(i8), dimension(3) :: output_shape, output_start, output_count

    apply_stride = .false.
    if (present(use_stride)) apply_stride = use_stride

    if (apply_stride) then
      stride_factors = self%output_stride
    else
      ! no striding (e.g. for checkpointing)
      stride_factors = self%full_resolution
    end if

    ! local extents, global extents and this rank's offset for data_loc
    dims = solver%mesh%get_dims(data_loc)
    shape_dims = int(solver%mesh%get_global_dims(data_loc), i8)
    start_dims = int(solver%mesh%par%n_offset, i8)
    count_dims = int(dims, i8)

    call self%get_output_dimensions( &
      shape_dims, start_dims, count_dims, stride_factors, &
      output_shape, output_start, output_count, &
      output_dims_local &
      )

    ! pre-allocate separate buffers for each field to enable race-free
    ! async I/O
    call prepare_field_buffers()

    do i_field = 1, size(field_names)
      call write_single_field( &
        trim(field_names(i_field)), host_fields(i_field)%ptr)
    end do

  contains

    subroutine prepare_field_buffers()
      !! (Re)create `self%field_buffers` with one named buffer per entry
      !! of `field_names`, each sized to `output_dims_local`.
      integer :: i

      ! deallocating the container also releases the old per-field buffers
      if (allocated(self%field_buffers)) deallocate (self%field_buffers)
      allocate (self%field_buffers(size(field_names)))

      do i = 1, size(field_names)
        self%field_buffers(i)%field_name = trim(field_names(i))
        allocate ( &
          self%field_buffers(i)%buffer( &
          output_dims_local(1), &
          output_dims_local(2), &
          output_dims_local(3)))
      end do
    end subroutine prepare_field_buffers

    subroutine write_single_field(field_name, host_field)
      !! Copy one field into its output buffer and hand it to the writer.
      character(len=*), intent(in) :: field_name
      class(field_t), pointer, intent(in) :: host_field

      integer :: buffer_idx
      logical :: buffer_found

      ! find the matching buffer for this field
      buffer_found = .false.
      do buffer_idx = 1, size(self%field_buffers)
        if (trim(self%field_buffers(buffer_idx)%field_name) &
            == trim(field_name)) then
          buffer_found = .true.
          exit
        end if
      end do

      if (buffer_found) then
        ! use the dedicated buffer for this field for true async I/O
        call self%stride_data_to_buffer( &
          host_field%data(1:dims(1), 1:dims(2), 1:dims(3)), dims, &
          stride_factors, self%field_buffers(buffer_idx)%buffer, &
          output_dims_local &
          )

        call writer%write_data( &
          field_name, self%field_buffers(buffer_idx)%buffer, &
          file, output_shape, output_start, output_count &
          )
      else
        ! fallback to shared buffer for fields not in the map
        ! NOTE(review): every fallback field reuses this one buffer, so it
        ! does not give the per-field isolation of field_buffers.
        if (allocated(self%output_buffer)) then
          ! a buffer left over from a call with a different stride or
          ! data location has the wrong extents; discard it
          if (any(shape(self%output_buffer) /= output_dims_local)) then
            deallocate (self%output_buffer)
          end if
        end if
        if (.not. allocated(self%output_buffer)) then
          allocate ( &
            self%output_buffer(output_dims_local(1), &
                               output_dims_local(2), &
                               output_dims_local(3)))
        end if
        call self%stride_data_to_buffer( &
          host_field%data(1:dims(1), 1:dims(2), 1:dims(3)), dims, &
          stride_factors, self%output_buffer, output_dims_local)

        call writer%write_data( &
          field_name, self%output_buffer, &
          file, output_shape, output_start, output_count &
          )
      end if

    end subroutine write_single_field

  end subroutine write_fields

  subroutine cleanup_output_buffers(self)
    !! Release the shared output buffer and every per-field buffer held by
    !! the checkpoint manager. Safe to call when nothing is allocated.
    class(checkpoint_manager_adios2_t), intent(inout) :: self

    if (allocated(self%output_buffer)) deallocate (self%output_buffer)

    ! Deallocating the container releases the allocatable `buffer`
    ! component of each element as well, so no per-element pass is needed.
    if (allocated(self%field_buffers)) deallocate (self%field_buffers)
  end subroutine cleanup_output_buffers

  subroutine finalise(self)
    !! Shut down the checkpoint manager: release its output buffers and
    !! finalise both of its writers (`adios2_writer` and `snapshot_writer`).
    class(checkpoint_manager_adios2_t), intent(inout) :: self

    ! buffers are released first, then the two writers are finalised
    call self%cleanup_output_buffers()
    call self%adios2_writer%finalise()
    call self%snapshot_writer%finalise()
  end subroutine finalise

end module m_checkpoint_manager_impl

module m_checkpoint_manager
  !! Public facade and factory function for the checkpoint manager.
  !!
  !! `checkpoint_manager_t` embeds a `checkpoint_manager_adios2_t` and
  !! forwards every operation to it. The generic interface of the same
  !! name lets callers write `mgr = checkpoint_manager_t(comm)` to obtain
  !! an initialised manager.
  use m_solver, only: solver_t
  use m_checkpoint_manager_impl, only: checkpoint_manager_adios2_t

  implicit none

  private
  public :: checkpoint_manager_t
  public :: create_checkpoint_manager

  type :: checkpoint_manager_t
    !! Wrapper that delegates to the ADIOS2 implementation held in `impl`.
    type(checkpoint_manager_adios2_t) :: impl
  contains
    procedure :: init => cm_init
    procedure :: is_restart => cm_is_restart
    procedure :: handle_restart => cm_handle_restart
    procedure :: handle_io_step => cm_handle_io_step
    procedure :: finalise => cm_finalise
  end type checkpoint_manager_t

  interface checkpoint_manager_t
    !! Constructor-style access to `create_checkpoint_manager`.
    module procedure create_checkpoint_manager
  end interface checkpoint_manager_t

contains

  function create_checkpoint_manager(comm) result(manager)
    !! Build a checkpoint manager and initialise it with `comm`.
    integer, intent(in) :: comm
    type(checkpoint_manager_t) :: manager

    call manager%init(comm)
  end function create_checkpoint_manager

  subroutine cm_init(self, comm)
    !! Forward initialisation to the implementation.
    class(checkpoint_manager_t), intent(inout) :: self
    integer, intent(in) :: comm

    call self%impl%init(comm)
  end subroutine cm_init

  function cm_is_restart(self) result(restart_requested)
    !! Report the `restart_from_checkpoint` flag of the implementation's
    !! checkpoint configuration.
    class(checkpoint_manager_t), intent(in) :: self
    logical :: restart_requested

    associate (cfg => self%impl%checkpoint_cfg)
      restart_requested = cfg%restart_from_checkpoint
    end associate
  end function cm_is_restart

  subroutine cm_handle_restart(self, solver, comm)
    !! Forward restart handling to the implementation. `comm` is passed
    !! through unchanged, whether or not it is present.
    class(checkpoint_manager_t), intent(inout) :: self
    class(solver_t), intent(inout) :: solver
    integer, optional, intent(in) :: comm

    call self%impl%handle_restart(solver, comm)
  end subroutine cm_handle_restart

  subroutine cm_handle_io_step(self, solver, timestep, comm)
    !! Forward per-timestep I/O handling to the implementation. `comm` is
    !! passed through unchanged, whether or not it is present.
    class(checkpoint_manager_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    integer, intent(in) :: timestep
    integer, optional, intent(in) :: comm

    call self%impl%handle_io_step(solver, timestep, comm)
  end subroutine cm_handle_io_step

  subroutine cm_finalise(self)
    !! Forward finalisation to the implementation.
    class(checkpoint_manager_t), intent(inout) :: self

    call self%impl%finalise()
  end subroutine cm_finalise

end module m_checkpoint_manager