Mod trainer to copy config file into experiments root
commit e9dc37f19c (parent 36ed28913a)
@@ -3,6 +3,8 @@ import math
 import argparse
 import random
 import logging
+import shutil
+
 from tqdm import tqdm
 
 import torch
@@ -13,6 +15,7 @@ from utils import util, options as option
 from data import create_dataloader, create_dataset
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
+from datetime import datetime
 
 from utils.util import opt_get
 
@@ -32,7 +35,7 @@ def init_dist(backend, **kwargs):
 
 class Trainer:
 
-    def init(self, opt, launcher, all_networks={}):
+    def init(self, opt_path, opt, launcher):
        self._profile = False
        self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
        self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
@@ -54,6 +57,7 @@ class Trainer:
            util.mkdirs(
                (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                 and 'pretrain_model' not in key and 'resume' not in key))
+           shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
 
            # config loggers. Before it, the log will not work
            util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
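The new shutil.copy line archives the exact YAML used to launch a run inside its experiments root, prefixed with a timestamp so repeated runs never overwrite each other. A standalone sketch of the same pattern (the function name and example paths here are hypothetical, not from the repo):

import os
import shutil
from datetime import datetime

def archive_config(opt_path, experiments_root):
    # Timestamp prefix ("%d%m%Y_%H%M%S") mirrors the f-string in the diff above.
    stamped = f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'
    shutil.copy(opt_path, os.path.join(experiments_root, stamped))

# Hypothetical usage:
# archive_config('../options/train_gpt_asr_mass_distill.yml', '../experiments/my_run')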
@@ -131,7 +135,7 @@ class Trainer:
        assert self.train_loader is not None
 
        #### create model
-       self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
+       self.model = ExtensibleTrainer(opt)
 
        ### Evaluators
        self.evaluators = []
@@ -280,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-   parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
+   parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
@@ -306,5 +310,5 @@ if __name__ == '__main__':
        trainer.world_size = torch.distributed.get_world_size()
        trainer.rank = torch.distributed.get_rank()
 
-   trainer.init(opt, args.launcher)
+   trainer.init(args.opt, opt, args.launcher)
    trainer.do_training()
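Because init() now needs the raw YAML path as well as the parsed options, the __main__ block passes args.opt through explicitly. A condensed sketch of the updated call sequence; option.parse and its is_train argument are assumptions inferred from the `options as option` import, not confirmed by this diff:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='../options/train_gpt_asr_mass_distill.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)  # assumed loader signature
    trainer = Trainer()
    trainer.init(args.opt, opt, args.launcher)   # opt_path is the new first argument
    trainer.do_training()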
trainer/ExtensibleTrainer.py

@@ -21,7 +21,7 @@ logger = logging.getLogger('base')
 
 
 class ExtensibleTrainer(BaseModel):
-    def __init__(self, opt, cached_networks={}):
+    def __init__(self, opt):
        super(ExtensibleTrainer, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
|
@ -54,10 +54,6 @@ class ExtensibleTrainer(BaseModel):
|
|||
if 'trainable' not in net.keys():
|
||||
net['trainable'] = True
|
||||
|
||||
if name in cached_networks.keys():
|
||||
new_net = cached_networks[name]
|
||||
else:
|
||||
new_net = None
|
||||
if net['type'] == 'generator':
|
||||
if new_net is None:
|
||||
new_net = networks.create_model(opt, net, self.netsG).to(self.device)
|
||||
|
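With cached_networks gone, every ExtensibleTrainer builds its networks fresh from the options; callers can no longer inject pre-built modules. A sketch of the resulting construction logic; the enclosing loop over opt['networks'] is an assumption based on the name/net variables in the hunk above:

for name, net in opt['networks'].items():
    if 'trainable' not in net.keys():
        net['trainable'] = True
    # No cache lookup any more, so the `is None` guard below is now always
    # true; the commit keeps it, presumably to minimize the diff.
    new_net = None
    if net['type'] == 'generator':
        if new_net is None:
            new_net = networks.create_model(opt, net, self.netsG).to(self.device)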