forked from mrq/DL-Art-School
Mod trainer to copy config file into experiments root
This commit is contained in:
parent 36ed28913a
commit e9dc37f19c
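In short, `Trainer.init` now takes the option-file path as an extra argument and archives a timestamped copy of the YAML config under the experiment root, while the unused network-cache plumbing is dropped. A minimal standalone sketch of the copy behavior (the `archive_config` helper name is hypothetical, not part of the commit):

import os
import shutil
from datetime import datetime

def archive_config(opt_path, experiments_root):
    # Prefix the copy with a day-month-year_hour-minute-second stamp so
    # repeated runs of the same config do not overwrite one another.
    stamped = f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'
    destination = os.path.join(experiments_root, stamped)
    shutil.copy(opt_path, destination)
    return destination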
codes/train.py
@@ -3,6 +3,8 @@ import math
 import argparse
 import random
 import logging
+import shutil
+
 from tqdm import tqdm
 
 import torch
@@ -13,6 +15,7 @@ from utils import util, options as option
 from data import create_dataloader, create_dataset
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
+from datetime import datetime
 
 from utils.util import opt_get
 
@@ -32,7 +35,7 @@ def init_dist(backend, **kwargs):
 
 class Trainer:
 
-    def init(self, opt, launcher, all_networks={}):
+    def init(self, opt_path, opt, launcher):
         self._profile = False
         self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
         self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)
@@ -54,6 +57,7 @@ class Trainer:
             util.mkdirs(
                 (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                  and 'pretrain_model' not in key and 'resume' not in key))
+            shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'], f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))
 
             # config loggers. Before it, the log will not work
             util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
@@ -131,7 +135,7 @@ class Trainer:
         assert self.train_loader is not None
 
         #### create model
-        self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
+        self.model = ExtensibleTrainer(opt)
 
         ### Evaluators
         self.evaluators = []
@@ -280,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -306,5 +310,5 @@ if __name__ == '__main__':
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
 
-    trainer.init(opt, args.launcher)
+    trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
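For reference, `%d%m%Y_%H%M%S` is a day-first timestamp, so the archived config picks up a name like `05032023_140709_train_gpt_asr_mass_distill.yml`. A quick check of the format string (the date itself is arbitrary):

from datetime import datetime

# 5 March 2023, 14:07:09 -> '05032023_140709'
print(datetime(2023, 3, 5, 14, 7, 9).strftime("%d%m%Y_%H%M%S"))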
codes/trainer/ExtensibleTrainer.py
@@ -21,7 +21,7 @@ logger = logging.getLogger('base')
 
 
 class ExtensibleTrainer(BaseModel):
-    def __init__(self, opt, cached_networks={}):
+    def __init__(self, opt):
         super(ExtensibleTrainer, self).__init__(opt)
         if opt['dist']:
             self.rank = torch.distributed.get_rank()
@@ -54,10 +54,6 @@ class ExtensibleTrainer(BaseModel):
             if 'trainable' not in net.keys():
                 net['trainable'] = True
 
-            if name in cached_networks.keys():
-                new_net = cached_networks[name]
-            else:
-                new_net = None
             if net['type'] == 'generator':
                 if new_net is None:
                     new_net = networks.create_model(opt, net, self.netsG).to(self.device)
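With `cached_networks` removed from `ExtensibleTrainer.__init__` (and the matching `all_networks` parameter gone from `Trainer.init`), networks can no longer be injected by a caller and are always constructed from the options. Instantiation reduces to a plain call (sketch, assuming `opt` is an already-parsed option dict):

from trainer.ExtensibleTrainer import ExtensibleTrainer

model = ExtensibleTrainer(opt)  # every network is built fresh from opt['networks']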