Skip to content
Snippets Groups Projects

Main

Closed liyifa11 requested to merge main into dev
1 file
+ 25
5
Compare changes
  • Side-by-side
  • Inline
# end2end.py
# end2end_debug.py
from __future__ import division
import os
@@ -20,7 +20,7 @@ import scipy.io
import cv2
# Local imports
from utils import to_var, to_data, spec_to_network_input
from models.model_components import maskCNNModel, classificationHybridModel
from models.model_components import maskCNNModel, DenoisingNet, Autoencoder, classificationHybridModel
import torch.autograd.profiler as profiler
import time
@@ -60,6 +60,24 @@ def create_model(opts):
return maskCNN, C_XtoY
def create_model2(opts):
"""Builds the generators and discriminators.
"""
denoisenet = Autoencoder()
C_XtoY = classificationHybridModel(conv_dim_in=opts.y_image_channel,
conv_dim_out=opts.n_classes,
conv_dim_lstm=opts.conv_dim_lstm)
if torch.cuda.is_available():
denoisenet.cuda()
C_XtoY.cuda()
print('Models moved to GPU.')
return denoisenet, C_XtoY
def checkpoint(iteration, mask_CNN, C_XtoY, opts):
"""Saves the parameters of both generators G_YtoX, G_XtoY and discriminators D_X, D_Y.
"""
@@ -187,7 +205,8 @@ def training_loop(training_dataloader_X, training_dataloader_Y, testing_dataload
if opts.load:
mask_CNN, C_XtoY = load_checkpoint(opts)
else:
mask_CNN, C_XtoY = create_model(opts)
# mask_CNN, C_XtoY = create_model(opts)
mask_CNN, C_XtoY = create_model2(opts)
g_params = list(mask_CNN.parameters()) + list(C_XtoY.parameters())
g_optimizer = optim.Adam(g_params, opts.lr, [opts.beta1, opts.beta2])
@@ -266,7 +285,8 @@ def training_loop(training_dataloader_X, training_dataloader_Y, testing_dataload
G_Y_loss = G_Image_loss + G_Class_loss
G_Y_loss.backward()
g_optimizer.step()
wandb.log({"G_Y_loss": G_Y_loss, "G_Image_loss": G_Image_loss, "G_class_loss": G_Class_loss})
current_lr = g_optimizer.param_groups[0]['lr']
wandb.log({"G_Y_loss": G_Y_loss, "G_Image_loss": G_Image_loss, "G_class_loss": G_Class_loss, 'lr': current_lr})
# Print the log info
if iteration % opts.log_step == 0:
@@ -385,4 +405,4 @@ def testing(mask_CNN, C_XtoY, testing_dataloader_X,
.format(iteration, iter_per_epoch_test))
error_matrix = np.divide(error_matrix, error_matrix_count+0.0001)
error_matrix_info = np.array(error_matrix_info)
return error_matrix, error_matrix_count, error_matrix_info, saved_data
\ No newline at end of file
return error_matrix, error_matrix_count, error_matrix_info, saved_data
Loading