import os
import math
import argparse
import random
import logging
import shutil

from tqdm import tqdm

import torch
from data.data_sampler import DistIterSampler
from trainer.eval.evaluator import create_evaluator

from utils import util, options as option
from data import create_dataloader, create_dataset, get_dataset_debugger
from trainer.ExtensibleTrainer import ExtensibleTrainer
from time import time
from datetime import datetime

from utils.util import opt_get, map_cuda_to_correct_device

def init_dist(backend, **kwargs):
    """Initialization for distributed training."""
    # These packages have globals that screw with Windows, so only import them if needed.
    import torch.distributed as dist
    import torch.multiprocessing as mp

    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
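

# Trainer wraps an entire training run: it builds the datasets/dataloaders and the ExtensibleTrainer
# model from the parsed options, configures logging (file, tensorboard, optionally wandb), and drives
# the step/checkpoint/validation loop.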
class Trainer:

    def init(self, opt_path, opt, launcher):
        self._profile = False
        self.val_compute_psnr = opt_get(opt, ['eval', 'compute_psnr'], False)
        self.val_compute_fea = opt_get(opt, ['eval', 'compute_fea'], False)

        #### loading resume state if exists
        if opt['path'].get('resume_state', None):
            # distributed resuming: every rank loads the checkpoint; map_cuda_to_correct_device
            # remaps the saved tensors onto the device this process should use.
            resume_state = torch.load(opt['path']['resume_state'],
                                      map_location=map_cuda_to_correct_device)
        else:
            resume_state = None

        #### mkdir and loggers
        if self.rank <= 0:  # normal training (self.rank -1) OR distributed training (self.rank 0)
            if resume_state is None:
                util.mkdir_and_rename(
                    opt['path']['experiments_root'])  # rename experiment folder if exists
                util.mkdirs(
                    (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                     and 'pretrain_model' not in key and 'resume' not in key))
            shutil.copy(opt_path, os.path.join(opt['path']['experiments_root'],
                                               f'{datetime.now().strftime("%d%m%Y_%H%M%S")}_{os.path.basename(opt_path)}'))

            # configure loggers. Logging will not work before this point.
            util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                              screen=True, tofile=True)
            self.logger = logging.getLogger('base')
            self.logger.info(option.dict2str(opt))
            # tensorboard logger
            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
                from torch.utils.tensorboard import SummaryWriter
                self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
        else:
            util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
            self.logger = logging.getLogger('base')

        if resume_state is not None:
            option.check_resume(opt, resume_state['iter'])  # check resume options

        # convert to NoneDict, which returns None for missing keys
        opt = option.dict_to_nonedict(opt)
        self.opt = opt

        #### wandb init
        if opt['wandb'] and self.rank <= 0:
            import wandb
            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
            project_name = opt_get(opt, ['wandb_project_name'], opt['name'])
            run_name = opt_get(opt, ['wandb_run_name'], None)
            wandb.init(project=project_name, dir=opt['path']['log'], config=opt, name=run_name)
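            # If 'wandb_project_name' / 'wandb_run_name' are not set, the project falls back to the
            # experiment name and wandb auto-generates a run name; the full option dict is uploaded
            # as the run config.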

        #### random seed
        seed = opt['train']['manual_seed']
        if seed is None:
            seed = random.randint(1, 10000)
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        seed += self.rank  # Different multiprocessing instances should behave differently.
        util.set_random_seed(seed)

        torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True)
        # torch.backends.cudnn.deterministic = True
        if opt_get(opt, ['anomaly_detection'], False):
            torch.autograd.set_detect_anomaly(True)
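        # Both switches above are optional in the YAML: 'cuda_benchmarking_enabled' defaults to True and
        # 'anomaly_detection' defaults to False (anomaly detection is slow, so enable it only when hunting
        # NaN/Inf gradients).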

        # Save the compiled opt dict to the global loaded_options variable.
        util.loaded_options = opt

        #### create train and val dataloader
        dataset_ratio = 1  # enlarge the size of each epoch
        for phase, dataset_opt in opt['datasets'].items():
            if phase == 'train':
                self.train_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
                self.dataset_debugger = get_dataset_debugger(dataset_opt)
                if self.dataset_debugger is not None and resume_state is not None:
                    self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
                train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                total_iters = int(opt['train']['niter'])
                self.total_epochs = int(math.ceil(total_iters / train_size))
                if opt['dist']:
                    self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
                    self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                    shuffle = False
                else:
                    self.train_sampler = None
                    shuffle = True
                self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler,
                                                      collate_fn=collate_fn, shuffle=shuffle)
                if self.rank <= 0:
                    self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                        len(self.train_set), train_size))
                    self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                        self.total_epochs, total_iters))
            elif phase == 'val':
                self.val_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
                self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None, collate_fn=collate_fn)
                if self.rank <= 0:
                    self.logger.info('Number of val images in [{:s}]: {:d}'.format(
                        dataset_opt['name'], len(self.val_set)))
            else:
                raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
        assert self.train_loader is not None
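        # A 'train' dataset is mandatory; a 'val' dataset is only required when eval.pure is enabled
        # (see the assert on val_loader below).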

        #### create model
        self.model = ExtensibleTrainer(opt)

        #### Evaluators
        self.evaluators = []
        if 'eval' in opt.keys() and 'evaluators' in opt['eval'].keys():
            # In "pure" mode, we propagate through the normal training steps, but use validation data
            # instead and average the total loss. A validation dataloader is required.
            if opt_get(opt, ['eval', 'pure'], False):
                assert hasattr(self, 'val_loader')
            for ev_key, ev_opt in opt['eval']['evaluators'].items():
                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
                                                        ev_opt, self.model.env))

        #### resume training
        if resume_state:
            self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
                resume_state['epoch'], resume_state['iter']))
            self.start_epoch = resume_state['epoch']
            self.current_step = resume_state['iter']
            self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys())  # handle optimizers and schedulers
        else:
            self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
            self.start_epoch = 0
        if 'force_start_step' in opt.keys():
            self.current_step = opt['force_start_step']
        opt['current_step'] = self.current_step
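
    # do_step() performs a single optimizer step plus all periodic bookkeeping: console/tensorboard/wandb
    # logging, checkpointing, and the two validation paths ('pure' eval over the val set and the
    # configured evaluators).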
    def do_step(self, train_data):
        if self._profile:
            # self._profile_ts is seeded in do_training()/create_training_generator() before the loop.
            print("Data fetch: %f" % (time() - self._profile_ts))
            self._profile_ts = time()

        opt = self.opt
        self.current_step += 1

        #### update learning rate
        self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])

        #### training
        if self._profile:
            print("Update LR: %f" % (time() - self._profile_ts))
            self._profile_ts = time()
        self.model.feed_data(train_data, self.current_step)
        self.model.optimize_parameters(self.current_step)
        if self._profile:
            print("Model feed + step: %f" % (time() - self._profile_ts))
            self._profile_ts = time()

        #### log
        if self.dataset_debugger is not None:
            self.dataset_debugger.update(train_data)
        if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
            logs = self.model.get_current_log(self.current_step)
            if self.dataset_debugger is not None:
                logs.update(self.dataset_debugger.get_debugging_map())
            message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
            for v in self.model.get_current_learning_rate():
                message += '{:.3e},'.format(v)
            message += ')] '
            for k, v in logs.items():
                if 'histogram' in k:
                    self.tb_logger.add_histogram(k, v, self.current_step)
                elif isinstance(v, dict):
                    self.tb_logger.add_scalars(k, v, self.current_step)
                else:
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        self.tb_logger.add_scalar(k, v, self.current_step)
            if opt['wandb'] and self.rank <= 0:
                import wandb
                wandb.log(logs, step=int(self.current_step * opt_get(opt, ['wandb_step_factor'], 1)))
            self.logger.info(message)

        #### save models and training states
        if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
            self.model.consolidate_state()
            if self.rank <= 0:
                self.logger.info('Saving models and training states.')
                self.model.save(self.current_step)
                state = {'epoch': self.epoch, 'iter': self.current_step}
                if self.dataset_debugger is not None:
                    state['dataset_debugger_state'] = self.dataset_debugger.get_state()
                self.model.save_training_state(state)
            if 'alt_path' in opt['path'].keys():
                print("Synchronizing tb_logger to alt_path..")
                alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
                shutil.rmtree(alt_tblogger, ignore_errors=True)
                shutil.copytree(self.tb_logger_path, alt_tblogger)

        #### validation
        if opt_get(opt, ['eval', 'pure'], False) and self.current_step % opt['train']['val_freq'] == 0:
            metrics = []
            for val_data in tqdm(self.val_loader):
                self.model.feed_data(val_data, self.current_step, perform_micro_batching=False)
                metrics.append(self.model.test())
            reduced_metrics = {}
            for metric in metrics:
                for k, v in metric.as_dict().items():
                    if isinstance(v, torch.Tensor) and len(v.shape) == 0:
                        if k in reduced_metrics.keys():
                            reduced_metrics[k].append(v)
                        else:
                            reduced_metrics[k] = [v]
            if self.rank <= 0:
                for k, v in reduced_metrics.items():
                    val = torch.stack(v).mean().item()
                    self.tb_logger.add_scalar(f'val_{k}', val, self.current_step)
                    print(f">>Eval {k}: {val}")
                if opt['wandb']:
                    import wandb
                    wandb.log({f'eval_{k}': torch.stack(v).mean().item() for k, v in reduced_metrics.items()})

        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
            eval_dict = {}
            for evaluator in self.evaluators:
                if evaluator.uses_all_ddp or self.rank <= 0:
                    eval_dict.update(evaluator.perform_eval())
            if self.rank <= 0:
                print("Evaluator results: ", eval_dict)
                for ek, ev in eval_dict.items():
                    self.tb_logger.add_scalar(ek, ev, self.current_step)
                if opt['wandb']:
                    import wandb
                    wandb.log(eval_dict)

        # Should not be necessary, but make absolutely sure that there is no grad leakage from validation runs.
        for net in self.model.networks.values():
            net.zero_grad()
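
    # do_training() runs the loop to completion; create_training_generator() exposes the same loop as a
    # generator that yields the underlying model before every step so an external caller can inspect or
    # modify it between steps.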
    def do_training(self):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)
            tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader

            self._profile_ts = time()  # seed the timer used by the optional profiling prints in do_step()
            for train_data in tq_ldr:
                self.do_step(train_data)

    def create_training_generator(self, index):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)
            tq_ldr = tqdm(self.train_loader, position=index)

            self._profile_ts = time()  # seed the profiling timer, as in do_training()
            for train_data in tq_ldr:
                yield self.model
                self.do_step(train_data)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                        default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    if args.launcher != 'none':
        # export CUDA_VISIBLE_DEVICES for running in distributed mode.
        if 'gpu_ids' in opt.keys():
            gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
            print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
    trainer = Trainer()

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        trainer.rank = -1
        if len(opt['gpu_ids']) == 1:
            torch.cuda.set_device(opt['gpu_ids'][0])
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist('nccl')
        trainer.world_size = torch.distributed.get_world_size()
        trainer.rank = torch.distributed.get_rank()

    trainer.init(args.opt, opt, args.launcher)
    trainer.do_training()
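
# Example invocations (illustrative only -- adjust the script path and option file to your setup):
#   single process:  python train.py -opt ../options/my_experiment.yml
#   distributed:     python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py \
#                        -opt ../options/my_experiment.yml --launcher pytorch
# The pytorch launcher sets the RANK environment variable for each process, which init_dist() uses to
# select the GPU for that rank.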