Support inference across batches, support inference on cpu, checkpoint

This is a checkpoint of a set of long tests with reduced-complexity networks. Some takeaways:
1) A full GAN using the resnet discriminator does appear to converge, but the quality is capped.
2) Likewise, a combination GAN/feature loss does not converge. The feature loss is optimized but
    the model appears unable to fight the discriminator, so the G-loss steadily increases.

Going forwards, I want to try some bigger models. In particular, I want to change the generator
to increase complexity and capacity. I also want to add skip connections between the
disc and generator.
This commit is contained in:
James Betker 2020-05-04 08:48:25 -06:00
parent 9c7debe75c
commit 44b89330c2
8 changed files with 26 additions and 26 deletions

View File

@ -3,6 +3,7 @@
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$"> <content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" /> <sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/codes/temp" />
<excludeFolder url="file://$MODULE_DIR$/datasets" /> <excludeFolder url="file://$MODULE_DIR$/datasets" />
<excludeFolder url="file://$MODULE_DIR$/experiments" /> <excludeFolder url="file://$MODULE_DIR$/experiments" />
<excludeFolder url="file://$MODULE_DIR$/results" /> <excludeFolder url="file://$MODULE_DIR$/results" />

View File

@ -21,7 +21,8 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
num_workers=num_workers, sampler=sampler, drop_last=True, num_workers=num_workers, sampler=sampler, drop_last=True,
pin_memory=False) pin_memory=False)
else: else:
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, batch_size = dataset_opt['batch_size'] or 1
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
pin_memory=False) pin_memory=False)

View File

@ -158,10 +158,7 @@ class SRGANModel(BaseModel):
self.fake_H = [] self.fake_H = []
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix): for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
if step > self.D_init_iters: fake_H = self.netG(var_L)
fake_H = self.netG(var_L)
else:
fake_H = pix
self.fake_H.append(fake_H.detach()) self.fake_H.append(fake_H.detach())
l_g_total = 0 l_g_total = 0

View File

@ -26,7 +26,7 @@ class SRModel(BaseModel):
self.netG = networks.define_G(opt).to(self.device) self.netG = networks.define_G(opt).to(self.device)
if opt['dist']: if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
else: elif opt['gpu_ids'] is not None:
self.netG = DataParallel(self.netG) self.netG = DataParallel(self.netG)
# print network # print network
self.print_network() self.print_network()

View File

@ -10,9 +10,10 @@ def parse(opt_path, is_train=True):
with open(opt_path, mode='r') as f: with open(opt_path, mode='r') as f:
opt = yaml.load(f, Loader=Loader) opt = yaml.load(f, Loader=Loader)
# export CUDA_VISIBLE_DEVICES # export CUDA_VISIBLE_DEVICES
gpu_list = ','.join(str(x) for x in opt['gpu_ids']) if 'gpu_ids' in opt.keys():
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
print('export CUDA_VISIBLE_DEVICES=' + gpu_list) os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
opt['is_train'] = is_train opt['is_train'] = is_train
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample': if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':

View File

@ -4,23 +4,23 @@ model: sr
distortion: sr distortion: sr
scale: 4 scale: 4
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
gpu_ids: [0] #gpu_ids: [0]
datasets: datasets:
test_1: # the 1st test dataset test_1: # the 1st test dataset
name: set5 name: set5
mode: LQ mode: LQ
dataroot_LQ: ../datasets/upsample_tests batch_size: 1
dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
#### network structures #### network structures
network_G: network_G:
which_model_G: RRDBNet which_model_G: RRDBNet
in_nc: 3 in_nc: 3
out_nc: 3 out_nc: 3
nf: 64 nf: 48
nb: 23 nb: 23
upscale: 4
#### path #### path
path: path:
pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/19500_G.pth pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth

View File

@ -5,7 +5,7 @@ model: srgan
distortion: sr distortion: sr
scale: 4 scale: 4
gpu_ids: [0] gpu_ids: [0]
amp_opt_level: O0 amp_opt_level: O1
#### datasets #### datasets
datasets: datasets:
@ -17,7 +17,7 @@ datasets:
doCrop: false doCrop: false
use_shuffle: true use_shuffle: true
n_workers: 12 # per GPU n_workers: 12 # per GPU
batch_size: 64 batch_size: 40
target_size: 256 target_size: 256
color: RGB color: RGB
val: val:
@ -40,18 +40,18 @@ network_D:
#### path #### path
path: path:
pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth pretrain_model_D: ~
strict_load: true strict_load: true
resume_state: ~ resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state
#### training settings: learning rate scheme, loss #### training settings: learning rate scheme, loss
train: train:
lr_G: !!float 5e-5 lr_G: !!float 1e-5
weight_decay_G: 0 weight_decay_G: 0
beta1_G: 0.9 beta1_G: 0.9
beta2_G: 0.99 beta2_G: 0.99
lr_D: !!float 8e-5 lr_D: !!float 4e-5
weight_decay_D: 0 weight_decay_D: 0
beta1_D: 0.9 beta1_D: 0.9
beta2_D: 0.99 beta2_D: 0.99
@ -61,20 +61,20 @@ train:
warmup_iter: -1 # no warm up warmup_iter: -1 # no warm up
lr_steps: [5000, 20000, 40000, 60000] lr_steps: [5000, 20000, 40000, 60000]
lr_gamma: 0.5 lr_gamma: 0.5
mega_batch_factor: 8 mega_batch_factor: 4
pixel_criterion: l1 pixel_criterion: l1
pixel_weight: !!float 1e-2 pixel_weight: !!float 1e-2
feature_criterion: l1 feature_criterion: l1
feature_weight: 0 feature_weight: 0
feature_weight_decay: .9 feature_weight_decay: .9
feature_weight_decay_steps: 500 feature_weight_decay_steps: 501
feature_weight_minimum: .1 feature_weight_minimum: 0
gan_type: gan # gan | ragan gan_type: gan # gan | ragan
gan_weight: 1 gan_weight: 1
D_update_ratio: 1 D_update_ratio: 1
D_init_iters: -1 D_init_iters: 997
manual_seed: 10 manual_seed: 10
val_freq: !!float 5e2 val_freq: !!float 5e2

View File

@ -15,7 +15,7 @@ if __name__ == "__main__":
#### options #### options
want_just_images = True want_just_images = True
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml') parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_adrianna_full.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)