Support inference across batches, support inference on CPU, checkpoint

This is a checkpoint of a set of long tests with reduced-complexity networks. Some takeaways:
1) A full GAN using the resnet discriminator does appear to converge, but the quality is capped.
2) In contrast, a combined GAN + feature loss does not converge. The feature loss is optimized, but
    the model appears unable to fight the discriminator, so the G-loss steadily increases.

Going forward, I want to try some bigger models. In particular, I want to change the generator
to increase complexity and capacity. I also want to add skip connections between the
disc and generator.
James Betker 2020-05-04 08:48:25 -06:00
parent 9c7debe75c
commit 44b89330c2
8 changed files with 26 additions and 26 deletions

View File

@@ -3,6 +3,7 @@
 <component name="NewModuleRootManager">
   <content url="file://$MODULE_DIR$">
     <sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
+    <excludeFolder url="file://$MODULE_DIR$/codes/temp" />
     <excludeFolder url="file://$MODULE_DIR$/datasets" />
     <excludeFolder url="file://$MODULE_DIR$/experiments" />
     <excludeFolder url="file://$MODULE_DIR$/results" />

View File

@@ -21,7 +21,8 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
                                            num_workers=num_workers, sampler=sampler, drop_last=True,
                                            pin_memory=False)
     else:
-        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
+        batch_size = dataset_opt['batch_size'] or 1
+        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                                            pin_memory=False)
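This is the half of the commit that enables batched inference: the test-time branch used to hard-code batch_size=1 and now reads it from the dataset options. A minimal, self-contained sketch of the new behavior (make_test_loader is a hypothetical stand-in for this else-branch, not a function in the repo):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def make_test_loader(dataset, dataset_opt):
        # Mirror of the patched branch: honor an optional batch_size from the
        # YAML, falling back to 1 when the option resolves to null (~).
        batch_size = dataset_opt['batch_size'] or 1
        return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                          num_workers=0, pin_memory=False)

    # Eight fake LQ images served four at a time -> two batches.
    dataset = TensorDataset(torch.randn(8, 3, 32, 32))
    loader = make_test_loader(dataset, {'batch_size': 4})
    assert len(loader) == 2

In test.py the options dict appears to pass through option.dict_to_nonedict, so a missing batch_size key resolves to None and the `or 1` fallback applies.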

View File

@@ -158,10 +158,7 @@ class SRGANModel(BaseModel):
         self.fake_H = []
         for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
-            if step > self.D_init_iters:
-                fake_H = self.netG(var_L)
-            else:
-                fake_H = pix
+            fake_H = self.netG(var_L)
             self.fake_H.append(fake_H.detach())
         l_g_total = 0
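With the gate removed, netG now runs on every step; the warm-up the old else-branch provided (passing the pixel reference through as fake_H) presumably moves entirely into the step > D_init_iters checks around the loss terms, with D_init_iters set to 997 in the training config below. Note also that var_L and friends are lists of chunks rather than single tensors; given the mega_batch_factor option further down, each batch appears to be pre-split so chunks can be forwarded one at a time to bound memory. A sketch of that assumed split (split_mega_batch is illustrative, not the repo's API):

    import torch

    def split_mega_batch(batch, mega_batch_factor):
        # Assumed semantics of mega_batch_factor: carve one large batch into
        # smaller chunks that are forwarded (and backpropagated) separately.
        return list(torch.chunk(batch, mega_batch_factor, dim=0))

    # batch_size 40 with mega_batch_factor 4 -> four chunks of ten images.
    chunks = split_mega_batch(torch.randn(40, 3, 64, 64), 4)
    assert [c.shape[0] for c in chunks] == [10, 10, 10, 10]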

View File

@@ -26,7 +26,7 @@ class SRModel(BaseModel):
         self.netG = networks.define_G(opt).to(self.device)
         if opt['dist']:
             self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
-        else:
+        elif opt['gpu_ids'] is not None:
             self.netG = DataParallel(self.netG)
         # print network
         self.print_network()
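This is the CPU-inference half: DataParallel is now applied only when gpu_ids is configured, so a GPU-less options file leaves the bare module on self.device. A hedged sketch of the pattern (wrap_for_devices is an illustrative name; the DistributedDataParallel branch is omitted):

    import torch.nn as nn
    from torch.nn.parallel import DataParallel

    def wrap_for_devices(net: nn.Module, opt: dict) -> nn.Module:
        # Wrap in DataParallel only when the parsed options list GPUs; a
        # CPU-only config (gpu_ids omitted or ~) returns the module as-is.
        if opt.get('gpu_ids') is not None:
            return DataParallel(net)
        return net

    netG = wrap_for_devices(nn.Conv2d(3, 3, 3, padding=1), {'gpu_ids': None})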

View File

@@ -10,9 +10,10 @@ def parse(opt_path, is_train=True):
     with open(opt_path, mode='r') as f:
         opt = yaml.load(f, Loader=Loader)
     # export CUDA_VISIBLE_DEVICES
-    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
-    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
-    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+    if 'gpu_ids' in opt.keys():
+        gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+        print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
     opt['is_train'] = is_train
     if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
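The same guard shows up at option-parsing time: CUDA_VISIBLE_DEVICES is only exported when the YAML actually contains gpu_ids, which is what lets the test config below comment the key out for a CPU run. Isolated for illustration:

    import os

    def maybe_export_gpus(opt: dict) -> None:
        # Pin CUDA_VISIBLE_DEVICES only when gpu_ids is present; otherwise
        # leave the environment untouched so the run can proceed on CPU.
        if 'gpu_ids' in opt:
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in opt['gpu_ids'])

    maybe_export_gpus({'gpu_ids': [0]})  # exports "0"
    maybe_export_gpus({})                # no-op: CPU-only config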

View File

@@ -4,23 +4,23 @@ model: sr
 distortion: sr
 scale: 4
 crop_border: ~  # crop border when evaluation. If None(~), crop the scale pixels
-gpu_ids: [0]
+#gpu_ids: [0]
 datasets:
   test_1:  # the 1st test dataset
     name: set5
     mode: LQ
-    dataroot_LQ: ../datasets/upsample_tests
+    batch_size: 1
+    dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
 #### network structures
 network_G:
   which_model_G: RRDBNet
   in_nc: 3
   out_nc: 3
-  nf: 64
+  nf: 48
   nb: 23
   upscale: 4
 #### path
 path:
-  pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/19500_G.pth
+  pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
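Note nf drops from 64 to 48, matching the reduced-complexity networks mentioned in the commit message; the pretrained generator must have been trained with the same nf or loading will typically fail with shape-mismatch errors. Presuming this file is the options/test/test_ESRGAN_adrianna_full.yml that test.py now defaults to, a quick pre-flight sanity check might look like:

    import yaml

    # Hypothetical check script; assumes it runs from the codes/ directory.
    with open('options/test/test_ESRGAN_adrianna_full.yml') as f:
        opt = yaml.safe_load(f)
    assert opt.get('gpu_ids') is None    # commented out -> CPU inference
    assert opt['network_G']['nf'] == 48  # must match the checkpoint
    assert opt['datasets']['test_1']['batch_size'] == 1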

View File

@@ -5,7 +5,7 @@ model: srgan
 distortion: sr
 scale: 4
 gpu_ids: [0]
-amp_opt_level: O0
+amp_opt_level: O1
 #### datasets
 datasets:
@@ -17,7 +17,7 @@ datasets:
     doCrop: false
     use_shuffle: true
     n_workers: 12  # per GPU
-    batch_size: 64
+    batch_size: 40
     target_size: 256
     color: RGB
   val:
@@ -40,18 +40,18 @@ network_D:
 #### path
 path:
-  pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
-  pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
+  pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
+  pretrain_model_D: ~
   strict_load: true
-  resume_state: ~
+  resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state
 #### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 5e-5
+  lr_G: !!float 1e-5
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 8e-5
+  lr_D: !!float 4e-5
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -61,20 +61,20 @@ train:
   warmup_iter: -1  # no warm up
   lr_steps: [5000, 20000, 40000, 60000]
   lr_gamma: 0.5
-  mega_batch_factor: 8
+  mega_batch_factor: 4
   pixel_criterion: l1
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
   feature_weight_decay: .9
-  feature_weight_decay_steps: 500
-  feature_weight_minimum: .1
+  feature_weight_decay_steps: 501
+  feature_weight_minimum: 0
   gan_type: gan  # gan | ragan
   gan_weight: 1
   D_update_ratio: 1
-  D_init_iters: -1
+  D_init_iters: 997
   manual_seed: 10
   val_freq: !!float 5e2
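A few of these hyperparameters encode the experiment from the commit message: feature_weight starts at 0 with a minimum of 0, so the feature loss contributes nothing and this run trains GAN-only, while D_init_iters: 997 appears to hold off generator updates for roughly the first thousand steps so the freshly initialized discriminator (pretrain_model_D: ~) can warm up against the pretrained generator. The decay options seem to describe a stepped exponential schedule; a sketch of the assumed semantics:

    def feature_weight_at(step, initial=0.0, decay=0.9, decay_steps=501, minimum=0.0):
        # Assumed schedule: multiply the weight by `decay` every `decay_steps`
        # iterations, clamped below at `minimum`. With initial and minimum
        # both 0, as in this config, the feature loss is disabled outright.
        return max(minimum, initial * decay ** (step // decay_steps))

    assert feature_weight_at(10_000) == 0.0  # GAN-only run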

View File

@@ -15,7 +15,7 @@ if __name__ == "__main__":
     #### options
     want_just_images = True
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_adrianna_full.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
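With the new default in place, a batched CPU inference run is simply `python test.py` from the codes directory (the -opt flag still overrides the options file), which ties the threads of this commit together: the parser skips the CUDA_VISIBLE_DEVICES export, the model skips DataParallel, and the dataloader batches by batch_size.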