diff --git a/.gitignore b/.gitignore
index 7ca6e7e00f515f5981a7853d48f28d1848a19fa8..c0c8a49b92343851eac1b2395a1efd1e34be1914 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,8 @@ raw_sf7_cross_instance/
 matlab/raw_sf7_cross_instance.zip
 *.pkl
 *.png
-*.pyc
\ No newline at end of file
+*.pyc
+*.log
+*.yaml
+*.json
+*wandb
\ No newline at end of file
diff --git a/neural_enhanced_demodulation/matlab/evaluation.m b/neural_enhanced_demodulation/matlab/evaluation.m
index 09901f6626f7a95361ad8c5f998c8614d9f4e77f..5e4bc242e56666915223d67ada43911552fb4dfe 100644
--- a/neural_enhanced_demodulation/matlab/evaluation.m
+++ b/neural_enhanced_demodulation/matlab/evaluation.m
@@ -8,40 +8,71 @@ set(fig,'DefaultAxesFontSize',20);
 set(fig,'DefaultAxesFontWeight','bold');
 
 data_root = '';
-color_list = linspecer(2);
+color_list = linspecer(4);
 BW=125000;
 SF=7;
 
-SNR_list=[-40:15];
 % nelora_file='evaluation/sf7_v1_';
-nelora_file='evaluation/sf7_125k_new_';
+% nelora_file='evaluation/sf7_125k_new_';
+% improved_nelora_file='evaluation/sf7_traineval5e-4_';
+nelora_file='evaluation/sf7_v4_';
+improved_nelora_file='evaluation/sf7_v2.3_maskcnn_no_mask_';
+prune_nelora_file='evaluation/prune_sf7_v1_';
+baseline_file='evaluation/baseline_error_matrix_';
+
 % nelora_file='evaluation/sf7_v3_';
+% nelora_file='evaluation/sf7_traineval5e-4_';
+SNR_list=[-40:15];
+SNR_improved_list=[-40:15];
+SNR_prone_list=[-40:15];
+SNR_list_baseline=[-30:0];
 
-SNR_list_baseline=-30:0;
-baseline_file='evaluation/baseline_error_matrix_';
+% improved nelora
+name_str=[improved_nelora_file,num2str(SF),'_',num2str(BW),'.mat'];
+error_path = [data_root,name_str];
+a2 = load(error_path);
+error_matrix = a2.error_matrix;
+error_rate = 1-error_matrix;
+plot(SNR_improved_list,error_rate,"-.*",'LineWidth',3,'color',color_list(3,:));
+hold on;
 
+% nelora
 name_str=[nelora_file,num2str(SF),'_',num2str(BW),'.mat'];
 error_path = [data_root,name_str];
 a = load(error_path);
 error_matrix = a.error_matrix;
-% error_matrix_info = a.error_matrix_info;
+error_rate = 1-error_matrix;
+plot(SNR_list,error_rate,"-.*",'LineWidth',3,'color',color_list(1,:));
+hold on;
 
-plot(SNR_list,1-error_matrix,"-.*",'LineWidth',3,'color',color_list(1,:));
+% prone
+name_str=[prune_nelora_file,num2str(SF),'_',num2str(BW),'.mat'];
+error_path = [data_root,name_str];
+a = load(error_path);
+error_matrix = a.error_matrix;
+error_rate = 1-error_matrix;
+plot(SNR_prone_list,error_rate,"-.*",'LineWidth',2,'color',color_list(4,:));
 hold on;
 
-name_str=[baseline_file,num2str(SF),'_',num2str(BW),'.mat'];
+% baseline
+name_str=[baseline_file,num2str(SF),'_',num2str(BW),'V1.mat'];
 error_path = [data_root,name_str];
 a = load(error_path);
 error_matrix = a.error_matrix;
 plot(SNR_list_baseline,1-error_matrix,"-.*",'LineWidth',2,'color',color_list(2,:));
 hold on;
 
-legend('NeLoRA','Baseline')
+% legend('Imroved NeLoRA', 'NeLoRA', 'Prone', 'Baseline')
+% legend('NeLoRA', 'Baseline')
+% legend('Imroved NeLoRA', 'NeLoRA', 'Baseline')
+% legend('Imroved NeLoRA', 'NeLoRA', 'Baseline')
+legend('Improved NeLoRA', 'NeLoRA', 'Prune', 'Baseline')
 
 % legend('abs_baselineNELoRa','Baseline')
 xlabel('SNR (dB)'); % x label
 ylabel('SER'); % y label
 title('Decode SER for SF=7')
-xlim([-30,-10]);
+xlim([-20,-10]);
 set(gcf,'WindowStyle','normal','Position', [200,200,640,360]);
+grid on;
 saveas(gcf,[data_root,'res/',num2str(SF),'_',num2str(BW),'.pdf'])
