checkpoint_manager.f90 Source File


This file depends on

sourcefile~~checkpoint_manager.f90~~EfferentGraph sourcefile~checkpoint_manager.f90 checkpoint_manager.f90 sourcefile~common.f90 common.f90 sourcefile~checkpoint_manager.f90->sourcefile~common.f90 sourcefile~config.f90 config.f90 sourcefile~checkpoint_manager.f90->sourcefile~config.f90 sourcefile~field.f90 field.f90 sourcefile~checkpoint_manager.f90->sourcefile~field.f90 sourcefile~io.f90 io.f90 sourcefile~checkpoint_manager.f90->sourcefile~io.f90 sourcefile~io_field_utils.f90 io_field_utils.f90 sourcefile~checkpoint_manager.f90->sourcefile~io_field_utils.f90 sourcefile~io_session.f90 io_session.f90 sourcefile~checkpoint_manager.f90->sourcefile~io_session.f90 sourcefile~solver.f90 solver.f90 sourcefile~checkpoint_manager.f90->sourcefile~solver.f90 sourcefile~config.f90->sourcefile~common.f90 sourcefile~field.f90->sourcefile~common.f90 sourcefile~io.f90->sourcefile~common.f90 sourcefile~io_base.f90 io_base.f90 sourcefile~io.f90->sourcefile~io_base.f90 sourcefile~io_field_utils.f90->sourcefile~common.f90 sourcefile~io_field_utils.f90->sourcefile~field.f90 sourcefile~io_field_utils.f90->sourcefile~solver.f90 sourcefile~io_field_utils.f90->sourcefile~io_base.f90 sourcefile~io_session.f90->sourcefile~common.f90 sourcefile~io_session.f90->sourcefile~io.f90 sourcefile~io_session.f90->sourcefile~io_base.f90 sourcefile~solver.f90->sourcefile~common.f90 sourcefile~solver.f90->sourcefile~config.f90 sourcefile~solver.f90->sourcefile~field.f90 sourcefile~allocator.f90 allocator.f90 sourcefile~solver.f90->sourcefile~allocator.f90 sourcefile~backend.f90~2 backend.f90 sourcefile~solver.f90->sourcefile~backend.f90~2 sourcefile~ibm.f90 ibm.f90 sourcefile~solver.f90->sourcefile~ibm.f90 sourcefile~mesh.f90 mesh.f90 sourcefile~solver.f90->sourcefile~mesh.f90 sourcefile~tdsops.f90 tdsops.f90 sourcefile~solver.f90->sourcefile~tdsops.f90 sourcefile~time_integrator.f90 time_integrator.f90 sourcefile~solver.f90->sourcefile~time_integrator.f90 sourcefile~vector_calculus.f90 
vector_calculus.f90 sourcefile~solver.f90->sourcefile~vector_calculus.f90 sourcefile~allocator.f90->sourcefile~common.f90 sourcefile~allocator.f90->sourcefile~field.f90 sourcefile~backend.f90~2->sourcefile~common.f90 sourcefile~backend.f90~2->sourcefile~field.f90 sourcefile~backend.f90~2->sourcefile~allocator.f90 sourcefile~backend.f90~2->sourcefile~mesh.f90 sourcefile~backend.f90~2->sourcefile~tdsops.f90 sourcefile~poisson_fft.f90~3 poisson_fft.f90 sourcefile~backend.f90~2->sourcefile~poisson_fft.f90~3 sourcefile~ibm.f90->sourcefile~common.f90 sourcefile~ibm.f90->sourcefile~field.f90 sourcefile~ibm.f90->sourcefile~io_session.f90 sourcefile~ibm.f90->sourcefile~allocator.f90 sourcefile~ibm.f90->sourcefile~backend.f90~2 sourcefile~ibm.f90->sourcefile~mesh.f90 sourcefile~io_base.f90->sourcefile~common.f90 sourcefile~mesh.f90->sourcefile~common.f90 sourcefile~mesh.f90->sourcefile~field.f90 sourcefile~decomp_dummy.f90 decomp_dummy.f90 sourcefile~mesh.f90->sourcefile~decomp_dummy.f90 sourcefile~mesh_content.f90 mesh_content.f90 sourcefile~mesh.f90->sourcefile~mesh_content.f90 sourcefile~tdsops.f90->sourcefile~common.f90 sourcefile~time_integrator.f90->sourcefile~common.f90 sourcefile~time_integrator.f90->sourcefile~field.f90 sourcefile~time_integrator.f90->sourcefile~allocator.f90 sourcefile~time_integrator.f90->sourcefile~backend.f90~2 sourcefile~vector_calculus.f90->sourcefile~common.f90 sourcefile~vector_calculus.f90->sourcefile~field.f90 sourcefile~vector_calculus.f90->sourcefile~allocator.f90 sourcefile~vector_calculus.f90->sourcefile~backend.f90~2 sourcefile~vector_calculus.f90->sourcefile~tdsops.f90 sourcefile~decomp_dummy.f90->sourcefile~mesh_content.f90 sourcefile~mesh_content.f90->sourcefile~common.f90 sourcefile~poisson_fft.f90~3->sourcefile~common.f90 sourcefile~poisson_fft.f90~3->sourcefile~field.f90 sourcefile~poisson_fft.f90~3->sourcefile~mesh.f90 sourcefile~poisson_fft.f90~3->sourcefile~tdsops.f90

