From 9a8bd2d3eae3b5ee57c21a6de60a58430fef4c38 Mon Sep 17 00:00:00 2001
From: "Kurt A. O'Hearn" <ohearnk@msu.edu>
Date: Tue, 17 Apr 2018 15:00:03 -0400
Subject: [PATCH] sPuReMD: add BiCGStab solver.

---
 sPuReMD/src/charges.c |  20 +++++++
 sPuReMD/src/init_md.c |  33 +++++++++--
 sPuReMD/src/lin_alg.c | 126 +++++++++++++++++++++++++++++++++++++++---
 sPuReMD/src/lin_alg.h |   4 ++
 sPuReMD/src/mytypes.h |   4 +-
 5 files changed, 174 insertions(+), 13 deletions(-)

diff --git a/sPuReMD/src/charges.c b/sPuReMD/src/charges.c
index 17c013e6..ca98e585 100644
--- a/sPuReMD/src/charges.c
+++ b/sPuReMD/src/charges.c
@@ -1342,6 +1342,14 @@ static void QEq( reax_system * const system, control_params * const control,
                       workspace->t[0], FALSE ) + 1;
         break;
 
+    case BiCGStab_S:
+        iters = BiCGStab( workspace, control, data, workspace->H, workspace->b_s, control->cm_solver_q_err,
+                workspace->s[0], (control->cm_solver_pre_comp_refactor > 0 &&
+                 (data->step - data->prev_steps) % control->cm_solver_pre_comp_refactor == 0) ? TRUE : FALSE ) + 1;
+        iters += BiCGStab( workspace, control, data, workspace->H, workspace->b_t, control->cm_solver_q_err,
+                workspace->t[0], FALSE ) + 1;
+        break;
+
     default:
         fprintf( stderr, "Unrecognized QEq solver selection. Terminating...\n" );
         exit( INVALID_INPUT );
@@ -1416,6 +1424,12 @@ static void EE( reax_system * const system, control_params * const control,
                  (data->step - data->prev_steps) % control->cm_solver_pre_comp_refactor == 0) ? TRUE : FALSE ) + 1;
         break;
 
+    case BiCGStab_S:
+        iters = BiCGStab( workspace, control, data, workspace->H, workspace->b_s, control->cm_solver_q_err,
+                workspace->s[0], (control->cm_solver_pre_comp_refactor > 0 &&
+                 (data->step - data->prev_steps) % control->cm_solver_pre_comp_refactor == 0) ? TRUE : FALSE ) + 1;
+        break;
+
     default:
         fprintf( stderr, "Unrecognized EE solver selection. Terminating...\n" );
         exit( INVALID_INPUT );
@@ -1499,6 +1513,12 @@ static void ACKS2( reax_system * const system, control_params * const control,
                  (data->step - data->prev_steps) % control->cm_solver_pre_comp_refactor == 0) ? TRUE : FALSE ) + 1;
         break;
 
+    case BiCGStab_S:
+        iters = BiCGStab( workspace, control, data, workspace->H, workspace->b_s, control->cm_solver_q_err,
+                workspace->s[0], (control->cm_solver_pre_comp_refactor > 0 &&
+                 (data->step - data->prev_steps) % control->cm_solver_pre_comp_refactor == 0) ? TRUE : FALSE ) + 1;
+        break;
+
     default:
         fprintf( stderr, "Unrecognized ACKS2 solver selection. Terminating...\n" );
         exit( INVALID_INPUT );
diff --git a/sPuReMD/src/init_md.c b/sPuReMD/src/init_md.c
index d098007c..02a45310 100644
--- a/sPuReMD/src/init_md.c
+++ b/sPuReMD/src/init_md.c
@@ -352,7 +352,7 @@ static void Init_Workspace( reax_system *system, control_params *control,
     }
 
     //TODO: check if unused
