# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# modified grossly to avoid additional dependencies

import itertools
import math
import random
from abc import ABC, abstractmethod
from math import ceil
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange

from nemo.core import ModelPT
from nemo.core.classes import Loss
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.neural_types.elements import (
    AudioSignal,
    EncodedRepresentation,
    Index,
    LengthsType,
    LossType,
    MelSpectrogramType,
    PredictionsType,
    RegressionValuesType,
    TokenIndex,
    VoidType,
)
# Imports required by classes defined below; paths follow the standard NeMo package layout.
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor  # used by MelSpectrogramProcessor
from nemo.collections.asr.parts.preprocessing.features import normalize_batch, splice_frames  # used by FilterbankFeatures
from nemo.collections.common.parts.utils import ClampActivation, HalfSnake, Snake, mask_sequence_tensor
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths  # used by SISDRLoss
from nemo.utils import model_utils
from nemo.utils.decorators import experimental

# Optional dependency: only needed by SSLModel / SLMDiscriminator below.
try:
    from transformers import AutoModel
except ImportError:
    AutoModel = None

CONSTANT = 1e-5


def instantiate(cfg):
    """Minimal replacement for hydra-style instantiation, limited to the targets used by this module."""
    cls = None
    cfg = dict(cfg)
    target = cfg.pop("_target_")
    if target == "nemo.collections.tts.modules.audio_codec_modules.HiFiGANEncoder":
        cls = HiFiGANEncoder
    elif target == "nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer":
        cls = GroupFiniteScalarQuantizer
    elif target == "nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder":
        cls = HiFiGANDecoder
    elif target == "nemo.collections.tts.modules.audio_codec_modules.Discriminator":
        cls = Discriminator
        # cheat here: recursively instantiate the nested discriminator configs, e.g.
        # {'discriminators': [
        #     {'_target_': 'nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator'},
        #     {'_target_': 'nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT',
        #      'resolutions': [[512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]],
        #      'stft_bands': [[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]]}]}
        cfg['discriminators'] = [instantiate(c) for c in cfg['discriminators']]
    elif target == "nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator":
        cls = MultiPeriodDiscriminator
    elif target == "nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT":
        cls = MultiResolutionDiscriminatorSTFT
    elif target == "nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss":
        cls = GeneratorSquaredLoss
    elif target == "nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss":
        cls = DiscriminatorSquaredLoss
    else:
        raise ValueError(f"Unsupported _target_ {target} with config {cfg}")
    return cls(**cfg)


class GaussianDropout(torch.nn.Module):
    """
    Gaussian dropout using multiplicative gaussian noise.
https://keras.io/api/layers/regularization_layers/gaussian_dropout/ Can be an effective alternative bottleneck to VAE or VQ: https://www.deepmind.com/publications/gaussian-dropout-as-an-information-bottleneck-layer Unlike some other implementations, this takes the standard deviation of the noise as input instead of the 'rate' typically defined as: stdev = sqrt(rate / (1 - rate)) """ def __init__(self, stdev=1.0): super(GaussianDropout, self).__init__() self.stdev = stdev def forward(self, inputs): if not self.training: return inputs noise = torch.normal(mean=1.0, std=self.stdev, size=inputs.shape, device=inputs.device) out = noise * inputs return out def get_padding(kernel_size: int, dilation: int = 1) -> int: return (kernel_size * dilation - dilation) // 2 def get_padding_2d(kernel_size: Tuple[int, int], dilation: Tuple[int, int]) -> Tuple[int, int]: paddings = (get_padding(kernel_size[0], dilation[0]), get_padding(kernel_size[1], dilation[1])) return paddings def get_down_sample_padding(kernel_size: int, stride: int) -> int: return (kernel_size - stride + 1) // 2 def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]: output_padding = (kernel_size - stride) % 2 padding = (kernel_size - stride + 1) // 2 return padding, output_padding class SSLModel(NeuralModule): def __init__(self, slm_model_name): super().__init__() self.ssl_model = AutoModel.from_pretrained(slm_model_name) def forward(self, *args, **kwargs): return self.ssl_model(*args, **kwargs) class SLMDiscriminator(NeuralModule): """SLM Discriminator, as described in both the StyleTTS2 and Low Frame-Rate Speech Codec papers. Args: slm_model_name: Hugging Face Speech Language Models name. slm_sr: Speech Language Models input sampling rate. input_sr: Audio input sampling rate. slm_hidden: Speech Language Model hidden dim. slm_layers: Speech Language Model number of layers. initial_channel: discriminative head number of channels. use_spectral_norm: If True uses spectral normalization otherwise uses weight norm. 
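        Illustrative usage (assumes the optional `transformers` dependency is available so the
        WavLM backbone can be downloaded; shapes follow the forward signature below):

            disc = SLMDiscriminator(input_sr=22050)
            audio = torch.randn(2, 22050)
            scores_real, scores_gen, fmaps_real, fmaps_gen = disc(audio_real=audio, audio_gen=audio)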
""" def __init__( self, slm_model_name="microsoft/wavlm-base-plus", slm_sr=16000, input_sr=22050, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False, ): super().__init__() self.slm_model = SSLModel(slm_model_name) # Freeze slm model self.slm_model.freeze() self.resample = torchaudio.transforms.Resample(input_sr, slm_sr) norm_f = torch.nn.utils.weight_norm if use_spectral_norm == False else torch.nn.utils.spectral_norm self.pre = norm_f(nn.Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)) self.convs = nn.ModuleList( [ norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)), norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)), norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)), ] ) self.conv_post = norm_f(nn.Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) def _forward(self, x): x = self.slm_model(input_values=self.resample(x), output_hidden_states=True).hidden_states x = torch.stack(x, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2) x = self.pre(x) fmap = [] for l in self.convs: x = l(x) x = F.leaky_relu(x, 0.1) fmap.append(x.unsqueeze(-1)) x = self.conv_post(x) x = torch.flatten(x, 1, -1) return x, fmap @property def input_types(self): return { "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), } @property def output_types(self): return { "scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())], "scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())], "fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]], "fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]], } @typecheck() def forward(self, audio_real, audio_gen): y_d_r, fmap_r = self._forward(audio_real) y_d_g, fmap_g = self._forward(audio_gen) return [y_d_r.unsqueeze(1)], [y_d_g.unsqueeze(1)], [fmap_r], [fmap_g] class CodecActivation(nn.Module): """ Choose between activation based on the input parameter. Args: activation: Name of activation to use. Valid options are "elu" (default), "lrelu", and "snake". channels: Input dimension. 
""" def __init__(self, activation: str = "elu", channels: int = 1): super().__init__() activation = activation.lower() if activation == "elu": self.activation = nn.ELU() elif activation == "lrelu": self.activation = torch.nn.LeakyReLU() elif activation == "snake": self.activation = Snake(channels) elif activation == "half_snake": self.activation = HalfSnake(channels) else: raise ValueError(f"Unknown activation {activation}") def forward(self, x): return self.activation(x) class Conv1dNorm(NeuralModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, padding: Optional[int] = None, ): super().__init__() if not padding: padding = get_padding(kernel_size=kernel_size, dilation=dilation) conv = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) @property def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "out": NeuralType(('B', 'C', 'T'), VoidType()), } def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) out = mask_sequence_tensor(out, input_len) return out class ConvTranspose1dNorm(NeuralModule): def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1): super().__init__() padding, output_padding = get_up_sample_padding(kernel_size, stride) conv = nn.ConvTranspose1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, padding_mode="zeros", ) self.conv = nn.utils.weight_norm(conv) @property def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "out": NeuralType(('B', 'C', 'T'), VoidType()), } def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) out = mask_sequence_tensor(out, input_len) return out class Conv2dNorm(NeuralModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: Tuple[int, int], stride: Tuple[int, int] = (1, 1), dilation: Tuple[int, int] = (1, 1), ): super().__init__() assert len(kernel_size) == len(dilation) padding = get_padding_2d(kernel_size, dilation) conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) @property def input_types(self): return { "inputs": NeuralType(('B', 'C', 'H', 'T'), VoidType()), } @property def output_types(self): return { "out": NeuralType(('B', 'C', 'H', 'T'), VoidType()), } def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) @typecheck() def forward(self, inputs): return self.conv(inputs) class PeriodDiscriminator(NeuralModule): """ Period discriminator introduced in HiFi-GAN https://arxiv.org/abs/2010.05646 which attempts to discriminate phase information by looking at equally spaced audio samples. Args: period: Spacing between audio sample inputs. lrelu_slope: Slope to use for activation. Leaky relu with slope of 0.1 or 0.2 is recommended for the stability of the feature matching loss. 
""" def __init__(self, period, lrelu_slope=0.1): super().__init__() self.period = period self.activation = nn.LeakyReLU(lrelu_slope) self.conv_layers = nn.ModuleList( [ Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)), Conv2dNorm(32, 128, kernel_size=(5, 1), stride=(3, 1)), Conv2dNorm(128, 512, kernel_size=(5, 1), stride=(3, 1)), Conv2dNorm(512, 1024, kernel_size=(5, 1), stride=(3, 1)), Conv2dNorm(1024, 1024, kernel_size=(5, 1), stride=(1, 1)), ] ) self.conv_post = Conv2dNorm(1024, 1, kernel_size=(3, 1)) @property def input_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), } @property def output_types(self): return { "score": NeuralType(('B', 'C', 'T_out'), VoidType()), "fmap": [NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())], } @typecheck() def forward(self, audio): batch_size, time = audio.shape out = rearrange(audio, 'B T -> B 1 T') # Pad audio so that it is divisible by the period if time % self.period != 0: n_pad = self.period - (time % self.period) out = F.pad(out, (0, n_pad), "reflect") time = time + n_pad # [batch, 1, (time / period), period] out = out.view(batch_size, 1, time // self.period, self.period) fmap = [] for conv in self.conv_layers: # [batch, filters, (time / period / stride), period] out = conv(inputs=out) out = self.activation(out) fmap.append(out) # [batch, 1, (time / period / strides), period] score = self.conv_post(inputs=out) fmap.append(score) score = rearrange(score, "B 1 T C -> B C T") return score, fmap class MultiPeriodDiscriminator(NeuralModule): """ Wrapper class to aggregate results of multiple period discriminators. The periods are expected to be increasing prime numbers in order to maximize coverage and minimize overlap """ def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11), lrelu_slope=0.1): super().__init__() self.discriminators = nn.ModuleList( [PeriodDiscriminator(period=period, lrelu_slope=lrelu_slope) for period in periods] ) @property def input_types(self): return { "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), } @property def output_types(self): return { "scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())], "scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())], "fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]], "fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]], } @typecheck() def forward(self, audio_real, audio_gen): scores_real = [] scores_gen = [] fmaps_real = [] fmaps_gen = [] for discriminator in self.discriminators: score_real, fmap_real = discriminator(audio=audio_real) score_gen, fmap_gen = discriminator(audio=audio_gen) scores_real.append(score_real) fmaps_real.append(fmap_real) scores_gen.append(score_gen) fmaps_gen.append(fmap_gen) return scores_real, scores_gen, fmaps_real, fmaps_gen class DiscriminatorSTFT(NeuralModule): """ Discriminator network from EnCodec for Complex STFT input, but without dilations. Args: filters: number of filters to use in Conv2d layers lrelu_slope: Slope to use for activations. 
Leaky relu with slope of 0.1 or 0.2 is recommended for the stability of the feature matching loss """ def __init__(self, filters: int = 32, lrelu_slope: float = 0.1): super().__init__() self.activation = nn.LeakyReLU(lrelu_slope) self.conv_layers = nn.ModuleList( [ Conv2dNorm(2, filters, kernel_size=(3, 9)), Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), Conv2dNorm(filters, filters, kernel_size=(3, 3)), ] ) self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3)) @property def input_types(self): return { "spec": NeuralType(('B', 'C', 'T_spec', 'D'), VoidType()), } @property def output_types(self): return { "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()), "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], } @typecheck() def forward(self, spec): fmap = [] # [batch, 2, T_spec, fft] out = spec for conv in self.conv_layers: # [batch, filters, T_spec, fft // strides] out = conv(inputs=out) out = self.activation(out) fmap.append(out) # [batch, 1, T_spec, fft // 8] scores = self.conv_post(inputs=out) fmap.append(scores) scores = rearrange(scores, "B 1 T C -> B C T") return scores, fmap class MultiBandDiscriminatorSTFT(NeuralModule): """ Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546). Computes the complex STFT for a given resolution and splits it into sub-bands, which are given to separate discriminator networks. Args: resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length) stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). The floats are in the range [0, 1] representing the fraction of all stft bands. For example for n_fft=1024, the stft output has 513 dimensions. For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. """ def __init__(self, resolution: Tuple[int], stft_bands: Iterable[Tuple[int]]): super().__init__() self.n_fft, self.hop_length, self.win_length = resolution self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands]) n_stft = self.n_fft // 2 + 1 self.stft_bands = [(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands] def compute_stft(self, audio): # [B, fft, T_spec] fft = torch.stft( audio, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, normalized=True, center=True, return_complex=True, ) fft = rearrange(fft, "B fft T -> B T fft") # [batch, 2, T_spec, fft] out = torch.stack([fft.real, fft.imag], dim=1) return out @property def input_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), } @property def output_types(self): return { "scores_list": [NeuralType(('B', 'C', 'T_spec'), VoidType())], "fmaps_list": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], } @typecheck() def forward(self, audio): scores_list = [] fmap_list = [] spec = self.compute_stft(audio) for band, disc in zip(self.stft_bands, self.discriminators): spec_band = spec[:, :, :, band[0] : band[1]] score, fmap = disc(spec=spec_band) scores_list.append(score) fmap_list.append(fmap) return scores_list, fmap_list class MultiResolutionDiscriminatorSTFT(NeuralModule): """ Multi-resolution discriminator which creates a multi-band discriminator for each input resolution. 
Args: resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length) stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). The floats are in the range [0, 1] representing the fraction of all stft bands. For example for n_fft=1024, the stft output has 513 dimensions. For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. """ def __init__(self, resolutions: Iterable[Tuple[int]], stft_bands: Iterable[Tuple[int]]): super().__init__() self.discriminators = nn.ModuleList( [MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands) for resolution in resolutions] ) @property def input_types(self): return { "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), } @property def output_types(self): return { "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())], "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())], "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], } @typecheck() def forward(self, audio_real, audio_gen): scores_real = [] scores_gen = [] fmaps_real = [] fmaps_gen = [] for disc in self.discriminators: score_real_i, fmap_real_i = disc(audio=audio_real) scores_real = scores_real + score_real_i fmaps_real = fmaps_real + fmap_real_i score_gen_i, fmap_gen_i = disc(audio=audio_gen) scores_gen = scores_gen + score_gen_i fmaps_gen = fmaps_gen + fmap_gen_i return scores_real, scores_gen, fmaps_real, fmaps_gen class Discriminator(NeuralModule): """ Wrapper class which takes a list of discriminators and aggregates the results across them. 
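    In the configuration handled by `instantiate` above, this wrapper combines a MultiPeriodDiscriminator
    with a MultiResolutionDiscriminatorSTFT, for example:

        discriminator = Discriminator(
            discriminators=[
                MultiPeriodDiscriminator(),
                MultiResolutionDiscriminatorSTFT(
                    resolutions=[[512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]],
                    stft_bands=[[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]],
                ),
            ]
        )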
""" def __init__(self, discriminators: Iterable[NeuralModule]): super().__init__() self.discriminators = nn.ModuleList(discriminators) @property def input_types(self): return { "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), } @property def output_types(self): return { "scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())], "scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())], "fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]], "fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]], } @typecheck() def forward(self, audio_real, audio_gen): scores_real = [] scores_gen = [] fmaps_real = [] fmaps_gen = [] for discriminator in self.discriminators: score_real, score_gen, fmap_real, fmap_gen = discriminator(audio_real=audio_real, audio_gen=audio_gen) scores_real += score_real fmaps_real += fmap_real scores_gen += score_gen fmaps_gen += fmap_gen return scores_real, scores_gen, fmaps_real, fmaps_gen class VectorQuantizerBase(NeuralModule, ABC): @property def input_types(self): return { "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "indices": NeuralType(('D', 'B', 'T'), Index()), } @typecheck() @abstractmethod def forward(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: pass @typecheck( input_types={ "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) @abstractmethod def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: pass @typecheck( input_types={ "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), }, ) @abstractmethod def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: pass class FiniteScalarQuantizer(VectorQuantizerBase): """This quantizer is based on the Finite Scalar Quantization (FSQ) method. It quantizes each element of the input vector independently into a number of levels. 
    Args:
        num_levels: number of levels for each dimension/element of the input vector
        eps: small regularization constant for scaling

    References:
        Mentzer et al., Finite Scalar Quantization: VQ-VAE Made Simple (https://arxiv.org/abs/2309.15505v1)
    """

    def __init__(self, num_levels: List[int], eps: float = 1e-3):
        super().__init__()

        # index base per dimension of the input vector
        # this is used to convert between per-dimension indices and a codebook token index
        dim_base_index = torch.cumprod(torch.tensor([1] + num_levels[:-1]), dim=0, dtype=torch.int32)
        dim_base_index = rearrange(dim_base_index, 'D -> 1 D 1')
        self.register_buffer('dim_base_index', dim_base_index)

        # Register the number of levels for each dimension
        num_levels = torch.tensor(num_levels, dtype=torch.int32)
        num_levels = rearrange(num_levels, 'D -> 1 D 1')
        self.register_buffer('num_levels', num_levels)

        # Regularization
        self.eps = eps

    @property
    def codebook_size(self):
        """Returns the size of the corresponding codebook."""
        return self.num_levels.prod().item()

    @property
    def dim(self):
        """Returns the dimension of the input vector."""
        return self.num_levels.numel()

    @property
    def codebook_dim(self):
        """Returns the dimension of the input vector.
        Kept for compatibility with the original RVQ implementation.
        """
        return self.dim

    @property
    def codes(self):
        """Returns the codebook entries.

        Note that the codebook entries are implicitly defined by the number of levels.
        """
        indices = torch.arange(self.codebook_size)
        # [D, B, T]
        indices = rearrange(indices, 'B -> 1 B 1')
        # [B, D, T]
        codes = self.decode(indices=indices, input_len=None)
        # Remove the time dimension
        codes = codes.squeeze(-1)
        return codes

    @property
    def codebook(self):
        """Returns the codebook entries.
        See self.codes for more details.
        """
        return self.codes

    @staticmethod
    def round(inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Round the input tensor to nearest integer and use a straight-through estimator for the gradient.
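        In the backward pass the rounding is treated as the identity, so gradients flow to `inputs`
        unchanged (straight-through estimator).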
""" inputs_rounded = torch.round(inputs) return inputs + (inputs_rounded - inputs).detach() def compress(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: """Apply compression to the input, to limit to values.""" output_scale = (self.num_levels - 1) / 2 # scale down a bit to avoid rounding issues output_scale = output_scale * (1 - self.eps) # offset for even number of levels output_offset = torch.where(self.num_levels % 2 == 0, 0.5, 0) # shift for even number of levels input_shift = (output_offset / output_scale).tan() # compressed output output = output_scale * (inputs + input_shift).tanh() - output_offset return output @typecheck( input_types={ "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={"codes": NeuralType(('B', 'D', 'T'), Index())}, ) def inputs_to_codes(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: # apply compression compressed = self.compress(inputs=inputs, input_len=input_len) # apply rounding to nearest integer codes = self.round(inputs=compressed, input_len=input_len) # normalize to [-1, 1] scale = self.num_levels // 2 codes = codes / scale return codes def codes_to_nonnegative(self, codes: torch.Tensor) -> torch.Tensor: """Convert values centered arouund zero to nonnegative values.""" scale = offset = self.num_levels // 2 return scale * codes + offset def nonnegative_to_codes(self, codes_nonnegative: torch.Tensor) -> torch.Tensor: """Convert nonnegative values to values centered arouund zero.""" scale = offset = self.num_levels // 2 return (codes_nonnegative - offset) / scale def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor: """Converts a code vector to a single index.""" if codes.size(1) != self.dim: raise RuntimeError( f'Input code dimension {codes.size(1)} not matching the expected dimension {self.dim}, input codes shape {codes.shape}' ) # convert code vectors to nonnegative values indices = self.codes_to_nonnegative(codes) # convert one nonnegative index per dimension to a single index per code vector indices = torch.sum(indices * self.dim_base_index, dim=1) return indices.to(torch.int32) # Implementation of VectorQuantiserBase API @typecheck() def forward( self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: if inputs.size(1) != self.dim: raise RuntimeError( f'Input dimension {inputs.size(1)} not matching the expected dimension {self.dim}, inputs shape {inputs.shape}' ) dequantized = self.inputs_to_codes(inputs=inputs, input_len=input_len) indices = self.codes_to_indices(codes=dequantized) if input_len is not None: # apply masking dequantized = mask_sequence_tensor(dequantized, input_len) indices = mask_sequence_tensor(indices, input_len) # only 1 codebook, but return in [D, B, T] format to match RVQ API indices = indices.unsqueeze(0) return dequantized, indices @typecheck( input_types={ "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: """Convert a continuous code vector to a single index.""" _, indices = self(inputs=inputs, input_len=input_len) return indices @typecheck( input_types={ "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, output_types={ 
"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), }, ) def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: """Convert a single index to a continuous code vector.""" if indices.size(0) > 1: # codebook dimension used for compatibility with RVQ raise ValueError( f'Expected a single codebook, got {indices.size(0)} codebooks for indices with shape {indices.shape}.' ) indices = rearrange(indices, 'D B T -> B D T') # convert a single index to nonnegative index per-dimension codes_nonnegative = (indices // self.dim_base_index) % self.num_levels # convert nonnegative codes to codes (centered around zero) dequantized = self.nonnegative_to_codes(codes_nonnegative) if input_len is not None: # apply masking dequantized = mask_sequence_tensor(dequantized, input_len) return dequantized class GroupFiniteScalarQuantizer(VectorQuantizerBase): """Split the input vector into groups and apply FSQ on each group separately. This class is for convenience. Since FSQ is applied on each group separately, groups can be defined arbitrarily by splitting the input vector. However, this class makes it easy to construct several groups with the same quantization num_levels. Args: num_groups: number of groups to split the input into, each group will be quantized separately using num_codebooks//num_groups codebooks codebook_dim: embedding dimension, will be split into num_groups **kwargs: parameters of FiniteScalarQuantizer References: Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765). """ def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs): super().__init__() self.num_groups = num_groups self.codebook_dim_per_group = len(num_levels_per_group) # Initialize FSQ for each group self.fsqs = torch.nn.ModuleList( [FiniteScalarQuantizer(num_levels=num_levels_per_group, **kwargs) for _ in range(self.num_groups)] ) @property def codebook_dim(self): """Input vector dimension.""" return self.codebook_dim_per_group * self.num_groups @property def codebook_size_per_group(self): """Returns the size of the implicit codebook for each group.""" return self.fsqs[0].codebook_size @property def codebook_size(self): """Returns the size of the implicit codebook.""" return self.codebook_size_per_group**self.num_groups @typecheck() def forward(self, inputs, input_len): """Quantize each group separately, then concatenate the results.""" inputs_grouped = inputs.chunk(self.num_groups, dim=1) dequantized, indices = [], [] for in_group, fsq_group in zip(inputs_grouped, self.fsqs): dequantized_group, indices_group = fsq_group(inputs=in_group, input_len=input_len) dequantized.append(dequantized_group) indices.append(indices_group) # concatenate along the feature dimension dequantized = torch.cat(dequantized, dim=1) # concatente along the codebook dimension indices = torch.cat(indices, dim=0) return dequantized, indices @typecheck( input_types={ "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: """Input is split into groups, each group is encoded separately, then the results are concatenated.""" inputs_grouped = inputs.chunk(self.num_groups, dim=1) indices = [] for in_group, fsq_group in zip(inputs_grouped, self.fsqs): indices_group = fsq_group.encode(inputs=in_group, 
input_len=input_len) indices.append(indices_group) # concatenate along the codebook dimension indices = torch.cat(indices, dim=0) return indices @typecheck( input_types={ "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), }, ) def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: """Input indices are split into groups, each group is decoded separately, then the results are concatenated.""" indices_grouped = indices.chunk(self.num_groups, dim=0) dequantized = [] for indices_group, fsq_group in zip(indices_grouped, self.fsqs): dequantized_group = fsq_group.decode(indices=indices_group, input_len=input_len) dequantized.append(dequantized_group) # concatenate along the feature dimension dequantized = torch.cat(dequantized, dim=1) return dequantized class ResidualBlock(NeuralModule): """ The residual block structure defined by the HiFi-GAN V1 and V2 configurations. Args: channels: Input dimension. filters: Number of channels in the residual convolutions. kernel_size: Kernel size of the residual convolutions. dilation: Dilation of the residual convolutions. dropout_rate: Dropout to apply to residuals. activation: Activation to apply in between residual convolutions. """ def __init__( self, channels: int, filters: int, kernel_size: int = 3, dilation: int = 1, dropout_rate: float = 0.0, activation: str = "lrelu", ): super(ResidualBlock, self).__init__() self.input_activation = CodecActivation(activation=activation, channels=channels) self.skip_activation = CodecActivation(activation=activation, channels=filters) self.dropout = torch.nn.Dropout(dropout_rate) self.input_conv = Conv1dNorm( in_channels=channels, out_channels=filters, kernel_size=kernel_size, dilation=dilation ) self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) def remove_weight_norm(self): self.input_conv.remove_weight_norm() self.skip_conv.remove_weight_norm() @property def input_types(self): return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} @property def output_types(self): return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} @typecheck() def forward(self, inputs, input_len): conv_input = self.input_activation(inputs) skip_input = self.input_conv(inputs=conv_input, input_len=input_len) skip_input = self.skip_activation(skip_input) res = self.skip_conv(inputs=skip_input, input_len=input_len) res = self.dropout(res) out = inputs + res return out class HiFiGANResBlock(NeuralModule): """ Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. Args: channels: Input dimension. kernel_size: Kernel size of the residual blocks. dilations: List of dilations. One residual block will be created for each dilation in the list. activation: Activation for the residual blocks. 
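    Illustrative usage:

        block = HiFiGANResBlock(channels=512, kernel_size=3, dilations=(1, 3, 5), activation="lrelu")
        out = block(inputs=torch.randn(2, 512, 100), input_len=torch.tensor([100, 80]))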
""" def __init__(self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str): super().__init__() self.res_blocks = nn.ModuleList( [ ResidualBlock( channels=channels, filters=channels, kernel_size=kernel_size, dilation=dilation, activation=activation, ) for dilation in dilations ] ) def remove_weight_norm(self): for res_block in self.res_blocks: res_block.remove_weight_norm() @property def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return {"out": NeuralType(('B', 'C', 'T'), VoidType())} @typecheck() def forward(self, inputs, input_len): out = inputs for res_block in self.res_blocks: out = res_block(inputs=out, input_len=input_len) return out class HiFiGANResLayer(NeuralModule): """ Residual block wrapper for HiFi-GAN which creates a block for multiple kernel sizes and dilations. One residual block is created for each combination of kernel size and dilation. Args: channels: Input dimension. kernel_sizes: List of kernel sizes. dilations: List of dilations. activation: Activation for the residual layers. """ def __init__(self, channels: int, kernel_sizes: Iterable[int], dilations: Iterable[int], activation: str): super().__init__() self.res_blocks = nn.ModuleList( [ HiFiGANResBlock(channels=channels, kernel_size=kernel_size, dilations=dilations, activation=activation) for kernel_size in kernel_sizes ] ) def remove_weight_norm(self): for res_block in self.res_blocks: res_block.remove_weight_norm() @property def input_types(self): return { "inputs": NeuralType(('B', 'D', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return {"out": NeuralType(('B', 'D', 'T'), VoidType())} @typecheck() def forward(self, inputs, input_len): residuals = [res_block(inputs=inputs, input_len=input_len) for res_block in self.res_blocks] out = sum(residuals) / len(residuals) return out class HiFiGANEncoder(NeuralModule): """ Audio encoder created by inverting the HiFi-GAN decoder. Args: encoded_dim: Dimension of encoder output. down_sample_rates: Rate to upsample for each decoder block. The product of the downsample rates will determine the output token rate. For example 2 * 2 * 8 * 8 = 256 samples per token. base_channels: Number of filters in the first convolution. The number of channels will be doubled after each downsample layer. in_kernel_size: Kernel size of the input convolution. out_kernel_size: Kernel size of the output convolution. resblock_kernel_sizes: List of kernel sizes to use in each residual block. resblock_dilation_sizes: List of dilations to use in each residual block. activation: Activation to use in residual and downsample layers, defaults to leaky relu. 
""" def __init__( self, encoded_dim: int, down_sample_rates: Iterable[int] = (2, 2, 8, 8), base_channels: int = 32, in_kernel_size: int = 7, out_kernel_size: int = 7, resblock_kernel_sizes: Iterable[int] = (3, 7, 11), resblock_dilation_sizes: Iterable[int] = (1, 3, 5), activation: str = "lrelu", ): assert in_kernel_size > 0 assert out_kernel_size > 0 super().__init__() self.down_sample_rates = down_sample_rates self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size) in_channels = base_channels self.activations = nn.ModuleList([]) self.down_sample_conv_layers = nn.ModuleList([]) self.res_layers = nn.ModuleList([]) for i, down_sample_rate in enumerate(self.down_sample_rates): res_layer = HiFiGANResLayer( channels=in_channels, kernel_sizes=resblock_kernel_sizes, dilations=resblock_dilation_sizes, activation=activation, ) self.res_layers.append(res_layer) act = CodecActivation(activation, channels=in_channels) self.activations.append(act) out_channels = 2 * in_channels kernel_size = 2 * down_sample_rate padding = get_down_sample_padding(kernel_size=kernel_size, stride=down_sample_rate) down_sample_conv = Conv1dNorm( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=down_sample_rate, padding=padding, ) in_channels = out_channels self.down_sample_conv_layers.append(down_sample_conv) self.post_activation = CodecActivation(activation, channels=in_channels) self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size) @property def input_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), } def remove_weight_norm(self): self.pre_conv.remove_weight_norm() self.post_conv.remove_weight_norm() for res_layer in self.res_layers: res_layer.remove_weight_norm() for down_sample_conv in self.down_sample_conv_layers: down_sample_conv.remove_weight_norm() @typecheck() def forward(self, audio, audio_len): encoded_len = audio_len audio = rearrange(audio, "B T -> B 1 T") # [B, C, T_audio] out = self.pre_conv(inputs=audio, input_len=encoded_len) for act, res_layer, down_sample_conv, down_sample_rate in zip( self.activations, self.res_layers, self.down_sample_conv_layers, self.down_sample_rates ): # [B, C, T] out = res_layer(inputs=out, input_len=encoded_len) out = act(out) encoded_len = encoded_len // down_sample_rate # [B, 2 * C, T / down_sample_rate] out = down_sample_conv(inputs=out, input_len=encoded_len) out = self.post_activation(out) # [B, encoded_dim, T_encoded] encoded = self.post_conv(inputs=out, input_len=encoded_len) return encoded, encoded_len class HiFiGANDecoder(NeuralModule): """ Codec decoder using the HiFi-GAN generator architecture. Default parameters match the HiFi-GAN V1 configuration for 22.05khz. Args: input_dim: Input dimension. up_sample_rates: Rate to upsample for each decoder block. The product of the upsample rates should be the same as the overall downsample rate for your encoder. For example, a symmetric encoder/decoder can be created with encoder downsample rates [2, 2, 8, 8] and decoder upsample rates [8, 8, 2, 2]. base_channels: Number of filters in the first convolution. The number of channels will be cut in half after each upsample layer. in_kernel_size: Kernel size of the input convolution. 
out_kernel_size: Kernel size of the output convolution. resblock_kernel_sizes: List of kernel sizes to use in each residual block. resblock_dilation_sizes: List of dilations to use in each residual block. activation: Activation to use in residual and upsample layers, defaults to leaky relu. output_activation: Activation to apply to output. To produce a valid audio signal, it should output values in the range [-1.0, 1.0]. Supports "tanh" and "clamp". """ def __init__( self, input_dim: int, up_sample_rates: Iterable[int] = (8, 8, 2, 2), base_channels: int = 512, in_kernel_size: int = 7, out_kernel_size: int = 3, resblock_kernel_sizes: Iterable[int] = (3, 7, 11), resblock_dilation_sizes: Iterable[int] = (1, 3, 5), activation: str = "lrelu", output_activation: str = "tanh", ): assert in_kernel_size > 0 assert out_kernel_size > 0 super().__init__() self.up_sample_rates = up_sample_rates self.pre_conv = Conv1dNorm(in_channels=input_dim, out_channels=base_channels, kernel_size=in_kernel_size) in_channels = base_channels self.activations = nn.ModuleList([]) self.up_sample_conv_layers = nn.ModuleList([]) self.res_layers = nn.ModuleList([]) for i, up_sample_rate in enumerate(self.up_sample_rates): out_channels = in_channels // 2 kernel_size = 2 * up_sample_rate act = CodecActivation(activation, channels=in_channels) self.activations.append(act) up_sample_conv = ConvTranspose1dNorm( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate ) in_channels = out_channels self.up_sample_conv_layers.append(up_sample_conv) res_layer = HiFiGANResLayer( channels=in_channels, kernel_sizes=resblock_kernel_sizes, dilations=resblock_dilation_sizes, activation=activation, ) self.res_layers.append(res_layer) self.post_activation = CodecActivation(activation, channels=in_channels) self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size) if output_activation == "tanh": self.out_activation = nn.Tanh() elif output_activation == "clamp": self.out_activation = ClampActivation() else: raise ValueError(f"Invalid audio output activation {output_activation}") @property def input_types(self): return { "inputs": NeuralType(('B', 'D', 'T_encoded'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } def remove_weight_norm(self): self.pre_conv.remove_weight_norm() for up_sample_conv in self.up_sample_conv_layers: up_sample_conv.remove_weight_norm() for res_layer in self.res_layers: res_layer.remove_weight_norm() @typecheck() def forward(self, inputs, input_len): audio_len = input_len # [B, C, T_encoded] out = self.pre_conv(inputs=inputs, input_len=audio_len) for act, res_layer, up_sample_conv, up_sample_rate in zip( self.activations, self.res_layers, self.up_sample_conv_layers, self.up_sample_rates ): audio_len = audio_len * up_sample_rate out = act(out) # [B, C / 2, T * up_sample_rate] out = up_sample_conv(inputs=out, input_len=audio_len) out = res_layer(inputs=out, input_len=audio_len) out = self.post_activation(out) # [B, 1, T_audio] out = self.post_conv(inputs=out, input_len=audio_len) audio = self.out_activation(out) audio = rearrange(audio, "B 1 T -> B T") return audio, audio_len class MelSpectrogramProcessor(NeuralModule): """ Wrapper interface for computing mel spectrogram for codec training. 
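    Illustrative usage (22.05 kHz HiFi-GAN style settings):

        mel_processor = MelSpectrogramProcessor(sample_rate=22050, win_length=1024, hop_length=256)
        spec, spec_len = mel_processor(audio=torch.randn(1, 25600), audio_len=torch.tensor([25600]))
        # spec: [1, 80, T_spec] with T_spec == audio_len // hop_length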
""" def __init__(self, sample_rate: int, win_length: int, hop_length: int, mel_dim: int = 80, log_guard: float = 1.0): super(MelSpectrogramProcessor, self).__init__() self.mel_dim = mel_dim self.hop_length = hop_length self.preprocessor = AudioToMelSpectrogramPreprocessor( sample_rate=sample_rate, highfreq=None, features=mel_dim, pad_to=1, exact_pad=True, n_window_size=win_length, n_window_stride=hop_length, window_size=False, window_stride=False, n_fft=win_length, mag_power=1.0, log=True, log_zero_guard_type="add", log_zero_guard_value=log_guard, mel_norm=None, normalize=None, preemph=None, dither=0.0, ) @property def input_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), "spec_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() def forward(self, audio, audio_len): spec, spec_len = self.preprocessor(input_signal=audio, length=audio_len) return spec, spec_len class ResNetEncoder(NeuralModule): """ Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing the time dimension. Args: in_channels: input dimension out_channels: output dimension num_layers: number of residual blocks to use hidden_channels: encoder hidden dimension filters: number of filters in residual block layers kernel_size: kernel size in residual block convolutions dropout_rate: Optional dropout rate to apply to residuals. activation: Activation to use, defaults to leaky relu. """ def __init__( self, in_channels: int, out_channels: int, num_layers: int = 6, hidden_channels: int = 256, filters: int = 768, kernel_size: int = 3, dropout_rate: float = 0.1, activation: str = "lrelu", ): super(ResNetEncoder, self).__init__() self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=hidden_channels, kernel_size=kernel_size) self.res_layers = nn.ModuleList( [ ResidualBlock( channels=hidden_channels, filters=filters, kernel_size=kernel_size, dropout_rate=dropout_rate, activation=activation, ) for _ in range(num_layers) ] ) self.post_activation = CodecActivation(activation, channels=hidden_channels) self.post_conv = Conv1dNorm(in_channels=hidden_channels, out_channels=out_channels, kernel_size=kernel_size) def remove_weight_norm(self): self.pre_conv.remove_weight_norm() self.post_conv.remove_weight_norm() for res_layer in self.res_layers: res_layer.remove_weight_norm() @property def input_types(self): return { "inputs": NeuralType(('B', 'D', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return {"encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} @typecheck() def forward(self, inputs, input_len): encoded = self.pre_conv(inputs=inputs, input_len=input_len) for res_layer in self.res_layers: encoded = res_layer(inputs=encoded, input_len=input_len) encoded = self.post_activation(encoded) encoded = self.post_conv(inputs=encoded, input_len=input_len) return encoded class FullBandMelEncoder(NeuralModule): """ Encoder which encodes the entire mel spectrogram with a single encoder network. Args: mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from input audio. encoder: ResNetEncoder or equivalent class for encoding the mel spectrogram. 
""" def __init__(self, mel_processor: NeuralModule, encoder: NeuralModule): super(FullBandMelEncoder, self).__init__() self.mel_processor = mel_processor self.encoder = encoder def remove_weight_norm(self): self.encoder.remove_weight_norm() @property def input_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "encoded": NeuralType(('B', 'C', 'T_encoded'), EncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() def forward(self, audio, audio_len): out, spec_len = self.mel_processor(audio=audio, audio_len=audio_len) encoded = self.encoder(inputs=out, input_len=spec_len) return encoded, spec_len class MultiBandMelEncoder(NeuralModule): """ Encoder which splits mel spectrogram into bands and encodes each using separate residual networks. Args: mel_bands: List of mel spectrogram bands to encode. Each list element is tuple of 2 elements with the start and end index of the mel features to use. mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from input audio. encoder_kwargs: Arguments for constructing encoder for each mel band. """ def __init__(self, mel_bands: Iterable[Tuple[int, int]], mel_processor: NeuralModule, **encoder_kwargs): super(MultiBandMelEncoder, self).__init__() self.validate_mel_bands(mel_dim=mel_processor.mel_dim, mel_bands=mel_bands) self.mel_bands = mel_bands self.mel_processor = mel_processor band_dims = [band[1] - band[0] for band in self.mel_bands] self.encoders = nn.ModuleList( [ResNetEncoder(in_channels=band_dim, **encoder_kwargs) for band_dim in band_dims] ) @staticmethod def validate_mel_bands(mel_dim: int, mel_bands: Iterable[Tuple[int, int]]): mel_dims_used = np.zeros([mel_dim], dtype=bool) for band in mel_bands: mel_dims_used[band[0] : band[1]] = True if not all(mel_dims_used): missing_dims = np.where(~mel_dims_used) raise ValueError(f"Mel bands must cover all {mel_dim} dimensions. Missing {missing_dims}.") return def remove_weight_norm(self): for encoder in self.encoders: encoder.remove_weight_norm() @property def input_types(self): return { "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "encoded": NeuralType(('B', 'C', 'T_encoded'), EncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() def forward(self, audio, audio_len): spec, spec_len = self.mel_processor(audio=audio, audio_len=audio_len) outputs = [] for (band_start, band_end), encoder in zip(self.mel_bands, self.encoders): # [B, D_band, T] spec_band = spec[:, band_start:band_end, :] band_out = encoder(inputs=spec_band, input_len=spec_len) outputs.append(band_out) # [B, C, T] encoded = torch.cat(outputs, dim=1) return encoded, spec_len class FilterbankFeatures(nn.Module): """Featurizer that converts wavs to Mel Spectrograms. See AudioToMelSpectrogramPreprocessor for args. 
""" def __init__( self, sample_rate=16000, n_window_size=320, n_window_stride=160, window="hann", normalize="per_feature", n_fft=None, preemph=0.97, nfilt=64, lowfreq=0, highfreq=None, log=True, log_zero_guard_type="add", log_zero_guard_value=2**-24, dither=CONSTANT, pad_to=16, max_duration=16.7, frame_splicing=1, exact_pad=False, pad_value=0, mag_power=2.0, use_grads=False, rng=None, nb_augmentation_prob=0.0, nb_max_freq=4000, mel_norm="slaney", stft_exact_pad=False, # Deprecated arguments; kept for config compatibility stft_conv=False, # Deprecated arguments; kept for config compatibility ): super().__init__() if exact_pad and n_window_stride % 2 == 1: raise NotImplementedError( f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the " "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size." ) self.log_zero_guard_value = log_zero_guard_value if ( n_window_size is None or n_window_stride is None or not isinstance(n_window_size, int) or not isinstance(n_window_stride, int) or n_window_size <= 0 or n_window_stride <= 0 ): raise ValueError( f"{self} got an invalid value for either n_window_size or " f"n_window_stride. Both must be positive ints." ) self.win_length = n_window_size self.hop_length = n_window_stride self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None self.exact_pad = exact_pad torch_windows = { 'hann': torch.hann_window, 'hamming': torch.hamming_window, 'blackman': torch.blackman_window, 'bartlett': torch.bartlett_window, 'none': None, } window_fn = torch_windows.get(window, None) window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None self.register_buffer("window", window_tensor) self.normalize = normalize self.log = log self.dither = dither self.frame_splicing = frame_splicing self.nfilt = nfilt self.preemph = preemph self.pad_to = pad_to highfreq = highfreq or sample_rate / 2 filterbanks = torch.tensor( librosa.filters.mel( sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm ), dtype=torch.float, ).unsqueeze(0) self.register_buffer("fb", filterbanks) # Calculate maximum sequence length max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float)) max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 self.max_length = max_length + max_pad self.pad_value = pad_value self.mag_power = mag_power # We want to avoid taking the log of zero # There are two options: either adding or clamping to a small value if log_zero_guard_type not in ["add", "clamp"]: raise ValueError( f"{self} received {log_zero_guard_type} for the " f"log_zero_guard_type parameter. It must be either 'add' or " f"'clamp'." 
) self.use_grads = use_grads if not use_grads: self.forward = torch.no_grad()(self.forward) self._rng = random.Random() if rng is None else rng self.nb_augmentation_prob = nb_augmentation_prob if self.nb_augmentation_prob > 0.0: if nb_max_freq >= sample_rate / 2: self.nb_augmentation_prob = 0.0 else: self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft) # log_zero_guard_value is the the small we want to use, we support # an actual number, or "tiny", or "eps" self.log_zero_guard_type = log_zero_guard_type def stft(self, x): return torch.stft( x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=False if self.exact_pad else True, window=self.window.to(dtype=torch.float), return_complex=True, ) def log_zero_guard_value_fn(self, x): if isinstance(self.log_zero_guard_value, str): if self.log_zero_guard_value == "tiny": return torch.finfo(x.dtype).tiny elif self.log_zero_guard_value == "eps": return torch.finfo(x.dtype).eps else: raise ValueError( f"{self} received {self.log_zero_guard_value} for the " f"log_zero_guard_type parameter. It must be either a " f"number, 'tiny', or 'eps'" ) else: return self.log_zero_guard_value def get_seq_len(self, seq_len): # Assuming that center is True is stft_pad_amount = 0 pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2 seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1 return seq_len.to(dtype=torch.long) @property def filter_banks(self): return self.fb def forward(self, x, seq_len, linear_spec=False): seq_len = self.get_seq_len(seq_len) if self.stft_pad_amount is not None: x = torch.nn.functional.pad( x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect" ).squeeze(1) # dither (only in training mode for eval determinism) if self.training and self.dither > 0: x += self.dither * torch.randn_like(x) # do preemphasis if self.preemph is not None: x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) # disable autocast to get full range of stft values with torch.amp.autocast(x.device.type, enabled=False): x = self.stft(x) # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude # guard is needed for sqrt if grads are passed through guard = 0 if not self.use_grads else CONSTANT x = torch.view_as_real(x) x = torch.sqrt(x.pow(2).sum(-1) + guard) if self.training and self.nb_augmentation_prob > 0.0: for idx in range(x.shape[0]): if self._rng.random() < self.nb_augmentation_prob: x[idx, self._nb_max_fft_bin :, :] = 0.0 # get power spectrum if self.mag_power != 1.0: x = x.pow(self.mag_power) # return plain spectrogram if required if linear_spec: return x, seq_len # dot with filterbank energies x = torch.matmul(self.fb.to(x.dtype), x) # log features if required if self.log: if self.log_zero_guard_type == "add": x = torch.log(x + self.log_zero_guard_value_fn(x)) elif self.log_zero_guard_type == "clamp": x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) else: raise ValueError("log_zero_guard_type was not understood") # frame splicing if required if self.frame_splicing > 1: x = splice_frames(x, self.frame_splicing) # normalize if required if self.normalize: x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize) # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency) max_len = x.size(-1) mask = torch.arange(max_len, device=x.device) mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1) x = 
x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value) del mask pad_to = self.pad_to if pad_to == "max": x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value) elif pad_to > 0: pad_amt = x.size(-1) % pad_to if pad_amt != 0: x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) return x, seq_len class MaskedLoss(Loss): def __init__(self, loss_fn, loss_scale: float = 1.0): super(MaskedLoss, self).__init__() self.loss_scale = loss_scale self.loss_fn = loss_fn @property def input_types(self): return { "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()), "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), "target_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "loss": NeuralType(elements_type=LossType()), } @typecheck() def forward(self, predicted, target, target_len): assert target.shape[2] == predicted.shape[2] # [B, D, T] loss = self.loss_fn(input=predicted, target=target) # [B, T] loss = torch.mean(loss, dim=1) # [B] loss = torch.sum(loss, dim=1) / torch.clamp(target_len, min=1.0) # [1] loss = torch.mean(loss) loss = self.loss_scale * loss return loss class MaskedMAELoss(MaskedLoss): def __init__(self, loss_scale: float = 1.0): loss_fn = torch.nn.L1Loss(reduction='none') super(MaskedMAELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) class MaskedMSELoss(MaskedLoss): def __init__(self, loss_scale: float = 1.0): loss_fn = torch.nn.MSELoss(reduction='none') super(MaskedMSELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) class TimeDomainLoss(Loss): def __init__(self): super(TimeDomainLoss, self).__init__() self.loss_fn = MaskedMAELoss() @property def input_types(self): return { "audio_real": NeuralType(('B', 'T'), AudioSignal()), "audio_gen": NeuralType(('B', 'T'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "loss": NeuralType(elements_type=LossType()), } @typecheck() def forward(self, audio_real, audio_gen, audio_len): audio_real = rearrange(audio_real, "B T -> B 1 T") audio_gen = rearrange(audio_gen, "B T -> B 1 T") loss = self.loss_fn(target=audio_real, predicted=audio_gen, target_len=audio_len) return loss class MultiResolutionMelLoss(Loss): """ Multi-resolution log mel spectrogram loss. Args: sample_rate: Sample rate of audio. resolutions: List of resolutions, each being 3 integers ordered [num_fft, hop_length, window_length] mel_dims: Dimension of mel spectrogram to compute for each resolution. Should be same length as 'resolutions'. log_guard: Value to add to mel spectrogram to avoid taking log of 0. 
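    Illustrative usage (the resolutions mirror the discriminator config noted in `instantiate`;
    the mel_dims values here are arbitrary):

        mel_loss_fn = MultiResolutionMelLoss(
            sample_rate=22050,
            resolutions=[[512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]],
            mel_dims=[64, 80, 96],
        )
        l1_loss, l2_loss = mel_loss_fn(audio_real=audio_real, audio_gen=audio_gen, audio_len=audio_len)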
""" def __init__(self, sample_rate: int, resolutions: List[List], mel_dims: List[int], log_guard: float = 1.0): super(MultiResolutionMelLoss, self).__init__() assert len(resolutions) == len(mel_dims) self.l1_loss_fn = MaskedMAELoss() self.l2_loss_fn = MaskedMSELoss() self.mel_features = torch.nn.ModuleList() for mel_dim, (n_fft, hop_len, win_len) in zip(mel_dims, resolutions): mel_feature = FilterbankFeatures( sample_rate=sample_rate, nfilt=mel_dim, n_window_size=win_len, n_window_stride=hop_len, n_fft=n_fft, pad_to=1, mag_power=1.0, log_zero_guard_type="add", log_zero_guard_value=log_guard, mel_norm=None, normalize=None, preemph=None, dither=0.0, use_grads=True, ) self.mel_features.append(mel_feature) @property def input_types(self): return { "audio_real": NeuralType(('B', 'T'), AudioSignal()), "audio_gen": NeuralType(('B', 'T'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return { "l1_loss": NeuralType(elements_type=LossType()), "l2_loss": NeuralType(elements_type=LossType()), } @typecheck() def forward(self, audio_real, audio_gen, audio_len): l1_loss = 0.0 l2_loss = 0.0 for mel_feature in self.mel_features: mel_real, mel_real_len = mel_feature(x=audio_real, seq_len=audio_len) mel_gen, _ = mel_feature(x=audio_gen, seq_len=audio_len) l1_loss += self.l1_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) l2_loss += self.l2_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) l1_loss /= len(self.mel_features) l2_loss /= len(self.mel_features) return l1_loss, l2_loss class STFTLoss(Loss): """ Log magnitude STFT loss. Args: resolution: Resolution of spectrogram, a list of 3 numbers ordered [num_fft, hop_length, window_length] log_guard: Value to add to magnitude spectrogram to avoid taking log of 0. sqrt_guard: Value to add to when computing absolute value of STFT to avoid NaN loss. """ def __init__(self, resolution: List[int], log_guard: float = 1.0, sqrt_guard: float = 1e-5): super(STFTLoss, self).__init__() self.loss_fn = MaskedMAELoss() self.n_fft, self.hop_length, self.win_length = resolution self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) self.log_guard = log_guard self.sqrt_guard = sqrt_guard def _compute_spectrogram(self, audio, spec_len): # [B, n_fft, T_spec] spec = torch.stft( audio, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True, ) # [B, n_fft, T_spec, 2] spec = torch.view_as_real(spec) # [B, n_fft, T_spec] spec_mag = torch.sqrt(spec.pow(2).sum(-1) + self.sqrt_guard) spec_log = torch.log(spec_mag + self.log_guard) spec_log = mask_sequence_tensor(spec_log, spec_len) return spec_log @property def input_types(self): return { "audio_real": NeuralType(('B', 'T'), AudioSignal()), "audio_gen": NeuralType(('B', 'T'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, audio_real, audio_gen, audio_len): spec_len = (audio_len // self.hop_length) + 1 spec_real = self._compute_spectrogram(audio=audio_real, spec_len=spec_len) spec_gen = self._compute_spectrogram(audio=audio_gen, spec_len=spec_len) loss = self.loss_fn(predicted=spec_gen, target=spec_real, target_len=spec_len) return loss class MultiResolutionSTFTLoss(Loss): """ Multi-resolution log magnitude STFT loss. 
Args: resolutions: List of resolutions, each being 3 integers ordered [num_fft, hop_length, window_length] log_guard: Value to add to magnitude spectrogram to avoid taking log of 0. sqrt_guard: Value to add to when computing absolute value of STFT to avoid NaN loss. """ def __init__(self, resolutions: List[List], log_guard: float = 1.0, sqrt_guard: float = 1e-5): super(MultiResolutionSTFTLoss, self).__init__() self.loss_fns = torch.nn.ModuleList( [STFTLoss(resolution=resolution, log_guard=log_guard, sqrt_guard=sqrt_guard) for resolution in resolutions] ) @property def input_types(self): return { "audio_real": NeuralType(('B', 'T'), AudioSignal()), "audio_gen": NeuralType(('B', 'T'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, audio_real, audio_gen, audio_len): loss = 0.0 for loss_fn in self.loss_fns: loss += loss_fn(audio_real=audio_real, audio_gen=audio_gen, audio_len=audio_len) loss /= len(self.loss_fns) return loss class SISDRLoss(Loss): """ SI-SDR loss based off of torchmetrics.functional.audio.sdr.scale_invariant_signal_distortion_ratio with added support for masking. """ def __init__(self, epsilon: float = 1e-8): super(SISDRLoss, self).__init__() self.epsilon = epsilon @property def input_types(self): return { "audio_real": NeuralType(('B', 'T'), AudioSignal()), "audio_gen": NeuralType(('B', 'T'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, audio_real, audio_gen, audio_len): mask = get_mask_from_lengths(x=audio_real, lengths=audio_len) audio_len = rearrange(audio_len, 'B -> B 1') # Shift audio to have zero-mean # [B, 1] target_mean = torch.sum(audio_real, dim=-1, keepdim=True) / audio_len pred_mean = torch.sum(audio_gen, dim=-1, keepdim=True) / audio_len # [B, T] target = audio_real - target_mean target = target * mask pred = audio_gen - pred_mean pred = pred * mask # [B, 1] ref_pred = torch.sum(pred * target, dim=-1, keepdim=True) ref_target = torch.sum(target**2, dim=-1, keepdim=True) alpha = (ref_pred + self.epsilon) / (ref_target + self.epsilon) # [B, T] target_scaled = alpha * target distortion = target_scaled - pred # [B] target_scaled_power = torch.sum(target_scaled**2, dim=-1) distortion_power = torch.sum(distortion**2, dim=-1) ratio = (target_scaled_power + self.epsilon) / (distortion_power + self.epsilon) si_sdr = 10 * torch.log10(ratio) # [1] loss = -torch.mean(si_sdr) return loss class FeatureMatchingLoss(Loss): """ Standard feature matching loss measuring the difference in the internal discriminator layer outputs (usually leaky relu activations) between real and generated audio, scaled down by the total number of discriminators and layers. """ def __init__(self): super(FeatureMatchingLoss, self).__init__() @property def input_types(self): return { "fmaps_real": [[NeuralType(elements_type=VoidType())]], "fmaps_gen": [[NeuralType(elements_type=VoidType())]], } @property def output_types(self): return { "loss": NeuralType(elements_type=LossType()), } @typecheck() def forward(self, fmaps_real, fmaps_gen): loss = 0.0 for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen): # [B, ..., time] for feat_real, feat_gen in zip(fmap_real, fmap_gen): # [B, ...] 
diff = torch.abs(feat_real - feat_gen) feat_loss = torch.mean(diff) / len(fmap_real) loss += feat_loss loss /= len(fmaps_real) return loss class RelativeFeatureMatchingLoss(Loss): """ Relative feature matching loss as described in https://arxiv.org/pdf/2210.13438.pdf. This is similar to standard feature matching loss, but it scales the loss by the absolute value of each feature averaged across time. This might be slightly different from the paper which says the "mean is computed over all dimensions", which could imply taking the average across both time and features. Args: div_guard: Value to add when dividing by mean to avoid large/NaN values. """ def __init__(self, div_guard=1e-3): super(RelativeFeatureMatchingLoss, self).__init__() self.div_guard = div_guard @property def input_types(self): return { "fmaps_real": [[NeuralType(elements_type=VoidType())]], "fmaps_gen": [[NeuralType(elements_type=VoidType())]], } @property def output_types(self): return { "loss": NeuralType(elements_type=LossType()), } @typecheck() def forward(self, fmaps_real, fmaps_gen): loss = 0.0 for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen): # [B, ..., time] for feat_real, feat_gen in zip(fmap_real, fmap_gen): # [B, ...] feat_mean = torch.mean(torch.abs(feat_real), dim=-1) diff = torch.mean(torch.abs(feat_real - feat_gen), dim=-1) feat_loss = diff / (feat_mean + self.div_guard) # [1] feat_loss = torch.mean(feat_loss) / len(fmap_real) loss += feat_loss loss /= len(fmaps_real) return loss class GeneratorHingedLoss(Loss): @property def input_types(self): return { "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, disc_scores_gen): loss = 0.0 for disc_score_gen in disc_scores_gen: loss += torch.mean(F.relu(1 - disc_score_gen)) loss /= len(disc_scores_gen) return loss class GeneratorSquaredLoss(Loss): @property def input_types(self): return { "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, disc_scores_gen): loss = 0.0 for disc_score_gen in disc_scores_gen: loss += torch.mean((1 - disc_score_gen) ** 2) loss /= len(disc_scores_gen) return loss class DiscriminatorHingedLoss(Loss): @property def input_types(self): return { "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())], "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, disc_scores_real, disc_scores_gen): loss = 0.0 for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): loss_real = torch.mean(F.relu(1 - disc_score_real)) loss_gen = torch.mean(F.relu(1 + disc_score_gen)) loss += (loss_real + loss_gen) / 2 loss /= len(disc_scores_real) return loss class DiscriminatorSquaredLoss(Loss): @property def input_types(self): return { "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())], "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], } @property def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() def forward(self, disc_scores_real, disc_scores_gen): loss = 0.0 for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): loss_real = torch.mean((1 - disc_score_real) ** 2) loss_gen = torch.mean(disc_score_gen**2) loss += (loss_real + loss_gen) / 2 loss /= 
len(disc_scores_real) return loss @experimental class AudioCodecModel(ModelPT): def __init__(self, cfg): cfg = model_utils.convert_model_config_to_dict_config(cfg) cfg = model_utils.maybe_update_config_version(cfg) self.world_size = 1 super().__init__(cfg=cfg) # Expected sample rate for the input audio self.sample_rate = cfg.sample_rate # Number of samples in each audio frame that is encoded self.samples_per_frame = cfg.samples_per_frame # Discriminator updates self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1) self.disc_update_period = cfg.get("disc_update_period", 1) if self.disc_updates_per_period > self.disc_update_period: raise ValueError( f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less or equal to the configured period ({self.disc_update_period})' ) # Encoder setup self.audio_encoder = instantiate(cfg.audio_encoder) # Optionally, add gaussian noise to encoder output as an information bottleneck encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0) if encoder_noise_stdev: self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev) else: self.encoder_noise = None if "vector_quantizer" in cfg: self.vector_quantizer = instantiate(cfg.vector_quantizer) vq_output_types = list(self.vector_quantizer.output_types.keys()) if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss': self.vector_quantizer_has_commit_loss = True else: self.vector_quantizer_has_commit_loss = False else: self.vector_quantizer = None # Decoder setup self.audio_decoder = instantiate(cfg.audio_decoder) # Discriminator setup self.discriminator = instantiate(cfg.discriminator) # Mel loss setup loss_resolutions = cfg.loss_resolutions mel_loss_dims = cfg.get("mel_loss_dims") mel_loss_log_guard = cfg.get("mel_loss_log_guard", 1.0) self.mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0) self.mel_loss_l2_scale = cfg.get("mel_loss_l2_scale", 1.0) self.mel_loss_fn = MultiResolutionMelLoss( sample_rate=self.sample_rate, mel_dims=mel_loss_dims, resolutions=loss_resolutions, log_guard=mel_loss_log_guard, ) # STFT loss setup stft_loss_log_guard = cfg.get("stft_loss_log_guard", 1.0) self.stft_loss_scale = cfg.get("stft_loss_scale", 0.0) self.stft_loss_fn = MultiResolutionSTFTLoss( resolutions=loss_resolutions, log_guard=stft_loss_log_guard, ) # Time domain loss setup self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0) self.si_sdr_loss_scale = cfg.get("si_sdr_loss_scale", 0.0) self.time_domain_loss_fn = TimeDomainLoss() self.si_sdr_loss_fn = SISDRLoss() # Discriminator loss setup self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0) self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0) self.gen_loss_fn = instantiate(cfg.generator_loss) self.disc_loss_fn = instantiate(cfg.discriminator_loss) feature_loss_type = cfg.get("feature_loss_type", "relative") if feature_loss_type == "relative": self.feature_loss_fn = RelativeFeatureMatchingLoss() elif feature_loss_type == "absolute": self.feature_loss_fn = FeatureMatchingLoss() else: raise ValueError(f'Unknown feature loss type {feature_loss_type}.') # Codebook loss setup if self.vector_quantizer: self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0) else: self.commit_loss_scale = 0.0 if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss: raise ValueError('Commit loss is enabled but the quantizer does not support it.') @typecheck( input_types={ "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), }, 
output_types={ "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), }, ) def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply encoder on the input audio signal. Input will be padded with zeros so the last frame has full `self.samples_per_frame` samples. Args: audio: input time-domain signal audio_len: valid length for each example in the batch Returns: Encoder output `encoded` and its length in number of frames `encoded_len` """ audio, audio_len = self.pad_audio(audio, audio_len) encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) return encoded, encoded_len @typecheck( input_types={ "inputs": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), }, ) def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation. Args: inputs: encoded signal input_len: valid length for each example in the batch Returns: Decoded output `audio` in the time domain and its length in number of samples `audio_len`. Note that `audio_len` will be a multiple of `self.samples_per_frame`. """ audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len) return audio, audio_len @typecheck( input_types={ "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), }, output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex())}, ) def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: """Quantize the continuous encoded representation into a discrete representation for each frame. Args: encoded: encoded signal representation encoded_len: valid length of the encoded representation in frames Returns: A tensor of tokens for each codebook for each frame. """ if not self.vector_quantizer: raise ValueError("Cannot quantize without quantizer") # vector quantizer is returning [C, B, T], where C is the number of codebooks tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) # use batch first for the output tokens = rearrange(tokens, 'C B T -> B C T') return tokens @typecheck( input_types={ "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()), "tokens_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), }, ) def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor: """Convert the discrete tokens into a continuous encoded representation. Args: tokens: discrete tokens for each codebook for each time frame tokens_len: valid length of each example in the batch Returns: Continuous encoded representation of the discrete input representation. 
""" if not self.vector_quantizer: raise ValueError("Cannot dequantize without quantizer") # vector quantizer is using [C, B, T], where C is the number of codebooks tokens = rearrange(tokens, 'B C T -> C B T') dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) return dequantized @typecheck( input_types={ "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()), "tokens_len": NeuralType(tuple('B'), LengthsType()), }, ) def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Convert input time-domain audio signal into a discrete representation (tokens). Args: audio: input time-domain signal, shape `(batch, number of samples)` audio_len: valid length for each example in the batch, shape `(batch size,)` Returns: Tokens for each codebook for each frame, shape `(batch, number of codebooks, number of frames)`, and the corresponding valid lengths, shape `(batch,)` """ # Apply encoder to obtain a continuous vector for each frame encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) # Apply quantizer to obtain discrete representation per frame tokens = self.quantize(encoded=encoded, encoded_len=encoded_len) return tokens, encoded_len @typecheck( input_types={ "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()), "tokens_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), }, ) def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Convert discrete tokens into a continuous time-domain signal. Args: tokens: discrete tokens for each codebook for each time frame, shape `(batch, number of codebooks, number of frames)` tokens_len: valid lengths, shape `(batch,)` Returns: Decoded output `audio` in the time domain and its length in number of samples `audio_len`. Note that `audio_len` will be a multiple of `self.samples_per_frame`. """ # Convert a discrete representation to a dequantized vector for each frame dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len) # Apply decoder to obtain time-domain audio for each frame audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len) return audio, audio_len @typecheck( input_types={ "audio": NeuralType(('B', 'T_audio'), AudioSignal()), "audio_len": NeuralType(tuple('B'), LengthsType()), }, output_types={ "output_audio": NeuralType(('B', 'T_audio'), EncodedRepresentation()), "output_audio_len": NeuralType(tuple('B'), LengthsType()), }, ) def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply encoder, quantizer, decoder on the input time-domain signal. Args: audio: input time-domain signal audio_len: valid length for each example in the batch Returns: Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`. 
""" encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) if self.vector_quantizer: # quantize to discrete tokens tokens = self.quantize(encoded=encoded, encoded_len=encoded_len) # decode tokens to audio output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len) else: # no quantization, directly decode to audio output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len) return output_audio, output_audio_len def pad_audio(self, audio, audio_len): """Zero pad the end of the audio so that we do not have a partial end frame. The output will be zero-padded to have an integer number of frames of length `self.samples_per_frame`. Args: audio: input time-domain signal audio_len: valid length for each example in the batch Returns: Padded time-domain signal `padded_audio` and its length `padded_len`. """ padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int() max_len = padded_len.max().item() num_padding = max_len - audio.shape[1] padded_audio = F.pad(audio, (0, num_padding)) return padded_audio, padded_len def _process_batch(self, batch): # [B, T_audio] audio = batch.get("audio") # [B] audio_len = batch.get("audio_lens") audio, audio_len = self.pad_audio(audio, audio_len) # [B, D, T_encoded] encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) if self.encoder_noise is not None: encoded = self.encoder_noise(encoded) if self.vector_quantizer: if self.vector_quantizer_has_commit_loss: encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) else: encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) commit_loss = 0.0 else: commit_loss = 0.0 # [B, T] audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len) return audio, audio_len, audio_gen, commit_loss @property def disc_update_prob(self) -> float: """Probability of updating the discriminator.""" return self.disc_updates_per_period / self.disc_update_period def should_update_disc(self, batch_idx) -> bool: """Decide whether to update the descriminator based on the batch index and configured discriminator update period. """ disc_update_step = batch_idx % self.disc_update_period return disc_update_step < self.disc_updates_per_period def setup_training_data(self): ... def setup_validation_data(self): ... 
@classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: models = [] model = PretrainedModelInfo( pretrained_model_name="audio_codec_16khz_small", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/audio_codec_16khz_small/versions/v1/files/audio_codec_16khz_small.nemo", description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/audio_codec_16khz_small", ) models.append(model) model = PretrainedModelInfo( pretrained_model_name="mel_codec_22khz_medium", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_medium/versions/v1/files/mel_codec_22khz_medium.nemo", description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_medium", ) models.append(model) model = PretrainedModelInfo( pretrained_model_name="mel_codec_44khz_medium", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_medium/versions/v1/files/mel_codec_44khz_medium.nemo", description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_medium", ) models.append(model) model = PretrainedModelInfo( pretrained_model_name="mel_codec_22khz_fullband_medium", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo", description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_fullband_medium", ) models.append(model) model = PretrainedModelInfo( pretrained_model_name="mel_codec_44khz_fullband_medium", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo", description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_fullband_medium", ) models.append(model) return models
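

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration added here, not part of the original
# module): round-trip a batch of audio through the codec. This assumes a working
# NeMo environment in which `ModelPT.from_pretrained` can download one of the
# checkpoints returned by `AudioCodecModel.list_available_models`; the model name
# below is taken from that list, and the random tensors merely stand in for real
# audio.
if __name__ == "__main__":
    model = AudioCodecModel.from_pretrained(model_name="audio_codec_16khz_small")
    model.eval()

    device = next(model.parameters()).device
    batch_size = 2
    num_samples = model.sample_rate  # roughly one second of audio per example
    audio = torch.randn(batch_size, num_samples, device=device)
    audio_len = torch.tensor([num_samples, num_samples // 2], dtype=torch.long, device=device)

    with torch.no_grad():
        # audio -> discrete tokens, shape [B, num_codebooks, T_frames]
        tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len)
        # tokens -> reconstructed audio, padded to a multiple of samples_per_frame
        audio_rec, audio_rec_len = model.decode(tokens=tokens, tokens_len=tokens_len)

    print("tokens:", tuple(tokens.shape), "reconstructed audio:", tuple(audio_rec.shape))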