diff --git a/codes/train.py b/codes/train.py
index d2e2b973..d66e222e 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -3,6 +3,8 @@ import math
 import argparse
 import random
 import logging
+import shutil
+
 from tqdm import tqdm
 
 import torch
@@ -13,6 +15,7 @@ from utils import util, options as option
 from data import create_dataloader, create_dataset
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
+from datetime import datetime
 from utils.util import opt_get
 
 
@@ -32,7 +35,7 @@ def init_dist(backend, **kwargs):
 
 
 class Trainer:
-    def init(self, opt, launcher, all_networks={}):
+    def init(self, opt_path, opt, launcher):
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
@@ -54,6 +57,7 @@ class Trainer:
                 util.mkdirs(
                     (path for key, path in opt['path'].items()
                      if not key == 'experiments_root' and path is not None and 'pretrain_model' not in key and 'resume' not in key))
+            shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
 
             # config loggers. Before it, the log will not work
             util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
@@ -131,7 +135,7 @@ class Trainer:
         assert self.train_loader is not None
 
         #### create model
-        self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
+        self.model = ExtensibleTrainer(opt)
 
         ### Evaluators
         self.evaluators = []
@@ -280,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -306,5 +310,5 @@ if __name__ == '__main__':
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
 
-    trainer.init(opt, args.launcher)
+    trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index edabeab0..e2c01ea6 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -21,7 +21,7 @@ logger = logging.getLogger('base')
 
 
 class ExtensibleTrainer(BaseModel):
-    def __init__(self, opt, cached_networks={}):
+    def __init__(self, opt):
         super(ExtensibleTrainer, self).__init__(opt)
         if opt['dist']:
             self.rank = torch.distributed.get_rank()
@@ -54,10 +54,7 @@ class ExtensibleTrainer(BaseModel):
             if 'trainable' not in net.keys():
                 net['trainable'] = True
 
-            if name in cached_networks.keys():
-                new_net = cached_networks[name]
-            else:
-                new_net = None
+            new_net = None  # network caching removed; always construct a fresh network below
             if net['type'] == 'generator':
                 if new_net is None:
                     new_net = networks.create_model(opt, net, self.netsG).to(self.device)
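For reference, the first train.py hunk makes each run self-documenting: the option YAML passed via -opt is copied into the run's experiments directory with a launch timestamp. A minimal standalone sketch of that pattern, assuming a hypothetical backup_config helper (the patch itself inlines this logic in Trainer.init using opt['path']['experiments_root']):

import os
import shutil
from datetime import datetime

def backup_config(opt_path: str, experiments_root: str) -> str:
    # Copy the training config next to the run's artifacts, stamped with launch time.
    os.makedirs(experiments_root, exist_ok=True)
    stamp = datetime.now().strftime("%d%m%Y_%H%M%S")  # same format string the patch uses
    dest = os.path.join(experiments_root, f"{stamp}_{os.path.basename(opt_path)}")
    shutil.copy(opt_path, dest)
    return dest

# Example (hypothetical paths):
# backup_config('../options/train_gpt_asr_mass_distill.yml', 'experiments/my_run')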