More tweaks to diffusion-vocoder

James Betker 2021-10-15 11:51:17 -06:00
parent 3b19581f9a
commit 1d0b44ebc2
2 changed files with 31 additions and 19 deletions

View File

@@ -97,12 +97,15 @@ def normalization(channels):
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
+    groups = 32
     if channels <= 16:
-        return GroupNorm32(8, channels)
+        groups = 8
     elif channels <= 64:
-        return GroupNorm32(16, channels)
-    else:
-        return GroupNorm32(32, channels)
+        groups = 16
+    while channels % groups != 0:
+        groups = int(groups / 2)
+    assert groups > 2
+    return GroupNorm32(groups, channels)
 
 
 def timestep_embedding(timesteps, dim, max_period=10000):
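Read in isolation, the new group-selection logic behaves as in the sketch below (pick_groups is a hypothetical standalone name for illustration; in the commit the loop lives inside normalization() and feeds GroupNorm32, the codebase's GroupNorm wrapper):

def pick_groups(channels):
    # Start from 32 groups (8/16 for narrow layers), then halve until the
    # group count evenly divides the channel count.
    groups = 32
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    while channels % groups != 0:
        groups = int(groups / 2)
    assert groups > 2  # as in the commit: refuse degenerate group counts
    return groups

print(pick_groups(96), pick_groups(48), pick_groups(36))  # -> 32 16 4

The halving loop matters because the fixed 8/16/32 group choices no longer always divide the layer width once fractional channel_mult entries (note the 1.5 introduced below) can produce widths like 36 or 48 for some base channel settings.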

View File

@@ -14,9 +14,15 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
     def __init__(self, discrete_codes, channels):
         super().__init__()
         self.emb = nn.Embedding(discrete_codes, channels)
+        self.norm = normalization(channels)
+        self.act = nn.SiLU()
+        self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
+                                  normalization(channels*2),
+                                  nn.SiLU(),
+                                  nn.Conv1d(channels*2, channels, kernel_size=3, padding=1))
 
     """
-    Embeds the given codes and concatenates them onto x. Return shape: bx2cxS
+    Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
     :param x: bxcxS waveform latent
     :param codes: bxN discrete codes, N <= S
@@ -27,7 +33,9 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
         assert N <= S
         emb = self.emb(codes).permute(0,2,1)
         emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
-        return torch.cat([x, emb], dim=1)
+        together = torch.cat([self.act(self.norm(x)), emb], dim=1)
+        together = self.intg(together)
+        return together + x
 
 
 class DiffusionVocoderWithRef(nn.Module):
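Combining the two hunks above: instead of concatenating raw embeddings onto x (which doubled its channel count), the block now normalizes and activates x, fuses it with the nearest-neighbor-upsampled code embeddings through the new intg stack, and adds the result back residually. A self-contained sketch, with the forward signature inferred from the docstring and plain nn.GroupNorm standing in for the normalization() helper patched above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiscreteSpectrogramConditioningBlock(nn.Module):
    def __init__(self, discrete_codes, channels):
        super().__init__()
        self.emb = nn.Embedding(discrete_codes, channels)
        self.norm = nn.GroupNorm(8, channels)   # stand-in for normalization(channels)
        self.act = nn.SiLU()
        self.intg = nn.Sequential(
            nn.Conv1d(channels * 2, channels * 2, kernel_size=1),
            nn.GroupNorm(8, channels * 2),      # stand-in for normalization(channels*2)
            nn.SiLU(),
            nn.Conv1d(channels * 2, channels, kernel_size=3, padding=1))

    def forward(self, x, codes):
        # x: (b, c, S) waveform latent; codes: (b, N) discrete codes, N <= S
        b, c, S = x.shape
        assert codes.shape[1] <= S
        emb = self.emb(codes).permute(0, 2, 1)                      # (b, c, N)
        emb = F.interpolate(emb, size=(S,), mode='nearest')         # (b, c, S)
        together = torch.cat([self.act(self.norm(x)), emb], dim=1)  # (b, 2c, S)
        return self.intg(together) + x                              # (b, c, S), same as x

blk = DiscreteSpectrogramConditioningBlock(8192, 32)
print(blk(torch.randn(2, 32, 128), torch.randint(8192, (2, 16))).shape)  # torch.Size([2, 32, 128])

Because the output shape now matches the input, callers no longer need to double their channel bookkeeping, which is exactly why the `ch *= 2` line disappears in a later hunk.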
@@ -68,11 +76,13 @@ class DiffusionVocoderWithRef(nn.Module):
                  out_channels=2,  # mean and variance
                  discrete_codes=8192,
                  dropout=0,
-                 # 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
-                 channel_mult=  (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
-                 num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
-                 spectrogram_conditioning_resolutions=(512,),
-                 attention_resolutions=(512,1024,2048,4096),
+                 # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
+                 channel_mult=  (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
+                 num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
+                 # spec_cond:    1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
+                 # attn:         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
+                 spectrogram_conditioning_resolutions=(1,8,64,512),
+                 attention_resolutions=(512,1024,2048),
                  conv_resample=True,
                  dims=1,
                  use_fp16=False,
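The new 1.5 entry means mult * model_channels is no longer guaranteed to be an integer, which is what the int() casts in the next two hunks are for. With the base width of 32 used by the smoke test at the bottom of the file, the per-level widths work out as below (a quick check, not code from the commit):

model_channels = 32
channel_mult = (1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48)
print([int(m * model_channels) for m in channel_mult])
# -> [32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536]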
@@ -136,7 +146,6 @@ class DiffusionVocoderWithRef(nn.Module):
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
                 self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
-                ch *= 2
 
             for _ in range(num_blocks):
                 layers = [
@ -144,13 +153,13 @@ class DiffusionVocoderWithRef(nn.Module):
ch, ch,
time_embed_dim, time_embed_dim,
dropout, dropout,
out_channels=mult * model_channels, out_channels=int(mult * model_channels),
dims=dims, dims=dims,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size, kernel_size=kernel_size,
) )
] ]
ch = mult * model_channels ch = int(mult * model_channels)
if ds in attention_resolutions: if ds in attention_resolutions:
layers.append( layers.append(
AttentionBlock( AttentionBlock(
@ -223,13 +232,13 @@ class DiffusionVocoderWithRef(nn.Module):
ch + ich, ch + ich,
time_embed_dim, time_embed_dim,
dropout, dropout,
out_channels=model_channels * mult, out_channels=int(model_channels * mult),
dims=dims, dims=dims,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
kernel_size=kernel_size, kernel_size=kernel_size,
) )
] ]
ch = model_channels * mult ch = int(model_channels * mult)
if ds in attention_resolutions: if ds in attention_resolutions:
layers.append( layers.append(
AttentionBlock( AttentionBlock(
@@ -326,9 +335,9 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 81920)
-    spec = torch.randint(8192, (2, 160,))
-    cond = torch.randn(2, 4, 80, 600)
+    clip = torch.randn(2, 1, 40960)
+    spec = torch.randint(8192, (2, 40,))
+    cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
     model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
     print(model(clip, ts, spec, cond, 4).shape)
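A quick sanity check on the updated tensors: 40960 samples at 22050 Hz is about 1.86 s, so the smoke test now covers roughly two seconds of audio (the unchanged "~4 second" comment above describes the old 81920-sample clip), and the 40 spectrogram codes work out to one code per 1024 waveform samples.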