Set kernel_size in diffusion_vocoder

This commit is contained in:
James Betker 2021-09-01 08:33:46 -06:00
parent 30cd33fe44
commit 3e073cff85
3 changed files with 12 additions and 5 deletions

View File

@@ -187,7 +187,6 @@ class ResBlock(TimestepBlock):
         up=False,
         down=False,
         kernel_size=3,
-        padding=1,
     ):
         super().__init__()
         self.channels = channels
@@ -196,6 +195,7 @@ class ResBlock(TimestepBlock):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.use_scale_shift_norm = use_scale_shift_norm
+        padding = 1 if kernel_size == 3 else 2
         self.in_layers = nn.Sequential(
             normalization(channels),

View File

@@ -61,6 +61,7 @@ class DiffusionVocoder(nn.Module):
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
+        kernel_size=5,
     ):
         super().__init__()
@@ -82,6 +83,8 @@ class DiffusionVocoder(nn.Module):
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims
+        padding = 1 if kernel_size == 3 else 2
+
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
             linear(model_channels, time_embed_dim),
@@ -92,7 +95,7 @@ class DiffusionVocoder(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
                 )
             ]
         )
@@ -127,6 +130,7 @@ class DiffusionVocoder(nn.Module):
                         out_channels=mult * model_channels,
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
                     )
                 ]
                 ch = mult * model_channels
@@ -154,6 +158,7 @@ class DiffusionVocoder(nn.Module):
                             dims=dims,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            kernel_size=kernel_size,
                         )
                         if resblock_updown
                         else Downsample(
@@ -176,6 +181,7 @@ class DiffusionVocoder(nn.Module):
                 dropout,
                 dims=dims,
                 use_scale_shift_norm=use_scale_shift_norm,
+                kernel_size=kernel_size,
             ),
             AttentionBlock(
                 ch,
@@ -189,6 +195,7 @@ class DiffusionVocoder(nn.Module):
                 dropout,
                 dims=dims,
                 use_scale_shift_norm=use_scale_shift_norm,
+                kernel_size=kernel_size,
             ),
         )
         self._feature_size += ch
@@ -205,6 +212,7 @@ class DiffusionVocoder(nn.Module):
                         out_channels=model_channels * mult,
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
                     )
                 ]
                 ch = model_channels * mult
@@ -228,6 +236,7 @@ class DiffusionVocoder(nn.Module):
                             dims=dims,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
+                            kernel_size=kernel_size,
                         )
                         if resblock_updown
                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
@@ -239,7 +248,7 @@ class DiffusionVocoder(nn.Module):
         self.out = nn.Sequential(
             normalization(ch),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

     def convert_to_fp16(self):

View File

@@ -383,8 +383,6 @@ class ExtensibleTrainer(BaseModel):
     def load(self):
         for netdict in [self.netsG, self.netsD]:
             for name, net in netdict.items():
-                if not self.opt['networks'][name]['trainable']:
-                    continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                 if load_path is None:
                     return