# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt

from torch.utils.checkpoint import checkpoint

from trainer.networks import register_model


Linear = nn.Linear
ConvTranspose2d = nn.ConvTranspose2d


def Conv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` whose weight is Kaiming-normal initialised.

    All positional and keyword arguments are forwarded to ``nn.Conv1d``.
    """
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


@torch.jit.script
def silu(x):
    return x * torch.sigmoid(x)


def _checkpoint_if_differentiable(fn, *args):
    """Run ``fn(*args)`` under activation checkpointing only when that is safe.

    With re-entrant checkpointing (the historical default of
    ``torch.utils.checkpoint.checkpoint``), a checkpointed segment whose
    inputs all have ``requires_grad=False`` produces an output detached from
    the graph, so the parameters used inside ``fn`` never receive gradients.
    In that situation (and whenever autograd is disabled) ``fn`` is called
    directly instead.
    """
    wants_grad = torch.is_grad_enabled() and any(
        torch.is_tensor(arg) and arg.requires_grad for arg in args
    )
    if wants_grad:
        return checkpoint(fn, *args)
    return fn(*args)


class DiffusionEmbedding(nn.Module):
    """Map diffusion timesteps to a 512-dimensional conditioning vector.

    A fixed sinusoidal table of width 128 (64 sines followed by 64 cosines)
    is looked up, or linearly interpolated for non-integer steps, and then
    passed through two ``Linear`` + SiLU layers.
    """

    def __init__(self, max_steps):
        """
        :param max_steps: number of rows in the sinusoidal table, i.e. the
            number of distinct integer timesteps that can be looked up
        """
        super().__init__()
        # Non-persistent: the table is rebuilt deterministically in __init__,
        # so it is kept out of the state dict.
        self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False)
        self.projection1 = Linear(128, 512)
        self.projection2 = Linear(512, 512)

    def forward(self, diffusion_step):
        """
        :param diffusion_step: tensor of timesteps; int32/int64 values index
            the table directly, any other dtype is linearly interpolated
        :return: tensor with a trailing dimension of size 512
        """
        if diffusion_step.dtype in [torch.int32, torch.int64]:
            x = self.embedding[diffusion_step]
        else:
            x = self._lerp_embedding(diffusion_step)
        x = self.projection1(x)
        x = silu(x)
        x = self.projection2(x)
        x = silu(x)
        return x

    def _lerp_embedding(self, t):
        """Linearly interpolate the table between ``floor(t)`` and ``ceil(t)``."""
        low_idx = torch.floor(t).long()
        high_idx = torch.ceil(t).long()
        low = self.embedding[low_idx]
        high = self.embedding[high_idx]
        # The looked-up rows carry one more (trailing) axis than ``t``; the
        # interpolation weight needs a matching singleton axis, otherwise a
        # batch of B timesteps only broadcasts when B happens to be 1 or 128.
        frac = (t - low_idx).unsqueeze(-1).to(low.dtype)
        return low + (high - low) * frac

    def _build_embedding(self, max_steps):
        """Return the ``[max_steps, 128]`` sinusoidal lookup table."""
        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
        dims = torch.arange(64).unsqueeze(0)          # [1,64]
        table = steps * 10.0 ** (dims * 4.0 / 63.0)   # [T,64]
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
        return table


class SpectrogramUpsampler(nn.Module):
    """Upsample a spectrogram along its last axis with two transposed convs.

    Each transposed convolution uses stride 16, kernel 32 and padding 8 on
    the last axis, so that axis grows by a factor of 16 per layer (256
    overall); the second-to-last axis keeps its size.
    """

    def __init__(self, n_mels):
        """
        :param n_mels: accepted for interface compatibility; not used by the
            layers built here
        """
        super().__init__()
        self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
        self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])

    def forward(self, x):
        """
        :param x: spectrogram without a channel axis; a singleton channel
            axis is inserted at dim 1 for the 2-D convs and removed again
        :return: the upsampled spectrogram
        """
        x = torch.unsqueeze(x, 1)
        x = self.conv1(x)
        x = F.leaky_relu(x, 0.4)
        x = self.conv2(x)
        x = F.leaky_relu(x, 0.4)
        x = torch.squeeze(x, 1)
        return x


class ResidualBlock(nn.Module):
    """Gated dilated-convolution block with residual and skip outputs."""

    def __init__(self, n_mels, residual_channels, dilation, uncond=False):
        """
        :param n_mels: inplanes of conv1x1 for spectrogram conditional
        :param residual_channels: audio conv
        :param dilation: audio conv dilation
        :param uncond: disable spectrogram conditional
        """
        super().__init__()
        self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3,
                                   padding=dilation, dilation=dilation)
        self.diffusion_projection = Linear(512, residual_channels)
        if not uncond:  # conditional model
            self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
        else:  # unconditional model
            self.conditioner_projection = None

        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, diffusion_step, conditioner=None):
        """
        :param x: input feature map with ``residual_channels`` channels at dim 1
        :param diffusion_step: 512-wide timestep embedding
        :param conditioner: spectrogram conditioning; must be given exactly
            when the block was built with ``uncond=False``
        :return: tuple ``(residual_output, skip)``
        """
        assert (conditioner is None and self.conditioner_projection is None) or \
               (conditioner is not None and self.conditioner_projection is not None), \
            'conditioner must be provided if and only if the block is conditional'

        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        y = self.dilated_conv(x + diffusion_step)
        if self.conditioner_projection is not None:
            conditioner = self.conditioner_projection(conditioner)
            # Match the conditioner length to the audio feature length.
            conditioner = F.interpolate(conditioner, size=y.shape[-1], mode='nearest')
            y = y + conditioner

        gate, filt = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filt)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class DiffWave(nn.Module):
    """Stack of ``ResidualBlock`` layers over a 1-channel input signal.

    The skip outputs of all blocks are summed, scaled by
    ``1 / sqrt(num_layers)`` and projected to a 2-channel output.
    """

    def __init__(self,
                 residual_layers=30,
                 residual_channels=64,
                 num_timesteps=4000,
                 n_mels=128,
                 dilation_cycle_length=10,
                 unconditional=False):
        """
        :param residual_layers: number of residual blocks
        :param residual_channels: channel width used inside the blocks
        :param num_timesteps: size of the diffusion-step embedding table
        :param n_mels: channel count of the conditioning spectrogram
        :param dilation_cycle_length: block ``i`` uses dilation
            ``2 ** (i % dilation_cycle_length)``
        :param unconditional: build the model without spectrogram conditioning
        """
        super().__init__()
        self.input_projection = Conv1d(1, residual_channels, 1)
        self.diffusion_embedding = DiffusionEmbedding(num_timesteps)
        if unconditional:  # use unconditional model
            self.spectrogram_upsampler = None
        else:
            self.spectrogram_upsampler = SpectrogramUpsampler(n_mels)

        self.residual_layers = nn.ModuleList([
            ResidualBlock(n_mels, residual_channels, 2 ** (i % dilation_cycle_length),
                          uncond=unconditional)
            for i in range(residual_layers)
        ])
        self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
        self.output_projection = Conv1d(residual_channels, 2, 1)
        # Zero weights make the initial output depend only on the bias.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, x, timesteps, spectrogram=None):
        """
        :param x: input signal with a single channel at dim 1
        :param timesteps: diffusion timesteps, one per batch element
        :param spectrogram: conditioning spectrogram; must be given exactly
            when the model was built with ``unconditional=False``
        :return: tensor with 2 channels at dim 1
        """
        assert (spectrogram is None and self.spectrogram_upsampler is None) or \
               (spectrogram is not None and self.spectrogram_upsampler is not None), \
            'spectrogram must be provided if and only if the model is conditional'

        x = self.input_projection(x)
        x = F.relu(x)

        # Timesteps (and normally the spectrogram) do not require grad, so an
        # unconditional re-entrant checkpoint here would cut these sub-modules
        # off from the autograd graph.
        timesteps = _checkpoint_if_differentiable(self.diffusion_embedding, timesteps)
        if self.spectrogram_upsampler is not None:  # use conditional model
            spectrogram = _checkpoint_if_differentiable(self.spectrogram_upsampler, spectrogram)

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = _checkpoint_if_differentiable(layer, x, timesteps, spectrogram)
            skip = skip_connection if skip is None else skip_connection + skip

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)
        return x


@register_model
def register_diffwave(opt_net, opt):
    """Construct a ``DiffWave`` from the ``'kwargs'`` entry of ``opt_net``."""
    return DiffWave(**opt_net['kwargs'])


if __name__ == '__main__':
    model = DiffWave()
    model(torch.randn(2, 1, 65536), torch.tensor([500, 3999]), torch.randn(2, 128, 256))