@@ -3,10 +3,11 @@ import math
 import argparse
 import random
 import logging
+from tqdm import tqdm
 
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
+#import torch.distributed as dist
+#import torch.multiprocessing as mp
 from data.data_sampler import DistIterSampler
 
 import options.options as option
@@ -28,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -138,7 +139,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0
 
     #### training
@@ -146,7 +147,8 @@ def main():
     for epoch in range(start_epoch, total_epochs + 1):
         if opt['dist']:
             train_sampler.set_epoch(epoch)
-        for _, train_data in enumerate(train_loader):
+        tq_ldr = tqdm(train_loader)
+        for _, train_data in enumerate(tq_ldr):
             current_step += 1
             if current_step > total_iters:
                 break
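
Note: the net effect of the last two hunks is a training loop that shows a tqdm progress bar and counts steps from 0, since current_step starts at -1 and is incremented before first use. A minimal runnable sketch of that pattern, using a dummy list in place of the real train_loader and an assumed total_iters value (neither is part of the patch):

    from tqdm import tqdm

    total_iters = 5  # assumed; the real script reads this from the YAML options file
    train_loader = [{'LQ': i, 'GT': i} for i in range(10)]  # dummy batches standing in for the DataLoader

    current_step = -1  # first increment below makes the initial step 0
    for epoch in range(1):
        tq_ldr = tqdm(train_loader)  # wrapping the loader prints a progress bar as it is iterated
        for _, train_data in enumerate(tq_ldr):
            current_step += 1
            if current_step > total_iters:
                break
            # the real script would call model.feed_data(train_data) and
            # model.optimize_parameters(current_step) here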