Make ExtensibleTrainer set the starting step for the LR scheduler

This commit is contained in:
James Betker 2021-01-02 22:22:34 -07:00
parent bdbab65082
commit edf9c38198
4 changed files with 16 additions and 6 deletions

View File

@ -294,7 +294,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_512.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -96,6 +96,16 @@ class ExtensibleTrainer(BaseModel):
for s in self.steps:
def_opt.extend(s.get_optimizers_with_default_scheduler())
self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
# Set the starting step count for the scheduler.
start_step = 0
if 'force_start_step' in opt.keys():
start_step = opt['force_start_step']
elif 'start_step' in opt.keys():
start_step = opt['start_step']
if start_step != 0:
for sched in self.schedulers:
sched.last_epoch = start_step
else:
self.schedulers = []

View File

@ -159,11 +159,11 @@ if __name__ == "__main__":
restart_weights = [1]
## four
T_period = [80000, 80000, 80000, 80000]
restarts = [100000, 200000]
restart_weights = [.5, .25]
T_period = [25000, 25000]
restarts = [252000]
restart_weights = [.5]
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=100000, eta_min=.01, restarts=restarts,
scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=227000, eta_min=.01, restarts=restarts,
weights=restart_weights)
##############################

View File

@ -104,7 +104,7 @@ class ConfigurableStep(Module):
elif self.step_opt['optimizer'] == 'lars':
from trainer.optimizers.larc import LARC
from trainer.optimizers.sgd import SGDNoBiasMomentum
optSGD = SGDNoBiasMomentum(optim_params, lr=opt_config['lr'], momentum=opt_config['momentum'],
optSGD = SGDNoBiasMomentum(list(optim_params.values()), lr=opt_config['lr'], momentum=opt_config['momentum'],
weight_decay=opt_config['weight_decay'])
opt = LARC(optSGD, trust_coefficient=opt_config['lars_coefficient'])
opt._config = opt_config # This is a bit seedy, but we will need these configs later.