Mod trainer to copy config file into experiments root

James Betker 2021-10-30 17:00:24 -06:00
parent 36ed28913a
commit e9dc37f19c
2 changed files with 9 additions and 9 deletions

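In short: Trainer.init now takes the path of the option YAML it was launched with and snapshots that file into the experiment directory before training starts, so every run keeps a timestamped record of its exact configuration. A minimal sketch of the new behavior (snapshot_config is a hypothetical helper mirroring the added shutil.copy call; it is not part of this commit):

    import os
    import shutil
    from datetime import datetime

    def snapshot_config(opt_path, experiments_root):
        # Prefix with a timestamp so repeated runs against the same
        # experiments root never overwrite an earlier snapshot.
        stamp = datetime.now().strftime('%d%m%Y_%H%M%S')
        dest = os.path.join(experiments_root, f'{stamp}_{os.path.basename(opt_path)}')
        shutil.copy(opt_path, dest)
        return dest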
View File

@@ -3,6 +3,8 @@ import math
 import argparse
 import random
 import logging
+import shutil
 from tqdm import tqdm
 import torch
@@ -13,6 +15,7 @@ from utils import util, options as option
 from data import create_dataloader, create_dataset
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
+from datetime import datetime
 from utils.util import opt_get
@@ -32,7 +35,7 @@ def init_dist(backend, **kwargs):
 class Trainer:
-    def init(self, opt, launcher, all_networks={}):
+    def init(self, opt_path, opt, launcher):
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
@@ -54,6 +57,7 @@ class Trainer:
             util.mkdirs(
                 (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                  and 'pretrain_model' not in key and 'resume' not in key))
+            shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
             # config loggers. Before it, the log will not work
             util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
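The strftime pattern is day-month-year followed by hour-minute-second. As a worked example of the resulting filename (using this commit's own timestamp, purely for illustration):

    from datetime import datetime

    stamp = datetime(2021, 10, 30, 17, 0, 24).strftime('%d%m%Y_%H%M%S')
    assert stamp == '30102021_170024'
    # -> <experiments_root>/30102021_170024_train_gpt_asr_mass_distill.yml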
@@ -131,7 +135,7 @@ class Trainer:
         assert self.train_loader is not None

         #### create model
-        self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
+        self.model = ExtensibleTrainer(opt)

         ### Evaluators
         self.evaluators = []
@@ -280,7 +284,7 @@
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -306,5 +310,5 @@ if __name__ == '__main__':
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
-    trainer.init(opt, args.launcher)
+    trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
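The call-site change mirrors the new signature: the raw CLI path travels alongside the parsed options dict so init() can copy the exact file the run was launched with. A sketch of the resulting flow, assuming the BasicSR-style option.parse helper this script imports (surrounding code elided):

    args = parser.parse_args()                    # args.opt: path to the YAML file
    opt = option.parse(args.opt, is_train=True)   # parsed into an options dict
    trainer = Trainer()
    trainer.init(args.opt, opt, args.launcher)    # path threaded through for the copy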

View File

@@ -21,7 +21,7 @@ logger = logging.getLogger('base')
 class ExtensibleTrainer(BaseModel):
-    def __init__(self, opt, cached_networks={}):
+    def __init__(self, opt):
         super(ExtensibleTrainer, self).__init__(opt)
         if opt['dist']:
             self.rank = torch.distributed.get_rank()
@@ -54,10 +54,6 @@ class ExtensibleTrainer(BaseModel):
             if 'trainable' not in net.keys():
                 net['trainable'] = True

-            if name in cached_networks.keys():
-                new_net = cached_networks[name]
-            else:
-                new_net = None
             if net['type'] == 'generator':
                 if new_net is None:
                     new_net = networks.create_model(opt, net, self.netsG).to(self.device)
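With cached_networks removed, every ExtensibleTrainer now builds its networks from scratch through networks.create_model instead of optionally reusing another trainer's instances. Illustrative only (opt_a, opt_b, and existing_nets are hypothetical, not from this repo):

    # Before: a second trainer could be seeded with existing networks.
    # trainer_b = ExtensibleTrainer(opt_b, cached_networks=existing_nets)
    # After: each trainer owns freshly constructed networks.
    trainer_a = ExtensibleTrainer(opt_a)
    trainer_b = ExtensibleTrainer(opt_b)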