From 7c57974da6bf21244a26adc08c013c6a214697dd Mon Sep 17 00:00:00 2001
From: "Kurt A. O'Hearn" <ohearnku@msu.edu>
Date: Wed, 23 Jan 2019 09:24:44 -0800
Subject: [PATCH] PuReMD-old: fix convergence checks for PIPECG and PIPECR.
 Update CG to used preconditioned residual in it's check.

---
 PuReMD/src/linear_solvers.c | 51 +++++++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 16 deletions(-)

diff --git a/PuReMD/src/linear_solvers.c b/PuReMD/src/linear_solvers.c
index 276daed2..cda691ea 100644
--- a/PuReMD/src/linear_solvers.c
+++ b/PuReMD/src/linear_solvers.c
@@ -1794,11 +1794,11 @@ int CG( reax_system *system, control_params *control, simulation_data *data,
         storage *workspace, sparse_matrix *H, real *b,
         real tol, real *x, mpi_datatypes* mpi_data )
 {
-    int  i, j;
-    real tmp, alpha, beta, b_norm;
+    int i, j;
+    real tmp, alpha, beta, norm, b_norm;
     real sig_old, sig_new;
     real t_start, t_pa, t_spmv, t_vops, t_comm, t_allreduce;
-    real timings[5];
+    real timings[5], redux[3];
 
     t_pa = 0.0;
     t_spmv = 0.0;
@@ -1863,11 +1863,19 @@ int CG( reax_system *system, control_params *control, simulation_data *data,
     }
 
     t_start = MPI_Wtime( );
-    b_norm = Parallel_Norm( b, system->n, mpi_data->world );
-    sig_new = Parallel_Dot( workspace->r, workspace->d, system->n, mpi_data->world );
+    redux[0] = Dot_local( workspace->r, workspace->d, system->n );
+    redux[1] = Dot_local( workspace->d, workspace->d, system->n );
+    redux[2] = Dot_local( b, b, system->n );
+    t_vops += MPI_Wtime( ) - t_start;
+
+    t_start = MPI_Wtime( );
+    MPI_Allreduce( MPI_IN_PLACE, redux, 3, MPI_DOUBLE, MPI_SUM, mpi_data->world );
     t_allreduce += MPI_Wtime( ) - t_start;
+    sig_new = redux[0];
+    norm = sqrt( redux[1] );
+    b_norm = sqrt( redux[2] );
 
-    for ( i = 0; i < control->cm_solver_max_iters && sqrt(sig_new) / b_norm > tol; ++i )
+    for ( i = 0; i < control->cm_solver_max_iters && norm / b_norm > tol; ++i )
     {
         t_start = MPI_Wtime( );
         Dist( system, mpi_data, workspace->d, REAL_PTR_TYPE, MPI_DOUBLE );
@@ -1932,9 +1940,16 @@ int CG( reax_system *system, control_params *control, simulation_data *data,
         }
 
         t_start = MPI_Wtime( );
+        redux[0] = Dot_local( workspace->r, workspace->p, system->n );
+        redux[1] = Dot_local( workspace->p, workspace->p, system->n );
+        t_vops += MPI_Wtime( ) - t_start;
+
+        t_start = MPI_Wtime( );
+        MPI_Allreduce( MPI_IN_PLACE, redux, 2, MPI_DOUBLE, MPI_SUM, mpi_data->world );
+        t_allreduce += MPI_Wtime( ) - t_start;
         sig_old = sig_new;
-        sig_new = Parallel_Dot( workspace->r, workspace->p, system->n, mpi_data->world );
-        t_allreduce += MPI_Wtime() - t_start;
+        sig_new = redux[0];
+        norm = sqrt( redux[1] );
 
         t_start = MPI_Wtime( );
         beta = sig_new / sig_old;
@@ -1987,9 +2002,9 @@ int PIPECG( reax_system *system, control_params *control, simulation_data *data,
         real tol, real *x, mpi_datatypes* mpi_data )
 {
     int i, j;
-    real alpha, beta, delta, gamma_old, gamma_new, norm;
+    real alpha, beta, delta, gamma_old, gamma_new, norm, b_norm;
     real t_start, t_pa, t_spmv, t_vops, t_comm, t_allreduce;
-    real timings[5], redux[3];
+    real timings[5], redux[4];
     MPI_Request req;
 
     t_pa = 0.0;
@@ -2089,9 +2104,10 @@ int PIPECG( reax_system *system, control_params *control, simulation_data *data,
     redux[0] = Dot_local( workspace->w, workspace->u, system->n );
     redux[1] = Dot_local( workspace->r, workspace->u, system->n );
     redux[2] = Dot_local( workspace->u, workspace->u, system->n );
+    redux[3] = Dot_local( b, b, system->n );
     t_vops += MPI_Wtime( ) - t_start;
 
-    MPI_Iallreduce( MPI_IN_PLACE, redux, 3, MPI_DOUBLE, MPI_SUM, mpi_data->world, &req );
+    MPI_Iallreduce( MPI_IN_PLACE, redux, 4, MPI_DOUBLE, MPI_SUM, mpi_data->world, &req );
 
     /* pre-conditioning */
     if ( control->cm_solver_pre_comp_type == NONE_PC )
@@ -2155,8 +2171,9 @@ int PIPECG( reax_system *system, control_params *control, simulation_data *data,
     delta = redux[0];
     gamma_new = redux[1];
     norm = sqrt( redux[2] );
+    b_norm = sqrt( redux[3] );
 
-    for ( i = 0; i < control->cm_solver_max_iters && norm > tol; ++i )
+    for ( i = 0; i < control->cm_solver_max_iters && norm / b_norm > tol; ++i )
     {
         if ( i > 0 )
         {
@@ -2293,9 +2310,9 @@ int PIPECR( reax_system *system, control_params *control, simulation_data *data,
         real tol, real *x, mpi_datatypes* mpi_data )
 {
     int i, j;
-    real alpha, beta, delta, gamma_old, gamma_new, norm;
+    real alpha, beta, delta, gamma_old, gamma_new, norm, b_norm;
     real t_start, t_pa, t_spmv, t_vops, t_comm, t_allreduce;
-    real timings[5], redux[3];
+    real timings[5], redux[4];
     MPI_Request req;
 
     t_pa = 0.0;
@@ -2394,7 +2411,7 @@ int PIPECR( reax_system *system, control_params *control, simulation_data *data,
     //TODO: better loop unrolling and termination condition check
     norm = tol + 1.0;
 
-    for ( i = 0; i < control->cm_solver_max_iters && norm > tol; ++i )
+    for ( i = 0; i < control->cm_solver_max_iters && norm / b_norm > tol; ++i )
     {
         /* pre-conditioning */
         if ( control->cm_solver_pre_comp_type == NONE_PC )
@@ -2429,9 +2446,10 @@ int PIPECR( reax_system *system, control_params *control, simulation_data *data,
         redux[0] = Dot_local( workspace->w, workspace->u, system->n );
         redux[1] = Dot_local( workspace->m, workspace->w, system->n );
         redux[2] = Dot_local( workspace->u, workspace->u, system->n );
+        redux[3] = Dot_local( b, b, system->n );
         t_vops += MPI_Wtime( ) - t_start;
 
-        MPI_Iallreduce( MPI_IN_PLACE, redux, 3, MPI_DOUBLE, MPI_SUM, mpi_data->world, &req );
+        MPI_Iallreduce( MPI_IN_PLACE, redux, 4, MPI_DOUBLE, MPI_SUM, mpi_data->world, &req );
 
         t_start = MPI_Wtime( );
         Dist( system, mpi_data, workspace->m, REAL_PTR_TYPE, MPI_DOUBLE );
@@ -2466,6 +2484,7 @@ int PIPECR( reax_system *system, control_params *control, simulation_data *data,
         gamma_new = redux[0];
         delta = redux[1];
         norm = sqrt( redux[2] );
+        b_norm = sqrt( redux[3] );
 
         if ( i > 0 )
         {
-- 
GitLab