Set kernel_size in diffusion_vocoder
parent 30cd33fe44
commit 3e073cff85
@@ -187,7 +187,6 @@ class ResBlock(TimestepBlock):
         up=False,
         down=False,
         kernel_size=3,
-        padding=1,
     ):
         super().__init__()
         self.channels = channels
@@ -196,6 +195,7 @@ class ResBlock(TimestepBlock):
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.use_scale_shift_norm = use_scale_shift_norm
+        padding = 1 if kernel_size == 3 else 2

         self.in_layers = nn.Sequential(
             normalization(channels),
@@ -61,6 +61,7 @@ class DiffusionVocoder(nn.Module):
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
+        kernel_size=5,
     ):
         super().__init__()

@@ -82,6 +83,8 @@ class DiffusionVocoder(nn.Module):
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims

+        padding = 1 if kernel_size == 3 else 2
+
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
             linear(model_channels, time_embed_dim),
@@ -92,7 +95,7 @@ class DiffusionVocoder(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
                 )
             ]
         )
@@ -127,6 +130,7 @@ class DiffusionVocoder(nn.Module):
                         out_channels=mult * model_channels,
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
                     )
                 ]
                 ch = mult * model_channels
@@ -154,6 +158,7 @@ class DiffusionVocoder(nn.Module):
                             dims=dims,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            kernel_size=kernel_size,
                         )
                         if resblock_updown
                         else Downsample(
@@ -176,6 +181,7 @@ class DiffusionVocoder(nn.Module):
                 dropout,
                 dims=dims,
                 use_scale_shift_norm=use_scale_shift_norm,
+                kernel_size=kernel_size,
             ),
             AttentionBlock(
                 ch,
@@ -189,6 +195,7 @@ class DiffusionVocoder(nn.Module):
                 dropout,
                 dims=dims,
                 use_scale_shift_norm=use_scale_shift_norm,
+                kernel_size=kernel_size,
             ),
         )
         self._feature_size += ch
@@ -205,6 +212,7 @@ class DiffusionVocoder(nn.Module):
                         out_channels=model_channels * mult,
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
                     )
                 ]
                 ch = model_channels * mult
@@ -228,6 +236,7 @@ class DiffusionVocoder(nn.Module):
                             dims=dims,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
+                            kernel_size=kernel_size,
                         )
                         if resblock_updown
                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
@@ -239,7 +248,7 @@ class DiffusionVocoder(nn.Module):
         self.out = nn.Sequential(
             normalization(ch),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

     def convert_to_fp16(self):
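Note: the derived padding above is the usual "same-length" rule for odd kernels with stride 1, padding = (kernel_size - 1) // 2, which gives 1 for kernel_size=3 and 2 for kernel_size=5. A minimal sketch (not part of the commit; it assumes conv_nd with dims=1 resolves to a plain nn.Conv1d, as in guided-diffusion-style code) checking that the sequence length is preserved:

    # Standalone check of the padding rule used in this commit.
    import torch
    import torch.nn as nn

    x = torch.randn(1, 8, 100)  # (batch, channels, length)
    for kernel_size in (3, 5):
        padding = 1 if kernel_size == 3 else 2  # == (kernel_size - 1) // 2
        conv = nn.Conv1d(8, 8, kernel_size, padding=padding)
        assert conv(x).shape[-1] == x.shape[-1], "length should be unchanged"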
@@ -383,8 +383,6 @@ class ExtensibleTrainer(BaseModel):
     def load(self):
         for netdict in [self.netsG, self.netsD]:
             for name, net in netdict.items():
-                if not self.opt['networks'][name]['trainable']:
-                    continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                 if load_path is None:
                     return