diff --git a/codes/train.py b/codes/train.py
index 5a819376..260b3874 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -294,7 +294,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_512.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 2737142c..7d3d1983 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -96,6 +96,16 @@ class ExtensibleTrainer(BaseModel):
             for s in self.steps:
                 def_opt.extend(s.get_optimizers_with_default_scheduler())
             self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
+
+            # Set the starting step count for the scheduler.
+            start_step = 0
+            if 'force_start_step' in opt.keys():
+                start_step = opt['force_start_step']
+            elif 'start_step' in opt.keys():
+                start_step = opt['start_step']
+            if start_step != 0:
+                for sched in self.schedulers:
+                    sched.last_epoch = start_step
         else:
             self.schedulers = []

diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py
index 307cec27..bc9f96f0 100644
--- a/codes/trainer/lr_scheduler.py
+++ b/codes/trainer/lr_scheduler.py
@@ -159,11 +159,11 @@ if __name__ == "__main__":
     restart_weights = [1]

     ## four
-    T_period = [80000, 80000, 80000, 80000]
-    restarts = [100000, 200000]
-    restart_weights = [.5, .25]
+    T_period = [25000, 25000]
+    restarts = [252000]
+    restart_weights = [.5]

-    scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=100000, eta_min=.01, restarts=restarts,
+    scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=227000, eta_min=.01, restarts=restarts,
                                           weights=restart_weights)

     ##############################
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index fb9d7033..07c24e2b 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -104,7 +104,7 @@ class ConfigurableStep(Module):
         elif self.step_opt['optimizer'] == 'lars':
             from trainer.optimizers.larc import LARC
             from trainer.optimizers.sgd import SGDNoBiasMomentum
-            optSGD = SGDNoBiasMomentum(optim_params, lr=opt_config['lr'], momentum=opt_config['momentum'],
+            optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
                                        weight_decay=opt_config['weight_decay'])
             opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
             opt._config = opt_config  # This is a bit seedy, but we will need these configs later.
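
For reference, a minimal standalone sketch (not part of the patch) of the fast-forward pattern the new ExtensibleTrainer block relies on: overwriting a scheduler's last_epoch so the next scheduler.step() continues the LR curve from the forced step instead of from zero. The optimizer, LambdaLR schedule, and start_step value below are illustrative stand-ins for the real training setup.

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # LambdaLR computes its LR purely from last_epoch, so fast-forwarding is well-defined.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.5 ** (step // 10))

    start_step = 25  # stands in for opt['force_start_step'] / opt['start_step']
    if start_step != 0:
        scheduler.last_epoch = start_step

    scheduler.step()  # advances to step 26: lr = 0.1 * 0.5 ** 2 = 0.025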