Files dependent on this one

sourcefile~~checkpoint_manager.f90~~AfferentGraph sourcefile~checkpoint_manager.f90 checkpoint_manager.f90 sourcefile~io_manager.f90 io_manager.f90 sourcefile~io_manager.f90->sourcefile~checkpoint_manager.f90 sourcefile~base_case.f90 base_case.f90 sourcefile~base_case.f90->sourcefile~io_manager.f90 sourcefile~channel.f90 channel.f90 sourcefile~channel.f90->sourcefile~base_case.f90 sourcefile~generic.f90 generic.f90 sourcefile~generic.f90->sourcefile~base_case.f90 sourcefile~tgv.f90 tgv.f90 sourcefile~tgv.f90->sourcefile~base_case.f90 sourcefile~xcompact.f90 xcompact.f90 sourcefile~xcompact.f90->sourcefile~base_case.f90 sourcefile~xcompact.f90->sourcefile~channel.f90 sourcefile~xcompact.f90->sourcefile~generic.f90 sourcefile~xcompact.f90->sourcefile~tgv.f90

Source Code

module m_checkpoint_manager
  !! Manages the creation and restoration of simulation checkpoints
  !! for restart capabilities.
  !!
  !! This module periodically saves the full, unstrided simulation state
  !! (the `u`, `v`, `w` fields together with the timestep, simulation time,
  !! `dt` and field data location) so that a simulation can be stopped and
  !! resumed from the state it was in.
  !!
  !! Key features include:
  !! - Reading all checkpoint settings from a configuration file
  !! - Periodically writing the full-resolution simulation state
  !! - Restarting a simulation from a specified checkpoint file, with each
  !!   rank reading only its own portion of the stored global arrays
  !! - A safe-write strategy that writes to a temporary file first and only
  !!   then renames it to the final filename, so an interrupted write never
  !!   leaves a partially written file under a checkpoint name
  !! - Optional cleanup of the previous checkpoint to conserve disk space;
  !!   the previous checkpoint is only removed once the new one exists
  use mpi, only: MPI_COMM_WORLD, MPI_Comm_rank, MPI_Abort
  use m_common, only: dp, i8, DIR_C, get_argument
  use m_field, only: field_t
  use m_solver, only: solver_t
  use m_io_session, only: reader_session_t, writer_session_t
  use m_config, only: checkpoint_config_t
  use m_io_field_utils, only: field_buffer_map_t, field_ptr_t, &
                              setup_field_arrays, cleanup_field_arrays, &
                              stride_data_to_buffer, get_output_dimensions, &
                              prepare_field_buffers, cleanup_field_buffers, &
                              write_single_field_to_buffer

  implicit none

  private
  public :: checkpoint_manager_t

  character(len=*), parameter :: checkpoint_field_names(*) = ["u", "v", "w"]
    !! Variable names of the velocity components stored in a checkpoint;
    !! shared by the writer and the reader so the two cannot drift apart.

  type :: checkpoint_manager_t
    !! Writes periodic checkpoints and restores solver state from them.
    type(checkpoint_config_t) :: config
      !! Checkpoint settings, read from the namelist file in `init`
    integer :: last_checkpoint_step = -1
      !! Timestep of the most recent checkpoint written; -1 if none yet
    integer, dimension(3) :: full_resolution = [1, 1, 1]
      !! Stride factors handed to the buffer utilities (all 1: no striding)
    type(field_buffer_map_t), allocatable :: field_buffers(:)
      !! Output buffers reused between checkpoint writes
    integer(i8), dimension(3) :: last_shape_dims = 0
      !! Cached buffer-geometry state passed to the io_field_utils helpers
    integer, dimension(3) :: last_stride_factors = 0
      !! Cached buffer-geometry state passed to the io_field_utils helpers
    integer(i8), dimension(3) :: last_output_shape = 0
      !! Cached buffer-geometry state passed to the io_field_utils helpers
  contains
    procedure :: init
    procedure :: handle_restart
    procedure :: handle_checkpoint_step
    procedure :: is_restart
    procedure :: finalise
    procedure, private :: write_checkpoint
    procedure, private :: restart_checkpoint
    procedure, private :: write_fields
    procedure, private :: cleanup_output_buffers
  end type checkpoint_manager_t

