diff --git a/tools/run_sim_puremd.py b/tools/run_sim_puremd.py
index c6c68500d34fdd34534b26c4946cc09f01506459..c5fafffa13e5b0f60128d70ad431e10673406121 100644
--- a/tools/run_sim_puremd.py
+++ b/tools/run_sim_puremd.py
@@ -93,15 +93,34 @@ class TestCase():
         fp_temp.write(lines)
         fp_temp.close()
 
-    def run(self, binary, process_results=False):
+    def run(self, binary, process_results=False, mpi_cmd=('mpirun',)):
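+        # mpi_cmd: MPI launcher spec split on ':', e.g. ['mpirun'] or ['srun', '<nodes>', '<tasks>', '<cores_per_task>']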
         from operator import mul
         from functools import reduce
 
-        # command to run as subprocess
-        args = [
-                'mpirun',
-                '-np',
-                '0', # placeholder, substituted below
+        # add MPI execution command to subprocess argument list
+        if mpi_cmd[0] == 'mpirun':
+            args = [
+                    'mpirun',
+                    '-np',
+                    '0', # placeholder, substituted below
+            ]
+        elif mpi_cmd[0] == 'srun':
+            # the Slurm scheduler launches MPI jobs via srun (e.g., on NERSC systems)
+            args = [
+                    'srun',
+                    '-N',
+                    '0', # placeholder, substituted below
+                    '-n',
+                    '0', # placeholder, substituted below
+                    '-c',
+                    '0', # placeholder, substituted below
+            ]
+        else:
+            print("[ERROR] Invalid MPI application type ({0}). Terminating...".format(mpi_cmd[0]))
+            exit(-1)
+
+        # append the binary and its input files to the subprocess argument list
+        args += [
                 binary,
                 self.__geo_file,
                 self.__ffield_file,
@@ -117,9 +136,6 @@ class TestCase():
             fout.write(self.__result_header_fmt.format(*self.__result_header))
             fout.flush()
 
-        temp_dir = mkdtemp()
-        temp_file = path.join(temp_dir, path.basename(self.__control_file))
-
         for p in product(*[self.__params[k] for k in self.__param_names]):
             param_dict = dict((k, v) for (k, v) in zip(self.__param_names, p))
             param_dict['name'] = path.basename(self.__geo_file).split('.')[0] \
@@ -138,15 +154,32 @@ class TestCase():
                 + '_pcsai' + param_dict['cm_solver_pre_comp_sai_thres'] \
                 + '_pa' + param_dict['cm_solver_pre_app_type'] \
                 + '_paji'+ str(param_dict['cm_solver_pre_app_jacobi_iters'])
+
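+            # generate the control file in a fresh scratch directory for each parameter combination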
+            temp_dir = mkdtemp()
+            temp_file = path.join(temp_dir, path.basename(self.__control_file))
         
-            args[2] = str(reduce(mul,
-                map(int, param_dict['proc_by_dim'].split(':')), 1))
+            # substitute MPI execution arguments into the subprocess argument list
+            if mpi_cmd[0] == 'mpirun':
+                # number of MPI processes
+                args[2] = str(reduce(mul,
+                    map(int, param_dict['proc_by_dim'].split(':')), 1))
+            elif mpi_cmd[0] == 'srun':
+                # number of nodes
+                args[2] = mpi_cmd[1]
+                # number of tasks
+                args[4] = mpi_cmd[2]
+                # number of cores per task
+                args[6] = mpi_cmd[3]
             
             if not process_results:
                 self._setup(param_dict, temp_file)
                 #env['OMP_NUM_THREADS'] = param_dict['threads']
                 start = time()
-                args[6] = temp_file;
+                # substitute the generated control file into the subprocess argument list
+                if mpi_cmd[0] == 'mpirun':
+                    args[6] = temp_file
+                elif mpi_cmd[0] == 'srun':
+                    args[10] = temp_file
                 proc_handle = Popen(args, stdout=PIPE, stderr=PIPE, env=env, universal_newlines=True)
                 stdout, stderr = proc_handle.communicate()
                 stop = time()
@@ -161,11 +194,12 @@ class TestCase():
             else:
                 self._process_result(fout, param_dict, self.__min_step, self.__max_step)
 
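+            # remove this iteration's control file and scratch directory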
+            if path.exists(temp_file):
+                remove(temp_file)
+            if path.exists(temp_dir):
+                rmdir(temp_dir)
+
         fout.close()
-        if path.exists(temp_file):
-            remove(temp_file)
-        if path.exists(temp_dir):
-            rmdir(temp_dir)
 
     def _process_result(self, fout, param, min_step, max_step):
         from operator import mul
@@ -177,6 +211,7 @@ class TestCase():
         init = 0.
         bonded = 0.
         nonbonded = 0.
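+        # wall-clock total reported by the log's 'total:' line (note: shadows any time() import within this method)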
+        time = 0.
         cm = 0.
         cm_sort = 0.
         s_iters = 0.
@@ -187,6 +222,7 @@ class TestCase():
         s_spmv = 0.
         s_vec_ops = 0.
         cnt = 0
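+        # cnt_valid indexes data lines for the min/max step window; cnt counts only lines aggregated into the averages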
+        cnt_valid = 0
         line_cnt = 0
         log_file = param['name'] + '.log'
 
@@ -198,28 +234,33 @@ class TestCase():
                 line = line.split()
                 try:
                     if (not min_step and not max_step) or \
-                    (min_step and not max_step and cnt >= min_step) or \
-                    (not min_step and max_step and cnt <= max_step) or \
-                    (cnt >= min_step and cnt <= max_step):
-                        if ( int(line[0]) < int(param['nsteps']) ):
-                            total = total + float(line[1])
-                            comm = comm + float(line[2])
-                            neighbors = neighbors + float(line[3])
-                            init = init + float(line[4])
-                            bonded = bonded + float(line[5])
-                            nonbonded = nonbonded + float(line[6])
-                            cm = cm + float(line[7])
-                            cm_sort = cm_sort + float(line[8])
-                            s_iters = s_iters + float(line[9])
-                            pre_comp = pre_comp + float(line[10])
-                            pre_app = pre_app + float(line[11])
-                            s_comm = s_comm + float(line[12])
-                            s_allr = s_allr + float(line[13])
-                            s_spmv = s_spmv + float(line[14])
-                            s_vec_ops = s_vec_ops + float(line[15])
-                            cnt = cnt + 1
+                    (min_step and not max_step and cnt_valid >= min_step) or \
+                    (not min_step and max_step and cnt_valid <= max_step) or \
+                    (cnt_valid >= min_step and cnt_valid <= max_step):
+                        total = total + float(line[1])
+                        comm = comm + float(line[2])
+                        neighbors = neighbors + float(line[3])
+                        init = init + float(line[4])
+                        bonded = bonded + float(line[5])
+                        nonbonded = nonbonded + float(line[6])
+                        cm = cm + float(line[7])
+                        cm_sort = cm_sort + float(line[8])
+                        s_iters = s_iters + float(line[9])
+                        pre_comp = pre_comp + float(line[10])
+                        pre_app = pre_app + float(line[11])
+                        s_comm = s_comm + float(line[12])
+                        s_allr = s_allr + float(line[13])
+                        s_spmv = s_spmv + float(line[14])
+                        s_vec_ops = s_vec_ops + float(line[15])
+                        cnt = cnt + 1
+                    cnt_valid = cnt_valid + 1
                 except Exception:
                     pass
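+                # capture the overall wall time from the log's 'total:' summary line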
+                if line and line[0] == 'total:':
+                    try:
+                        time = float(line[1])
+                    except Exception:
+                        pass
                 line_cnt = line_cnt + 1
             if cnt > 0:
                 comm = comm / cnt
@@ -281,6 +322,8 @@ if __name__ == '__main__':
             help='Process simulation results only (do not perform simulations).')
     parser.add_argument('-p', '--params', metavar='params', action='append', default=None, nargs=2,
             help='Parameter name and value pairs for the simulation, with multiple values comma-delimited.')
+    parser.add_argument('-m', '--mpi_cmd', metavar='mpi_cmd', default='mpirun',
+            help='MPI launcher spec, colon-delimited. Examples: \'mpirun\', \'srun:<nodes>:<tasks>:<cores_per_task>\' (e.g., \'srun:1:32:1\').')
     parser.add_argument('-n', '--min_step', metavar='min_step', default=None, nargs=1,
             help='Minimum simulation step to begin aggregating results.')
     parser.add_argument('-x', '--max_step', metavar='max_step', default=None, nargs=1,
@@ -333,7 +376,7 @@ if __name__ == '__main__':
     if args.binary:
         binary = args.binary[0]
     else:
-        binary  = path.join(base_dir, 'PuReMD/bin/puremd')
+        binary = path.join(base_dir, 'PuReMD/bin/puremd')
 
     # overwrite default params, if supplied via command line args
     if args.params:
@@ -447,4 +490,5 @@ if __name__ == '__main__':
                 min_step=min_step, max_step=max_step))
 
     for test in test_cases:
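+        # mpi_cmd arrives colon-delimited, e.g. 'srun:1:32:1' -> ['srun', '1', '32', '1']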
-        test.run(binary, process_results=args.process_results)
+        test.run(binary, process_results=args.process_results,
+                mpi_cmd=args.mpi_cmd.split(':'))