forked from mrq/DL-Art-School
Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
a63bf2ea2f
|
@ -71,6 +71,8 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
self.recurrent = recurrent
|
self.recurrent = recurrent
|
||||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
if recurrent:
|
if recurrent:
|
||||||
|
self.recurrent_nf = recurrent_nf
|
||||||
|
self.recurrent_stride = recurrent_stride
|
||||||
self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
|
self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
|
||||||
self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
|
||||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
|
@ -86,7 +88,12 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
fea = self.initial_conv(x)
|
fea = self.initial_conv(x)
|
||||||
if self.recurrent:
|
if self.recurrent:
|
||||||
if recurrent is None:
|
if recurrent is None:
|
||||||
recurrent = torch.zeros_like(fea)
|
if self.recurrent_nf == 3:
|
||||||
|
recurrent = torch.zeros_like(x)
|
||||||
|
if self.recurrent_stride != 1:
|
||||||
|
recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest')
|
||||||
|
else:
|
||||||
|
recurrent = torch.zeros_like(fea)
|
||||||
rec = self.recurrent_process(recurrent)
|
rec = self.recurrent_process(recurrent)
|
||||||
fea, recstd = self.recurrent_join(fea, rec)
|
fea, recstd = self.recurrent_join(fea, rec)
|
||||||
self.ref_join_std = recstd.item()
|
self.ref_join_std = recstd.item()
|
||||||
|
|
|
@ -170,6 +170,8 @@ def main():
|
||||||
else:
|
else:
|
||||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
|
if 'force_start_step' in opt.keys():
|
||||||
|
current_step = opt['force_start_step']
|
||||||
|
|
||||||
#### training
|
#### training
|
||||||
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user