contains

  subroutine init(self, comm)
    !! Initialise the checkpoint manager.
    !!
    !! Reads the checkpoint configuration from the namelist file given as
    !! the first command-line argument and, if checkpointing is enabled
    !! (`checkpoint_freq > 0`), reports the settings on rank 0.
    class(checkpoint_manager_t), intent(inout) :: self
    integer, intent(in) :: comm
      !! MPI communicator used to determine the reporting rank

    self%config = checkpoint_config_t()
    call self%config%read(nml_file=get_argument(1))

    if (self%config%checkpoint_freq > 0) then
      call configure_output(self, comm)
    end if
  end subroutine init

  subroutine configure_output(self, comm)
    !! Print the checkpoint settings on rank 0 of `comm`.
    !!
    !! Nothing is printed when the default I/O backend is the dummy backend.
    use m_io_backend, only: get_default_backend, IO_BACKEND_DUMMY
    class(checkpoint_manager_t), intent(inout) :: self
    integer, intent(in) :: comm

    integer :: myrank, ierr

    call MPI_Comm_rank(comm, myrank, ierr)

    if (myrank == 0 .and. get_default_backend() /= IO_BACKEND_DUMMY) then
      print *, 'Checkpoint frequency: ', self%config%checkpoint_freq
      print *, 'Keep all checkpoints: ', self%config%keep_checkpoint
      print *, 'Checkpoint prefix: ', trim(self%config%checkpoint_prefix)
    end if
  end subroutine configure_output

  function is_restart(self) result(restart)
    !! Return `.true.` if the configuration requests a restart run.
    class(checkpoint_manager_t), intent(in) :: self
    logical :: restart

    restart = self%config%restart_from_checkpoint
  end function is_restart

  subroutine handle_restart(self, solver, comm)
    !! Restore the solver state from the checkpoint file named in the
    !! configuration and set `solver%current_iter` to the stored timestep.
    !!
    !! If `comm` is absent, `MPI_COMM_WORLD` is used.
    class(checkpoint_manager_t), intent(inout) :: self
    class(solver_t), intent(inout) :: solver
    integer, intent(in), optional :: comm

    character(len=:), allocatable :: restart_file
    integer :: comm_to_use
    integer :: restart_timestep
    real(dp) :: restart_time

    ! An absent optional argument must not be forwarded to the
    ! non-optional `comm` dummy of restart_checkpoint, so resolve a
    ! default here (same convention as handle_checkpoint_step).
    comm_to_use = MPI_COMM_WORLD
    if (present(comm)) comm_to_use = comm

    ! Deferred-length, trimmed copy: no truncation of long paths and no
    ! trailing blanks handed to the I/O layer.
    restart_file = trim(self%config%restart_file)
    if (solver%mesh%par%is_root()) then
      print *, 'Restarting from checkpoint: ', restart_file
    end if

    call self%restart_checkpoint(solver, restart_file, restart_timestep, &
                                 restart_time, comm_to_use)

    solver%current_iter = restart_timestep

    if (solver%mesh%par%is_root()) then
      print *, 'Successfully restarted from checkpoint at iteration ', &
        restart_timestep, ' with time ', restart_time
    end if
  end subroutine handle_restart

  subroutine handle_checkpoint_step(self, solver, timestep, comm)
    !! Write a checkpoint for `timestep` if it is due.
    !!
    !! If `comm` is absent, `MPI_COMM_WORLD` is used.
    class(checkpoint_manager_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    integer, intent(in) :: timestep
    integer, intent(in), optional :: comm

    integer :: comm_to_use

    comm_to_use = MPI_COMM_WORLD
    if (present(comm)) comm_to_use = comm

    call self%write_checkpoint(solver, timestep, comm_to_use)
  end subroutine handle_checkpoint_step

  subroutine write_checkpoint(self, solver, timestep, comm)
    !! Write a checkpoint file for simulation restart.
    !!
    !! Returns immediately unless checkpointing is enabled and `timestep`
    !! is a multiple of `checkpoint_freq`. All ranks write collectively to
    !! `<prefix>_temp.bp`; rank 0 then renames it to
    !! `<prefix>_<timestep, 6 digits>.bp`. If `keep_checkpoint` is false,
    !! rank 0 removes the previous checkpoint, but only after the new one
    !! exists, so a failed write never destroys the last good checkpoint.
    class(checkpoint_manager_t), intent(inout) :: self
    class(solver_t), intent(in) :: solver
    integer, intent(in) :: timestep
    integer, intent(in) :: comm

    character(len=256) :: filename, temp_filename
    integer :: ierr, myrank
    real(dp) :: simulation_time
    logical :: published
    type(field_ptr_t), allocatable :: field_ptrs(:), host_fields(:)
    integer :: data_loc
    type(writer_session_t) :: writer_session

    if (self%config%checkpoint_freq <= 0) return
    if (mod(timestep, self%config%checkpoint_freq) /= 0) return

    call MPI_Comm_rank(comm, myrank, ierr)

    filename = checkpoint_filename(self%config%checkpoint_prefix, timestep)
    write (temp_filename, '(A,A)') &
      trim(self%config%checkpoint_prefix), '_temp.bp'

    call writer_session%open(trim(temp_filename), comm)
    if (writer_session%is_session_functional() .and. myrank == 0) then
      print *, 'Writing checkpoint: ', trim(filename)
    end if

    simulation_time = timestep*solver%dt
    data_loc = solver%u%data_loc
    call writer_session%write_data("timestep", timestep)
    call writer_session%write_data("time", simulation_time)
    call writer_session%write_data("dt", real(solver%dt, dp))
    call writer_session%write_data("data_loc", data_loc)

    call setup_field_arrays(solver, checkpoint_field_names, field_ptrs, &
                            host_fields)

    call self%write_fields( &
      checkpoint_field_names, host_fields, &
      solver, writer_session, data_loc &
      )

    call writer_session%close()

    call cleanup_field_arrays(solver, field_ptrs, host_fields)

    if (myrank == 0) then
      call publish_checkpoint(temp_filename, filename, published)

      ! Remove the previous checkpoint only if the new one is in place.
      ! -1 is the "nothing written yet" sentinel, so step 0 is valid; never
      ! delete the file that was just written (same step written twice).
      if (published .and. .not. self%config%keep_checkpoint) then
        if (self%last_checkpoint_step >= 0 &
            .and. self%last_checkpoint_step /= timestep) then
          call remove_checkpoint(checkpoint_filename( &
                                 self%config%checkpoint_prefix, &
                                 self%last_checkpoint_step))
        end if
      end if
    end if

    self%last_checkpoint_step = timestep
  end subroutine write_checkpoint

  subroutine restart_checkpoint( &
    self, solver, filename, timestep, restart_time, comm &
    )
    !! Restore the simulation state from a checkpoint file.
    !!
    !! Reads the stored timestep, time and data location, sets that data
    !! location on `u`, `v`, `w`, and loads each field. Each rank reads
    !! only its own sub-block (offset `mesh%par%n_offset`, extent
    !! `mesh%get_dims(data_loc)`), matching the selection used when the
    !! fields were written. Aborts `comm` if the file does not exist.
    class(checkpoint_manager_t), intent(inout) :: self
    class(solver_t), intent(inout) :: solver
    character(len=*), intent(in) :: filename
    integer, intent(out) :: timestep
    real(dp), intent(out) :: restart_time
    integer, intent(in) :: comm

    type(reader_session_t) :: reader_session
    integer :: ierr
    integer :: dims(3)
    integer(i8), dimension(3) :: start_dims, count_dims
    logical :: file_exists
    integer :: data_loc

    inquire (file=trim(filename), exist=file_exists)
    if (.not. file_exists) then
      if (solver%mesh%par%is_root()) then
        print *, 'ERROR: Checkpoint file not found: ', trim(filename)
      end if
      call MPI_Abort(comm, 1, ierr)
      return
    end if

    call reader_session%open(trim(filename), comm)
    call reader_session%read_data("timestep", timestep)
    call reader_session%read_data("time", restart_time)
    call reader_session%read_data("data_loc", data_loc)

    dims = solver%mesh%get_dims(data_loc)
    start_dims = int(solver%mesh%par%n_offset, i8)
    count_dims = int(dims, i8)

    call solver%u%set_data_loc(data_loc)
    call solver%v%set_data_loc(data_loc)
    call solver%w%set_data_loc(data_loc)

    block
      real(dp), allocatable, target :: field_data_u(:, :, :)
      real(dp), allocatable, target :: field_data_v(:, :, :)
      real(dp), allocatable, target :: field_data_w(:, :, :)

      allocate (field_data_u(count_dims(1), count_dims(2), count_dims(3)))
      allocate (field_data_v(count_dims(1), count_dims(2), count_dims(3)))
      allocate (field_data_w(count_dims(1), count_dims(2), count_dims(3)))

      ! Select this rank's portion of the global array; without the
      ! selection the whole global variable would be read into a buffer
      ! sized for the local block only.
      call reader_session%read_data(checkpoint_field_names(1), field_data_u, &
                                    start_dims=start_dims, &
                                    count_dims=count_dims)
      call reader_session%read_data(checkpoint_field_names(2), field_data_v, &
                                    start_dims=start_dims, &
                                    count_dims=count_dims)
      call reader_session%read_data(checkpoint_field_names(3), field_data_w, &
                                    start_dims=start_dims, &
                                    count_dims=count_dims)

      call solver%backend%set_field_data(solver%u, field_data_u)
      call solver%backend%set_field_data(solver%v, field_data_v)
      call solver%backend%set_field_data(solver%w, field_data_w)
    end block

    call reader_session%close()
  end subroutine restart_checkpoint

  subroutine write_fields( &
    self, field_names, host_fields, solver, writer_session, data_loc &
    )
    !! Write field data for checkpoints at full resolution (no striding).
    !!
    !! Each field is copied into its reusable buffer and written as this
    !! rank's sub-block of the global array.
    class(checkpoint_manager_t), intent(inout) :: self
    character(len=*), dimension(:), intent(in) :: field_names
    class(field_ptr_t), dimension(:), target, intent(in) :: host_fields
    class(solver_t), intent(in) :: solver
    type(writer_session_t), intent(inout) :: writer_session
    integer, intent(in) :: data_loc

    integer :: i_field
    integer(i8), dimension(3) :: output_start, output_count
    integer, dimension(3) :: output_dims_local

    ! Prepare buffers for full resolution (no striding for checkpoints)
    call prepare_field_buffers( &
      solver, self%full_resolution, field_names, data_loc, &
      self%field_buffers, self%last_shape_dims, self%last_stride_factors, &
      self%last_output_shape &
      )

    ! Calculate this rank's start/count within the global output array
    call get_output_dimensions( &
      int(solver%mesh%get_global_dims(data_loc), i8), &
      int(solver%mesh%par%n_offset, i8), &
      int(solver%mesh%get_dims(data_loc), i8), &
      self%full_resolution, &
      self%last_output_shape, output_start, output_count, &
      output_dims_local, &
      self%last_shape_dims, self%last_stride_factors, &
      self%last_output_shape &
      )

    do i_field = 1, size(field_names)
      call write_single_field_to_buffer( &
        trim(field_names(i_field)), host_fields(i_field)%ptr, &
        solver, self%full_resolution, data_loc, &
        self%field_buffers, self%last_shape_dims, self%last_stride_factors, &
        self%last_output_shape &
        )

      call writer_session%write_data( &
        trim(field_names(i_field)), &
        self%field_buffers(i_field)%buffer, &
        start_dims=output_start, count_dims=output_count &
        )
    end do
  end subroutine write_fields

  subroutine cleanup_output_buffers(self)
    !! Release the reusable field output buffers.
    class(checkpoint_manager_t), intent(inout) :: self

    call cleanup_field_buffers(self%field_buffers)
  end subroutine cleanup_output_buffers

  subroutine finalise(self)
    !! Clean up the checkpoint manager (releases output buffers).
    class(checkpoint_manager_t), intent(inout) :: self

    call self%cleanup_output_buffers()
  end subroutine finalise

  function checkpoint_filename(prefix, step) result(fname)
    !! Build the checkpoint name `<prefix>_<step, zero-padded to 6>.bp`.
    character(len=*), intent(in) :: prefix
    integer, intent(in) :: step
    character(len=256) :: fname

    write (fname, '(A,A,I0.6,A)') trim(prefix), '_', step, '.bp'
  end function checkpoint_filename

  subroutine publish_checkpoint(temp_filename, filename, published)
    !! Rename the temporary checkpoint to its final name (called on rank 0).
    !!
    !! `published` is `.true.` only if `filename` exists afterwards. A
    !! missing temporary file (nothing was written, e.g. with the dummy
    !! I/O backend) is skipped silently.
    character(len=*), intent(in) :: temp_filename, filename
    logical, intent(out) :: published

    logical :: file_exists, ok

    published = .false.

    inquire (file=trim(temp_filename), exist=file_exists)
    if (.not. file_exists) return

    ! A stale checkpoint with the same name (e.g. left behind by a run
    ! restarted from an earlier step) must go first: if it is a directory,
    ! `mv` would move the temporary file *into* it rather than replace it.
    inquire (file=trim(filename), exist=file_exists)
    if (file_exists) then
      call run_command('rm -rf '//shell_quote(filename), ok)
      if (.not. ok) then
        print *, 'ERROR: could not replace existing checkpoint: ', &
          trim(filename)
        return
      end if
    end if

    call run_command('mv '//shell_quote(temp_filename)//' '// &
                     shell_quote(filename), ok)

    inquire (file=trim(filename), exist=file_exists)
    if (.not. ok .or. .not. file_exists) then
      print *, 'ERROR: Checkpoint file not created: ', trim(filename)
      return
    end if

    published = .true.
  end subroutine publish_checkpoint

  subroutine remove_checkpoint(old_filename)
    !! Best-effort removal of an old checkpoint; warns instead of failing.
    character(len=*), intent(in) :: old_filename

    logical :: file_exists, ok

    inquire (file=trim(old_filename), exist=file_exists)
    if (.not. file_exists) return

    call run_command('rm -rf '//shell_quote(old_filename), ok)
    if (.not. ok) then
      print *, 'WARNING: failed to remove old checkpoint: ', &
        trim(old_filename)
    end if
  end subroutine remove_checkpoint

  subroutine run_command(command, ok)
    !! Run a shell command synchronously; `ok` is `.true.` only if it was
    !! launched and exited with status 0.
    !!
    !! `cmdstat` is requested so that a failure to launch the command is
    !! reported instead of causing error termination.
    character(len=*), intent(in) :: command
    logical, intent(out) :: ok

    integer :: exitstat, cmdstat

    ! exitstat is left unchanged if the command is not executed.
    exitstat = -1
    cmdstat = 0
    call execute_command_line(command, wait=.true., exitstat=exitstat, &
                              cmdstat=cmdstat)
    ok = (cmdstat == 0 .and. exitstat == 0)
  end subroutine run_command

  pure function shell_quote(path) result(quoted)
    !! Wrap `path` (trailing blanks dropped) in single quotes for a POSIX
    !! shell, escaping embedded single quotes as '\''.
    character(len=*), intent(in) :: path
    character(len=:), allocatable :: quoted

    integer :: i

    quoted = "'"
    do i = 1, len_trim(path)
      if (path(i:i) == "'") then
        quoted = quoted//"'\''"
      else
        quoted = quoted//path(i:i)
      end if
    end do
    quoted = quoted//"'"
  end function shell_quote

end module m_checkpoint_manager