diff --git a/neural_enhanced_demodulation/matlab/evaluation/sf7_v2.3_maskcnn_no_mask_7_125000.mat b/neural_enhanced_demodulation/matlab/evaluation/sf7_v2.3_maskcnn_no_mask_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..b9500550d89e86ea0512af9e8ff6815d2b8d935d
Binary files /dev/null and b/neural_enhanced_demodulation/matlab/evaluation/sf7_v2.3_maskcnn_no_mask_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/None_7_125000.mat b/neural_enhanced_demodulation/pytorch/None_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..4fb75fbec0c34a24a5e73980f1b6cf270e33dcd6
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/None_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/config.py b/neural_enhanced_demodulation/pytorch/config.py
index d4abb27fac5d559e5a8ba25e54efad2fff72b9bb..f387d928511de0d75e4bf3782c9ed6d73dcd8ff5 100644
--- a/neural_enhanced_demodulation/pytorch/config.py
+++ b/neural_enhanced_demodulation/pytorch/config.py
@@ -144,7 +144,7 @@ def create_parser():
         help='Choose the root path to rf signals.',
     )
 
-    parser.add_argument('--network', type=str, default='end2end', choices=['end2end', 'end2end_fig4', 'end2end_real'])
+    parser.add_argument('--network', type=str, default='end2end_improve', choices=['end2end', 'end2end_improve', 'end2end_fig4', 'end2end_real'])
 
     parser.add_argument(
         '--feature_name',
@@ -196,6 +196,7 @@ def create_parser():
         default='checkpoints'
     )
     parser.add_argument('--dir_comment', type=str, default='None')
+    parser.add_argument('--wandb_entity', type=str, default='jackthewizard')
     parser.add_argument('--sample_dir', type=str, default='samples')
     parser.add_argument('--testing_dir', type=str, default='testing')
     # parser.add_argument('--load', type=str, default='pre_trained')
