diff --git a/codes/train.py b/codes/train.py index 5db47536..980d00c2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_xlbatch_ragan.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -121,7 +121,7 @@ def main(): # torch.backends.cudnn.deterministic = True #### create train and val dataloader - dataset_ratio = 200 # enlarge the size of each epoch + dataset_ratio = 1 # enlarge the size of each epoch for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = create_dataset(dataset_opt)