Skip to content
Snippets Groups Projects

Main

Closed liyifa11 requested to merge main into dev
1 file
+ 10
26
Compare changes
  • Side-by-side
  • Inline
"""Main script for project."""
from __future__ import print_function
import os
from utils import generate_dataset, create_dir, set_gpu, print_opts
import config
import datasets.data_loader as data_loader
import end2end, end2end_debug
from prune import go_prune
from utils import create_dir, generate_dataset, print_opts, set_gpu
import end2end, end2end_improve
import os
@@ -15,19 +11,15 @@ def main(opts):
"""Loads the data, creates checkpoint and sample directories, and starts the training loop.
"""
[files_train, files_test
] = generate_dataset(
opts.root_path, opts.data_dir, opts.ratio_bt_train_and_test,
opts.code_list, opts.snr_list, opts.bw_list, opts.sf_list,
opts.instance_list, opts.sorting_type
)
] = generate_dataset(opts.root_path, opts.data_dir, opts.ratio_bt_train_and_test,
opts.code_list, opts.snr_list, opts.bw_list, opts.sf_list,
opts.instance_list, opts.sorting_type)
# Create train and test dataloaders for images from the two domains X and Y
training_dataloader_X, testing_dataloader_X = data_loader.lora_loader(
opts, files_train, files_test, False
)
opts, files_train, files_test, False)
training_dataloader_Y, testing_dataloader_Y = data_loader.lora_loader(
opts, files_train, files_test, True
)
opts, files_train, files_test, True)
# Create checkpoint and sample directories
create_dir(opts.checkpoint_dir)
@@ -38,19 +30,11 @@ def main(opts):
# Start training
set_gpu(opts.free_gpu_id)
# if prune mode
if opts.prune:
go_prune(opts, testing_dataloader_X, testing_dataloader_Y)
return
# select the model
if opts.network == 'end2end':
end2end.training_loop(
training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
testing_dataloader_Y, opts
)
elif opts.network == 'end2end_debug':
end2end.training_loop(training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
testing_dataloader_Y, opts)
elif opts.network == 'end2end_improve':
end2end_debug.training_loop(training_dataloader_X, training_dataloader_Y, testing_dataloader_X,
testing_dataloader_Y, opts)
Loading