forked from mrq/DL-Art-School
Support inference across batches, support inference on cpu, checkpoint
This is a checkpoint of a set of long tests with reduced-complexity networks. Some takeaways: 1) A full GAN using the resnet discriminator does appear to converge, but the quality is capped. 2) Likewise, a combination GAN/feature loss does not converge. The feature loss is optimized but the model appears unable to fight the discriminator, so the G-loss steadily increases. Going forwards, I want to try some bigger models. In particular, I want to change the generator to increase complexity and capacity. I also want to add skip connections between the disc and generator.
This commit is contained in:
parent
9c7debe75c
commit
44b89330c2
|
@ -3,6 +3,7 @@
|
|||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/codes/temp" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/datasets" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/experiments" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/results" />
|
||||
|
|
|
@ -21,7 +21,8 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
|||
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||
pin_memory=False)
|
||||
else:
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
|
||||
batch_size = dataset_opt['batch_size'] or 1
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
|
||||
pin_memory=False)
|
||||
|
||||
|
||||
|
|
|
@ -158,10 +158,7 @@ class SRGANModel(BaseModel):
|
|||
|
||||
self.fake_H = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
if step > self.D_init_iters:
|
||||
fake_H = self.netG(var_L)
|
||||
else:
|
||||
fake_H = pix
|
||||
fake_H = self.netG(var_L)
|
||||
self.fake_H.append(fake_H.detach())
|
||||
|
||||
l_g_total = 0
|
||||
|
|
|
@ -26,7 +26,7 @@ class SRModel(BaseModel):
|
|||
self.netG = networks.define_G(opt).to(self.device)
|
||||
if opt['dist']:
|
||||
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||
else:
|
||||
elif opt['gpu_ids'] is not None:
|
||||
self.netG = DataParallel(self.netG)
|
||||
# print network
|
||||
self.print_network()
|
||||
|
|
|
@ -10,9 +10,10 @@ def parse(opt_path, is_train=True):
|
|||
with open(opt_path, mode='r') as f:
|
||||
opt = yaml.load(f, Loader=Loader)
|
||||
# export CUDA_VISIBLE_DEVICES
|
||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
if 'gpu_ids' in opt.keys():
|
||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||
|
||||
opt['is_train'] = is_train
|
||||
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
||||
|
|
|
@ -4,23 +4,23 @@ model: sr
|
|||
distortion: sr
|
||||
scale: 4
|
||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||
gpu_ids: [0]
|
||||
#gpu_ids: [0]
|
||||
|
||||
datasets:
|
||||
test_1: # the 1st test dataset
|
||||
name: set5
|
||||
mode: LQ
|
||||
dataroot_LQ: ../datasets/upsample_tests
|
||||
batch_size: 1
|
||||
dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: RRDBNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nf: 48
|
||||
nb: 23
|
||||
upscale: 4
|
||||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/19500_G.pth
|
||||
pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
|
|
@ -5,7 +5,7 @@ model: srgan
|
|||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
amp_opt_level: O0
|
||||
amp_opt_level: O1
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
|
@ -17,7 +17,7 @@ datasets:
|
|||
doCrop: false
|
||||
use_shuffle: true
|
||||
n_workers: 12 # per GPU
|
||||
batch_size: 64
|
||||
batch_size: 40
|
||||
target_size: 256
|
||||
color: RGB
|
||||
val:
|
||||
|
@ -40,18 +40,18 @@ network_D:
|
|||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
|
||||
pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
|
||||
pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
|
||||
pretrain_model_D: ~
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 5e-5
|
||||
lr_G: !!float 1e-5
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 8e-5
|
||||
lr_D: !!float 4e-5
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
|
@ -61,20 +61,20 @@ train:
|
|||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [5000, 20000, 40000, 60000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 8
|
||||
mega_batch_factor: 4
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 0
|
||||
feature_weight_decay: .9
|
||||
feature_weight_decay_steps: 500
|
||||
feature_weight_minimum: .1
|
||||
feature_weight_decay_steps: 501
|
||||
feature_weight_minimum: 0
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: 1
|
||||
|
||||
D_update_ratio: 1
|
||||
D_init_iters: -1
|
||||
D_init_iters: 997
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e2
|
||||
|
|
|
@ -15,7 +15,7 @@ if __name__ == "__main__":
|
|||
#### options
|
||||
want_just_images = True
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_adrianna_full.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user