From 252684e9fd83a9e84ed4f4e990db9543cee2b372 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20M=C3=BCller?= <mueller.seb@posteo.de>
Date: Wed, 11 Dec 2024 13:54:49 +0100
Subject: [PATCH] mcmc: refactor to use optimizee; update tests

---
 src/mo_mcmc.F90              | 41 +++++++++++++++++-------------------
 src/pf_tests/test_mo_mcmc.pf | 24 ++++++++++++---------
 2 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/src/mo_mcmc.F90 b/src/mo_mcmc.F90
index ab4e12a..5a5c522 100644
--- a/src/mo_mcmc.F90
+++ b/src/mo_mcmc.F90
@@ -15,7 +15,7 @@ MODULE mo_mcmc
   USE mo_append, only : append
   USE mo_moment, only : stddev
   !$ USE omp_lib,    only: OMP_GET_NUM_THREADS
-  use mo_optimization_utils, only : eval_interface, objective_interface
+  use mo_optimization_utils, only : optimizee
   use mo_message, only : error_message
 #ifdef FORCES_WITH_NETCDF
   use mo_ncwrite, only : dump_netcdf
@@ -132,7 +132,7 @@ MODULE mo_mcmc
   !!    \b Example
   !!
   !!    \code{.f90}
-  !!    call mcmc(  likelihood, para, rangePar, mcmc_paras, burnin_paras,                  &
+  !!    call mcmc(  objective, para, rangePar, mcmc_paras, burnin_paras,                  &
   !!                seed_in=seeds, printflag_in=printflag, maskpara_in=maskpara,           &
   !!                tmp_file=tmp_file, loglike_in=loglike,                                 &
   !!                ParaSelectMode_in=ParaSelectMode,                                      &
@@ -153,8 +153,7 @@ MODULE mo_mcmc
   !!         90-103. doi:10.1016/j.ecolmodel.2012.03.009.
   !!
 
-  !>    \param[in]  "real(dp) :: likelihood(x)"                    Interface Function which calculates likelihood
-  !!                                                                   of given parameter set x
+  !>    \param[inout] "class(optimizee) :: objective"              Likelihood objective to search the optimum
   !>    \param[in]  "real(dp) :: para(:)"                          Inital parameter set (should be GOOD approximation
   !!                                                                   of best parameter set)
   !>    \param[in]  "real(dp) :: rangePar(size(para),2)"           Min/max range of parameters
@@ -194,7 +193,7 @@ MODULE mo_mcmc
   !>    \param[out]  "real(dp), allocatable :: burnin_paras(:,:)"  Parameter sets sampled during burn-in part of algorithm
 
   !>    \note
-  !>    Likelihood has to be defined as a function interface\n
+  !>    Likelihood objective has to be implented as an extended type of optimizee\n
   !>    The maximal number of parameters is 1000.
 
   !>        \authors Juliane Mai
@@ -306,7 +305,7 @@ MODULE mo_mcmc
   !!    \b Example
   !!
   !!    \code{.f90}
-  !!    call mcmc(  likelihood, para, rangePar, mcmc_paras, burnin_paras,                  &
+  !!    call mcmc(  objective, para, rangePar, mcmc_paras, burnin_paras,                   &
   !!                seed_in=seeds, printflag_in=printflag, maskpara_in=maskpara,           &
   !!                tmp_file=tmp_file, loglike_in=loglike,                                 &
   !!                ParaSelectMode_in=ParaSelectMode,                                      &
@@ -327,8 +326,8 @@ MODULE mo_mcmc
   !!        90-103. doi:10.1016/j.ecolmodel.2012.03.009.
   !!
 
-  !>    \param[in]  "real(dp) :: likelihood(x,sigma,stddev_new,likeli_new)"
-  !>                                                                   Interface Function which calculates likelihood
+  !>    \param[inout] "class(optimizee) :: objective"                  Likelihood objective to search the optimum, with
+  !>                                                                   type-bound procedure 'evaluate' which calculates likelihood
   !>                                                                   of given parameter set x and given standard deviation sigma
   !>                                                                   and returns optionally the standard deviation stddev_new
   !>                                                                   of the errors using x and
@@ -430,7 +429,7 @@ CONTAINS
 
   !-----------------------------------------------------------------------------------------------
 
-  SUBROUTINE mcmc_dp(eval, likelihood, para, rangePar, &   ! obligatory IN
+  SUBROUTINE mcmc_dp(objective, para, rangePar, &   ! obligatory IN
           mcmc_paras, burnin_paras, &   ! obligatory OUT
           seed_in, printflag_in, maskpara_in, &   ! optional IN
           restart, restart_file, &   ! optional IN: if mcmc is restarted and file which contains restart variables
@@ -445,8 +444,7 @@ CONTAINS
 
     IMPLICIT NONE
 
-    procedure(eval_interface), INTENT(IN), POINTER :: eval
-    procedure(objective_interface), intent(in), pointer :: likelihood
+    class(optimizee), intent(inout) :: objective
 
     REAL(DP), DIMENSION(:, :), INTENT(IN) :: rangePar           ! range for each parameter
     REAL(DP), DIMENSION(:), INTENT(IN) :: para               ! initial parameter i
@@ -734,7 +732,7 @@ CONTAINS
       parabest = para
 
       ! initialize likelihood
-      likelibest = likelihood(parabest, eval)
+      likelibest = objective%evaluate(parabest)
 
       !----------------------------------------------------------------------
       ! (1) BURN IN
@@ -798,7 +796,7 @@ CONTAINS
                     paranew, ChangePara)
 
             ! (B) new likelihood
-            likelinew = likelihood(paranew, eval)
+            likelinew = objective%evaluate(paranew)
 
             oddsSwitch1 = .false.
             if (loglike) then
@@ -1039,7 +1037,7 @@ CONTAINS
         else
           paraold = mcmc_paras_3d(Ipos(chain) + Ineg(chain), :, chain)
         end if
-        likeliold = likelihood(paraold, eval)
+        likeliold = objective%evaluate(paraold)
 
         markovchainMCMC : do
 
@@ -1053,7 +1051,7 @@ CONTAINS
                   paranew, ChangePara)
 
           ! (B) new likelihood
-          likelinew = likelihood(paranew, eval)
+          likelinew = objective%evaluate(paranew)
           oddsSwitch1 = .false.
           if (loglike) then
             oddsRatio = likelinew - likeliold
@@ -1286,7 +1284,7 @@ CONTAINS
     RETURN
   END SUBROUTINE mcmc_dp
 
-  SUBROUTINE mcmc_stddev_dp(eval, likelihood, para, rangePar, &   ! obligatory IN
+  SUBROUTINE mcmc_stddev_dp(objective, para, rangePar, &   ! obligatory IN
           mcmc_paras, burnin_paras, &   ! obligatory OUT
           seed_in, printflag_in, maskpara_in, &   ! optional IN
           tmp_file, &   ! optional IN : filename for temporal output of
@@ -1300,8 +1298,7 @@ CONTAINS
 
     IMPLICIT NONE
 
-    procedure(eval_interface), INTENT(IN), POINTER :: eval
-    procedure(objective_interface), intent(in), pointer :: likelihood
+    class(optimizee), intent(inout) :: objective
 
     REAL(DP), DIMENSION(:, :), INTENT(IN) :: rangePar           ! range for each parameter
     REAL(DP), DIMENSION(:), INTENT(IN) :: para               ! initial parameter i
@@ -1533,7 +1530,7 @@ CONTAINS
     ! write(*,*) parabest
 
     ! initialize likelihood and sigma
-    likelibest = likelihood(parabest, eval, 1.0_dp, stddev_new, likeli_new)
+    likelibest = objective%evaluate(parabest, 1.0_dp, stddev_new, likeli_new)
     likelibest = likeli_new
     stddev_data = stddev_new
 
@@ -1612,7 +1609,7 @@ CONTAINS
                     paranew, ChangePara)
 
             ! (B) new likelihood
-            likelinew = likelihood(paranew, eval, stddev_data, stddev_new, likeli_new)
+            likelinew = objective%evaluate(paranew, stddev_data, stddev_new, likeli_new)
 
             oddsSwitch1 = .false.
             if (loglike) then
@@ -1794,7 +1791,7 @@ CONTAINS
         else
           paraold = mcmc_paras_3d(Ipos(chain) + Ineg(chain), :, chain)
         end if
-        likeliold = likelihood(paraold, eval, stddev_data)
+        likeliold = objective%evaluate(paraold, stddev_data)
 
         markovchainMCMC : do
 
@@ -1808,7 +1805,7 @@ CONTAINS
                   paranew, ChangePara)
 
           ! (B) new likelihood
-          likelinew = likelihood(paranew, eval, stddev_data)
+          likelinew = objective%evaluate(paranew, stddev_data)
           oddsSwitch1 = .false.
           if (loglike) then
             oddsRatio = likelinew - likeliold
diff --git a/src/pf_tests/test_mo_mcmc.pf b/src/pf_tests/test_mo_mcmc.pf
index 238cad7..a723f11 100644
--- a/src/pf_tests/test_mo_mcmc.pf
+++ b/src/pf_tests/test_mo_mcmc.pf
@@ -5,7 +5,7 @@ module test_mo_mcmc
   use mo_likelihood, only: setmeas, loglikelihood_dp, loglikelihood_stddev_dp, model_dp
   use mo_mcmc,       only: mcmc, mcmc_stddev
   use mo_moment,     only: mean, stddev
-  use mo_optimization_utils, only: eval_interface, objective_interface
+  use mo_optimization_utils, only: eval_interface, objective_interface, eval_optimizee
   use mo_message, only: error_message
 
   implicit none
@@ -13,6 +13,7 @@ module test_mo_mcmc
   ! for running mcmc
   procedure(eval_interface), pointer :: eval_func
   procedure(objective_interface), pointer :: likelihood
+  type(eval_optimizee) :: objective
 
   real(dp)                              :: p
   real(dp)                              :: likeli_new
@@ -46,8 +47,9 @@ contains
     write(*,*) ' (A1) "real" likelihood  (sigma is an error model_dp or given) --> e.g. loglikelihood of mo_likelihood'
     write(*,*) '      full run '
     write(*,*) '---------------------------------------------------------------------------------------------'
-    eval_func => model_dp
-    p = loglikelihood_dp(parabest, eval_func)
+    objective%eval_pointer => model_dp
+    objective%obj_pointer => loglikelihood_dp
+    p = objective%evaluate(parabest)
     write(*,*) 'log-likelihood = ',p
 
     ! initializing the ranges of parameters
@@ -64,8 +66,7 @@ contains
     !     (1) Burn-in will be performed to optimize settings for MCMC
     !     (2) posterior distribution of the parameters at the minimum (best parameterset)
     !         will be sampled by MCMC
-    likelihood => loglikelihood_dp
-    call mcmc(eval_func, likelihood, parabest, rangePar, mcmc_paras, burnin_paras, &
+    call mcmc(objective, parabest, rangePar, mcmc_paras, burnin_paras, &
         ParaSelectMode_in=2_i4,tmp_file='A_make_check_test_file',              &
         restart=.false., restart_file='restart_make_check_test_file', &
         seed_in=seed, loglike_in=.true., maskpara_in=maskpara, printflag_in=.true.)
@@ -114,7 +115,9 @@ contains
     write(*,*) '---------------------------------------------------------------------------------------------'
     ! starting MCMC:
     !     (1) starting from restart file
-    call mcmc(eval_func, likelihood, parabest, rangePar, mcmc_paras, burnin_paras, &
+    objective%eval_pointer => model_dp
+    objective%obj_pointer => loglikelihood_dp
+    call mcmc(objective, parabest, rangePar, mcmc_paras, burnin_paras, &
         ParaSelectMode_in=2_i4,tmp_file='A_make_check_test_file',              &
         restart=.true., restart_file='restart_make_check_test_file', &
         seed_in=seed, loglike_in=.true., maskpara_in=maskpara, printflag_in=.true.)
@@ -165,7 +168,9 @@ contains
     ! sigma of errors is unknown --> initial guess e.g. 1.0_dp
     !    stddev_new = standard deviation of errors using paraset
     !    likeli_new = likelihood using stddev_new
-    p = loglikelihood_stddev_dp(parabest,eval_func,1.0_dp,likeli_new=likeli_new,stddev_new=stddev_new)
+    objective%eval_pointer => model_dp
+    objective%obj_pointer => loglikelihood_stddev_dp
+    p = objective%evaluate(parabest,1.0_dp,likeli_new=likeli_new,stddev_new=stddev_new)
     write(*,*) 'guessed log-likelihood = ',p
     write(*,*) 'best log-likelihood    = ',likeli_new
 
@@ -183,8 +188,7 @@ contains
     !     (1) Burn-in will be performed to optimize settings for MCMC
     !     (2) posterior distribution of the parameters at the minimum (best parameterset)
     !         will be sampled by MCMC
-    likelihood => loglikelihood_stddev_dp
-    call mcmc_stddev(eval_func, likelihood, parabest, rangePar, mcmc_paras, burnin_paras, &
+    call mcmc_stddev(objective, parabest, rangePar, mcmc_paras, burnin_paras, &
         ParaSelectMode_in=2_i4,tmp_file='B_make_check_test_file',              &
         seed_in=seed, loglike_in=.true., maskpara_in=maskpara, printflag_in=.true.)
 
@@ -221,5 +225,5 @@ contains
     deallocate(burnin_paras)
 
   end subroutine test_mcmc_stddev
-  
+
 end module test_mo_mcmc
\ No newline at end of file
-- 
GitLab