forked from mrq/DL-Art-School
Set kernel_size in diffusion_vocoder
parent 30cd33fe44
commit 3e073cff85
@@ -187,7 +187,6 @@ class ResBlock(TimestepBlock):
         up=False,
         down=False,
         kernel_size=3,
-        padding=1,
     ):
         super().__init__()
         self.channels = channels
@@ -196,6 +195,7 @@ class ResBlock(TimestepBlock):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.use_scale_shift_norm = use_scale_shift_norm
+        padding = 1 if kernel_size == 3 else 2

         self.in_layers = nn.Sequential(
             normalization(channels),
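Note (illustration, not part of the diff): the new padding expression hard-codes "same" padding for the two kernel widths this model uses. For a stride-1 convolution with an odd kernel, the general rule is padding = (kernel_size - 1) // 2; the sketch below checks that the hard-coded ternary agrees with that rule for kernel sizes 3 and 5, and shows that any other width would need a different padding value.

# Illustration only: compare the commit's padding rule with the general
# "same"-length rule for odd, stride-1 convolution kernels.
def same_padding(kernel_size: int) -> int:
    # keeps output length equal to input length for stride 1, odd kernel_size
    return (kernel_size - 1) // 2

for k in (3, 5, 7):
    hard_coded = 1 if k == 3 else 2   # expression introduced by this commit
    print(k, hard_coded, same_padding(k), hard_coded == same_padding(k))
# kernels 3 and 5 match; a hypothetical kernel_size=7 would need padding=3,
# so the ternary only covers the widths the vocoder is actually configured with.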
@@ -61,6 +61,7 @@ class DiffusionVocoder(nn.Module):
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
+        kernel_size=5,
     ):
         super().__init__()

@@ -82,6 +83,8 @@ class DiffusionVocoder(nn.Module):
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims

+        padding = 1 if kernel_size == 3 else 2
+
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
             linear(model_channels, time_embed_dim),
@@ -92,7 +95,7 @@ class DiffusionVocoder(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
                 )
             ]
         )
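Note (illustration, not part of the diff): with padding derived from kernel_size, the wider input projection still preserves the sequence length, so no downstream shapes change. Assuming dims=1 so that conv_nd dispatches to nn.Conv1d (the usual convention in this family of codebases), and using placeholder channel counts, a standalone PyTorch check looks like:

# Illustration with placeholder channel counts: a stride-1 Conv1d with
# padding = (kernel_size - 1) // 2 keeps the temporal length unchanged,
# so swapping kernel 3 / padding 1 for kernel 5 / padding 2 is shape-neutral.
import torch
import torch.nn as nn

x = torch.randn(2, 1, 4096)                        # (batch, channels, samples)
old = nn.Conv1d(1, 128, kernel_size=3, padding=1)
new = nn.Conv1d(1, 128, kernel_size=5, padding=2)
assert old(x).shape == new(x).shape == (2, 128, 4096)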
@@ -127,6 +130,7 @@ class DiffusionVocoder(nn.Module):
                         out_channels=mult * model_channels,
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
                     )
                 ]
                 ch = mult * model_channels
@@ -154,6 +158,7 @@ class DiffusionVocoder(nn.Module):
                             dims=dims,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            kernel_size=kernel_size,
                         )
                         if resblock_updown
                         else Downsample(
@@ -176,6 +181,7 @@ class DiffusionVocoder(nn.Module):
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
+               kernel_size=kernel_size,
            ),
            AttentionBlock(
                ch,
@@ -189,6 +195,7 @@ class DiffusionVocoder(nn.Module):
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
+               kernel_size=kernel_size,
            ),
        )
        self._feature_size += ch
@@ -205,6 +212,7 @@ class DiffusionVocoder(nn.Module):
                         out_channels=model_channels * mult,
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
                     )
                 ]
                 ch = model_channels * mult
@@ -228,6 +236,7 @@ class DiffusionVocoder(nn.Module):
                             dims=dims,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
+                            kernel_size=kernel_size,
                         )
                         if resblock_updown
                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
@@ -239,7 +248,7 @@ class DiffusionVocoder(nn.Module):
         self.out = nn.Sequential(
             normalization(ch),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

     def convert_to_fp16(self):
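Note (context, not part of the diff): the output projection is wrapped in zero_module, so changing its kernel size has no effect at initialization; the head still starts out producing zeros. In the guided-diffusion-style utilities this code builds on, zero_module is typically the following one-liner (a sketch of the assumed helper, not this repository's verified source):

# Assumed form of the zero_module helper: zero every parameter of the wrapped
# module and return it, so the final conv contributes nothing at the start of
# training regardless of its kernel size.
def zero_module(module):
    for p in module.parameters():
        p.detach().zero_()
    return module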
@@ -383,8 +383,6 @@ class ExtensibleTrainer(BaseModel):
     def load(self):
         for netdict in [self.netsG, self.netsD]:
             for name, net in netdict.items():
-                if not self.opt['networks'][name]['trainable']:
-                    continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                 if load_path is None:
                     return
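Note on the ExtensibleTrainer change (illustration, not part of the diff): with the trainable check removed, load() no longer skips networks marked trainable: False; every entry in netsG and netsD now has its pretrain_model_<name> path looked up and, when set, restored. A hypothetical options fragment, mirroring only the keys load() reads above (network names and paths are made up):

# Hypothetical options dict for illustration; the key structure follows the
# fields referenced in load() above, values are placeholders.
opt = {
    'networks': {
        'generator': {'trainable': True},
        'vocoder': {'trainable': False},   # previously skipped on load, now restored too
    },
    'path': {
        'pretrain_model_generator': 'experiments/gen_latest.pth',
        'pretrain_model_vocoder': 'experiments/vocoder_pretrained.pth',
    },
}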