# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# modified grossly to avoid additional dependencies

import itertools
import math
import random
from abc import ABC, abstractmethod
from math import ceil
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange

# Required by SSLModel below; the original snippet used AutoModel without importing it.
from transformers import AutoModel

from nemo.core import ModelPT
from nemo.core.classes import Loss
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.neural_types.elements import (
    AudioSignal,
    EncodedRepresentation,
    Index,
    LengthsType,
    MelSpectrogramType,
    TokenIndex,
    VoidType,
)

from nemo.collections.common.parts.utils import ClampActivation, HalfSnake, Snake, mask_sequence_tensor

# Used by MelSpectrogramProcessor below but not imported in the original snippet;
# this import path follows the usual NeMo layout and may need adjusting.
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor

from nemo.utils import model_utils
from nemo.utils.decorators import experimental

# Default dither magnitude used by FilterbankFeatures.
CONSTANT = 1e-5

def instantiate(cfg):
    """Minimal replacement for hydra.utils.instantiate supporting only the targets used here."""
    cfg = dict(cfg)
    target = cfg.pop("_target_")

    if target == "nemo.collections.tts.modules.audio_codec_modules.HiFiGANEncoder":
        cls = HiFiGANEncoder
    elif target == "nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer":
        cls = GroupFiniteScalarQuantizer
    elif target == "nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder":
        cls = HiFiGANDecoder
    elif target == "nemo.collections.tts.modules.audio_codec_modules.Discriminator":
        cls = Discriminator
        # Discriminator wraps a list of sub-discriminator configs, e.g.:
        # {'discriminators': [
        #     {'_target_': '...MultiPeriodDiscriminator'},
        #     {'_target_': '...MultiResolutionDiscriminatorSTFT',
        #      'resolutions': [[512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]],
        #      'stft_bands': [[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]]}]}
        # so instantiate each of them recursively.
        cfg['discriminators'] = [instantiate(c) for c in cfg['discriminators']]
    elif target == "nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator":
        cls = MultiPeriodDiscriminator
    elif target == "nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT":
        cls = MultiResolutionDiscriminatorSTFT
    elif target == "nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss":
        cls = GeneratorSquaredLoss
    elif target == "nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss":
        cls = DiscriminatorSquaredLoss
    else:
        raise ValueError(f"Unsupported _target_ {target} with config {cfg}")

    return cls(**cfg)

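
# A minimal usage sketch for `instantiate` (illustrative only), assuming a config
# dict shaped like the nested discriminator config documented above.
def _example_instantiate_discriminator():
    cfg = {
        "_target_": "nemo.collections.tts.modules.audio_codec_modules.Discriminator",
        "discriminators": [
            {"_target_": "nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator"},
        ],
    }
    return instantiate(cfg)
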
class GaussianDropout(torch.nn.Module):
    """
    Gaussian dropout using multiplicative gaussian noise.

    https://keras.io/api/layers/regularization_layers/gaussian_dropout/

    Can be an effective alternative bottleneck to VAE or VQ:

    https://www.deepmind.com/publications/gaussian-dropout-as-an-information-bottleneck-layer

    Unlike some other implementations, this takes the standard deviation of the noise as input
    instead of the 'rate' typically defined as: stdev = sqrt(rate / (1 - rate))
    """

    def __init__(self, stdev=1.0):
        super().__init__()
        self.stdev = stdev

    def forward(self, inputs):
        if not self.training:
            return inputs

        noise = torch.normal(mean=1.0, std=self.stdev, size=inputs.shape, device=inputs.device)
        out = noise * inputs
        return out

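
# A minimal behavior sketch for GaussianDropout (illustrative, not part of the
# original module): in training mode the layer multiplies by N(1, stdev**2)
# noise, so activations keep their expected value; in eval mode it is the identity.
def _example_gaussian_dropout():
    layer = GaussianDropout(stdev=0.5)
    x = torch.ones(4, 8)
    layer.train()
    noisy = layer(x)  # elementwise x * N(1, 0.25), E[noisy] == x
    layer.eval()
    assert torch.equal(layer(x), x)
    return noisy
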
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    return (kernel_size * dilation - dilation) // 2


def get_padding_2d(kernel_size: Tuple[int, int], dilation: Tuple[int, int]) -> Tuple[int, int]:
    paddings = (get_padding(kernel_size[0], dilation[0]), get_padding(kernel_size[1], dilation[1]))
    return paddings


def get_down_sample_padding(kernel_size: int, stride: int) -> int:
    return (kernel_size - stride + 1) // 2


def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]:
    output_padding = (kernel_size - stride) % 2
    padding = (kernel_size - stride + 1) // 2
    return padding, output_padding

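
# A quick numeric check of the padding helpers above (illustrative only). For a
# kernel_size=16, stride=8 transposed convolution, get_up_sample_padding yields
# padding=4 and output_padding=0, so T_out = (T_in - 1) * 8 - 2 * 4 + 16 = 8 * T_in,
# an exact 8x upsample. Likewise get_padding(7) == 3 keeps a stride-1 convolution
# length-preserving ("same" padding).
def _example_padding_math():
    assert get_padding(kernel_size=7, dilation=1) == 3
    assert get_down_sample_padding(kernel_size=16, stride=8) == 4
    assert get_up_sample_padding(kernel_size=16, stride=8) == (4, 0)
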
class SSLModel(NeuralModule):
    def __init__(self, slm_model_name):
        super().__init__()
        self.ssl_model = AutoModel.from_pretrained(slm_model_name)

    def forward(self, *args, **kwargs):
        return self.ssl_model(*args, **kwargs)

class SLMDiscriminator(NeuralModule):
    """SLM Discriminator, as described in both the StyleTTS2 and Low Frame-Rate Speech Codec papers.

    Args:
        slm_model_name: Hugging Face speech language model name.
        slm_sr: Speech language model input sampling rate.
        input_sr: Audio input sampling rate.
        slm_hidden: Speech language model hidden dimension.
        slm_layers: Number of speech language model layers.
        initial_channel: Number of channels in the discriminative head.
        use_spectral_norm: If True, uses spectral normalization; otherwise uses weight norm.
    """

    def __init__(
        self,
        slm_model_name="microsoft/wavlm-base-plus",
        slm_sr=16000,
        input_sr=22050,
        slm_hidden=768,
        slm_layers=13,
        initial_channel=64,
        use_spectral_norm=False,
    ):
        super().__init__()

        self.slm_model = SSLModel(slm_model_name)

        # Freeze the SLM model
        self.slm_model.freeze()

        self.resample = torchaudio.transforms.Resample(input_sr, slm_sr)

        norm_f = torch.nn.utils.spectral_norm if use_spectral_norm else torch.nn.utils.weight_norm
        self.pre = norm_f(nn.Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))

        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
                norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
                norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
            ]
        )

        self.conv_post = norm_f(nn.Conv1d(initial_channel * 4, 1, 3, 1, padding=1))

    def _forward(self, x):
        x = self.slm_model(input_values=self.resample(x), output_hidden_states=True).hidden_states
        x = torch.stack(x, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        x = self.pre(x)
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, 0.1)
            fmap.append(x.unsqueeze(-1))

        x = self.conv_post(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
            "scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
            "fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
            "fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
        }

    @typecheck()
    def forward(self, audio_real, audio_gen):
        y_d_r, fmap_r = self._forward(audio_real)
        y_d_g, fmap_g = self._forward(audio_gen)

        return [y_d_r.unsqueeze(1)], [y_d_g.unsqueeze(1)], [fmap_r], [fmap_g]

class CodecActivation(nn.Module):
    """
    Choose between activations based on the input parameter.

    Args:
        activation: Name of activation to use. Valid options are "elu" (default), "lrelu", "snake",
            and "half_snake".
        channels: Input dimension.
    """

    def __init__(self, activation: str = "elu", channels: int = 1):
        super().__init__()
        activation = activation.lower()
        if activation == "elu":
            self.activation = nn.ELU()
        elif activation == "lrelu":
            self.activation = torch.nn.LeakyReLU()
        elif activation == "snake":
            self.activation = Snake(channels)
        elif activation == "half_snake":
            self.activation = HalfSnake(channels)
        else:
            raise ValueError(f"Unknown activation {activation}")

    def forward(self, x):
        return self.activation(x)

class Conv1dNorm(NeuralModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        padding: Optional[int] = None,
    ):
        super().__init__()
        if padding is None:
            # Default to "same" padding so stride-1 convolutions preserve length.
            padding = get_padding(kernel_size=kernel_size, dilation=dilation)
        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            padding_mode="reflect",
        )
        self.conv = nn.utils.weight_norm(conv)

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'C', 'T'), VoidType()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "out": NeuralType(('B', 'C', 'T'), VoidType()),
        }

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv)

    @typecheck()
    def forward(self, inputs, input_len):
        out = self.conv(inputs)
        out = mask_sequence_tensor(out, input_len)
        return out

class ConvTranspose1dNorm(NeuralModule):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1):
        super().__init__()
        padding, output_padding = get_up_sample_padding(kernel_size, stride)
        conv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            padding_mode="zeros",
        )
        self.conv = nn.utils.weight_norm(conv)

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'C', 'T'), VoidType()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "out": NeuralType(('B', 'C', 'T'), VoidType()),
        }

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv)

    @typecheck()
    def forward(self, inputs, input_len):
        out = self.conv(inputs)
        out = mask_sequence_tensor(out, input_len)
        return out

class Conv2dNorm(NeuralModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        stride: Tuple[int, int] = (1, 1),
        dilation: Tuple[int, int] = (1, 1),
    ):
        super().__init__()
        assert len(kernel_size) == len(dilation)
        padding = get_padding_2d(kernel_size, dilation)
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode="reflect",
        )
        self.conv = nn.utils.weight_norm(conv)

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'C', 'H', 'T'), VoidType()),
        }

    @property
    def output_types(self):
        return {
            "out": NeuralType(('B', 'C', 'H', 'T'), VoidType()),
        }

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv)

    @typecheck()
    def forward(self, inputs):
        return self.conv(inputs)

class PeriodDiscriminator(NeuralModule):
    """
    Period discriminator introduced in HiFi-GAN https://arxiv.org/abs/2010.05646 which attempts to
    discriminate phase information by looking at equally spaced audio samples.

    Args:
        period: Spacing between audio sample inputs.
        lrelu_slope: Slope to use for activation. Leaky relu with slope of 0.1 or 0.2 is recommended for the
            stability of the feature matching loss.
    """

    def __init__(self, period, lrelu_slope=0.1):
        super().__init__()
        self.period = period
        self.activation = nn.LeakyReLU(lrelu_slope)
        self.conv_layers = nn.ModuleList(
            [
                Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)),
                Conv2dNorm(32, 128, kernel_size=(5, 1), stride=(3, 1)),
                Conv2dNorm(128, 512, kernel_size=(5, 1), stride=(3, 1)),
                Conv2dNorm(512, 1024, kernel_size=(5, 1), stride=(3, 1)),
                Conv2dNorm(1024, 1024, kernel_size=(5, 1), stride=(1, 1)),
            ]
        )
        self.conv_post = Conv2dNorm(1024, 1, kernel_size=(3, 1))

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "score": NeuralType(('B', 'C', 'T_out'), VoidType()),
            "fmap": [NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())],
        }

    @typecheck()
    def forward(self, audio):
        batch_size, time = audio.shape
        out = rearrange(audio, 'B T -> B 1 T')
        # Pad audio so that it is divisible by the period
        if time % self.period != 0:
            n_pad = self.period - (time % self.period)
            out = F.pad(out, (0, n_pad), "reflect")
            time = time + n_pad
        # [batch, 1, (time / period), period]
        out = out.view(batch_size, 1, time // self.period, self.period)

        fmap = []
        for conv in self.conv_layers:
            # [batch, filters, (time / period / stride), period]
            out = conv(inputs=out)
            out = self.activation(out)
            fmap.append(out)
        # [batch, 1, (time / period / strides), period]
        score = self.conv_post(inputs=out)
        fmap.append(score)
        score = rearrange(score, "B 1 T C -> B C T")

        return score, fmap

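
# A toy illustration of the period trick described above (not part of the
# original module): after the reshape, every `period`-th sample lands in the
# same column, so the (k, 1) kernels compare equally spaced samples.
def _example_period_reshape():
    period = 3
    audio = torch.arange(12.0).view(1, 12)  # [B=1, T=12]
    framed = audio.view(1, 1, 12 // period, period)  # [B, 1, T / period, period]
    # framed[0, 0, :, 0] == tensor([0., 3., 6., 9.]) -- samples spaced `period` apart
    return framed
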
class MultiPeriodDiscriminator(NeuralModule):
    """
    Wrapper class to aggregate results of multiple period discriminators.

    The periods are expected to be increasing prime numbers in order to maximize coverage and minimize overlap.
    """

    def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11), lrelu_slope=0.1):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [PeriodDiscriminator(period=period, lrelu_slope=lrelu_slope) for period in periods]
        )

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
            "scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
            "fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
            "fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
        }

    @typecheck()
    def forward(self, audio_real, audio_gen):
        scores_real = []
        scores_gen = []
        fmaps_real = []
        fmaps_gen = []
        for discriminator in self.discriminators:
            score_real, fmap_real = discriminator(audio=audio_real)
            score_gen, fmap_gen = discriminator(audio=audio_gen)
            scores_real.append(score_real)
            fmaps_real.append(fmap_real)
            scores_gen.append(score_gen)
            fmaps_gen.append(fmap_gen)

        return scores_real, scores_gen, fmaps_real, fmaps_gen

class DiscriminatorSTFT(NeuralModule):
    """
    Discriminator network from EnCodec for complex STFT input, but without dilations.

    Args:
        filters: number of filters to use in Conv2d layers
        lrelu_slope: Slope to use for activations. Leaky relu with slope of 0.1 or 0.2 is recommended for the
            stability of the feature matching loss
    """

    def __init__(self, filters: int = 32, lrelu_slope: float = 0.1):
        super().__init__()

        self.activation = nn.LeakyReLU(lrelu_slope)
        self.conv_layers = nn.ModuleList(
            [
                Conv2dNorm(2, filters, kernel_size=(3, 9)),
                Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)),
                Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)),
                Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)),
                Conv2dNorm(filters, filters, kernel_size=(3, 3)),
            ]
        )
        self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3))

    @property
    def input_types(self):
        return {
            "spec": NeuralType(('B', 'C', 'T_spec', 'D'), VoidType()),
        }

    @property
    def output_types(self):
        return {
            "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()),
            "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())],
        }

    @typecheck()
    def forward(self, spec):
        fmap = []

        # [batch, 2, T_spec, fft]
        out = spec
        for conv in self.conv_layers:
            # [batch, filters, T_spec, fft // strides]
            out = conv(inputs=out)
            out = self.activation(out)
            fmap.append(out)
        # [batch, 1, T_spec, fft // 8]
        scores = self.conv_post(inputs=out)
        fmap.append(scores)
        scores = rearrange(scores, "B 1 T C -> B C T")

        return scores, fmap

class MultiBandDiscriminatorSTFT(NeuralModule):
    """
    Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546).

    Computes the complex STFT for a given resolution and splits it into sub-bands,
    which are given to separate discriminator networks.

    Args:
        resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length)
        stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end).
            The floats are in the range [0, 1] representing the fraction of all stft bands.
            For example for n_fft=1024, the stft output has 513 dimensions.
            For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512].
    """

    def __init__(self, resolution: Tuple[int, int, int], stft_bands: Iterable[Tuple[float, float]]):
        super().__init__()

        self.n_fft, self.hop_length, self.win_length = resolution
        self.register_buffer("window", torch.hann_window(self.win_length, periodic=False))
        self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands])
        n_stft = self.n_fft // 2 + 1
        self.stft_bands = [(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands]

    def compute_stft(self, audio):
        # [B, fft, T_spec]
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            normalized=True,
            center=True,
            return_complex=True,
        )
        fft = rearrange(fft, "B fft T -> B T fft")
        # [batch, 2, T_spec, fft]
        out = torch.stack([fft.real, fft.imag], dim=1)
        return out

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "scores_list": [NeuralType(('B', 'C', 'T_spec'), VoidType())],
            "fmaps_list": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]],
        }

    @typecheck()
    def forward(self, audio):
        scores_list = []
        fmap_list = []
        spec = self.compute_stft(audio)
        for band, disc in zip(self.stft_bands, self.discriminators):
            spec_band = spec[:, :, :, band[0] : band[1]]
            score, fmap = disc(spec=spec_band)
            scores_list.append(score)
            fmap_list.append(fmap)

        return scores_list, fmap_list

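
# Band-index arithmetic matching the docstring example above (illustrative):
# n_fft=1024 gives n_stft = 513 bins, so fractional bands [(0, 0.25), (0.25, 1.0)]
# map to bin ranges [0, 128) and [128, 513).
def _example_stft_band_indices():
    n_fft = 1024
    n_stft = n_fft // 2 + 1  # 513
    bands = [(0.0, 0.25), (0.25, 1.0)]
    index_ranges = [(int(start * n_stft), int(end * n_stft)) for start, end in bands]
    assert index_ranges == [(0, 128), (128, 513)]
    return index_ranges
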
class MultiResolutionDiscriminatorSTFT(NeuralModule):
    """
    Multi-resolution discriminator which creates a multi-band discriminator for each input resolution.

    Args:
        resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered
            (num_fft, hop_length, window_length)
        stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end).
            The floats are in the range [0, 1] representing the fraction of all stft bands.
            For example for n_fft=1024, the stft output has 513 dimensions.
            For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512].
    """

    def __init__(self, resolutions: Iterable[Tuple[int, int, int]], stft_bands: Iterable[Tuple[float, float]]):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands) for resolution in resolutions]
        )

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())],
            "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())],
            "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]],
            "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]],
        }

    @typecheck()
    def forward(self, audio_real, audio_gen):
        scores_real = []
        scores_gen = []
        fmaps_real = []
        fmaps_gen = []

        for disc in self.discriminators:
            score_real_i, fmap_real_i = disc(audio=audio_real)
            scores_real = scores_real + score_real_i
            fmaps_real = fmaps_real + fmap_real_i

            score_gen_i, fmap_gen_i = disc(audio=audio_gen)
            scores_gen = scores_gen + score_gen_i
            fmaps_gen = fmaps_gen + fmap_gen_i

        return scores_real, scores_gen, fmaps_real, fmaps_gen

class Discriminator(NeuralModule):
    """
    Wrapper class which takes a list of discriminators and aggregates the results across them.
    """

    def __init__(self, discriminators: Iterable[NeuralModule]):
        super().__init__()
        self.discriminators = nn.ModuleList(discriminators)

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
        }

    @property
    def output_types(self):
        return {
            "scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
            "scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
            "fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
            "fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
        }

    @typecheck()
    def forward(self, audio_real, audio_gen):
        scores_real = []
        scores_gen = []
        fmaps_real = []
        fmaps_gen = []
        for discriminator in self.discriminators:
            score_real, score_gen, fmap_real, fmap_gen = discriminator(audio_real=audio_real, audio_gen=audio_gen)
            scores_real += score_real
            fmaps_real += fmap_real
            scores_gen += score_gen
            fmaps_gen += fmap_gen

        return scores_real, scores_gen, fmaps_real, fmaps_gen

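
# A usage sketch for the aggregate Discriminator (illustrative shapes and
# hyperparameters, assuming 1 second of audio at 22.05 kHz): 5 period
# discriminators plus 1 resolution x 2 STFT bands yields 7 score tensors per input.
def _example_discriminator_forward():
    disc = Discriminator(
        discriminators=[
            MultiPeriodDiscriminator(),
            MultiResolutionDiscriminatorSTFT(
                resolutions=[(512, 128, 512)],
                stft_bands=[(0.0, 0.5), (0.5, 1.0)],
            ),
        ]
    )
    audio_real = torch.randn(2, 22050)
    audio_gen = torch.randn(2, 22050)
    scores_real, scores_gen, fmaps_real, fmaps_gen = disc(audio_real=audio_real, audio_gen=audio_gen)
    assert len(scores_real) == len(scores_gen) == 7
    return scores_real
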
class VectorQuantizerBase(NeuralModule, ABC):
    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "indices": NeuralType(('D', 'B', 'T'), Index()),
        }

    @typecheck()
    @abstractmethod
    def forward(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pass

    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
    )
    @abstractmethod
    def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        pass

    @typecheck(
        input_types={
            "indices": NeuralType(('D', 'B', 'T'), Index()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
        },
    )
    @abstractmethod
    def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        pass

class FiniteScalarQuantizer(VectorQuantizerBase):
    """This quantizer is based on the Finite Scalar Quantization (FSQ) method.
    It quantizes each element of the input vector independently into a number of levels.

    Args:
        num_levels: number of levels for each dimension/element of the input vector
        eps: small regularization constant for scaling

    References:
        Mentzer et al., Finite Scalar Quantization: VQ-VAE Made Simple (https://arxiv.org/abs/2309.15505v1)
    """

    def __init__(self, num_levels: List[int], eps: float = 1e-3):
        super().__init__()

        # index base per dimension of the input vector
        # this is used to convert between per-dimension indices and a codebook token index
        dim_base_index = torch.cumprod(torch.tensor([1] + num_levels[:-1]), dim=0, dtype=torch.int32)
        dim_base_index = rearrange(dim_base_index, 'D -> 1 D 1')
        self.register_buffer('dim_base_index', dim_base_index)

        # Register the number of levels for each dimension
        num_levels = torch.tensor(num_levels, dtype=torch.int32)
        num_levels = rearrange(num_levels, 'D -> 1 D 1')
        self.register_buffer('num_levels', num_levels)

        # Regularization
        self.eps = eps

    @property
    def codebook_size(self):
        """Returns the size of the corresponding codebook."""
        return self.num_levels.prod().item()

    @property
    def dim(self):
        """Returns the dimension of the input vector."""
        return self.num_levels.numel()

    @property
    def codebook_dim(self):
        """Returns the dimension of the input vector.
        Kept for compatibility with the original RVQ implementation.
        """
        return self.dim

    @property
    def codes(self):
        """Returns the codebook entries.

        Note that the codebook entries are implicitly defined by the number of levels.
        """
        indices = torch.arange(self.codebook_size)
        # [D, B, T]
        indices = rearrange(indices, 'B -> 1 B 1')
        # [B, D, T]
        codes = self.decode(indices=indices, input_len=None)
        # Remove the time dimension
        codes = codes.squeeze(-1)
        return codes

    @property
    def codebook(self):
        """Returns the codebook entries.
        See self.codes for more details.
        """
        return self.codes

    @staticmethod
    def round(inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Round the input tensor to the nearest integer
        and use a straight-through estimator for the gradient.
        """
        inputs_rounded = torch.round(inputs)
        return inputs + (inputs_rounded - inputs).detach()

    def compress(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Apply compression to the input to limit the range of values."""
        output_scale = (self.num_levels - 1) / 2
        # scale down a bit to avoid rounding issues
        output_scale = output_scale * (1 - self.eps)
        # offset for even number of levels
        output_offset = torch.where(self.num_levels % 2 == 0, 0.5, 0)
        # shift for even number of levels
        input_shift = (output_offset / output_scale).tan()
        # compressed output
        output = output_scale * (inputs + input_shift).tanh() - output_offset
        return output

    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"codes": NeuralType(('B', 'D', 'T'), Index())},
    )
    def inputs_to_codes(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        # apply compression
        compressed = self.compress(inputs=inputs, input_len=input_len)
        # apply rounding to the nearest integer
        codes = self.round(inputs=compressed, input_len=input_len)
        # normalize to [-1, 1]
        scale = self.num_levels // 2
        codes = codes / scale
        return codes

    def codes_to_nonnegative(self, codes: torch.Tensor) -> torch.Tensor:
        """Convert values centered around zero to nonnegative values."""
        scale = offset = self.num_levels // 2
        return scale * codes + offset

    def nonnegative_to_codes(self, codes_nonnegative: torch.Tensor) -> torch.Tensor:
        """Convert nonnegative values to values centered around zero."""
        scale = offset = self.num_levels // 2
        return (codes_nonnegative - offset) / scale

    def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor:
        """Converts a code vector to a single index."""
        if codes.size(1) != self.dim:
            raise RuntimeError(
                f'Input code dimension {codes.size(1)} not matching the expected dimension {self.dim}, input codes shape {codes.shape}'
            )
        # convert code vectors to nonnegative values
        indices = self.codes_to_nonnegative(codes)
        # convert one nonnegative index per dimension to a single index per code vector
        indices = torch.sum(indices * self.dim_base_index, dim=1)
        return indices.to(torch.int32)

    # Implementation of the VectorQuantizerBase API
    @typecheck()
    def forward(
        self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if inputs.size(1) != self.dim:
            raise RuntimeError(
                f'Input dimension {inputs.size(1)} not matching the expected dimension {self.dim}, inputs shape {inputs.shape}'
            )

        dequantized = self.inputs_to_codes(inputs=inputs, input_len=input_len)
        indices = self.codes_to_indices(codes=dequantized)

        if input_len is not None:
            # apply masking
            dequantized = mask_sequence_tensor(dequantized, input_len)
            indices = mask_sequence_tensor(indices, input_len)

        # only 1 codebook, but return in [D, B, T] format to match the RVQ API
        indices = indices.unsqueeze(0)
        return dequantized, indices

    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType(), optional=True),
        },
        output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
    )
    def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Convert a continuous code vector to a single index."""
        _, indices = self(inputs=inputs, input_len=input_len)
        return indices

    @typecheck(
        input_types={
            "indices": NeuralType(('D', 'B', 'T'), Index()),
            "input_len": NeuralType(tuple('B'), LengthsType(), optional=True),
        },
        output_types={
            "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
        },
    )
    def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Convert a single index to a continuous code vector."""
        if indices.size(0) > 1:
            # codebook dimension used for compatibility with RVQ
            raise ValueError(
                f'Expected a single codebook, got {indices.size(0)} codebooks for indices with shape {indices.shape}.'
            )

        indices = rearrange(indices, 'D B T -> B D T')
        # convert a single index to a nonnegative index per dimension
        codes_nonnegative = (indices // self.dim_base_index) % self.num_levels
        # convert nonnegative codes to codes (centered around zero)
        dequantized = self.nonnegative_to_codes(codes_nonnegative)

        if input_len is not None:
            # apply masking
            dequantized = mask_sequence_tensor(dequantized, input_len)
        return dequantized

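
# A small round-trip sketch for FiniteScalarQuantizer (illustrative): with 4
# levels per dimension the implicit codebook has 4 * 4 * 4 = 64 entries, and
# decoding the produced indices reproduces the dequantized output exactly.
def _example_fsq_roundtrip():
    fsq = FiniteScalarQuantizer(num_levels=[4, 4, 4])
    assert fsq.codebook_size == 64
    inputs = torch.randn(2, 3, 10)  # [B, D=3, T]
    input_len = torch.tensor([10, 10])
    dequantized, indices = fsq(inputs=inputs, input_len=input_len)  # indices: [1, B, T]
    decoded = fsq.decode(indices=indices, input_len=input_len)
    assert torch.allclose(decoded, dequantized)
    return indices
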
class GroupFiniteScalarQuantizer(VectorQuantizerBase):
    """Split the input vector into groups and apply FSQ on each group separately.
    This class is for convenience. Since FSQ is applied on each group separately,
    groups can be defined arbitrarily by splitting the input vector. However, this
    class makes it easy to construct several groups with the same quantization num_levels.

    Args:
        num_groups: number of groups to split the input into; each group is quantized separately
            with its own FSQ.
        num_levels_per_group: number of levels for each dimension within a group, shared across all groups.
        **kwargs: additional parameters of FiniteScalarQuantizer (e.g., eps)

    References:
        Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765).
    """

    def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs):
        super().__init__()

        self.num_groups = num_groups
        self.codebook_dim_per_group = len(num_levels_per_group)

        # Initialize FSQ for each group
        self.fsqs = torch.nn.ModuleList(
            [FiniteScalarQuantizer(num_levels=num_levels_per_group, **kwargs) for _ in range(self.num_groups)]
        )

    @property
    def codebook_dim(self):
        """Input vector dimension."""
        return self.codebook_dim_per_group * self.num_groups

    @property
    def codebook_size_per_group(self):
        """Returns the size of the implicit codebook for each group."""
        return self.fsqs[0].codebook_size

    @property
    def codebook_size(self):
        """Returns the size of the implicit codebook."""
        return self.codebook_size_per_group**self.num_groups

    @typecheck()
    def forward(self, inputs, input_len):
        """Quantize each group separately, then concatenate the results."""
        inputs_grouped = inputs.chunk(self.num_groups, dim=1)

        dequantized, indices = [], []

        for in_group, fsq_group in zip(inputs_grouped, self.fsqs):
            dequantized_group, indices_group = fsq_group(inputs=in_group, input_len=input_len)
            dequantized.append(dequantized_group)
            indices.append(indices_group)

        # concatenate along the feature dimension
        dequantized = torch.cat(dequantized, dim=1)

        # concatenate along the codebook dimension
        indices = torch.cat(indices, dim=0)

        return dequantized, indices

    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
    )
    def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Input is split into groups, each group is encoded separately, then the results are concatenated."""
        inputs_grouped = inputs.chunk(self.num_groups, dim=1)
        indices = []

        for in_group, fsq_group in zip(inputs_grouped, self.fsqs):
            indices_group = fsq_group.encode(inputs=in_group, input_len=input_len)
            indices.append(indices_group)

        # concatenate along the codebook dimension
        indices = torch.cat(indices, dim=0)

        return indices

    @typecheck(
        input_types={
            "indices": NeuralType(('D', 'B', 'T'), Index()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
        },
    )
    def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
        """Input indices are split into groups, each group is decoded separately, then the results are concatenated."""
        indices_grouped = indices.chunk(self.num_groups, dim=0)
        dequantized = []

        for indices_group, fsq_group in zip(indices_grouped, self.fsqs):
            dequantized_group = fsq_group.decode(indices=indices_group, input_len=input_len)
            dequantized.append(dequantized_group)

        # concatenate along the feature dimension
        dequantized = torch.cat(dequantized, dim=1)

        return dequantized

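
# Codebook-size arithmetic for the grouped quantizer above (illustrative): with
# 2 groups of num_levels_per_group=[8, 5, 5, 5], the input must have 2 * 4 = 8
# channels, each group has an implicit codebook of 8 * 5 * 5 * 5 = 1000 entries,
# and the combined codebook size is 1000 ** 2 = 1_000_000.
def _example_group_fsq_sizes():
    gfsq = GroupFiniteScalarQuantizer(num_groups=2, num_levels_per_group=[8, 5, 5, 5])
    assert gfsq.codebook_dim == 8
    assert gfsq.codebook_size_per_group == 1000
    assert gfsq.codebook_size == 1_000_000
    return gfsq
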
class ResidualBlock(NeuralModule):
    """
    The residual block structure defined by the HiFi-GAN V1 and V2 configurations.

    Args:
        channels: Input dimension.
        filters: Number of channels in the residual convolutions.
        kernel_size: Kernel size of the residual convolutions.
        dilation: Dilation of the residual convolutions.
        dropout_rate: Dropout to apply to residuals.
        activation: Activation to apply in between residual convolutions.
    """

    def __init__(
        self,
        channels: int,
        filters: int,
        kernel_size: int = 3,
        dilation: int = 1,
        dropout_rate: float = 0.0,
        activation: str = "lrelu",
    ):
        super().__init__()

        self.input_activation = CodecActivation(activation=activation, channels=channels)
        self.skip_activation = CodecActivation(activation=activation, channels=filters)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.input_conv = Conv1dNorm(
            in_channels=channels, out_channels=filters, kernel_size=kernel_size, dilation=dilation
        )
        self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size)

    def remove_weight_norm(self):
        self.input_conv.remove_weight_norm()
        self.skip_conv.remove_weight_norm()

    @property
    def input_types(self):
        return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())}

    @property
    def output_types(self):
        return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())}

    @typecheck()
    def forward(self, inputs, input_len):
        conv_input = self.input_activation(inputs)
        skip_input = self.input_conv(inputs=conv_input, input_len=input_len)
        skip_input = self.skip_activation(skip_input)
        res = self.skip_conv(inputs=skip_input, input_len=input_len)
        res = self.dropout(res)
        out = inputs + res
        return out

class HiFiGANResBlock(NeuralModule):
    """
    Residual block wrapper for HiFi-GAN which creates a block for multiple dilations.

    Args:
        channels: Input dimension.
        kernel_size: Kernel size of the residual blocks.
        dilations: List of dilations. One residual block will be created for each dilation in the list.
        activation: Activation for the residual blocks.
    """

    def __init__(self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str):
        super().__init__()

        self.res_blocks = nn.ModuleList(
            [
                ResidualBlock(
                    channels=channels,
                    filters=channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    activation=activation,
                )
                for dilation in dilations
            ]
        )

    def remove_weight_norm(self):
        for res_block in self.res_blocks:
            res_block.remove_weight_norm()

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'C', 'T'), VoidType()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {"out": NeuralType(('B', 'C', 'T'), VoidType())}

    @typecheck()
    def forward(self, inputs, input_len):
        out = inputs
        for res_block in self.res_blocks:
            out = res_block(inputs=out, input_len=input_len)
        return out

class HiFiGANResLayer(NeuralModule):
    """
    Residual block wrapper for HiFi-GAN which creates a block for multiple kernel sizes and dilations.
    One residual block is created for each combination of kernel size and dilation.

    Args:
        channels: Input dimension.
        kernel_sizes: List of kernel sizes.
        dilations: List of dilations.
        activation: Activation for the residual layers.
    """

    def __init__(self, channels: int, kernel_sizes: Iterable[int], dilations: Iterable[int], activation: str):
        super().__init__()

        self.res_blocks = nn.ModuleList(
            [
                HiFiGANResBlock(channels=channels, kernel_size=kernel_size, dilations=dilations, activation=activation)
                for kernel_size in kernel_sizes
            ]
        )

    def remove_weight_norm(self):
        for res_block in self.res_blocks:
            res_block.remove_weight_norm()

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'D', 'T'), VoidType()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {"out": NeuralType(('B', 'D', 'T'), VoidType())}

    @typecheck()
    def forward(self, inputs, input_len):
        residuals = [res_block(inputs=inputs, input_len=input_len) for res_block in self.res_blocks]
        out = sum(residuals) / len(residuals)
        return out

class HiFiGANEncoder(NeuralModule):
    """
    Audio encoder created by inverting the HiFi-GAN decoder.

    Args:
        encoded_dim: Dimension of encoder output.
        down_sample_rates: Rate to downsample by in each encoder block. The product of the downsample rates
            determines the output token rate. For example 2 * 2 * 8 * 8 = 256 samples per token.
        base_channels: Number of filters in the first convolution. The number of channels will be doubled after each
            downsample layer.
        in_kernel_size: Kernel size of the input convolution.
        out_kernel_size: Kernel size of the output convolution.
        resblock_kernel_sizes: List of kernel sizes to use in each residual block.
        resblock_dilation_sizes: List of dilations to use in each residual block.
        activation: Activation to use in residual and downsample layers, defaults to leaky relu.
    """

    def __init__(
        self,
        encoded_dim: int,
        down_sample_rates: Iterable[int] = (2, 2, 8, 8),
        base_channels: int = 32,
        in_kernel_size: int = 7,
        out_kernel_size: int = 7,
        resblock_kernel_sizes: Iterable[int] = (3, 7, 11),
        resblock_dilation_sizes: Iterable[int] = (1, 3, 5),
        activation: str = "lrelu",
    ):
        assert in_kernel_size > 0
        assert out_kernel_size > 0

        super().__init__()

        self.down_sample_rates = down_sample_rates
        self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size)

        in_channels = base_channels
        self.activations = nn.ModuleList([])
        self.down_sample_conv_layers = nn.ModuleList([])
        self.res_layers = nn.ModuleList([])
        for i, down_sample_rate in enumerate(self.down_sample_rates):
            res_layer = HiFiGANResLayer(
                channels=in_channels,
                kernel_sizes=resblock_kernel_sizes,
                dilations=resblock_dilation_sizes,
                activation=activation,
            )
            self.res_layers.append(res_layer)

            act = CodecActivation(activation, channels=in_channels)
            self.activations.append(act)

            out_channels = 2 * in_channels
            kernel_size = 2 * down_sample_rate

            padding = get_down_sample_padding(kernel_size=kernel_size, stride=down_sample_rate)
            down_sample_conv = Conv1dNorm(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=down_sample_rate,
                padding=padding,
            )
            in_channels = out_channels
            self.down_sample_conv_layers.append(down_sample_conv)

        self.post_activation = CodecActivation(activation, channels=in_channels)
        self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size)

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        }

    def remove_weight_norm(self):
        self.pre_conv.remove_weight_norm()
        self.post_conv.remove_weight_norm()
        for res_layer in self.res_layers:
            res_layer.remove_weight_norm()
        for down_sample_conv in self.down_sample_conv_layers:
            down_sample_conv.remove_weight_norm()

    @typecheck()
    def forward(self, audio, audio_len):
        encoded_len = audio_len
        audio = rearrange(audio, "B T -> B 1 T")
        # [B, C, T_audio]
        out = self.pre_conv(inputs=audio, input_len=encoded_len)
        for act, res_layer, down_sample_conv, down_sample_rate in zip(
            self.activations, self.res_layers, self.down_sample_conv_layers, self.down_sample_rates
        ):
            # [B, C, T]
            out = res_layer(inputs=out, input_len=encoded_len)
            out = act(out)

            encoded_len = encoded_len // down_sample_rate
            # [B, 2 * C, T / down_sample_rate]
            out = down_sample_conv(inputs=out, input_len=encoded_len)

        out = self.post_activation(out)
        # [B, encoded_dim, T_encoded]
        encoded = self.post_conv(inputs=out, input_len=encoded_len)
        return encoded, encoded_len

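
# Length bookkeeping for the encoder above (a sketch, not part of the original
# module): with the default down_sample_rates (2, 2, 8, 8) one token covers
# 2 * 2 * 8 * 8 = 256 samples, so 1 second at 22.05 kHz yields 86 tokens.
def _example_encoder_token_rate():
    rates = (2, 2, 8, 8)
    samples_per_token = 1
    for rate in rates:
        samples_per_token *= rate
    assert samples_per_token == 256
    # Mirrors the integer division `encoded_len = encoded_len // down_sample_rate`.
    encoded_len = 22050
    for rate in rates:
        encoded_len //= rate
    return encoded_len  # 86
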
class HiFiGANDecoder(NeuralModule):
    """
    Codec decoder using the HiFi-GAN generator architecture.

    Default parameters match the HiFi-GAN V1 configuration for 22.05 kHz.

    Args:
        input_dim: Input dimension.
        up_sample_rates: Rate to upsample by in each decoder block. The product of the upsample rates should be the
            same as the overall downsample rate for your encoder. For example, a symmetric encoder/decoder can be
            created with encoder downsample rates [2, 2, 8, 8] and decoder upsample rates [8, 8, 2, 2].
        base_channels: Number of filters in the first convolution. The number of channels will be cut in
            half after each upsample layer.
        in_kernel_size: Kernel size of the input convolution.
        out_kernel_size: Kernel size of the output convolution.
        resblock_kernel_sizes: List of kernel sizes to use in each residual block.
        resblock_dilation_sizes: List of dilations to use in each residual block.
        activation: Activation to use in residual and upsample layers, defaults to leaky relu.
        output_activation: Activation to apply to output. To produce a valid audio signal, it should output values in
            the range [-1.0, 1.0]. Supports "tanh" and "clamp".
    """

    def __init__(
        self,
        input_dim: int,
        up_sample_rates: Iterable[int] = (8, 8, 2, 2),
        base_channels: int = 512,
        in_kernel_size: int = 7,
        out_kernel_size: int = 3,
        resblock_kernel_sizes: Iterable[int] = (3, 7, 11),
        resblock_dilation_sizes: Iterable[int] = (1, 3, 5),
        activation: str = "lrelu",
        output_activation: str = "tanh",
    ):
        assert in_kernel_size > 0
        assert out_kernel_size > 0

        super().__init__()

        self.up_sample_rates = up_sample_rates
        self.pre_conv = Conv1dNorm(in_channels=input_dim, out_channels=base_channels, kernel_size=in_kernel_size)

        in_channels = base_channels
        self.activations = nn.ModuleList([])
        self.up_sample_conv_layers = nn.ModuleList([])
        self.res_layers = nn.ModuleList([])
        for i, up_sample_rate in enumerate(self.up_sample_rates):
            out_channels = in_channels // 2
            kernel_size = 2 * up_sample_rate

            act = CodecActivation(activation, channels=in_channels)
            self.activations.append(act)

            up_sample_conv = ConvTranspose1dNorm(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate
            )
            in_channels = out_channels
            self.up_sample_conv_layers.append(up_sample_conv)

            res_layer = HiFiGANResLayer(
                channels=in_channels,
                kernel_sizes=resblock_kernel_sizes,
                dilations=resblock_dilation_sizes,
                activation=activation,
            )
            self.res_layers.append(res_layer)

        self.post_activation = CodecActivation(activation, channels=in_channels)
        self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size)
        if output_activation == "tanh":
            self.out_activation = nn.Tanh()
        elif output_activation == "clamp":
            self.out_activation = ClampActivation()
        else:
            raise ValueError(f"Invalid audio output activation {output_activation}")

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'D', 'T_encoded'), VoidType()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    def remove_weight_norm(self):
        self.pre_conv.remove_weight_norm()
        for up_sample_conv in self.up_sample_conv_layers:
            up_sample_conv.remove_weight_norm()
        for res_layer in self.res_layers:
            res_layer.remove_weight_norm()

    @typecheck()
    def forward(self, inputs, input_len):
        audio_len = input_len
        # [B, C, T_encoded]
        out = self.pre_conv(inputs=inputs, input_len=audio_len)
        for act, res_layer, up_sample_conv, up_sample_rate in zip(
            self.activations, self.res_layers, self.up_sample_conv_layers, self.up_sample_rates
        ):
            audio_len = audio_len * up_sample_rate
            out = act(out)
            # [B, C / 2, T * up_sample_rate]
            out = up_sample_conv(inputs=out, input_len=audio_len)
            out = res_layer(inputs=out, input_len=audio_len)

        out = self.post_activation(out)
        # [B, 1, T_audio]
        out = self.post_conv(inputs=out, input_len=audio_len)
        audio = self.out_activation(out)
        audio = rearrange(audio, "B 1 T -> B T")
        return audio, audio_len

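
# An end-to-end shape sketch combining the encoder and decoder above into a
# symmetric autoencoder (illustrative dimensions, no quantizer in between).
def _example_codec_roundtrip():
    encoder = HiFiGANEncoder(encoded_dim=32, down_sample_rates=(2, 2, 8, 8))
    decoder = HiFiGANDecoder(input_dim=32, up_sample_rates=(8, 8, 2, 2))
    audio = torch.randn(1, 25600)  # 100 tokens worth of samples
    audio_len = torch.tensor([25600])
    encoded, encoded_len = encoder(audio=audio, audio_len=audio_len)  # encoded: [1, 32, 100]
    audio_out, audio_out_len = decoder(inputs=encoded, input_len=encoded_len)
    # audio_out: [1, 25600] -- the upsample rates undo the 256x downsampling
    return audio_out, audio_out_len
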
class MelSpectrogramProcessor(NeuralModule):
    """
    Wrapper interface for computing mel spectrogram for codec training.
    """

    def __init__(self, sample_rate: int, win_length: int, hop_length: int, mel_dim: int = 80, log_guard: float = 1.0):
        super().__init__()
        self.mel_dim = mel_dim
        self.hop_length = hop_length
        self.preprocessor = AudioToMelSpectrogramPreprocessor(
            sample_rate=sample_rate,
            highfreq=None,
            features=mel_dim,
            pad_to=1,
            exact_pad=True,
            n_window_size=win_length,
            n_window_stride=hop_length,
            window_size=False,
            window_stride=False,
            n_fft=win_length,
            mag_power=1.0,
            log=True,
            log_zero_guard_type="add",
            log_zero_guard_value=log_guard,
            mel_norm=None,
            normalize=None,
            preemph=None,
            dither=0.0,
        )

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
            "spec_len": NeuralType(tuple('B'), LengthsType()),
        }

    @typecheck()
    def forward(self, audio, audio_len):
        spec, spec_len = self.preprocessor(input_signal=audio, length=audio_len)
        return spec, spec_len

class ResNetEncoder(NeuralModule):
    """
    Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing
    the time dimension.

    Args:
        in_channels: input dimension
        out_channels: output dimension
        num_layers: number of residual blocks to use
        hidden_channels: encoder hidden dimension
        filters: number of filters in residual block layers
        kernel_size: kernel size in residual block convolutions
        dropout_rate: Optional dropout rate to apply to residuals.
        activation: Activation to use, defaults to leaky relu.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 6,
        hidden_channels: int = 256,
        filters: int = 768,
        kernel_size: int = 3,
        dropout_rate: float = 0.1,
        activation: str = "lrelu",
    ):
        super().__init__()

        self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=hidden_channels, kernel_size=kernel_size)
        self.res_layers = nn.ModuleList(
            [
                ResidualBlock(
                    channels=hidden_channels,
                    filters=filters,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate,
                    activation=activation,
                )
                for _ in range(num_layers)
            ]
        )
        self.post_activation = CodecActivation(activation, channels=hidden_channels)
        self.post_conv = Conv1dNorm(in_channels=hidden_channels, out_channels=out_channels, kernel_size=kernel_size)

    def remove_weight_norm(self):
        self.pre_conv.remove_weight_norm()
        self.post_conv.remove_weight_norm()
        for res_layer in self.res_layers:
            res_layer.remove_weight_norm()

    @property
    def input_types(self):
        return {
            "inputs": NeuralType(('B', 'D', 'T'), VoidType()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {"encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation())}

    @typecheck()
    def forward(self, inputs, input_len):
        encoded = self.pre_conv(inputs=inputs, input_len=input_len)
        for res_layer in self.res_layers:
            encoded = res_layer(inputs=encoded, input_len=input_len)
        encoded = self.post_activation(encoded)
        encoded = self.post_conv(inputs=encoded, input_len=input_len)
        return encoded

class FullBandMelEncoder(NeuralModule):
    """
    Encoder which encodes the entire mel spectrogram with a single encoder network.

    Args:
        mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from
            input audio.
        encoder: ResNetEncoder or equivalent class for encoding the mel spectrogram.
    """

    def __init__(self, mel_processor: NeuralModule, encoder: NeuralModule):
        super().__init__()
        self.mel_processor = mel_processor
        self.encoder = encoder

    def remove_weight_norm(self):
        self.encoder.remove_weight_norm()

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "encoded": NeuralType(('B', 'C', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        }

    @typecheck()
    def forward(self, audio, audio_len):
        out, spec_len = self.mel_processor(audio=audio, audio_len=audio_len)
        encoded = self.encoder(inputs=out, input_len=spec_len)
        return encoded, spec_len

class MultiBandMelEncoder(NeuralModule):
    """
    Encoder which splits the mel spectrogram into bands and encodes each using separate residual networks.

    Args:
        mel_bands: List of mel spectrogram bands to encode.
            Each list element is a tuple of 2 elements with the start and end index of the mel features to use.
        mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from
            input audio.
        encoder_kwargs: Arguments for constructing the encoder for each mel band.
    """

    def __init__(self, mel_bands: Iterable[Tuple[int, int]], mel_processor: NeuralModule, **encoder_kwargs):
        super().__init__()
        self.validate_mel_bands(mel_dim=mel_processor.mel_dim, mel_bands=mel_bands)
        self.mel_bands = mel_bands
        self.mel_processor = mel_processor
        band_dims = [band[1] - band[0] for band in self.mel_bands]
        self.encoders = nn.ModuleList(
            [ResNetEncoder(in_channels=band_dim, **encoder_kwargs) for band_dim in band_dims]
        )

    @staticmethod
    def validate_mel_bands(mel_dim: int, mel_bands: Iterable[Tuple[int, int]]):
        mel_dims_used = np.zeros([mel_dim], dtype=bool)
        for band in mel_bands:
            mel_dims_used[band[0] : band[1]] = True

        if not all(mel_dims_used):
            missing_dims = np.where(~mel_dims_used)
            raise ValueError(f"Mel bands must cover all {mel_dim} dimensions. Missing {missing_dims}.")

        return

    def remove_weight_norm(self):
        for encoder in self.encoders:
            encoder.remove_weight_norm()

    @property
    def input_types(self):
        return {
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "encoded": NeuralType(('B', 'C', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        }

    @typecheck()
    def forward(self, audio, audio_len):
        spec, spec_len = self.mel_processor(audio=audio, audio_len=audio_len)
        outputs = []
        for (band_start, band_end), encoder in zip(self.mel_bands, self.encoders):
            # [B, D_band, T]
            spec_band = spec[:, band_start:band_end, :]
            band_out = encoder(inputs=spec_band, input_len=spec_len)
            outputs.append(band_out)
        # [B, C, T]
        encoded = torch.cat(outputs, dim=1)
        return encoded, spec_len

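
# A band-layout sketch for the multi-band encoder above (illustrative): an
# 80-dim mel spectrogram split into four contiguous 20-dim bands.
# `validate_mel_bands` requires the bands to cover every mel dimension.
def _example_mel_bands():
    mel_bands = [(0, 20), (20, 40), (40, 60), (60, 80)]
    MultiBandMelEncoder.validate_mel_bands(mel_dim=80, mel_bands=mel_bands)  # passes
    band_dims = [end - start for start, end in mel_bands]
    assert band_dims == [20, 20, 20, 20]
    return mel_bands
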
import math  # used below for the n_fft computation; not imported at the top of this trimmed file


class FilterbankFeatures(nn.Module):
    """Featurizer that converts wavs to Mel Spectrograms.
    See AudioToMelSpectrogramPreprocessor for args.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=CONSTANT,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        mel_norm="slaney",
        stft_exact_pad=False,  # Deprecated argument; kept for config compatibility
        stft_conv=False,  # Deprecated argument; kept for config compatibility
    ):
        super().__init__()
        if exact_pad and n_window_stride % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0, then "
                "the returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
            )
        self.log_zero_guard_value = log_zero_guard_value
        if (
            n_window_size is None
            or n_window_stride is None
            or not isinstance(n_window_size, int)
            or not isinstance(n_window_stride, int)
            or n_window_size <= 0
            or n_window_stride <= 0
        ):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints."
            )

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
        self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
        self.exact_pad = exact_pad

        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(
            librosa.filters.mel(
                sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
            ),
            dtype=torch.float,
        ).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length
        max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero
        # There are two options: either adding or clamping to a small value
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'."
            )

        self.use_grads = use_grads
        if not use_grads:
            self.forward = torch.no_grad()(self.forward)
        self._rng = random.Random() if rng is None else rng
        self.nb_augmentation_prob = nb_augmentation_prob
        if self.nb_augmentation_prob > 0.0:
            if nb_max_freq >= sample_rate / 2:
                self.nb_augmentation_prob = 0.0
            else:
                # use self.n_fft so this also works when the n_fft argument was left as None
                self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * self.n_fft)

        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
        self.log_zero_guard_type = log_zero_guard_type

    def stft(self, x):
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False if self.exact_pad else True,
            window=self.window.to(dtype=torch.float),
            return_complex=True,
        )

    def log_zero_guard_value_fn(self, x):
        if isinstance(self.log_zero_guard_value, str):
            if self.log_zero_guard_value == "tiny":
                return torch.finfo(x.dtype).tiny
            elif self.log_zero_guard_value == "eps":
                return torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_value} for the "
                    f"log_zero_guard_value parameter. It must be either a "
                    f"number, 'tiny', or 'eps'"
                )
        else:
            return self.log_zero_guard_value

    def get_seq_len(self, seq_len):
        # When stft_pad_amount is None, center=True is assumed and torch.stft pads n_fft // 2 on each side
        pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
        seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
        return seq_len.to(dtype=torch.long)

    @property
    def filter_banks(self):
        return self.fb

    def forward(self, x, seq_len, linear_spec=False):
        seq_len = self.get_seq_len(seq_len)

        if self.stft_pad_amount is not None:
            x = torch.nn.functional.pad(
                x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
            ).squeeze(1)

        # dither (only in training mode for eval determinism)
        if self.training and self.dither > 0:
            x += self.dither * torch.randn_like(x)

        # do preemphasis
        if self.preemph is not None:
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)

        # disable autocast to get full range of stft values
        with torch.amp.autocast(x.device.type, enabled=False):
            x = self.stft(x)

        # torch stft returns complex tensor (of shape [B, N, T]); so convert to magnitude
        # guard is needed for sqrt if grads are passed through
        guard = 0 if not self.use_grads else CONSTANT
        x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1) + guard)

        if self.training and self.nb_augmentation_prob > 0.0:
            for idx in range(x.shape[0]):
                if self._rng.random() < self.nb_augmentation_prob:
                    x[idx, self._nb_max_fft_bin :, :] = 0.0

        # get power spectrum
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)

        # return plain spectrogram if required
        if linear_spec:
            return x, seq_len

        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x)
        # log features if required
        if self.log:
            if self.log_zero_guard_type == "add":
                x = torch.log(x + self.log_zero_guard_value_fn(x))
            elif self.log_zero_guard_type == "clamp":
                x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
            else:
                raise ValueError("log_zero_guard_type was not understood")

        # frame splicing if required
        if self.frame_splicing > 1:
            x = splice_frames(x, self.frame_splicing)

        # normalize if required
        if self.normalize:
            x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)

        # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
        max_len = x.size(-1)
        mask = torch.arange(max_len, device=x.device)
        mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
        del mask
        pad_to = self.pad_to
        if pad_to == "max":
            x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
        elif pad_to > 0:
            pad_amt = x.size(-1) % pad_to
            if pad_amt != 0:
                x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
        return x, seq_len


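# Minimal usage sketch (not part of the original file): computing log mel
# features for a batch of raw audio. The shapes and parameter values below are
# illustrative assumptions; `normalize=None` is passed so the sketch does not
# depend on the `normalize_batch` helper.
def _demo_filterbank_features():
    featurizer = FilterbankFeatures(
        sample_rate=16000, n_window_size=400, n_window_stride=160, nfilt=80, normalize=None, dither=0.0, pad_to=0
    )
    audio = torch.randn(2, 16000)
    audio_len = torch.tensor([16000, 12000])
    mel, mel_len = featurizer(x=audio, seq_len=audio_len)
    print(mel.shape, mel_len)  # torch.Size([2, 80, 101]) tensor([101, 76])

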
# These element types are used by the losses below but are not imported at the
# top of this trimmed file.
from nemo.core.neural_types.elements import LossType, PredictionsType, RegressionValuesType


class MaskedLoss(Loss):
    def __init__(self, loss_fn, loss_scale: float = 1.0):
        super(MaskedLoss, self).__init__()
        self.loss_scale = loss_scale
        self.loss_fn = loss_fn

    @property
    def input_types(self):
        return {
            "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()),
            "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()),
            "target_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, predicted, target, target_len):
        assert target.shape[2] == predicted.shape[2]

        # [B, D, T]
        loss = self.loss_fn(input=predicted, target=target)
        # [B, T]
        loss = torch.mean(loss, dim=1)
        # [B]
        loss = torch.sum(loss, dim=1) / torch.clamp(target_len, min=1.0)

        # [1]
        loss = torch.mean(loss)
        loss = self.loss_scale * loss

        return loss


class MaskedMAELoss(MaskedLoss):
    def __init__(self, loss_scale: float = 1.0):
        loss_fn = torch.nn.L1Loss(reduction='none')
        super(MaskedMAELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale)


class MaskedMSELoss(MaskedLoss):
    def __init__(self, loss_scale: float = 1.0):
        loss_fn = torch.nn.MSELoss(reduction='none')
        super(MaskedMSELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale)


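# Illustrative sketch (not part of the original file): MaskedMAELoss averages
# the per-frame error over the feature dim and the valid length. The padded
# tail still enters the time sum, so inputs are expected to be zero-padded
# identically. All values below are made up.
def _demo_masked_mae_loss():
    loss_fn = MaskedMAELoss()
    target = torch.zeros(1, 4, 10)
    target[:, :, :5] = 1.0  # 5 valid frames, then zero padding
    predicted = torch.zeros(1, 4, 10)
    target_len = torch.tensor([5])
    loss = loss_fn(predicted=predicted, target=target, target_len=target_len)
    print(loss)  # tensor(1.) : mean absolute error of 1 over the 5 valid frames

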
class TimeDomainLoss(Loss):
    def __init__(self):
        super(TimeDomainLoss, self).__init__()
        self.loss_fn = MaskedMAELoss()

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, audio_real, audio_gen, audio_len):
        audio_real = rearrange(audio_real, "B T -> B 1 T")
        audio_gen = rearrange(audio_gen, "B T -> B 1 T")
        loss = self.loss_fn(target=audio_real, predicted=audio_gen, target_len=audio_len)
        return loss


class MultiResolutionMelLoss(Loss):
    """
    Multi-resolution log mel spectrogram loss.

    Args:
        sample_rate: Sample rate of audio.
        resolutions: List of resolutions, each being 3 integers ordered [num_fft, hop_length, window_length]
        mel_dims: Dimension of mel spectrogram to compute for each resolution. Should be same length as 'resolutions'.
        log_guard: Value to add to mel spectrogram to avoid taking log of 0.
    """

    def __init__(self, sample_rate: int, resolutions: List[List], mel_dims: List[int], log_guard: float = 1.0):
        super(MultiResolutionMelLoss, self).__init__()
        assert len(resolutions) == len(mel_dims)

        self.l1_loss_fn = MaskedMAELoss()
        self.l2_loss_fn = MaskedMSELoss()

        self.mel_features = torch.nn.ModuleList()
        for mel_dim, (n_fft, hop_len, win_len) in zip(mel_dims, resolutions):
            mel_feature = FilterbankFeatures(
                sample_rate=sample_rate,
                nfilt=mel_dim,
                n_window_size=win_len,
                n_window_stride=hop_len,
                n_fft=n_fft,
                pad_to=1,
                mag_power=1.0,
                log_zero_guard_type="add",
                log_zero_guard_value=log_guard,
                mel_norm=None,
                normalize=None,
                preemph=None,
                dither=0.0,
                use_grads=True,
            )
            self.mel_features.append(mel_feature)

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {
            "l1_loss": NeuralType(elements_type=LossType()),
            "l2_loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, audio_real, audio_gen, audio_len):
        l1_loss = 0.0
        l2_loss = 0.0
        for mel_feature in self.mel_features:
            mel_real, mel_real_len = mel_feature(x=audio_real, seq_len=audio_len)
            mel_gen, _ = mel_feature(x=audio_gen, seq_len=audio_len)
            l1_loss += self.l1_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len)
            l2_loss += self.l2_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len)

        l1_loss /= len(self.mel_features)
        l2_loss /= len(self.mel_features)

        return l1_loss, l2_loss


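# Usage sketch (not part of the original file): a three-resolution mel loss.
# The resolutions, mel dims, and sample rate here are illustrative assumptions,
# not values taken from any shipped config.
def _demo_multi_resolution_mel_loss():
    loss_fn = MultiResolutionMelLoss(
        sample_rate=22050,
        resolutions=[[512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]],
        mel_dims=[32, 64, 128],
        log_guard=1.0,
    )
    audio_real = torch.randn(2, 22050)
    audio_gen = torch.randn(2, 22050)
    audio_len = torch.tensor([22050, 17640])
    l1_loss, l2_loss = loss_fn(audio_real=audio_real, audio_gen=audio_gen, audio_len=audio_len)
    print(l1_loss, l2_loss)

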
class STFTLoss(Loss):
    """
    Log magnitude STFT loss.

    Args:
        resolution: Resolution of spectrogram, a list of 3 numbers ordered [num_fft, hop_length, window_length]
        log_guard: Value to add to magnitude spectrogram to avoid taking log of 0.
        sqrt_guard: Value to add when computing the absolute value of the STFT to avoid NaN loss.
    """

    def __init__(self, resolution: List[int], log_guard: float = 1.0, sqrt_guard: float = 1e-5):
        super(STFTLoss, self).__init__()
        self.loss_fn = MaskedMAELoss()
        self.n_fft, self.hop_length, self.win_length = resolution
        self.register_buffer("window", torch.hann_window(self.win_length, periodic=False))
        self.log_guard = log_guard
        self.sqrt_guard = sqrt_guard

    def _compute_spectrogram(self, audio, spec_len):
        # [B, n_fft // 2 + 1, T_spec]
        spec = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        # [B, n_fft // 2 + 1, T_spec, 2]
        spec = torch.view_as_real(spec)
        # [B, n_fft // 2 + 1, T_spec]
        spec_mag = torch.sqrt(spec.pow(2).sum(-1) + self.sqrt_guard)
        spec_log = torch.log(spec_mag + self.log_guard)
        spec_log = mask_sequence_tensor(spec_log, spec_len)
        return spec_log

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, audio_real, audio_gen, audio_len):
        spec_len = (audio_len // self.hop_length) + 1
        spec_real = self._compute_spectrogram(audio=audio_real, spec_len=spec_len)
        spec_gen = self._compute_spectrogram(audio=audio_gen, spec_len=spec_len)
        loss = self.loss_fn(predicted=spec_gen, target=spec_real, target_len=spec_len)
        return loss


class MultiResolutionSTFTLoss(Loss):
    """
    Multi-resolution log magnitude STFT loss.

    Args:
        resolutions: List of resolutions, each being 3 integers ordered [num_fft, hop_length, window_length]
        log_guard: Value to add to magnitude spectrogram to avoid taking log of 0.
        sqrt_guard: Value to add when computing the absolute value of the STFT to avoid NaN loss.
    """

    def __init__(self, resolutions: List[List], log_guard: float = 1.0, sqrt_guard: float = 1e-5):
        super(MultiResolutionSTFTLoss, self).__init__()
        self.loss_fns = torch.nn.ModuleList(
            [STFTLoss(resolution=resolution, log_guard=log_guard, sqrt_guard=sqrt_guard) for resolution in resolutions]
        )

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, audio_real, audio_gen, audio_len):
        loss = 0.0
        for loss_fn in self.loss_fns:
            loss += loss_fn(audio_real=audio_real, audio_gen=audio_gen, audio_len=audio_len)
        loss /= len(self.loss_fns)
        return loss


class SISDRLoss(Loss):
    """
    SI-SDR loss based on torchmetrics.functional.audio.sdr.scale_invariant_signal_distortion_ratio,
    with added support for masking.
    """

    def __init__(self, epsilon: float = 1e-8):
        super(SISDRLoss, self).__init__()
        self.epsilon = epsilon

    @property
    def input_types(self):
        return {
            "audio_real": NeuralType(('B', 'T'), AudioSignal()),
            "audio_gen": NeuralType(('B', 'T'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, audio_real, audio_gen, audio_len):
        mask = get_mask_from_lengths(x=audio_real, lengths=audio_len)
        audio_len = rearrange(audio_len, 'B -> B 1')

        # Shift audio to have zero-mean
        # [B, 1]
        target_mean = torch.sum(audio_real, dim=-1, keepdim=True) / audio_len
        pred_mean = torch.sum(audio_gen, dim=-1, keepdim=True) / audio_len

        # [B, T]
        target = audio_real - target_mean
        target = target * mask
        pred = audio_gen - pred_mean
        pred = pred * mask

        # [B, 1]
        ref_pred = torch.sum(pred * target, dim=-1, keepdim=True)
        ref_target = torch.sum(target**2, dim=-1, keepdim=True)
        alpha = (ref_pred + self.epsilon) / (ref_target + self.epsilon)

        # [B, T]
        target_scaled = alpha * target
        distortion = target_scaled - pred

        # [B]
        target_scaled_power = torch.sum(target_scaled**2, dim=-1)
        distortion_power = torch.sum(distortion**2, dim=-1)

        ratio = (target_scaled_power + self.epsilon) / (distortion_power + self.epsilon)
        si_sdr = 10 * torch.log10(ratio)

        # [1]
        loss = -torch.mean(si_sdr)
        return loss


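# Numerical sketch (not part of the original file) of the scale invariance that
# gives SI-SDR its name: rescaling the estimate leaves the loss (almost exactly)
# unchanged, because the optimal scale `alpha` absorbs the gain. This sketch
# assumes the `get_mask_from_lengths` helper defined earlier in this file.
def _demo_si_sdr_scale_invariance():
    loss_fn = SISDRLoss()
    audio_real = torch.randn(1, 8000)
    audio_len = torch.tensor([8000])
    loss_half = loss_fn(audio_real=audio_real, audio_gen=0.5 * audio_real, audio_len=audio_len)
    loss_double = loss_fn(audio_real=audio_real, audio_gen=2.0 * audio_real, audio_len=audio_len)
    print(loss_half, loss_double)  # both strongly negative (high SI-SDR) and nearly equal

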
class FeatureMatchingLoss(Loss):
    """
    Standard feature matching loss measuring the difference in the internal discriminator layer outputs
    (usually leaky relu activations) between real and generated audio, scaled down by the total number of
    discriminators and layers.
    """

    def __init__(self):
        super(FeatureMatchingLoss, self).__init__()

    @property
    def input_types(self):
        return {
            "fmaps_real": [[NeuralType(elements_type=VoidType())]],
            "fmaps_gen": [[NeuralType(elements_type=VoidType())]],
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, fmaps_real, fmaps_gen):
        loss = 0.0
        for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen):
            # [B, ..., time]
            for feat_real, feat_gen in zip(fmap_real, fmap_gen):
                # [B, ...]
                diff = torch.abs(feat_real - feat_gen)
                feat_loss = torch.mean(diff) / len(fmap_real)
                loss += feat_loss

        loss /= len(fmaps_real)

        return loss


class RelativeFeatureMatchingLoss(Loss):
    """
    Relative feature matching loss as described in https://arxiv.org/pdf/2210.13438.pdf.

    This is similar to the standard feature matching loss, but it scales the loss by the absolute value of
    each feature averaged across time. This might be slightly different from the paper, which says the
    "mean is computed over all dimensions", which could imply taking the average across both time and
    features.

    Args:
        div_guard: Value to add when dividing by the mean to avoid large/NaN values.
    """

    def __init__(self, div_guard=1e-3):
        super(RelativeFeatureMatchingLoss, self).__init__()
        self.div_guard = div_guard

    @property
    def input_types(self):
        return {
            "fmaps_real": [[NeuralType(elements_type=VoidType())]],
            "fmaps_gen": [[NeuralType(elements_type=VoidType())]],
        }

    @property
    def output_types(self):
        return {
            "loss": NeuralType(elements_type=LossType()),
        }

    @typecheck()
    def forward(self, fmaps_real, fmaps_gen):
        loss = 0.0
        for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen):
            # [B, ..., time]
            for feat_real, feat_gen in zip(fmap_real, fmap_gen):
                # [B, ...]
                feat_mean = torch.mean(torch.abs(feat_real), dim=-1)
                diff = torch.mean(torch.abs(feat_real - feat_gen), dim=-1)
                feat_loss = diff / (feat_mean + self.div_guard)
                # [1]
                feat_loss = torch.mean(feat_loss) / len(fmap_real)
                loss += feat_loss

        loss /= len(fmaps_real)

        return loss


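# Toy sketch (not part of the original file): the feature-matching losses
# consume nested lists of discriminator activations, with the outer list over
# discriminators and the inner list over layers. The shapes below are arbitrary.
def _demo_feature_matching_loss():
    loss_fn = RelativeFeatureMatchingLoss()
    fmaps_real = [[torch.randn(2, 8, 100), torch.randn(2, 16, 50)] for _ in range(2)]
    fmaps_gen = [[torch.randn_like(feat) for feat in fmap] for fmap in fmaps_real]
    loss = loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)
    print(loss)

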
class GeneratorHingedLoss(Loss):
    @property
    def input_types(self):
        return {
            "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, disc_scores_gen):
        loss = 0.0
        for disc_score_gen in disc_scores_gen:
            loss += torch.mean(F.relu(1 - disc_score_gen))

        loss /= len(disc_scores_gen)

        return loss


class GeneratorSquaredLoss(Loss):
    @property
    def input_types(self):
        return {
            "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, disc_scores_gen):
        loss = 0.0
        for disc_score_gen in disc_scores_gen:
            loss += torch.mean((1 - disc_score_gen) ** 2)

        loss /= len(disc_scores_gen)

        return loss


class DiscriminatorHingedLoss(Loss):
    @property
    def input_types(self):
        return {
            "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())],
            "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, disc_scores_real, disc_scores_gen):
        loss = 0.0
        for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen):
            loss_real = torch.mean(F.relu(1 - disc_score_real))
            loss_gen = torch.mean(F.relu(1 + disc_score_gen))
            loss += (loss_real + loss_gen) / 2

        loss /= len(disc_scores_real)

        return loss


class DiscriminatorSquaredLoss(Loss):
    @property
    def input_types(self):
        return {
            "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())],
            "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())],
        }

    @property
    def output_types(self):
        return {"loss": NeuralType(elements_type=LossType())}

    @typecheck()
    def forward(self, disc_scores_real, disc_scores_gen):
        loss = 0.0
        for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen):
            loss_real = torch.mean((1 - disc_score_real) ** 2)
            loss_gen = torch.mean(disc_score_gen**2)
            loss += (loss_real + loss_gen) / 2

        loss /= len(disc_scores_real)

        return loss


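# Paired usage sketch (not part of the original file): the squared (LS-GAN)
# losses used as generator/discriminator objectives. Scores near 1 mean "real"
# and near 0 mean "fake"; the toy scores below are random.
def _demo_squared_gan_losses():
    gen_loss_fn = GeneratorSquaredLoss()
    disc_loss_fn = DiscriminatorSquaredLoss()
    disc_scores_real = [torch.rand(2, 1, 50), torch.rand(2, 1, 25)]
    disc_scores_gen = [torch.rand(2, 1, 50), torch.rand(2, 1, 25)]
    g_loss = gen_loss_fn(disc_scores_gen=disc_scores_gen)
    d_loss = disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen)
    print(g_loss, d_loss)

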
@experimental
class AudioCodecModel(ModelPT):
    def __init__(self, cfg):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)
        self.world_size = 1
        super().__init__(cfg=cfg)

        # Expected sample rate for the input audio
        self.sample_rate = cfg.sample_rate

        # Number of samples in each audio frame that is encoded
        self.samples_per_frame = cfg.samples_per_frame

        # Discriminator updates
        self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
        self.disc_update_period = cfg.get("disc_update_period", 1)
        if self.disc_updates_per_period > self.disc_update_period:
            raise ValueError(
                f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less than or equal to the configured period ({self.disc_update_period})'
            )

        # Encoder setup
        self.audio_encoder = instantiate(cfg.audio_encoder)

        # Optionally, add gaussian noise to encoder output as an information bottleneck
        encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0)
        if encoder_noise_stdev:
            self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev)
        else:
            self.encoder_noise = None

        if "vector_quantizer" in cfg:
            self.vector_quantizer = instantiate(cfg.vector_quantizer)

            vq_output_types = list(self.vector_quantizer.output_types.keys())

            if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss':
                self.vector_quantizer_has_commit_loss = True
            else:
                self.vector_quantizer_has_commit_loss = False

        else:
            self.vector_quantizer = None

        # Decoder setup
        self.audio_decoder = instantiate(cfg.audio_decoder)

        # Discriminator setup
        self.discriminator = instantiate(cfg.discriminator)

        # Mel loss setup
        loss_resolutions = cfg.loss_resolutions
        mel_loss_dims = cfg.get("mel_loss_dims")
        mel_loss_log_guard = cfg.get("mel_loss_log_guard", 1.0)
        self.mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0)
        self.mel_loss_l2_scale = cfg.get("mel_loss_l2_scale", 1.0)
        self.mel_loss_fn = MultiResolutionMelLoss(
            sample_rate=self.sample_rate,
            mel_dims=mel_loss_dims,
            resolutions=loss_resolutions,
            log_guard=mel_loss_log_guard,
        )

        # STFT loss setup
        stft_loss_log_guard = cfg.get("stft_loss_log_guard", 1.0)
        self.stft_loss_scale = cfg.get("stft_loss_scale", 0.0)
        self.stft_loss_fn = MultiResolutionSTFTLoss(
            resolutions=loss_resolutions,
            log_guard=stft_loss_log_guard,
        )

        # Time domain loss setup
        self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0)
        self.si_sdr_loss_scale = cfg.get("si_sdr_loss_scale", 0.0)
        self.time_domain_loss_fn = TimeDomainLoss()
        self.si_sdr_loss_fn = SISDRLoss()

        # Discriminator loss setup
        self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0)
        self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0)
        self.gen_loss_fn = instantiate(cfg.generator_loss)
        self.disc_loss_fn = instantiate(cfg.discriminator_loss)

        feature_loss_type = cfg.get("feature_loss_type", "relative")
        if feature_loss_type == "relative":
            self.feature_loss_fn = RelativeFeatureMatchingLoss()
        elif feature_loss_type == "absolute":
            self.feature_loss_fn = FeatureMatchingLoss()
        else:
            raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

        # Codebook loss setup
        if self.vector_quantizer:
            self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0)
        else:
            self.commit_loss_scale = 0.0

        if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss:
            raise ValueError('Commit loss is enabled but the quantizer does not support it.')

    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply encoder on the input audio signal. Input will be padded with zeros so
        the last frame has full `self.samples_per_frame` samples.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Encoder output `encoded` and its length in number of frames `encoded_len`
        """
        audio, audio_len = self.pad_audio(audio, audio_len)
        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
        return encoded, encoded_len

    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation.

        Args:
            inputs: encoded signal
            input_len: valid length for each example in the batch

        Returns:
            Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
            Note that `audio_len` will be a multiple of `self.samples_per_frame`.
        """
        audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
        return audio, audio_len

    @typecheck(
        input_types={
            "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex())},
    )
    def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
        """Quantize the continuous encoded representation into a discrete
        representation for each frame.

        Args:
            encoded: encoded signal representation
            encoded_len: valid length of the encoded representation in frames

        Returns:
            A tensor of tokens for each codebook for each frame.
        """
        if not self.vector_quantizer:
            raise ValueError("Cannot quantize without quantizer")

        # vector quantizer is returning [C, B, T], where C is the number of codebooks
        tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
        # use batch first for the output
        tokens = rearrange(tokens, 'C B T -> B C T')
        return tokens

    @typecheck(
        input_types={
            "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
            "tokens_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
        },
    )
    def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor:
        """Convert the discrete tokens into a continuous encoded representation.

        Args:
            tokens: discrete tokens for each codebook for each time frame
            tokens_len: valid length of each example in the batch

        Returns:
            Continuous encoded representation of the discrete input representation.
        """
        if not self.vector_quantizer:
            raise ValueError("Cannot dequantize without quantizer")

        # vector quantizer is using [C, B, T], where C is the number of codebooks
        tokens = rearrange(tokens, 'B C T -> C B T')
        dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
        return dequantized

    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
            "tokens_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input time-domain audio signal into a discrete representation (tokens).

        Args:
            audio: input time-domain signal, shape `(batch, number of samples)`
            audio_len: valid length for each example in the batch, shape `(batch size,)`

        Returns:
            Tokens for each codebook for each frame, shape `(batch, number of codebooks, number of frames)`,
            and the corresponding valid lengths, shape `(batch,)`
        """
        # Apply encoder to obtain a continuous vector for each frame
        encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
        # Apply quantizer to obtain discrete representation per frame
        tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
        return tokens, encoded_len

    @typecheck(
        input_types={
            "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
            "tokens_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert discrete tokens into a continuous time-domain signal.

        Args:
            tokens: discrete tokens for each codebook for each time frame, shape `(batch, number of codebooks, number of frames)`
            tokens_len: valid lengths, shape `(batch,)`

        Returns:
            Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
            Note that `audio_len` will be a multiple of `self.samples_per_frame`.
        """
        # Convert a discrete representation to a dequantized vector for each frame
        dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
        # Apply decoder to obtain time-domain audio for each frame
        audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)

        return audio, audio_len

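    # Illustrative round-trip sketch (comment-only; not part of the original file).
    # Given an instantiated `codec_model` and a batch of audio, encode() produces
    # discrete tokens of shape [B, C, T_encoded] and decode() reconstructs audio:
    #
    #   tokens, tokens_len = codec_model.encode(audio=audio, audio_len=audio_len)
    #   audio_out, audio_out_len = codec_model.decode(tokens=tokens, tokens_len=tokens_len)
    #
    # `audio_out_len` is a multiple of `samples_per_frame`, so it may exceed the
    # original `audio_len` by up to one frame of padding.
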
    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "output_audio": NeuralType(('B', 'T_audio'), EncodedRepresentation()),
            "output_audio_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply encoder, quantizer, decoder on the input time-domain signal.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`.
        """
        encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)

        if self.vector_quantizer:
            # quantize to discrete tokens
            tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
            # decode tokens to audio
            output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len)
        else:
            # no quantization, directly decode to audio
            output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len)

        return output_audio, output_audio_len

    def pad_audio(self, audio, audio_len):
        """Zero pad the end of the audio so that we do not have a partial end frame.
        The output will be zero-padded to have an integer number of frames of
        length `self.samples_per_frame`.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Padded time-domain signal `padded_audio` and its length `padded_len`.
        """
        padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int()
        max_len = padded_len.max().item()
        num_padding = max_len - audio.shape[1]
        padded_audio = F.pad(audio, (0, num_padding))
        return padded_audio, padded_len

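    # Worked example of the padding arithmetic above (comment-only; the values
    # are made up). With samples_per_frame = 320 and audio_len = [1000, 1280]:
    #
    #   padded_len = 320 * ceil([1000, 1280] / 320) = [1280, 1280]
    #
    # so the 1000-sample signal gains 280 trailing zeros and the 1280-sample
    # signal is unchanged.
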
    def _process_batch(self, batch):
        # [B, T_audio]
        audio = batch.get("audio")
        # [B]
        audio_len = batch.get("audio_lens")
        audio, audio_len = self.pad_audio(audio, audio_len)

        # [B, D, T_encoded]
        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)

        if self.encoder_noise is not None:
            encoded = self.encoder_noise(encoded)

        if self.vector_quantizer:
            if self.vector_quantizer_has_commit_loss:
                encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
            else:
                encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
                commit_loss = 0.0
        else:
            commit_loss = 0.0

        # [B, T]
        audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)

        return audio, audio_len, audio_gen, commit_loss

    @property
    def disc_update_prob(self) -> float:
        """Probability of updating the discriminator."""
        return self.disc_updates_per_period / self.disc_update_period

    def should_update_disc(self, batch_idx) -> bool:
        """Decide whether to update the discriminator based
        on the batch index and the configured discriminator update period.
        """
        disc_update_step = batch_idx % self.disc_update_period
        return disc_update_step < self.disc_updates_per_period

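    # Example schedule (comment-only; the period values are assumptions). With
    # disc_update_period = 3 and disc_updates_per_period = 1, batch indices
    # 0, 3, 6, ... satisfy batch_idx % 3 < 1, so the discriminator is updated on
    # every third batch and disc_update_prob evaluates to 1 / 3.
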
    def setup_training_data(self):
        ...

    def setup_validation_data(self):
        ...

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        models = []

        model = PretrainedModelInfo(
            pretrained_model_name="audio_codec_16khz_small",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/audio_codec_16khz_small/versions/v1/files/audio_codec_16khz_small.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/audio_codec_16khz_small",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_22khz_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_medium/versions/v1/files/mel_codec_22khz_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_44khz_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_medium/versions/v1/files/mel_codec_44khz_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_22khz_fullband_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_fullband_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_44khz_fullband_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_fullband_medium",
        )
        models.append(model)

        return models
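

# Hedged usage sketch (not part of the original file): any of the checkpoints
# listed above can be loaded through the standard ModelPT API. This assumes a
# full NeMo install and network access to NGC.
def _demo_load_pretrained_codec():
    model = AudioCodecModel.from_pretrained(model_name="mel_codec_22khz_medium")
    model.eval()
    return model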