fixed issues that may rise from updating transformers with attention, added nvidia/audio-codec-44khz backend support (by gutting everything necessary because I do NOT want to install more dependencies

This commit is contained in:
mrq 2025-02-04 20:30:07 -06:00
parent 0841f366e8
commit bb2ebe1ca2
10 changed files with 3481 additions and 186 deletions

BIN
test.wav Normal file

Binary file not shown.

View File

@ -216,6 +216,9 @@ class Dataset:
if cfg.sample_rate == 16_000:
return 50
if cfg.audio_backend == "nemo":
return 86.1
# 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz
return 75
@ -815,6 +818,11 @@ class Config(BaseConfig):
audio_extension = ".dec"
sample_rate = 48_000
cfg.model.resp_levels = 8 # ?
elif cfg.audio_backend == "nemo":
audio_extension = ".nem"
sample_rate = 44_100
cfg.model.resp_levels = 8
cfg.model.audio_tokens = 1000
else:
raise Exception(f"Unknown audio backend: {audio_backend}")
@ -827,6 +835,8 @@ class Config(BaseConfig):
audio_extension = ".dac"
elif self.audio_backend == "audiodec":
audio_extension = ".dec"
elif self.audio_backend == "nemo":
audio_extension = ".nem"
return audio_extension
@property

175
vall_e/emb/codecs/dac.py Normal file
View File

@ -0,0 +1,175 @@
import torch
from dac import DACFile
from audiotools import AudioSignal
from dac.utils import load_model as __load_dac_model
"""
Patch decode to skip things related to the metadata (namely the waveform trimming)
So far it seems the raw waveform can just be returned without any post-processing
A smart implementation would just reuse the values from the input prompt
"""
from dac.model.base import CodecMixin
@torch.no_grad()
def CodecMixin_compress(
self,
audio_path_or_signal: Union[str, Path, AudioSignal],
win_duration: float = 1.0,
verbose: bool = False,
normalize_db: float = -16,
n_quantizers: int = None,
) -> DACFile:
"""Processes an audio signal from a file or AudioSignal object into
discrete codes. This function processes the signal in short windows,
using constant GPU memory.
Parameters
----------
audio_path_or_signal : Union[str, Path, AudioSignal]
audio signal to reconstruct
win_duration : float, optional
window duration in seconds, by default 5.0
verbose : bool, optional
by default False
normalize_db : float, optional
normalize db, by default -16
Returns
-------
DACFile
Object containing compressed codes and metadata
required for decompression
"""
audio_signal = audio_path_or_signal
if isinstance(audio_signal, (str, Path)):
audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
self.eval()
original_padding = self.padding
original_device = audio_signal.device
audio_signal = audio_signal.clone()
original_sr = audio_signal.sample_rate
resample_fn = audio_signal.resample
loudness_fn = audio_signal.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if audio_signal.signal_duration >= 10 * 60 * 60:
resample_fn = audio_signal.ffmpeg_resample
loudness_fn = audio_signal.ffmpeg_loudness
original_length = audio_signal.signal_length
resample_fn(self.sample_rate)
input_db = loudness_fn()
if normalize_db is not None:
audio_signal.normalize(normalize_db)
audio_signal.ensure_max_of_audio()
nb, nac, nt = audio_signal.audio_data.shape
audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
win_duration = (
audio_signal.signal_duration if win_duration is None else win_duration
)
if audio_signal.signal_duration <= win_duration:
# Unchunked compression (used if signal length < win duration)
self.padding = True
n_samples = nt
hop = nt
else:
# Chunked inference
self.padding = False
# Zero-pad signal on either side by the delay
audio_signal.zero_pad(self.delay, self.delay)
n_samples = int(win_duration * self.sample_rate)
# Round n_samples to nearest hop length multiple
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
hop = self.get_output_length(n_samples)
codes = []
range_fn = range if not verbose else tqdm.trange
for i in range_fn(0, nt, hop):
x = audio_signal[..., i : i + n_samples]
x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
audio_data = x.audio_data.to(self.device)
audio_data = self.preprocess(audio_data, self.sample_rate)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
_, c, _, _, _ = self.encode(audio_data, n_quantizers)
codes.append(c.to(original_device))
chunk_length = c.shape[-1]
codes = torch.cat(codes, dim=-1)
dac_file = DACFile(
codes=codes,
chunk_length=chunk_length,
original_length=original_length,
input_db=input_db,
channels=nac,
sample_rate=original_sr,
padding=self.padding,
dac_version="1.0.0",
#dac_version=SUPPORTED_VERSIONS[-1],
)
if n_quantizers is not None:
codes = codes[:, :n_quantizers, :]
self.padding = original_padding
return dac_file
@torch.no_grad()
def CodecMixin_decompress(
self,
obj: Union[str, Path, DACFile],
verbose: bool = False,
) -> AudioSignal:
self.eval()
if isinstance(obj, (str, Path)):
obj = DACFile.load(obj)
original_padding = self.padding
self.padding = obj.padding
range_fn = range if not verbose else tqdm.trange
codes = obj.codes
original_device = codes.device
chunk_length = obj.chunk_length
recons = []
for i in range_fn(0, codes.shape[-1], chunk_length):
c = codes[..., i : i + chunk_length].to(self.device)
z = self.quantizer.from_codes(c)[0]
r = self.decode(z)
recons.append(r.to(original_device))
recons = torch.cat(recons, dim=-1)
recons = AudioSignal(recons, self.sample_rate)
# to-do, original implementation
if not hasattr(obj, "dummy") or not obj.dummy:
resample_fn = recons.resample
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
self.padding = original_padding
return recons
CodecMixin.compress = CodecMixin_compress
CodecMixin.decompress = CodecMixin_decompress

View File

@ -0,0 +1,2 @@
from encodec import EncodecModel
from encodec.utils import convert_audio

View File

@ -0,0 +1,459 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
# ResBlock1, ResBlock2, Generator, DiscriminatorP, DiscriminatorS, MultiScaleDiscriminator,
# MultiPeriodDiscriminator, init_weights, get_padding
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType, VoidType
from nemo.core.neural_types.neural_type import NeuralType
LRELU_SLOPE = 0.1
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ResBlock1(torch.nn.Module):
__constants__ = ['lrelu_slope']
def __init__(self, channels, kernel_size, dilation):
super().__init__()
self.lrelu_slope = LRELU_SLOPE
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, self.lrelu_slope)
xt = c1(xt)
xt = F.leaky_relu(xt, self.lrelu_slope)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
__constants__ = ['lrelu_slope']
def __init__(self, channels, kernel_size, dilation):
super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
self.lrelu_slope = LRELU_SLOPE
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, self.lrelu_slope)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(NeuralModule):
__constants__ = ['lrelu_slope', 'num_kernels', 'num_upsamples']
def __init__(
self,
resblock,
upsample_rates,
upsample_kernel_sizes,
upsample_initial_channel,
resblock_kernel_sizes,
resblock_dilation_sizes,
initial_input_size=80,
apply_weight_init_conv_pre=False,
):
super().__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = weight_norm(Conv1d(initial_input_size, upsample_initial_channel, 7, 1, padding=3))
self.lrelu_slope = LRELU_SLOPE
resblock = ResBlock1 if resblock == 1 else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2 ** i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
resblock_list = nn.ModuleList()
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
resblock_list.append(resblock(ch, k, d))
self.resblocks.append(resblock_list)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
if apply_weight_init_conv_pre:
self.conv_pre.apply(init_weights)
@property
def input_types(self):
return {
"x": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
}
@property
def output_types(self):
return {
"audio": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@typecheck()
def forward(self, x):
x = self.conv_pre(x)
for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
x = F.leaky_relu(x, self.lrelu_slope)
x = upsample_layer(x)
xs = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
for resblock in resblock_group:
xs += resblock(x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for group in self.resblocks:
for block in group:
block.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class DiscriminatorP(NeuralModule):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, debug=False):
super().__init__()
self.period = period
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
conv_ch = [32, 128, 512, 1024] if not debug else [8, 12, 32, 64]
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, conv_ch[0], (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(conv_ch[0], conv_ch[1], (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(conv_ch[1], conv_ch[2], (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(conv_ch[2], conv_ch[3], (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(conv_ch[3], conv_ch[3], (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(Conv2d(conv_ch[3], 1, (3, 1), 1, padding=(1, 0)))
@property
def input_types(self):
return {
"x": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"decision": NeuralType(('B', 'T'), VoidType()),
"feature_maps": [NeuralType(("B", "C", "H", "W"), VoidType())],
}
@typecheck()
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(NeuralModule):
def __init__(self, debug=False):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2, debug=debug),
DiscriminatorP(3, debug=debug),
DiscriminatorP(5, debug=debug),
DiscriminatorP(7, debug=debug),
DiscriminatorP(11, debug=debug),
]
)
@property
def input_types(self):
return {
"y": NeuralType(('B', 'S', 'T'), AudioSignal()),
"y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"real_scores": [NeuralType(('B', 'T'), VoidType())],
"fake_scores": [NeuralType(('B', 'T'), VoidType())],
"real_feature_maps": [[NeuralType(("B", "C", "H", "W"), VoidType())]],
"fake_feature_maps": [[NeuralType(("B", "C", "H", "W"), VoidType())]],
}
@typecheck()
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(NeuralModule):
def __init__(self, use_spectral_norm=False, debug=False):
super().__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
conv_ch = [128, 256, 512, 1024] if not debug else [16, 32, 32, 64]
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, conv_ch[0], 15, 1, padding=7)),
norm_f(Conv1d(conv_ch[0], conv_ch[0], 41, 2, groups=4, padding=20)),
norm_f(Conv1d(conv_ch[0], conv_ch[1], 41, 2, groups=16, padding=20)),
norm_f(Conv1d(conv_ch[1], conv_ch[2], 41, 4, groups=16, padding=20)),
norm_f(Conv1d(conv_ch[2], conv_ch[3], 41, 4, groups=16, padding=20)),
norm_f(Conv1d(conv_ch[3], conv_ch[3], 41, 1, groups=16, padding=20)),
norm_f(Conv1d(conv_ch[3], conv_ch[3], 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(conv_ch[3], 1, 3, 1, padding=1))
@property
def input_types(self):
return {
"x": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"decision": NeuralType(('B', 'T'), VoidType()),
"feature_maps": [NeuralType(("B", "C", "T"), VoidType())],
}
@typecheck()
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(NeuralModule):
def __init__(self, debug=False):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=True, debug=debug),
DiscriminatorS(debug=debug),
DiscriminatorS(debug=debug),
]
)
self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
@property
def input_types(self):
return {
"y": NeuralType(('B', 'S', 'T'), AudioSignal()),
"y_hat": NeuralType(('B', 'S', 'T'), AudioSignal()),
}
@property
def output_types(self):
return {
"real_scores": [NeuralType(('B', 'T'), VoidType())],
"fake_scores": [NeuralType(('B', 'T'), VoidType())],
"real_feature_maps": [[NeuralType(("B", "C", "T"), VoidType())]],
"fake_feature_maps": [[NeuralType(("B", "C", "T"), VoidType())]],
}
@typecheck()
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i - 1](y)
y_hat = self.meanpools[i - 1](y_hat)
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs

2771
vall_e/emb/codecs/nemo.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
from vocos import Vocos

View File

@ -19,205 +19,27 @@ from torch import Tensor
from tqdm import tqdm
try:
from encodec import EncodecModel
from encodec.utils import convert_audio
from .codecs.encodec import *
except Exception as e:
cfg.inference.use_encodec = False
_logger.warning(str(e))
try:
from vocos import Vocos
from .codecs.vocos import *
except Exception as e:
cfg.inference.use_vocos = False
_logger.warning(str(e))
try:
from dac import DACFile
from audiotools import AudioSignal
from dac.utils import load_model as __load_dac_model
"""
Patch decode to skip things related to the metadata (namely the waveform trimming)
So far it seems the raw waveform can just be returned without any post-processing
A smart implementation would just reuse the values from the input prompt
"""
from dac.model.base import CodecMixin
@torch.no_grad()
def CodecMixin_compress(
self,
audio_path_or_signal: Union[str, Path, AudioSignal],
win_duration: float = 1.0,
verbose: bool = False,
normalize_db: float = -16,
n_quantizers: int = None,
) -> DACFile:
"""Processes an audio signal from a file or AudioSignal object into
discrete codes. This function processes the signal in short windows,
using constant GPU memory.
Parameters
----------
audio_path_or_signal : Union[str, Path, AudioSignal]
audio signal to reconstruct
win_duration : float, optional
window duration in seconds, by default 5.0
verbose : bool, optional
by default False
normalize_db : float, optional
normalize db, by default -16
Returns
-------
DACFile
Object containing compressed codes and metadata
required for decompression
"""
audio_signal = audio_path_or_signal
if isinstance(audio_signal, (str, Path)):
audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
self.eval()
original_padding = self.padding
original_device = audio_signal.device
audio_signal = audio_signal.clone()
original_sr = audio_signal.sample_rate
resample_fn = audio_signal.resample
loudness_fn = audio_signal.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if audio_signal.signal_duration >= 10 * 60 * 60:
resample_fn = audio_signal.ffmpeg_resample
loudness_fn = audio_signal.ffmpeg_loudness
original_length = audio_signal.signal_length
resample_fn(self.sample_rate)
input_db = loudness_fn()
if normalize_db is not None:
audio_signal.normalize(normalize_db)
audio_signal.ensure_max_of_audio()
nb, nac, nt = audio_signal.audio_data.shape
audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
win_duration = (
audio_signal.signal_duration if win_duration is None else win_duration
)
if audio_signal.signal_duration <= win_duration:
# Unchunked compression (used if signal length < win duration)
self.padding = True
n_samples = nt
hop = nt
else:
# Chunked inference
self.padding = False
# Zero-pad signal on either side by the delay
audio_signal.zero_pad(self.delay, self.delay)
n_samples = int(win_duration * self.sample_rate)
# Round n_samples to nearest hop length multiple
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
hop = self.get_output_length(n_samples)
codes = []
range_fn = range if not verbose else tqdm.trange
for i in range_fn(0, nt, hop):
x = audio_signal[..., i : i + n_samples]
x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
audio_data = x.audio_data.to(self.device)
audio_data = self.preprocess(audio_data, self.sample_rate)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
_, c, _, _, _ = self.encode(audio_data, n_quantizers)
codes.append(c.to(original_device))
chunk_length = c.shape[-1]
codes = torch.cat(codes, dim=-1)
dac_file = DACFile(
codes=codes,
chunk_length=chunk_length,
original_length=original_length,
input_db=input_db,
channels=nac,
sample_rate=original_sr,
padding=self.padding,
dac_version="1.0.0",
#dac_version=SUPPORTED_VERSIONS[-1],
)
if n_quantizers is not None:
codes = codes[:, :n_quantizers, :]
self.padding = original_padding
return dac_file
@torch.no_grad()
def CodecMixin_decompress(
self,
obj: Union[str, Path, DACFile],
verbose: bool = False,
) -> AudioSignal:
self.eval()
if isinstance(obj, (str, Path)):
obj = DACFile.load(obj)
original_padding = self.padding
self.padding = obj.padding
range_fn = range if not verbose else tqdm.trange
codes = obj.codes
original_device = codes.device
chunk_length = obj.chunk_length
recons = []
for i in range_fn(0, codes.shape[-1], chunk_length):
c = codes[..., i : i + chunk_length].to(self.device)
z = self.quantizer.from_codes(c)[0]
r = self.decode(z)
recons.append(r.to(original_device))
recons = torch.cat(recons, dim=-1)
recons = AudioSignal(recons, self.sample_rate)
# to-do, original implementation
if not hasattr(obj, "dummy") or not obj.dummy:
resample_fn = recons.resample
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
self.padding = original_padding
return recons
CodecMixin.compress = CodecMixin_compress
CodecMixin.decompress = CodecMixin_decompress
from .codecs.dac import *
except Exception as e:
cfg.inference.use_dac = False
_logger.warning(str(e))
# uses https://github.com/facebookresearch/AudioDec/
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
# I was not happy with testing, it sounded rather mediocre.
"""
try:
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
except Exception as e:
cfg.inference.use_audiodec = False
from .codecs.nemo import *
cfg.inference.use_nemo = False
_logger.warning(str(e))
"""
@cache
def _load_encodec_model(device="cuda", levels=0):
@ -306,7 +128,7 @@ def _load_audiodec_model(device="cuda", model_name=None):
model_name = "libritts_v1" if cfg.sample_rate == 24_000 else "vctk_v1"
sample_rate, encoder_checkpoint, decoder_checkpoint = _audiodec_assign_model(model_name)
model = AudioDec(tx_device=device , rx_device=device )
model = AudioDec(tx_device=device, rx_device=device)
model.load_transmitter(encoder_checkpoint)
model.load_receiver(encoder_checkpoint, decoder_checkpoint)
@ -316,11 +138,27 @@ def _load_audiodec_model(device="cuda", model_name=None):
return model
@cache
def _load_nemo_model(device="cuda", model_name=None):
if not model_name:
model_name = "nvidia/audio-codec-44khz"
model = AudioCodecModel.from_pretrained(model_name).to(device).eval()
model.backend = "nemo"
model.sample_rate = 44_100
#model.device = device
return model
@cache
def _load_model(device="cuda", backend=None):
if not backend:
backend = cfg.audio_backend
if backend == "nemo":
return _load_nemo_model(device)
if backend == "audiodec":
return _load_audiodec_model(device)
if backend == "dac":
@ -354,6 +192,15 @@ def decode(codes: Tensor, device="cuda", metadata=None, window_duration=None):
# load the model
model = _load_model(device)
# NeMo uses a different pathway
if model.backend == "nemo":
# ugh
codes = rearrange( codes, "b q t -> b t q")
codes = codes.to( device=device )
l = torch.tensor([codes.shape[-1]], device=device, dtype=torch.int32)
wav, _ = model.decode(tokens=codes, tokens_len=l)
return wav, model.sample_rate
# AudioDec uses a different pathway
if model.backend == "audiodec":
codes = codes.to( device=device )[0]
@ -483,6 +330,23 @@ def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device=
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadata=True, window_duration=None):
# NeMo uses a different pathway
if cfg.audio_backend == "nemo":
model = _load_nemo_model( device )
# reshape (channel, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, model.sample_rate, 1)
wav = wav.to(device)[0, :, :]
l = torch.tensor([wav[0].shape[0]]).to(device)
codes, _ = model.encode(audio=wav, audio_len=l)
# ( batch, level, frame )
return codes[0]
# DAC uses a different pathway
if cfg.audio_backend == "dac":
model = _load_dac_model( device )

View File

@ -32,6 +32,8 @@ class LlamaAttention_Adapted(LlamaAttention):
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.mode == "cudnn":
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
elif self.mode == "sdpa":
self.mode = torch.nn.attention.SDPBackend.MATH
super().__init__(*args, **kwargs)
@ -393,6 +395,11 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
# ugh
if isinstance( is_causal, list ) and len(is_causal) == 1:
is_causal = is_causal[0]
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,

View File

@ -671,9 +671,12 @@ class Base(nn.Module):
#gradient_checkpointing=self.gradient_checkpointing,
))
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
"""
# replace with desired attention
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
"""
else:
self.model = MixtralModel_Adapted(MixtralConfig(
vocab_size =n_resp_tokens,
@ -694,8 +697,11 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
"""
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
"""
if self.layerskip:
self.model.layer_dropout_p = layerskip_p_max