Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
CSE891 NELoRA WHL
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
liyifa11
CSE891 NELoRA WHL
Merge requests
!5
Main
Code
Review changes
Check out branch
Download
Patches
Plain diff
Closed
Main
main
into
dev
Overview
0
Commits
17
Pipelines
0
Changes
1
Closed
liyifa11
requested to merge
main
into
dev
1 year ago
Overview
0
Commits
17
Pipelines
0
Changes
1
Expand
Merge dev to main
0
0
Merge request reports
Viewing commit
74d4be14
Prev
Next
Show latest version
1 file
+
408
−
0
Inline
Compare changes
Side-by-side
Inline
Show whitespace changes
Show one file at a time
74d4be14
Add new file
· 74d4be14
liyifa11
authored
1 year ago
neural_enhanced_demodulation/pytorch/end2end_improve.py
0 → 100644
+
408
−
0
Options
# end2end_improve.py
from
__future__
import
division
import
os
import
warnings
warnings
.
filterwarnings
(
"
ignore
"
)
# Torch imports
import
torch
import
torch.fft
import
torch.nn
as
nn
import
torch.optim
as
optim
# Numpy & Scipy imports
import
numpy
as
np
import
scipy.io
import
cv2
# Local imports
from
utils
import
to_var
,
to_data
,
spec_to_network_input
from
models.model_components
import
maskCNNModel
,
DenoisingNet
,
Autoencoder
,
classificationHybridModel
import
torch.autograd.profiler
as
profiler
import
time
SEED
=
11
# Set the random seed manually for reproducibility.
np
.
random
.
seed
(
SEED
)
torch
.
manual_seed
(
SEED
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
SEED
)
def
print_models
(
Model
):
"""
Prints model information for the generators and discriminators.
"""
print
(
"
Model
"
)
print
(
"
---------------------------------------
"
)
print
(
Model
)
print
(
"
---------------------------------------
"
)
def
create_model
(
opts
):
"""
Builds the generators and discriminators.
"""
maskCNN
=
maskCNNModel
(
opts
)
C_XtoY
=
classificationHybridModel
(
conv_dim_in
=
opts
.
y_image_channel
,
conv_dim_out
=
opts
.
n_classes
,
conv_dim_lstm
=
opts
.
conv_dim_lstm
)
if
torch
.
cuda
.
is_available
():
maskCNN
.
cuda
()
C_XtoY
.
cuda
()
print
(
'
Models moved to GPU.
'
)
return
maskCNN
,
C_XtoY
def
create_model2
(
opts
):
"""
Builds the generators and discriminators.
"""
denoisenet
=
Autoencoder
()
C_XtoY
=
classificationHybridModel
(
conv_dim_in
=
opts
.
y_image_channel
,
conv_dim_out
=
opts
.
n_classes
,
conv_dim_lstm
=
opts
.
conv_dim_lstm
)
if
torch
.
cuda
.
is_available
():
denoisenet
.
cuda
()
C_XtoY
.
cuda
()
print
(
'
Models moved to GPU.
'
)
return
denoisenet
,
C_XtoY
def
checkpoint
(
iteration
,
mask_CNN
,
C_XtoY
,
opts
):
"""
Saves the parameters of both generators G_YtoX, G_XtoY and discriminators D_X, D_Y.
"""
mask_CNN_path
=
os
.
path
.
join
(
opts
.
checkpoint_dir
,
str
(
iteration
)
+
'
_maskCNN.pkl
'
)
torch
.
save
(
mask_CNN
.
state_dict
(),
mask_CNN_path
)
C_XtoY_path
=
os
.
path
.
join
(
opts
.
checkpoint_dir
,
str
(
iteration
)
+
'
_C_XtoY.pkl
'
)
torch
.
save
(
C_XtoY
.
state_dict
(),
C_XtoY_path
)
def
load_checkpoint
(
opts
):
"""
Loads the generator and discriminator models from checkpoints.
"""
maskCNN_path
=
os
.
path
.
join
(
opts
.
checkpoint_dir
,
str
(
opts
.
load_iters
)
+
'
_maskCNN.pkl
'
)
maskCNN
=
maskCNNModel
(
opts
)
maskCNN
.
load_state_dict
(
torch
.
load
(
maskCNN_path
,
map_location
=
lambda
storage
,
loc
:
storage
),
strict
=
False
)
C_XtoY_path
=
os
.
path
.
join
(
opts
.
checkpoint_dir
,
str
(
opts
.
load_iters
)
+
'
_C_XtoY.pkl
'
)
C_XtoY
=
classificationHybridModel
(
conv_dim_in
=
opts
.
x_image_channel
,
conv_dim_out
=
opts
.
n_classes
,
conv_dim_lstm
=
opts
.
conv_dim_lstm
)
C_XtoY
.
load_state_dict
(
torch
.
load
(
C_XtoY_path
,
map_location
=
lambda
storage
,
loc
:
storage
),
strict
=
False
)
if
torch
.
cuda
.
is_available
():
maskCNN
.
cuda
()
C_XtoY
.
cuda
()
print
(
'
Models moved to GPU.
'
)
return
maskCNN
,
C_XtoY
def
merge_images
(
sources
,
targets
,
batch_size
,
image_channel
):
"""
Creates a grid consisting of pairs of columns, where the first column in
each pair contains images source images and the second column in each pair
contains images generated by the CycleGAN from the corresponding images in
the first column.
"""
_
,
_
,
h
,
w
=
sources
.
shape
row
=
int
(
np
.
sqrt
(
batch_size
))
column
=
int
(
batch_size
/
row
)
merged
=
np
.
zeros
([
image_channel
,
row
*
h
,
column
*
w
*
2
])
for
idx
,
(
s
,
t
)
in
enumerate
(
zip
(
sources
,
targets
)):
i
=
idx
//
column
j
=
idx
%
column
merged
[:,
i
*
h
:(
i
+
1
)
*
h
,
(
j
*
2
)
*
w
:(
j
*
2
+
1
)
*
w
]
=
s
merged
[:,
i
*
h
:(
i
+
1
)
*
h
,
(
j
*
2
+
1
)
*
w
:(
j
*
2
+
2
)
*
w
]
=
t
return
merged
.
transpose
(
1
,
2
,
0
)
def
save_samples
(
iteration
,
fixed_Y
,
fixed_X
,
mask_CNN
,
opts
):
"""
Saves samples from both generators X->Y and Y->X.
"""
fake_Y
=
mask_CNN
(
fixed_X
)
fixed_X
=
to_data
(
fixed_X
)
Y
,
fake_Y
=
to_data
(
fixed_Y
),
to_data
(
fake_Y
)
merged
=
merge_images
(
fixed_X
,
fake_Y
,
opts
.
batch_size
,
opts
.
y_image_channel
)
path
=
os
.
path
.
join
(
opts
.
sample_dir
,
'
sample-{:06d}-Y.png
'
.
format
(
iteration
))
merged
=
np
.
abs
(
merged
[:,
:,
0
]
+
1j
*
merged
[:,
:,
1
])
merged
=
(
merged
-
np
.
amin
(
merged
))
/
(
np
.
amax
(
merged
)
-
np
.
amin
(
merged
))
*
255
merged
=
cv2
.
flip
(
merged
,
0
)
cv2
.
imwrite
(
path
,
merged
)
print
(
'
Saved {}
'
.
format
(
path
))
def
save_samples_separate
(
iteration
,
fixed_Y
,
fixed_X
,
mask_CNN
,
opts
,
name_X_test
,
labels_Y_test
,
saved_dir
):
"""
Saves samples from both generators X->Y and Y->X.
"""
fake_Y
=
mask_CNN
(
fixed_X
)
fixed_Y
,
fake_Y
,
fixed_X
=
to_data
(
fixed_Y
),
to_data
(
fake_Y
),
to_data
(
fixed_X
)
for
batch_index
in
range
(
opts
.
batch_size
):
if
batch_index
<
len
(
name_X_test
):
path_src
=
os
.
path
.
join
(
saved_dir
,
name_X_test
[
batch_index
])
groundtruth_image
=
(
np
.
squeeze
(
fixed_Y
[
batch_index
,
:,
:,
:]).
transpose
(
1
,
2
,
0
))
groundtruth_image
=
np
.
abs
(
groundtruth_image
[:,
:,
0
]
+
1j
*
groundtruth_image
[:,
:,
1
])
groundtruth_image
=
(
groundtruth_image
-
np
.
amin
(
groundtruth_image
))
/
(
np
.
amax
(
groundtruth_image
)
-
np
.
amin
(
groundtruth_image
))
*
255
cv2
.
imwrite
(
path_src
+
'
_groundtruth_
'
+
str
(
iteration
)
+
'
.png
'
,
groundtruth_image
)
fake_image
=
(
np
.
squeeze
(
fake_Y
[
batch_index
,
:,
:,
:]).
transpose
(
1
,
2
,
0
))
fake_image
=
np
.
abs
(
fake_image
[:,
:,
0
]
+
1j
*
fake_image
[:,
:,
1
])
fake_image
=
(
fake_image
-
np
.
amin
(
fake_image
))
/
(
np
.
amax
(
fake_image
)
-
np
.
amin
(
fake_image
))
*
255
cv2
.
imwrite
(
path_src
+
'
_fake_
'
+
str
(
iteration
)
+
'
.png
'
,
fake_image
)
raw_image
=
(
np
.
squeeze
(
fixed_X
[
batch_index
,
:,
:,
:]).
transpose
(
1
,
2
,
0
))
raw_image
=
np
.
abs
(
raw_image
[:,
:,
0
]
+
1j
*
raw_image
[:,
:,
1
])
raw_image
=
(
raw_image
-
np
.
amin
(
raw_image
))
/
(
np
.
amax
(
raw_image
)
-
np
.
amin
(
raw_image
))
*
255
cv2
.
imwrite
(
path_src
+
'
_raw_
'
+
str
(
iteration
)
+
'
.png
'
,
raw_image
)
# print('Saved {}'.format(path))
def
training_loop
(
training_dataloader_X
,
training_dataloader_Y
,
testing_dataloader_X
,
testing_dataloader_Y
,
opts
):
"""
Runs the training loop.
* Saves checkpoint every opts.checkpoint_every iterations
* Saves generated samples every opts.sample_every iterations
"""
import
wandb
;
wandb
.
init
(
project
=
"
CSE891
"
)
loss_spec
=
torch
.
nn
.
MSELoss
(
reduction
=
'
mean
'
)
loss_class
=
nn
.
CrossEntropyLoss
()
# Create generators and discriminators
if
opts
.
load
:
mask_CNN
,
C_XtoY
=
load_checkpoint
(
opts
)
else
:
# mask_CNN, C_XtoY = create_model(opts)
mask_CNN
,
C_XtoY
=
create_model2
(
opts
)
g_params
=
list
(
mask_CNN
.
parameters
())
+
list
(
C_XtoY
.
parameters
())
g_optimizer
=
optim
.
Adam
(
g_params
,
opts
.
lr
,
[
opts
.
beta1
,
opts
.
beta2
])
iter_X
=
iter
(
training_dataloader_X
)
iter_Y
=
iter
(
training_dataloader_Y
)
test_iter_X
=
iter
(
testing_dataloader_X
)
test_iter_Y
=
iter
(
testing_dataloader_Y
)
# Get some fixed data from domains X and Y for sampling. These are images that are held
# constant throughout training, that allow us to inspect the model's performance.
fixed_X
,
name_X_fixed
=
next
(
test_iter_X
)
fixed_X
=
to_var
(
fixed_X
)
fixed_Y
,
name_Y_fixed
=
next
(
test_iter_Y
)
fixed_Y
=
to_var
(
fixed_Y
)
# print("Fixed_X {}".format(fixed_X.shape))
fixed_X_spectrum_raw
=
torch
.
stft
(
input
=
fixed_X
,
n_fft
=
opts
.
stft_nfft
,
hop_length
=
opts
.
stft_overlap
,
win_length
=
opts
.
stft_window
,
pad_mode
=
'
constant
'
);
fixed_X_spectrum
=
spec_to_network_input
(
fixed_X_spectrum_raw
,
opts
)
# print("Fixed {}".format(fixed_X_spectrum.shape))
fixed_Y_spectrum_raw
=
torch
.
stft
(
input
=
fixed_Y
,
n_fft
=
opts
.
stft_nfft
,
hop_length
=
opts
.
stft_overlap
,
win_length
=
opts
.
stft_window
,
pad_mode
=
'
constant
'
);
fixed_Y_spectrum
=
spec_to_network_input
(
fixed_Y_spectrum_raw
,
opts
)
iter_per_epoch
=
min
(
len
(
iter_X
),
len
(
iter_Y
))
for
iteration
in
range
(
1
,
opts
.
train_iters
+
1
):
if
iteration
%
iter_per_epoch
==
0
:
iter_X
=
iter
(
training_dataloader_X
)
iter_Y
=
iter
(
training_dataloader_Y
)
# images_X, name_X = iter_X.next()
images_X
,
name_X
=
next
(
iter_X
)
labels_X_mapping
=
list
(
map
(
lambda
x
:
int
(
x
.
split
(
'
_
'
)[
5
]),
name_X
))
images_X
,
labels_X
=
to_var
(
images_X
),
to_var
(
torch
.
tensor
(
labels_X_mapping
))
# images_Y, name_Y = iter_Y.next()
images_Y
,
name_Y
=
next
(
iter_Y
)
labels_Y_mapping
=
list
(
map
(
lambda
x
:
int
(
x
.
split
(
'
_
'
)[
5
]),
name_Y
))
images_Y
,
labels_Y
=
to_var
(
images_Y
),
to_var
(
torch
.
tensor
(
labels_Y_mapping
))
# ============================================
# TRAIN THE GENERATOR
# ============================================
images_X_spectrum_raw
=
torch
.
stft
(
input
=
images_X
,
n_fft
=
opts
.
stft_nfft
,
hop_length
=
opts
.
stft_overlap
,
win_length
=
opts
.
stft_window
,
pad_mode
=
'
constant
'
);
images_X_spectrum
=
spec_to_network_input
(
images_X_spectrum_raw
,
opts
)
images_Y_spectrum_raw
=
torch
.
stft
(
input
=
images_Y
,
n_fft
=
opts
.
stft_nfft
,
hop_length
=
opts
.
stft_overlap
,
win_length
=
opts
.
stft_window
,
pad_mode
=
'
constant
'
);
images_Y_spectrum
=
spec_to_network_input
(
images_Y_spectrum_raw
,
opts
)
#########################################
## FILL THIS IN: X--Y ##
#########################################
if
iteration
%
50
==
0
:
print
(
"
Iteration: {}/{}
"
.
format
(
iteration
,
opts
.
train_iters
))
mask_CNN
.
train
()
C_XtoY
.
train
()
fake_Y_spectrum
=
mask_CNN
(
images_X_spectrum
)
# B, 2, 128, 33
# 2. Compute the generator loss based on domain Y
g_y_pix_loss
=
loss_spec
(
fake_Y_spectrum
,
images_Y_spectrum
)
# Try other loss functions
labels_X_estimated
=
C_XtoY
(
fake_Y_spectrum
)
# B, 128
g_y_class_loss
=
loss_class
(
labels_X_estimated
,
labels_X
)
# Try other classification loss functions
g_optimizer
.
zero_grad
()
G_Image_loss
=
opts
.
scaling_for_imaging_loss
*
g_y_pix_loss
G_Class_loss
=
opts
.
scaling_for_classification_loss
*
g_y_class_loss
G_Y_loss
=
G_Image_loss
+
G_Class_loss
G_Y_loss
.
backward
()
g_optimizer
.
step
()
current_lr
=
g_optimizer
.
param_groups
[
0
][
'
lr
'
]
wandb
.
log
({
"
G_Y_loss
"
:
G_Y_loss
,
"
G_Image_loss
"
:
G_Image_loss
,
"
G_class_loss
"
:
G_Class_loss
,
'
lr
'
:
current_lr
})
# Print the log info
if
iteration
%
opts
.
log_step
==
0
:
print
(
'
Iteration [{:5d}/{:5d}] | G_Y_loss: {:6.4f}| G_Image_loss: {:6.4f}| G_Class_loss: {:6.4f}
'
.
format
(
iteration
,
opts
.
train_iters
,
G_Y_loss
.
item
(),
G_Image_loss
.
item
(),
G_Class_loss
.
item
()))
# Save the generated samples
if
(
iteration
%
opts
.
sample_every
==
0
)
and
(
not
opts
.
server
):
# save_samples(iteration, fixed_Y_spectrum, fixed_X_spectrum, mask_CNN, opts)
save_samples_separate
(
iteration
,
fixed_Y_spectrum
,
fixed_X_spectrum
,
mask_CNN
,
opts
,
name_X_fixed
,
name_Y_fixed
,
opts
.
sample_dir
)
# Save the model parameters
if
iteration
%
opts
.
checkpoint_every
==
0
:
checkpoint
(
iteration
,
mask_CNN
,
C_XtoY
,
opts
)
###########################################
################## Testing ################
###########################################
if
iteration
%
opts
.
log_step
==
0
:
error_matrix
,
error_matrix_count
,
error_matrix_info
,
saved_data
=
testing
(
mask_CNN
,
C_XtoY
,
testing_dataloader_X
,
testing_dataloader_Y
,
opts
)
wandb
.
log
({
"
error_matrix
"
:
error_matrix
[
11
:
41
].
mean
(),
'
error_matrix_count
'
:
error_matrix_count
.
mean
()})
wandb
.
finish
()
scipy
.
io
.
savemat
(
opts
.
root_path
+
'
/
'
+
opts
.
dir_comment
+
'
_
'
+
str
(
opts
.
sf
)
+
'
_
'
+
str
(
opts
.
bw
)
+
'
.mat
'
,
dict
(
error_matrix
=
error_matrix
,
error_matrix_count
=
error_matrix_count
,
error_matrix_info
=
error_matrix_info
))
with
open
(
'
test.npy
'
,
'
wb
'
)
as
f
:
np
.
save
(
f
,
saved_data
)
f
.
close
()
def
testing
(
mask_CNN
,
C_XtoY
,
testing_dataloader_X
,
testing_dataloader_Y
,
opts
):
mask_CNN
.
eval
()
C_XtoY
.
eval
()
test_iter_X
=
iter
(
testing_dataloader_X
)
test_iter_Y
=
iter
(
testing_dataloader_Y
)
# iter_per_epoch_test = min(len(test_iter_X), len(test_iter_Y))
iter_per_epoch_test
=
3000
error_matrix
=
np
.
zeros
([
len
(
opts
.
snr_list
),
1
],
dtype
=
float
)
error_matrix_count
=
np
.
zeros
([
len
(
opts
.
snr_list
),
1
],
dtype
=
int
)
error_matrix_info
=
[]
saved_data
=
{}
for
iteration
in
range
(
iter_per_epoch_test
):
# images_X_test, name_X_test = test_iter_X.next()
images_X_test
,
name_X_test
=
next
(
test_iter_X
)
code_X_test_mapping
=
list
(
map
(
lambda
x
:
float
(
x
.
split
(
'
_
'
)[
0
]),
name_X_test
))
snr_X_test_mapping
=
list
(
map
(
lambda
x
:
int
(
x
.
split
(
'
_
'
)[
1
]),
name_X_test
))
instance_X_test_mapping
=
list
(
map
(
lambda
x
:
int
(
x
.
split
(
'
_
'
)[
4
]),
name_X_test
))
labels_X_test_mapping
=
list
(
map
(
lambda
x
:
int
(
x
.
split
(
'
_
'
)[
5
]),
name_X_test
))
images_X_test
,
labels_X_test
=
to_var
(
images_X_test
),
to_var
(
torch
.
tensor
(
labels_X_test_mapping
))
# images_Y_test, labels_Y_test = test_iter_Y.next()
images_Y_test
,
labels_Y_test
=
next
(
test_iter_Y
)
images_Y_test
=
to_var
(
images_Y_test
)
images_X_test_spectrum_raw
=
torch
.
stft
(
input
=
images_X_test
,
n_fft
=
opts
.
stft_nfft
,
hop_length
=
opts
.
stft_overlap
,
win_length
=
opts
.
stft_window
,
pad_mode
=
'
constant
'
);
images_X_test_spectrum
=
spec_to_network_input
(
images_X_test_spectrum_raw
,
opts
)
images_Y_test_spectrum_raw
=
torch
.
stft
(
input
=
images_Y_test
,
n_fft
=
opts
.
stft_nfft
,
hop_length
=
opts
.
stft_overlap
,
win_length
=
opts
.
stft_window
,
pad_mode
=
'
constant
'
);
images_Y_test_spectrum
=
spec_to_network_input
(
images_Y_test_spectrum_raw
,
opts
)
fake_Y_test_spectrum
=
mask_CNN
(
images_X_test_spectrum
)
labels_X_estimated
=
C_XtoY
(
fake_Y_test_spectrum
)
saved_sample
=
to_data
(
labels_X_estimated
)
for
i
,
label
in
enumerate
(
to_data
(
labels_X_test
)):
if
label
not
in
saved_data
.
keys
():
saved_data
[
label
]
=
[]
saved_data
[
label
].
append
(
saved_sample
[
i
])
else
:
saved_data
[
label
].
append
(
saved_sample
[
i
])
_
,
labels_X_test_estimated
=
torch
.
max
(
labels_X_estimated
,
1
)
test_right_case
=
(
labels_X_test_estimated
==
labels_X_test
)
test_right_case
=
to_data
(
test_right_case
)
for
batch_index
in
range
(
opts
.
batch_size
):
try
:
snr_index
=
opts
.
snr_list
.
index
(
snr_X_test_mapping
[
batch_index
])
error_matrix
[
snr_index
]
+=
test_right_case
[
batch_index
]
error_matrix_count
[
snr_index
]
+=
1
error_matrix_info
.
append
([
instance_X_test_mapping
[
batch_index
],
code_X_test_mapping
[
batch_index
],
snr_X_test_mapping
[
batch_index
],
labels_X_test_estimated
[
batch_index
].
cpu
().
data
.
int
(),
labels_X_test
[
batch_index
].
cpu
().
data
.
int
()])
except
:
print
(
"
Something else went wrong
"
)
if
iteration
%
opts
.
log_step
==
0
:
print
(
'
Testing Iteration [{:5d}/{:5d}]
'
.
format
(
iteration
,
iter_per_epoch_test
))
error_matrix
=
np
.
divide
(
error_matrix
,
error_matrix_count
+
0.0001
)
error_matrix_info
=
np
.
array
(
error_matrix_info
)
return
error_matrix
,
error_matrix_count
,
error_matrix_info
,
saved_data
Loading