-    //workspace->w        = (real *) scalloc( cm_lin_sys_size, sizeof( real ),
+    //workspace->w = (real *) scalloc( cm_lin_sys_size, sizeof( real ),
     //"Init_Workspace::workspace->droptol" );
     //TODO: check if unused
     workspace->b = (real *) scalloc( system->N_cm * 2, sizeof( real ),
@@ -433,7 +433,6 @@ static void Init_Workspace( reax_system *system, control_params *control,
 
     switch ( control->cm_solver_type )
     {
-        /* GMRES storage */
         case GMRES_S:
         case GMRES_H_S:
             workspace->y = (real *) scalloc( control->cm_solver_restart + 1, sizeof( real ),
@@ -473,7 +472,6 @@ static void Init_Workspace( reax_system *system, control_params *control,
                     "Init_Workspace::workspace->p" );
             break;
 
-        /* CG storage */
         case CG_S:
             workspace->r = (real *) scalloc( system->N_cm, sizeof( real ),
                     "Init_Workspace::workspace->r" );
@@ -494,6 +492,23 @@ static void Init_Workspace( reax_system *system, control_params *control,
                     "Init_Workspace::workspace->q" );
             break;
 
+        case BiCGStab_S:
+            workspace->r = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->r" );
+            workspace->r_hat = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->r_hat" );
+            workspace->d = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->d" );
+            workspace->q = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->q" );
+            workspace->p = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->p" );
+            workspace->y = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->y" );
+            workspace->z = (real *) scalloc( system->N_cm, sizeof( real ),
+                    "Init_Workspace::workspace->z" );
+            break;
+
         default:
             fprintf( stderr, "Unknown charge method linear solver type. Terminating...\n" );
             exit( INVALID_INPUT );
@@ -1217,7 +1232,6 @@ static void Finalize_Workspace( reax_system *system, control_params *control,
 
     switch ( control->cm_solver_type )
     {
-        /* GMRES storage */
         case GMRES_S:
         case GMRES_H_S:
             for ( i = 0; i < control->cm_solver_restart + 1; ++i )
@@ -1242,7 +1256,6 @@ static void Finalize_Workspace( reax_system *system, control_params *control,
             sfree( workspace->p, "Finalize_Workspace::workspace->p" );
             break;
 
-        /* CG storage */
         case CG_S:
             sfree( workspace->r, "Finalize_Workspace::workspace->r" );
             sfree( workspace->d, "Finalize_Workspace::workspace->d" );
@@ -1256,6 +1269,16 @@ static void Finalize_Workspace( reax_system *system, control_params *control,
             sfree( workspace->q, "Finalize_Workspace::workspace->q" );
             break;
 
+        case BiCGStab_S:
+            sfree( workspace->r, "Finalize_Workspace::workspace->r" );
+            sfree( workspace->r_hat, "Finalize_Workspace::workspace->r_hat" );
+            sfree( workspace->d, "Finalize_Workspace::workspace->d" );
+            sfree( workspace->q, "Finalize_Workspace::workspace->q" );
+            sfree( workspace->p, "Finalize_Workspace::workspace->p" );
+            sfree( workspace->y, "Finalize_Workspace::workspace->y" );
+            sfree( workspace->z, "Finalize_Workspace::workspace->z" );
+            break;
+
         default:
             fprintf( stderr, "Unknown charge method linear solver type. Terminating...\n" );
             exit( INVALID_INPUT );
diff --git a/sPuReMD/src/lin_alg.c b/sPuReMD/src/lin_alg.c
index bcf78513..ac3c2166 100644
--- a/sPuReMD/src/lin_alg.c
+++ b/sPuReMD/src/lin_alg.c
@@ -3162,6 +3162,7 @@ int GMRES( const static_storage * const workspace, const control_params * const
     if ( g_itr >= control->cm_solver_max_iters )
     {
         fprintf( stderr, "[WARNING] GMRES convergence failed (%d outer iters)\n", g_itr );
+        fprintf( stderr, "  [INFO] Rel. residual error: %f\n", FABS(workspace->g[g_j]) / bnorm );
         return g_itr * (control->cm_solver_restart + 1) + g_j + 1;
     }
 
@@ -3403,6 +3404,7 @@ int GMRES_HouseHolder( const static_storage * const workspace,
     if ( g_itr >= control->cm_solver_max_iters )
     {
         fprintf( stderr, "[WARNING] GMRES convergence failed (%d outer iters)\n", g_itr );
+        fprintf( stderr, "  [INFO] Rel. residual error: %f\n", FABS(w[g_j]) / bnorm );
         return g_itr * (control->cm_solver_restart + 1) + j + 1;
     }
 
@@ -3461,7 +3463,7 @@ int CG( const static_storage * const workspace, const control_params * const con
 
         t_start = Get_Time( );
         Vector_Copy( p, z, N );
-        sig_new = Dot( r, z, N );
+        sig_new = Dot( r, p, N );
         t_vops += Get_Timing_Info( t_start );
 
         for ( i = 0; i < control->cm_solver_max_iters && r_norm / b_norm > tol; ++i )
@@ -3474,7 +3476,7 @@ int CG( const static_storage * const workspace, const control_params * const con
             tmp = Dot( d, p, N );
             alpha = sig_new / tmp;
             Vector_Add( x, alpha, p, N );
-            Vector_Add( r, -alpha, d, N );
+            Vector_Add( r, -1.0 * alpha, d, N );
             r_norm = Norm( r, N );
             t_vops += Get_Timing_Info( t_start );
 
@@ -3487,7 +3489,7 @@ int CG( const static_storage * const workspace, const control_params * const con
             sig_old = sig_new;
             sig_new = Dot( r, z, N );
             beta = sig_new / sig_old;
-            Vector_Sum( p, 1., z, beta, p, N );
+            Vector_Sum( p, 1.0, z, beta, p, N );
             t_vops += Get_Timing_Info( t_start );
         }
 
@@ -3504,6 +3506,115 @@ int CG( const static_storage * const workspace, const control_params * const con
     if ( g_itr >= control->cm_solver_max_iters )
     {
         fprintf( stderr, "[WARNING] CG convergence failed (%d iters)\n", g_itr );
+        fprintf( stderr, "  [INFO] Rel. residual error: %f\n", r_norm / b_norm );
+        return g_itr;
+    }
+
+    return g_itr;
+}
+
+
+/* Bi-conjugate gradient stabilized method with left preconditioning for
+ * solving nonsymmetric linear systems
+ *
+ * workspace: struct containing storage for workspace for the linear solver
+ * control: struct containing parameters governing the simulation and numeric methods
+ * data: struct containing simulation data (e.g., atom info)
+ * H: sparse, symmetric matrix, lower half stored in CSR format
+ * b: right-hand side of the linear system
+ * tol: tolerance compared against the relative residual for determining convergence
+ * x: initial guess
+ * fresh_pre: flag for determining if preconditioners should be recomputed
+ * */
+int BiCGStab( const static_storage * const workspace, const control_params * const control,
+        simulation_data * const data, const sparse_matrix * const H, const real * const b,
+        const real tol, real * const x, const int fresh_pre )
+{
+    int i, g_itr, N;
+    real tmp, alpha, beta, omega, rho, rho_old, sigma, r_norm, b_norm;
+    real t_start, t_pa, t_spmv, t_vops;
+
+    N = H->n;
+    t_pa = 0.0;
+    t_spmv = 0.0;
+    t_vops = 0.0;
+
+#ifdef _OPENMP
+    #pragma omp parallel default(none) \
+    private(i, tmp, alpha, beta, omega, rho, rho_old, sigma, r_norm, b_norm, t_start) \
+    reduction(+: t_pa, t_spmv, t_vops) \
+    shared(g_itr, N)
+#endif
+    {
+        t_pa = 0.0;
+        t_spmv = 0.0;
+        t_vops = 0.0;
+
+        t_start = Get_Time( );
+        Sparse_MatVec( workspace, H, x, workspace->d );
+        t_spmv += Get_Timing_Info( t_start );
+
+        t_start = Get_Time( );
+        b_norm = Norm( b, N );
+        Vector_Sum( workspace->r, 1.0,  b, -1.0, workspace->d, N );
+        t_vops += Get_Timing_Info( t_start );
+
+        t_start = Get_Time( );
+        Vector_Copy( workspace->r_hat, workspace->r, N );
+        r_norm = Norm( workspace->r, N );
+        Vector_Copy( workspace->p, workspace->r, N );
+        rho_old = Dot( workspace->r, workspace->r_hat, N );
+        t_vops += Get_Timing_Info( t_start );
+
+        for ( i = 0; i < control->cm_solver_max_iters && r_norm / b_norm > tol; ++i )
+        {
+            t_start = Get_Time( );
+            Sparse_MatVec( workspace, H, workspace->p, workspace->d );
+            t_spmv += Get_Timing_Info( t_start );
+
+            t_start = Get_Time( );
+            tmp = Dot( workspace->d, workspace->r_hat, N );
+            alpha = rho_old / tmp;
+            Vector_Sum( workspace->q, 1.0, workspace->r, -1.0 * alpha, workspace->d, N );
+            t_vops += Get_Timing_Info( t_start );
+
+            t_start = Get_Time( );
+            Sparse_MatVec( workspace, H, workspace->q, workspace->y );
+            t_spmv += Get_Timing_Info( t_start );
+
+            t_start = Get_Time( );
+            sigma = Dot( workspace->y, workspace->q, N );
+            tmp = Dot( workspace->y, workspace->y, N );
+            omega = sigma / tmp;
+            Vector_Sum( workspace->z, alpha, workspace->p, omega, workspace->q, N );
+            Vector_Add( x, 1.0, workspace->z, N );
+            Vector_Sum( workspace->r, 1.0, workspace->q, -1.0 * omega, workspace->y, N );
+            r_norm = Norm( workspace->r, N );
+            t_vops += Get_Timing_Info( t_start );
+
+            t_start = Get_Time( );
+            rho = Dot( workspace->r, workspace->r_hat, N );
+            beta = (rho / rho_old) * (alpha / omega);
+            Vector_Sum( workspace->z, 1.0, workspace->p, -1.0 * omega, workspace->d, N );
+            Vector_Sum( workspace->p, 1.0, workspace->r, beta, workspace->z, N );
+            rho_old = rho;
+            t_vops += Get_Timing_Info( t_start );
+        }
+
+#ifdef _OPENMP
+        #pragma omp single
+#endif
+        g_itr = i;
+    }
+
+    data->timing.cm_solver_pre_app += t_pa / control->num_threads;
+    data->timing.cm_solver_spmv += t_spmv / control->num_threads;
+    data->timing.cm_solver_vector_ops += t_vops / control->num_threads;
+
+    if ( g_itr >= control->cm_solver_max_iters )
+    {
+        fprintf( stderr, "[WARNING] BiCGStab convergence failed (%d iters)\n", g_itr );
+        fprintf( stderr, "  [INFO] Rel. residual error: %f\n", r_norm / b_norm );
         return g_itr;
     }
 
@@ -3550,8 +3661,8 @@ int SDM( const static_storage * const workspace, const control_params * const co
         t_vops += Get_Timing_Info( t_start );
 
         t_start = Get_Time( );
-        apply_preconditioner( workspace, control, workspace->r, workspace->d, fresh_pre, LEFT );
-        apply_preconditioner( workspace, control, workspace->r, workspace->d, fresh_pre, RIGHT );
+        apply_preconditioner( workspace, control, workspace->r, workspace->q, fresh_pre, LEFT );
+        apply_preconditioner( workspace, control, workspace->q, workspace->d, fresh_pre, RIGHT );
         t_pa += Get_Timing_Info( t_start );
 
         t_start = Get_Time( );
@@ -3583,8 +3694,8 @@ int SDM( const static_storage * const workspace, const control_params * const co
             t_vops += Get_Timing_Info( t_start );
 
             t_start = Get_Time( );
-            apply_preconditioner( workspace, control, workspace->r, workspace->d, FALSE, LEFT );
-            apply_preconditioner( workspace, control, workspace->r, workspace->d, FALSE, RIGHT );
+            apply_preconditioner( workspace, control, workspace->r, workspace->q, FALSE, LEFT );
+            apply_preconditioner( workspace, control, workspace->q, workspace->d, FALSE, RIGHT );
             t_pa += Get_Timing_Info( t_start );
         }
 
@@ -3601,6 +3712,7 @@ int SDM( const static_storage * const workspace, const control_params * const co
     if ( g_itr >= control->cm_solver_max_iters  )
     {
         fprintf( stderr, "[WARNING] SDM convergence failed (%d iters)\n", g_itr );
+        fprintf( stderr, "  [INFO] Rel. residual error: %f\n", SQRT(sig) / b_norm );
         return g_itr;
     }
 
diff --git a/sPuReMD/src/lin_alg.h b/sPuReMD/src/lin_alg.h
index ab0237c1..696fdc9b 100644
--- a/sPuReMD/src/lin_alg.h
+++ b/sPuReMD/src/lin_alg.h
@@ -101,6 +101,10 @@ int CG( const static_storage * const, const control_params * const,
         simulation_data * const, const sparse_matrix * const, const real * const,
         const real, real * const, const int );
 
+int BiCGStab( const static_storage * const, const control_params * const,
+        simulation_data * const, const sparse_matrix * const, const real * const,
+        const real, real * const, const int );
+
 int SDM( const static_storage * const, const control_params * const,
         simulation_data * const, const sparse_matrix * const, const real * const, const real,
         real * const, const int );
diff --git a/sPuReMD/src/mytypes.h b/sPuReMD/src/mytypes.h
index 0abe07f2..c2712e83 100644
--- a/sPuReMD/src/mytypes.h
+++ b/sPuReMD/src/mytypes.h
@@ -219,6 +219,7 @@ enum solver
     GMRES_H_S = 1,
     CG_S = 2,
     SDM_S = 3,
+    BiCGStab_S = 4,
 };
 
 enum pre_comp
@@ -980,8 +981,9 @@ typedef struct
     real **rn;
     real **v;
 
-    /* CG related storage */
+    /* CG, SDM, BiCGStab related storage */
     real *r;
+    real *r_hat;
     real *d;
     real *q;
     real *p;
-- 
GitLab