More tweaks to diffusion-vocoder
commit 1d0b44ebc2
parent 3b19581f9a
@@ -97,12 +97,15 @@ def normalization(channels):
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
+    groups = 32
     if channels <= 16:
-        return GroupNorm32(8, channels)
+        groups = 8
     elif channels <= 64:
-        return GroupNorm32(16, channels)
-    else:
-        return GroupNorm32(32, channels)
+        groups = 16
+    while channels % groups != 0:
+        groups = int(groups / 2)
+    assert groups > 2
+    return GroupNorm32(groups, channels)
 
 
 def timestep_embedding(timesteps, dim, max_period=10000):
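This hunk replaces three hard-coded GroupNorm32 configurations with a computed group count that is guaranteed to divide the channel count, which matters now that fractional channel multipliers (see the config hunk below) can produce widths the old fixed group counts don't divide. A minimal standalone sketch of the selection rule; pick_groups is a hypothetical name (the repo computes this inline in normalization()) and plain nn.GroupNorm stands in for GroupNorm32:

import torch.nn as nn

def pick_groups(channels):
    # Start at 32 groups, step down for narrow layers, then halve
    # until the group count divides the channel count evenly.
    groups = 32
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    while channels % groups != 0:
        groups = int(groups / 2)
    assert groups > 2
    return groups

def normalization_sketch(channels):
    return nn.GroupNorm(pick_groups(channels), channels)

# pick_groups(48) -> 16, pick_groups(96) -> 32, pick_groups(40) -> 8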
@@ -14,9 +14,15 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
     def __init__(self, discrete_codes, channels):
         super().__init__()
         self.emb = nn.Embedding(discrete_codes, channels)
+        self.norm = normalization(channels)
+        self.act = nn.SiLU()
+        self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
+                                  normalization(channels*2),
+                                  nn.SiLU(),
+                                  nn.Conv1d(channels*2, channels, kernel_size=3, padding=1))
 
     """
-    Embeds the given codes and concatenates them onto x. Return shape: bx2cxS
+    Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
 
     :param x: bxcxS waveform latent
     :param codes: bxN discrete codes, N <= S
@@ -27,7 +33,9 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
         assert N <= S
         emb = self.emb(codes).permute(0,2,1)
         emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
-        return torch.cat([x, emb], dim=1)
+        together = torch.cat([self.act(self.norm(x)), emb], dim=1)
+        together = self.intg(together)
+        return together + x
 
 
 class DiffusionVocoderWithRef(nn.Module):
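The conditioning block no longer simply concatenates the embedded codes onto x, which doubled the channel count seen by downstream layers. It now normalizes x, concatenates the nearest-neighbor-upsampled code embeddings, fuses the pair through a 1x1 conv / norm / SiLU / 3x3 conv stack, and adds the result back to x, making the block a shape-preserving residual unit. That is why the ch *= 2 bookkeeping disappears in a later hunk. A self-contained sketch of the revised forward pass, with plain nn.GroupNorm approximating the repo's normalization() helper (group counts may differ):

import torch
import torch.nn as nn

class CondBlockSketch(nn.Module):
    def __init__(self, discrete_codes, channels):
        super().__init__()
        self.emb = nn.Embedding(discrete_codes, channels)
        self.norm = nn.GroupNorm(8, channels)
        self.act = nn.SiLU()
        self.intg = nn.Sequential(
            nn.Conv1d(channels * 2, channels * 2, kernel_size=1),
            nn.GroupNorm(8, channels * 2),
            nn.SiLU(),
            nn.Conv1d(channels * 2, channels, kernel_size=3, padding=1))

    def forward(self, x, codes):
        b, c, S = x.shape
        emb = self.emb(codes).permute(0, 2, 1)                           # bxNxc -> bxcxN
        emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')  # stretch codes to length S
        together = torch.cat([self.act(self.norm(x)), emb], dim=1)       # bx2cxS
        return self.intg(together) + x                                   # residual; back to bxcxS

print(CondBlockSketch(8192, 32)(torch.randn(2, 32, 256), torch.randint(8192, (2, 64))).shape)
# torch.Size([2, 32, 256])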
@@ -68,11 +76,13 @@ class DiffusionVocoderWithRef(nn.Module):
             out_channels=2, # mean and variance
             discrete_codes=8192,
             dropout=0,
-            # 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
-            channel_mult=  (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
-            num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
-            spectrogram_conditioning_resolutions=(512,),
-            attention_resolutions=(512,1024,2048,4096),
+            # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
+            channel_mult=  (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
+            num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
+            # spec_cond:    1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
+            # attn:         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
+            spectrogram_conditioning_resolutions=(1,8,64,512),
+            attention_resolutions=(512,1024,2048),
             conv_resample=True,
             dims=1,
             use_fp16=False,
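channel_mult now contains a fractional entry (1.5), so mult * model_channels can be a float; the two hunks further down wrap those products in int() because layer constructors need integer channel counts. Assuming model_channels=32 (the value the smoke test at the bottom passes in), the per-level widths work out as:

model_channels = 32
channel_mult = (1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48)
# 1.5 * 32 == 48.0, a float; nn.Conv1d and friends need ints,
# hence the int(mult * model_channels) casts below.
widths = [int(m * model_channels) for m in channel_mult]
print(widths)  # [32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536]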
@@ -136,7 +146,6 @@ class DiffusionVocoderWithRef(nn.Module):
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
                 self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
-                ch *= 2
 
             for _ in range(num_blocks):
                 layers = [
@@ -144,13 +153,13 @@ class DiffusionVocoderWithRef(nn.Module):
                         ch,
                         time_embed_dim,
                         dropout,
-                        out_channels=mult * model_channels,
+                        out_channels=int(mult * model_channels),
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
                         kernel_size=kernel_size,
                     )
                 ]
-                ch = mult * model_channels
+                ch = int(mult * model_channels)
                 if ds in attention_resolutions:
                     layers.append(
                         AttentionBlock(
@@ -223,13 +232,13 @@ class DiffusionVocoderWithRef(nn.Module):
                         ch + ich,
                         time_embed_dim,
                         dropout,
-                        out_channels=model_channels * mult,
+                        out_channels=int(model_channels * mult),
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
                         kernel_size=kernel_size,
                     )
                 ]
-                ch = model_channels * mult
+                ch = int(model_channels * mult)
                 if ds in attention_resolutions:
                     layers.append(
                         AttentionBlock(
@@ -326,9 +335,9 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 81920)
-    spec = torch.randint(8192, (2, 160,))
-    cond = torch.randn(2, 4, 80, 600)
+    clip = torch.randn(2, 1, 40960)
+    spec = torch.randint(8192, (2, 40,))
+    cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
     model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
     print(model(clip, ts, spec, cond, 4).shape)
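The smoke test now feeds a 40960-sample clip (about 1.9 s at 22050 Hz, so the "~4 second" comment above it is stale) with 40 discrete codes, i.e. 1024 waveform samples per code. With out_channels=2 the forward pass should print torch.Size([2, 2, 40960]), assuming the UNet returns the mean/variance pair at full input resolution. A quick arithmetic check of those figures:

sample_rate = 22050
print(40960 / sample_rate)  # ~1.86 s of audio per test clip
print(40960 // 40)          # 1024 waveform samples per discrete code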