diff --git a/neural_enhanced_demodulation/pytorch/end2end.py b/neural_enhanced_demodulation/pytorch/end2end.py
index 936cc2ce1da5145d8db455815b98bde8ba2ccf37..7d954b25e52d7675fe0a1fc03d6a7abaae782d58 100644
--- a/neural_enhanced_demodulation/pytorch/end2end.py
+++ b/neural_enhanced_demodulation/pytorch/end2end.py
@@ -194,6 +194,9 @@ def training_loop(
         * Saves checkpoint every opts.checkpoint_every iterations
         * Saves generated samples every opts.sample_every iterations
     """
+    import wandb;
+    wandb.init(project="CSE891", entity=opts.wandb_entity)
+    
     loss_spec = torch.nn.MSELoss(reduction='mean')
     loss_class = nn.CrossEntropyLoss()
     # Create generators and discriminators
@@ -285,6 +288,7 @@ def training_loop(
         G_Image_loss = opts.scaling_for_imaging_loss * g_y_pix_loss
         G_Class_loss = opts.scaling_for_classification_loss * g_y_class_loss
         G_Y_loss = G_Image_loss + G_Class_loss
+        wandb.log({"G_Y_loss": G_Y_loss, "G_Image_loss": G_Image_loss, "G_class_loss": G_Class_loss})
         G_Y_loss.backward()
         g_optimizer.step()
 
@@ -312,6 +316,7 @@ def training_loop(
         if iteration % opts.checkpoint_every == 0:
             checkpoint(iteration, mask_CNN, C_XtoY, opts)
 
+    wandb.finish()
     test_iter_X = iter(testing_dataloader_X)
     test_iter_Y = iter(testing_dataloader_Y)
     iter_per_epoch_test = min(len(test_iter_X), len(test_iter_Y))
diff --git a/neural_enhanced_demodulation/pytorch/end2end_improve.py b/neural_enhanced_demodulation/pytorch/end2end_improve.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc63ef27f7952c1e8a73789e86ab9f03775127b5
--- /dev/null
+++ b/neural_enhanced_demodulation/pytorch/end2end_improve.py
@@ -0,0 +1,408 @@
+# end2end_improve.py
+
+from __future__ import division
+import os
+
+import warnings
+
+warnings.filterwarnings("ignore")
+
+# Torch imports
+import torch
+import torch.fft
+import torch.nn as nn
+import torch.optim as optim
+
+# Numpy & Scipy imports
+import numpy as np
+import scipy.io
+
+import cv2
+# Local imports
+from utils import to_var, to_data, spec_to_network_input
+from models.model_components import maskCNNModel, DenoisingNet, Autoencoder, classificationHybridModel
+import torch.autograd.profiler as profiler
+import time
+
+SEED = 11
+
+# Set the random seed manually for reproducibility.
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed(SEED)
+
+
+def print_models(Model):
+    """Prints model information for the generators and discriminators.
+    """
+    print("                 Model                ")
+    print("---------------------------------------")
+    print(Model)
+    print("---------------------------------------")
+
+
+def create_model(opts):
+    """Builds the generators and discriminators.
+    """
+
+    maskCNN = maskCNNModel(opts)
+
+    C_XtoY = classificationHybridModel(conv_dim_in=opts.y_image_channel,
+                                       conv_dim_out=opts.n_classes,
+                                       conv_dim_lstm=opts.conv_dim_lstm)
+
+    if torch.cuda.is_available():
+        maskCNN.cuda()
+        C_XtoY.cuda()
+        print('Models moved to GPU.')
+
+    return maskCNN, C_XtoY
+
+
+def create_model2(opts):
+    """Builds the generators and discriminators.
+    """
+
+    denoisenet = Autoencoder()
+
+    C_XtoY = classificationHybridModel(conv_dim_in=opts.y_image_channel,
+                                       conv_dim_out=opts.n_classes,
+                                       conv_dim_lstm=opts.conv_dim_lstm)
+
+    if torch.cuda.is_available():
+        denoisenet.cuda()
+        C_XtoY.cuda()
+        print('Models moved to GPU.')
+
+    return denoisenet, C_XtoY
+
+
+def checkpoint(iteration, mask_CNN, C_XtoY, opts):
+    """Saves the parameters of both generators G_YtoX, G_XtoY and discriminators D_X, D_Y.
+    """
+
+    mask_CNN_path = os.path.join(opts.checkpoint_dir, str(iteration) + '_maskCNN.pkl')
+    torch.save(mask_CNN.state_dict(), mask_CNN_path)
+
+    C_XtoY_path = os.path.join(opts.checkpoint_dir, str(iteration) + '_C_XtoY.pkl')
+    torch.save(C_XtoY.state_dict(), C_XtoY_path)
+
+
+def load_checkpoint(opts):
+    """Loads the generator and discriminator models from checkpoints.
+    """
+
+    maskCNN_path = os.path.join(opts.checkpoint_dir, str(opts.load_iters) + '_maskCNN.pkl')
+
+    maskCNN = maskCNNModel(opts)
+
+    maskCNN.load_state_dict(torch.load(
+        maskCNN_path, map_location=lambda storage, loc: storage),
+        strict=False)
+
+    C_XtoY_path = os.path.join(opts.checkpoint_dir, str(opts.load_iters) + '_C_XtoY.pkl')
+
+    C_XtoY = classificationHybridModel(conv_dim_in=opts.x_image_channel,
+                                       conv_dim_out=opts.n_classes,
+                                       conv_dim_lstm=opts.conv_dim_lstm)
+
+    C_XtoY.load_state_dict(torch.load(
+        C_XtoY_path, map_location=lambda storage, loc: storage),
+        strict=False)
+
+    if torch.cuda.is_available():
+        maskCNN.cuda()
+        C_XtoY.cuda()
+        print('Models moved to GPU.')
+
+    return maskCNN, C_XtoY
+
+
+def merge_images(sources, targets, batch_size, image_channel):
+    """Creates a grid consisting of pairs of columns, where the first column in
+    each pair contains images source images and the second column in each pair
+    contains images generated by the CycleGAN from the corresponding images in
+    the first column.
+    """
+    _, _, h, w = sources.shape
+    row = int(np.sqrt(batch_size))
+    column = int(batch_size / row)
+    merged = np.zeros([image_channel, row * h, column * w * 2])
+    for idx, (s, t) in enumerate(zip(sources, targets)):
+        i = idx // column
+        j = idx % column
+        merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
+        merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
+    return merged.transpose(1, 2, 0)
+
+
+def save_samples(iteration, fixed_Y, fixed_X, mask_CNN, opts):
+    """Saves samples from both generators X->Y and Y->X.
+    """
+    fake_Y = mask_CNN(fixed_X)
+    fixed_X = to_data(fixed_X)
+
+    Y, fake_Y = to_data(fixed_Y), to_data(fake_Y)
+
+    merged = merge_images(fixed_X, fake_Y, opts.batch_size, opts.y_image_channel)
+
+    path = os.path.join(opts.sample_dir,
+                        'sample-{:06d}-Y.png'.format(iteration))
+    merged = np.abs(merged[:, :, 0] + 1j * merged[:, :, 1])
+    merged = (merged - np.amin(merged)) / (np.amax(merged) - np.amin(merged)) * 255
+    merged = cv2.flip(merged, 0)
+    cv2.imwrite(path, merged)
+    print('Saved {}'.format(path))
+
+
+def save_samples_separate(iteration, fixed_Y, fixed_X, mask_CNN, opts,
+                          name_X_test, labels_Y_test, saved_dir):
+    """Saves samples from both generators X->Y and Y->X.
+    """
+    fake_Y = mask_CNN(fixed_X)
+
+    fixed_Y, fake_Y, fixed_X = to_data(fixed_Y), to_data(fake_Y), to_data(fixed_X)
+
+    for batch_index in range(opts.batch_size):
+        if batch_index < len(name_X_test):
+            path_src = os.path.join(saved_dir, name_X_test[batch_index])
+            groundtruth_image = (
+                np.squeeze(fixed_Y[batch_index, :, :, :]).transpose(1, 2, 0))
+
+            groundtruth_image = np.abs(groundtruth_image[:, :, 0] + 1j * groundtruth_image[:, :, 1])
+            groundtruth_image = (groundtruth_image - np.amin(groundtruth_image)) / (
+                    np.amax(groundtruth_image) - np.amin(groundtruth_image)) * 255
+            cv2.imwrite(path_src + '_groundtruth_' + str(iteration) + '.png', groundtruth_image)
+
+            fake_image = (
+                np.squeeze(fake_Y[batch_index, :, :, :]).transpose(1, 2, 0))
+            fake_image = np.abs(fake_image[:, :, 0] + 1j * fake_image[:, :, 1])
+            fake_image = (fake_image - np.amin(fake_image)) / (np.amax(fake_image) - np.amin(fake_image)) * 255
+            cv2.imwrite(path_src + '_fake_' + str(iteration) + '.png', fake_image)
+
+            raw_image = (
+                np.squeeze(fixed_X[batch_index, :, :, :]).transpose(1, 2, 0))
+            raw_image = np.abs(raw_image[:, :, 0] + 1j * raw_image[:, :, 1])
+            raw_image = (raw_image - np.amin(raw_image)) / (np.amax(raw_image) - np.amin(raw_image)) * 255
+            cv2.imwrite(path_src + '_raw_' + str(iteration) + '.png', raw_image)
+
+    # print('Saved {}'.format(path))
+
+
+def training_loop(training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
+                  testing_dataloader_Y, opts):
+    """Runs the training loop.
+        * Saves checkpoint every opts.checkpoint_every iterations
+        * Saves generated samples every opts.sample_every iterations
+    """
+    import wandb;
+    wandb.init(project="CSE891")
+    
+    loss_spec = torch.nn.MSELoss(reduction='mean')
+    loss_class = nn.CrossEntropyLoss()
+    # Create generators and discriminators
+    if opts.load:
+        mask_CNN, C_XtoY = load_checkpoint(opts)
+    else:
+        # mask_CNN, C_XtoY = create_model(opts)
+        mask_CNN, C_XtoY = create_model2(opts)
+
+    g_params = list(mask_CNN.parameters()) + list(C_XtoY.parameters())
+    g_optimizer = optim.Adam(g_params, opts.lr, [opts.beta1, opts.beta2])
+
+    iter_X = iter(training_dataloader_X)
+    iter_Y = iter(training_dataloader_Y)
+
+    test_iter_X = iter(testing_dataloader_X)
+    test_iter_Y = iter(testing_dataloader_Y)
+
+    # Get some fixed data from domains X and Y for sampling. These are images that are held
+    # constant throughout training, that allow us to inspect the model's performance.
+    fixed_X, name_X_fixed = next(test_iter_X)
+    fixed_X = to_var(fixed_X)
+
+    fixed_Y, name_Y_fixed = next(test_iter_Y)
+    fixed_Y = to_var(fixed_Y)
+    # print("Fixed_X {}".format(fixed_X.shape))
+    fixed_X_spectrum_raw = torch.stft(input=fixed_X, n_fft=opts.stft_nfft, hop_length=opts.stft_overlap,
+                                      win_length=opts.stft_window, pad_mode='constant');
+    fixed_X_spectrum = spec_to_network_input(fixed_X_spectrum_raw, opts)
+    # print("Fixed {}".format(fixed_X_spectrum.shape))
+
+    fixed_Y_spectrum_raw = torch.stft(input=fixed_Y, n_fft=opts.stft_nfft, hop_length=opts.stft_overlap,
+                                      win_length=opts.stft_window, pad_mode='constant');
+    fixed_Y_spectrum = spec_to_network_input(fixed_Y_spectrum_raw, opts)
+
+    iter_per_epoch = min(len(iter_X), len(iter_Y))
+
+    for iteration in range(1, opts.train_iters + 1):
+        if iteration % iter_per_epoch == 0:
+            iter_X = iter(training_dataloader_X)
+            iter_Y = iter(training_dataloader_Y)
+
+        # images_X, name_X = iter_X.next()
+        images_X, name_X = next(iter_X)
+        labels_X_mapping = list(
+            map(lambda x: int(x.split('_')[5]), name_X))
+        images_X, labels_X = to_var(images_X), to_var(
+            torch.tensor(labels_X_mapping))
+        # images_Y, name_Y = iter_Y.next()
+        images_Y, name_Y = next(iter_Y)
+        labels_Y_mapping = list(
+            map(lambda x: int(x.split('_')[5]), name_Y))
+        images_Y, labels_Y = to_var(images_Y), to_var(
+            torch.tensor(labels_Y_mapping))
+
+        # ============================================
+        #            TRAIN THE GENERATOR
+        # ============================================
+
+        images_X_spectrum_raw = torch.stft(input=images_X, n_fft=opts.stft_nfft, hop_length=opts.stft_overlap,
+
+                                           win_length=opts.stft_window, pad_mode='constant');
+        images_X_spectrum = spec_to_network_input(images_X_spectrum_raw, opts)
+
+        images_Y_spectrum_raw = torch.stft(input=images_Y, n_fft=opts.stft_nfft, hop_length=opts.stft_overlap,
+                                           win_length=opts.stft_window, pad_mode='constant');
+        images_Y_spectrum = spec_to_network_input(images_Y_spectrum_raw, opts)
+        
+        #########################################
+        ##    FILL THIS IN: X--Y               ##
+        #########################################
+        if iteration % 50 == 0:
+            print("Iteration: {}/{}".format(iteration, opts.train_iters))
+        mask_CNN.train()
+        C_XtoY.train()
+        fake_Y_spectrum = mask_CNN(images_X_spectrum) # B, 2, 128, 33
+        # 2. Compute the generator loss based on domain Y
+        g_y_pix_loss = loss_spec(fake_Y_spectrum, images_Y_spectrum) # Try other loss functions
+        labels_X_estimated = C_XtoY(fake_Y_spectrum) # B, 128
+        g_y_class_loss = loss_class(labels_X_estimated, labels_X) # Try other classification loss functions
+        g_optimizer.zero_grad()
+        G_Image_loss = opts.scaling_for_imaging_loss * g_y_pix_loss
+        G_Class_loss = opts.scaling_for_classification_loss * g_y_class_loss
+        G_Y_loss = G_Image_loss + G_Class_loss
+        G_Y_loss.backward()
+        g_optimizer.step()
+        current_lr = g_optimizer.param_groups[0]['lr']
+        wandb.log({"G_Y_loss": G_Y_loss, "G_Image_loss": G_Image_loss, "G_class_loss": G_Class_loss, 'lr': current_lr})
+
+        # Print the log info
+        if iteration % opts.log_step == 0:
+            print(
+                'Iteration [{:5d}/{:5d}] | G_Y_loss: {:6.4f}| G_Image_loss: {:6.4f}| G_Class_loss: {:6.4f}'
+                    .format(iteration, opts.train_iters,
+                            G_Y_loss.item(),
+                            G_Image_loss.item(),
+                            G_Class_loss.item()))
+
+        # Save the generated samples
+        if (iteration % opts.sample_every == 0) and (not opts.server):
+            # save_samples(iteration, fixed_Y_spectrum, fixed_X_spectrum, mask_CNN, opts)
+            save_samples_separate(iteration, fixed_Y_spectrum, fixed_X_spectrum,
+                                  mask_CNN, opts, name_X_fixed, name_Y_fixed, opts.sample_dir)
+
+        # Save the model parameters
+        if iteration % opts.checkpoint_every == 0:
+            checkpoint(iteration, mask_CNN, C_XtoY, opts)
+
+        ###########################################
+        ################## Testing ################
+        ###########################################
+        if iteration % opts.log_step == 0:
+            error_matrix, error_matrix_count, error_matrix_info, saved_data = testing(mask_CNN, C_XtoY, testing_dataloader_X, testing_dataloader_Y, opts)
+            wandb.log({"error_matrix": error_matrix[11:41].mean(), 'error_matrix_count': error_matrix_count.mean()})
+    wandb.finish()
+    
+    scipy.io.savemat(
+        opts.root_path + '/' + opts.dir_comment + '_' + str(opts.sf) + '_' + str(opts.bw) + '.mat',
+        dict(error_matrix=error_matrix,
+            error_matrix_count=error_matrix_count,
+            error_matrix_info=error_matrix_info))
+
+
+    with open('test.npy', 'wb') as f:
+        np.save(f, saved_data)
+        f.close()
+
+
+def testing(mask_CNN, C_XtoY, testing_dataloader_X,
+                  testing_dataloader_Y, opts):
+    mask_CNN.eval()
+    C_XtoY.eval()
+    test_iter_X = iter(testing_dataloader_X)
+    test_iter_Y = iter(testing_dataloader_Y)
+    # iter_per_epoch_test = min(len(test_iter_X), len(test_iter_Y))
+    iter_per_epoch_test = 3000
+
+    error_matrix = np.zeros([len(opts.snr_list), 1], dtype=float)
+    error_matrix_count = np.zeros([len(opts.snr_list), 1], dtype=int)
+
+    error_matrix_info = []
+
+    saved_data = {}
+    for iteration in range(iter_per_epoch_test):
+        # images_X_test, name_X_test = test_iter_X.next()
+        images_X_test, name_X_test = next(test_iter_X)
+
+        code_X_test_mapping = list(
+            map(lambda x: float(x.split('_')[0]), name_X_test))
+
+        snr_X_test_mapping = list(
+            map(lambda x: int(x.split('_')[1]), name_X_test))
+
+        instance_X_test_mapping = list(
+            map(lambda x: int(x.split('_')[4]), name_X_test))
+
+        labels_X_test_mapping = list(
+            map(lambda x: int(x.split('_')[5]), name_X_test))
+
+        images_X_test, labels_X_test = to_var(images_X_test), to_var(
+            torch.tensor(labels_X_test_mapping))
+
+        # images_Y_test, labels_Y_test = test_iter_Y.next()
+        images_Y_test, labels_Y_test = next(test_iter_Y)
+        images_Y_test = to_var(images_Y_test)
+
+        images_X_test_spectrum_raw = torch.stft(input=images_X_test, n_fft=opts.stft_nfft,
+                                                hop_length=opts.stft_overlap, win_length=opts.stft_window,
+                                                pad_mode='constant');
+        images_X_test_spectrum = spec_to_network_input(images_X_test_spectrum_raw, opts)
+
+        images_Y_test_spectrum_raw = torch.stft(input=images_Y_test, n_fft=opts.stft_nfft,
+                                                hop_length=opts.stft_overlap, win_length=opts.stft_window,
+                                                pad_mode='constant');
+        images_Y_test_spectrum = spec_to_network_input(images_Y_test_spectrum_raw, opts)
+        fake_Y_test_spectrum = mask_CNN(images_X_test_spectrum)
+        labels_X_estimated = C_XtoY(fake_Y_test_spectrum)
+        saved_sample = to_data(labels_X_estimated)
+
+        for i, label in enumerate(to_data(labels_X_test)):
+            if label not in saved_data.keys():
+                saved_data[label] = []
+                saved_data[label].append(saved_sample[i])
+            else:
+                saved_data[label].append(saved_sample[i])
+        _, labels_X_test_estimated = torch.max(labels_X_estimated, 1)
+
+        test_right_case = (labels_X_test_estimated == labels_X_test)
+        test_right_case = to_data(test_right_case)
+
+        for batch_index in range(opts.batch_size):
+            try:
+                snr_index = opts.snr_list.index(snr_X_test_mapping[batch_index])
+                error_matrix[snr_index] += test_right_case[batch_index]
+                error_matrix_count[snr_index] += 1
+                error_matrix_info.append([instance_X_test_mapping[batch_index], code_X_test_mapping[batch_index],
+                                          snr_X_test_mapping[batch_index],
+                                          labels_X_test_estimated[batch_index].cpu().data.int(),
+                                          labels_X_test[batch_index].cpu().data.int()])
+            except:
+                print("Something else went wrong")
+        if iteration % opts.log_step == 0:
+            print('Testing Iteration [{:5d}/{:5d}]'
+                  .format(iteration, iter_per_epoch_test))
+    error_matrix = np.divide(error_matrix, error_matrix_count+0.0001)
+    error_matrix_info = np.array(error_matrix_info)
+    return error_matrix, error_matrix_count, error_matrix_info, saved_data
diff --git a/neural_enhanced_demodulation/pytorch/main.py b/neural_enhanced_demodulation/pytorch/main.py
index 42528b85271740cb6f8a87da5e9e376ab05b63ed..d2608fc0d3df12d784f7115840ca960b7dae3a87 100644
--- a/neural_enhanced_demodulation/pytorch/main.py
+++ b/neural_enhanced_demodulation/pytorch/main.py
@@ -1,32 +1,25 @@
 """Main script for project."""
 from __future__ import print_function
-
-import os
-
+from utils import generate_dataset, create_dir, set_gpu, print_opts
 import config
 import datasets.data_loader as data_loader
-import end2end
-from prune import go_prune
-from utils import create_dir, generate_dataset, print_opts, set_gpu
+import end2end, end2end_improve
+import os
 
 
 def main(opts):
     """Loads the data, creates checkpoint and sample directories, and starts the training loop.
     """
     [files_train, files_test
-     ] = generate_dataset(
-        opts.root_path, opts.data_dir, opts.ratio_bt_train_and_test,
-        opts.code_list, opts.snr_list, opts.bw_list, opts.sf_list,
-        opts.instance_list, opts.sorting_type
-    )
+     ] = generate_dataset(opts.root_path, opts.data_dir, opts.ratio_bt_train_and_test,
+                          opts.code_list, opts.snr_list, opts.bw_list, opts.sf_list,
+                          opts.instance_list, opts.sorting_type)
     # Create train and test dataloaders for images from the two domains X and Y
 
     training_dataloader_X, testing_dataloader_X = data_loader.lora_loader(
-        opts, files_train, files_test, False
-    )
+        opts, files_train, files_test, False)
     training_dataloader_Y, testing_dataloader_Y = data_loader.lora_loader(
-        opts, files_train, files_test, True
-    )
+        opts, files_train, files_test, True)
 
     # Create checkpoint and sample directories
     create_dir(opts.checkpoint_dir)
@@ -37,18 +31,13 @@ def main(opts):
     # Start training
     set_gpu(opts.free_gpu_id)
 
-    # if prune mode
-    if opts.prune:
-        go_prune(opts, testing_dataloader_X, testing_dataloader_Y)
-        return
-
     # select the model
     if opts.network == 'end2end':
-        end2end.training_loop(
-            training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
-            testing_dataloader_Y, opts
-        )
-
+        end2end.training_loop(training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
+                              testing_dataloader_Y, opts)
+    elif opts.network == 'end2end_improve':
+        end2end_improve.training_loop(training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
+                              testing_dataloader_Y, opts)
 
 if __name__ == "__main__":
     parser = config.create_parser()
diff --git a/neural_enhanced_demodulation/pytorch/models/model_components.py b/neural_enhanced_demodulation/pytorch/models/model_components.py
index b1a963d17417a9e85981e00911d772300daffca0..9e9e4882f9c0ef0afbc0a0e819c51f04dcc89969 100644
--- a/neural_enhanced_demodulation/pytorch/models/model_components.py
+++ b/neural_enhanced_demodulation/pytorch/models/model_components.py
@@ -24,11 +24,20 @@ class classificationHybridModel(nn.Module):
         self.drop1 = nn.Dropout(0.2)
         self.drop2 = nn.Dropout(0.5)
         self.act = nn.ReLU()
-
-    def forward(self, x):
-        out = self.act(self.conv1(x))
-        out = self.pool1(out)
-        out = out.view(out.size(0), -1)
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        # import ipdb; ipdb.set_trace()
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.kaiming_normal_(module.weight, mode='fan_in')
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+                    
+    def forward(self, x): # 8, 2, 128, 33
+        out = self.act(self.conv1(x)) # 8, 16, 64, 17
+        out = self.pool1(out) # 8, 16, 32, 8
+        out = out.reshape(out.size(0), -1) # 
 
         out = self.act(self.dense(out))
         out = self.drop2(out)
@@ -85,7 +94,7 @@ class maskCNNModel(nn.Module):
             nn.BatchNorm2d(8), nn.ReLU(),
 
         )
-
+        
         self.lstm = nn.LSTM(
             opts.conv_dim_lstm,
             opts.lstm_dim,
@@ -95,21 +104,136 @@ class maskCNNModel(nn.Module):
 
         self.fc1 = nn.Linear(2 * opts.lstm_dim, opts.fc1_dim)
         self.fc2 = nn.Linear(opts.fc1_dim, opts.freq_size * opts.y_image_channel)
-
-    def forward(self, x):
+        self._initialize_weights() #initialize with kaiming initialization for all the linear layers
+
+    def _initialize_weights(self):
+        # import ipdb; ipdb.set_trace()
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.kaiming_normal_(module.weight, mode='fan_in')
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+                    
+    def forward(self, x): # x: 8, 2, 128, 33
         out = x.transpose(2, 3).contiguous()
-        out = self.conv(out)
-        out = out.transpose(1, 2).contiguous()
-        out = out.view(out.size(0), out.size(1), -1)
-        out, _ = self.lstm(out)
+        out = self.conv(out) # out: 8, 8, 33, 128
+        out = out.transpose(1, 2).contiguous()  # out: 8, 33, 8, 128
+        out = out.view(out.size(0), out.size(1), -1) # out: 8, 33, 1024
+        out, _ = self.lstm(out) # out: 8, 33, 800
         out = F.relu(out)
-        out = self.fc1(out)
+        out = self.fc1(out) # out: 8, 33, 600
         out = F.relu(out)
-        out = self.fc2(out)
-
-        out = out.view(out.size(0), out.size(1), self.opts.y_image_channel, -1)
-        out = torch.sigmoid(out)
-        out = out.transpose(1, 2).contiguous()
-        out = out.transpose(2, 3).contiguous()
-        masked = out * x  # out is mask, masked is denoised
+        out = self.fc2(out) # out: 8, 33, 256
+
+        out = out.view(out.size(0), out.size(1), self.opts.y_image_channel, -1) # out: 8, 33, 2, 128
+        # out = torch.sigmoid(out) # 
+        out = out.transpose(1, 2).contiguous() # out: 8, 2, 33, 128
+        out = out.transpose(2, 3).contiguous() # out: 8, 2, 128, 33
+        masked = out
+        # masked = out * x  # out is mask, masked is denoised
         return masked
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(channels)
+        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(channels)
+
+    def forward(self, x):
+        identity = x
+
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+
+        out += identity
+        out = F.relu(out)
+
+        return out
+
+
+class AEResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
+        super(AEResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.skip = nn.Sequential()
+        if stride != 1 or in_channels != out_channels:
+            self.skip = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, 1, stride),
+                nn.BatchNorm2d(out_channels)
+            )
+
+    def forward(self, x):
+        identity = self.skip(x)
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class DenoisingNet(nn.Module):
+    def __init__(self, opts=None):
+        super(DenoisingNet, self).__init__()
+        self.initial_conv = nn.Conv2d(2, 64, kernel_size=3, padding=1)
+        self.res_block1 = ResidualBlock(64)
+        self.res_block2 = ResidualBlock(64)
+        # self.res_block3 = ResidualBlock(64)
+        # self.res_block4 = ResidualBlock(64)
+        self.final_conv = nn.Conv2d(64, 2, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        output = F.relu(self.initial_conv(x))
+        output = self.res_block1(output)
+        output = self.res_block2(output)
+        # output = self.res_block3(output)
+        # output = self.res_block4(output)
+        output = self.final_conv(output)
+        mask = output.sigmoid()
+        output = mask * x
+        return output
+
+class Autoencoder(nn.Module):
+    def __init__(self):
+        super(Autoencoder, self).__init__()
+
+        self.encoder = nn.Sequential(
+            AEResidualBlock(2, 32),
+            AEResidualBlock(64, 64),
+            AEResidualBlock(64, 128)
+        )
+        self.decoder = nn.Sequential(
+            AEResidualBlock(128, 64),
+            AEResidualBlock(64, 32),
+            AEResidualBlock(32, 2)
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+if __name__ == '__main__':
+    # create mo
+    # model = DenoisingNet()
+    model = Autoencoder()
+
+    # Input data
+    batch_size = 8
+    dummy_input = torch.randn(batch_size, 2, 128, 33)
+
+    output = model(dummy_input)
+    print(output.shape)  # [5, 2, 128, 33]
diff --git a/neural_enhanced_demodulation/pytorch/prune_sf7_v1_7_125000.mat b/neural_enhanced_demodulation/pytorch/prune_sf7_v1_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..2d4f0afd4582a668945848e1c3e8889f2398d549
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/prune_sf7_v1_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/sf7_v2.1_7_125000.mat b/neural_enhanced_demodulation/pytorch/sf7_v2.1_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..383f170471ddfacbe21e109752c3461597ed4d4a
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/sf7_v2.1_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/sf7_v2.2_2resblks_mask_7_125000.mat b/neural_enhanced_demodulation/pytorch/sf7_v2.2_2resblks_mask_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..2be0f9c41c4615a9775b289a27283ddb0ab1fd70
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/sf7_v2.2_2resblks_mask_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/sf7_v2.2_7_125000.mat b/neural_enhanced_demodulation/pytorch/sf7_v2.2_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..99ae6d9830f9486e65b7ad7dff225a52b032e054
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/sf7_v2.2_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/sf7_v2.3_maskcnn_no_mask_7_125000.mat b/neural_enhanced_demodulation/pytorch/sf7_v2.3_maskcnn_no_mask_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..b9500550d89e86ea0512af9e8ff6815d2b8d935d
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/sf7_v2.3_maskcnn_no_mask_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/sf7_v2.4_AE_7_125000.mat b/neural_enhanced_demodulation/pytorch/sf7_v2.4_AE_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..6ed4cae766f5d0c058a05c209536a046272f6f50
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/sf7_v2.4_AE_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/sf7_v2_7_125000.mat b/neural_enhanced_demodulation/pytorch/sf7_v2_7_125000.mat
new file mode 100644
index 0000000000000000000000000000000000000000..6512a078d03897feb2f11bd76fe14a2c2bb1a661
Binary files /dev/null and b/neural_enhanced_demodulation/pytorch/sf7_v2_7_125000.mat differ
diff --git a/neural_enhanced_demodulation/pytorch/shs/train.sh b/neural_enhanced_demodulation/pytorch/shs/train.sh
index 9a5f5dba05d537fde9703c50934eccf7a1d4619b..de27d711883d449a2913d265dd4b1ee36424bf1a 100644
--- a/neural_enhanced_demodulation/pytorch/shs/train.sh
+++ b/neural_enhanced_demodulation/pytorch/shs/train.sh
@@ -1,10 +1,14 @@
-python main.py --dir_comment sf7_v1 \
+
+python main.py --dir_comment sf7_v2.4_AE \
                --batch_size 16 \
+               --lr 0.0005 \
                --root_path . \
                --data_dir /home/liyifa11/MyCodes/AIoT891/project1/sf7_125k \
                --groundtruth_code 35 \
                --normalization \
                --train_iter 100000 \
+               --log_step 2000 \
                --ratio_bt_train_and_test 0.8 \
-               --network end2end
+               --network end2end_debug \
+               --free_gpu_id 1
 
diff --git a/neural_enhanced_demodulation/pytorch/test.npy b/neural_enhanced_demodulation/pytorch/test.npy
index e29e468f9d2d89f14c5067ea8fd694fcc3733207..304b1bf5603868822917c944e05b8e2a6d068391 100644
Binary files a/neural_enhanced_demodulation/pytorch/test.npy and b/neural_enhanced_demodulation/pytorch/test.npy differ