From ae5f934ea13652859c92ffb00ac4f6491e7878a3 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 2 May 2022 00:05:04 -0600 Subject: [PATCH] diffwave --- codes/models/audio/music/diffwave.py | 177 +++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 codes/models/audio/music/diffwave.py diff --git a/codes/models/audio/music/diffwave.py b/codes/models/audio/music/diffwave.py new file mode 100644 index 00000000..348721bb --- /dev/null +++ b/codes/models/audio/music/diffwave.py @@ -0,0 +1,177 @@ +# Copyright 2020 LMNT, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from math import sqrt + +from trainer.networks import register_model + +Linear = nn.Linear +ConvTranspose2d = nn.ConvTranspose2d + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + nn.init.kaiming_normal_(layer.weight) + return layer + + +@torch.jit.script +def silu(x): + return x * torch.sigmoid(x) + + +class DiffusionEmbedding(nn.Module): + def __init__(self, max_steps): + super().__init__() + self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False) + self.projection1 = Linear(128, 512) + self.projection2 = Linear(512, 512) + + def forward(self, diffusion_step): + if diffusion_step.dtype in [torch.int32, torch.int64]: + x = self.embedding[diffusion_step] + else: + x = self._lerp_embedding(diffusion_step) + x = self.projection1(x) + x = silu(x) + x = self.projection2(x) + x = silu(x) + return x + + def _lerp_embedding(self, t): + low_idx = torch.floor(t).long() + high_idx = torch.ceil(t).long() + low = self.embedding[low_idx] + high = self.embedding[high_idx] + return low + (high - low) * (t - low_idx) + + def _build_embedding(self, max_steps): + steps = torch.arange(max_steps).unsqueeze(1) # [T,1] + dims = torch.arange(64).unsqueeze(0) # [1,64] + table = steps * 10.0 ** (dims * 4.0 / 63.0) # [T,64] + table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) + return table + + +class SpectrogramUpsampler(nn.Module): + def __init__(self, n_mels): + super().__init__() + self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) + self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) + + def forward(self, x): + x = torch.unsqueeze(x, 1) + x = self.conv1(x) + x = F.leaky_relu(x, 0.4) + x = self.conv2(x) + x = F.leaky_relu(x, 0.4) + x = torch.squeeze(x, 1) + return x + + +class ResidualBlock(nn.Module): + def __init__(self, n_mels, residual_channels, dilation, uncond=False): + ''' + :param n_mels: inplanes of conv1x1 for spectrogram conditional + :param residual_channels: audio conv + :param dilation: audio conv dilation + :param uncond: disable spectrogram conditional + ''' + super().__init__() + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = Linear(512, residual_channels) + if not uncond: # conditional model + self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) + else: # unconditional model + self.conditioner_projection = None + + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, diffusion_step, conditioner=None): + assert (conditioner is None and self.conditioner_projection is None) or \ + (conditioner is not None and self.conditioner_projection is not None) + + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + y = x + diffusion_step + if self.conditioner_projection is None: # using a unconditional model + y = self.dilated_conv(y) + else: + y = self.dilated_conv(y) + conditioner = self.conditioner_projection(conditioner) + conditioner = F.interpolate(conditioner, size=y.shape[-1], mode='nearest') + y = y + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x + residual) / sqrt(2.0), skip + + +class DiffWave(nn.Module): + def __init__(self, residual_layers=30, residual_channels=64, num_timesteps=4000, n_mels=128, + dilation_cycle_length=10, unconditional=False): + super().__init__() + self.input_projection = Conv1d(1, residual_channels, 1) + self.diffusion_embedding = DiffusionEmbedding(num_timesteps) + if unconditional: # use unconditional model + self.spectrogram_upsampler = None + else: + self.spectrogram_upsampler = SpectrogramUpsampler(n_mels) + + self.residual_layers = nn.ModuleList([ + ResidualBlock(n_mels, residual_channels, 2 ** (i % dilation_cycle_length), uncond=unconditional) + for i in range(residual_layers) + ]) + self.skip_projection = Conv1d(residual_channels, residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 1, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, x, timesteps, spectrogram=None): + assert (spectrogram is None and self.spectrogram_upsampler is None) or \ + (spectrogram is not None and self.spectrogram_upsampler is not None) + x = self.input_projection(x) + x = F.relu(x) + + timesteps = self.diffusion_embedding(timesteps) + if self.spectrogram_upsampler: # use conditional model + spectrogram = self.spectrogram_upsampler(spectrogram) + + skip = None + for layer in self.residual_layers: + x, skip_connection = layer(x, timesteps, spectrogram) + skip = skip_connection if skip is None else skip_connection + skip + + x = skip / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) + return x + + +@register_model +def register_diffwave(opt_net, opt): + return DiffWave(**opt_net['kwargs']) + + +if __name__ == '__main__': + model = DiffWave() + model(torch.randn(2,1,20000), torch.tensor([500,3999]), torch.randn(2,128,78)) \ No newline at end of file