# DL-Art-School/codes/trainer/injectors/audio_injectors.py

import random

import torch
import torch.nn.functional as F
import torchaudio

from models.audio.music.cheater_gen_ar import ConditioningAR
from trainer.inject import Injector
from utils.music_utils import get_music_codegen
from utils.util import opt_get, load_model_from_config, pad_or_truncate

MEL_MIN = -11.512925148010254
TACOTRON_MEL_MAX = 2.3143386840820312
TORCH_MEL_MAX = 4.82


def normalize_torch_mel(mel):
    return 2 * ((mel - MEL_MIN) / (TORCH_MEL_MAX - MEL_MIN)) - 1


def denormalize_torch_mel(norm_mel):
    return ((norm_mel + 1) / 2) * (TORCH_MEL_MAX - MEL_MIN) + MEL_MIN


def normalize_mel(mel):
    return 2 * ((mel - MEL_MIN) / (TACOTRON_MEL_MAX - MEL_MIN)) - 1


def denormalize_mel(norm_mel):
    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - MEL_MIN) + MEL_MIN
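

# Sanity-check sketch (illustrative, hypothetical helper; not used by the trainer): normalize_mel and
# denormalize_mel should be exact inverses over the Tacotron MEL range [MEL_MIN, TACOTRON_MEL_MAX].
def _check_mel_norm_roundtrip():
    mel = torch.rand(4, 80, 100) * (TACOTRON_MEL_MAX - MEL_MIN) + MEL_MIN
    norm = normalize_mel(mel)
    assert norm.min() >= -1 and norm.max() <= 1  # Normalization maps into [-1, 1].
    assert torch.allclose(denormalize_mel(norm), mel, atol=1e-5)  # Round trip recovers the original.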


class MelSpectrogramInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        from models.audio.tts.tacotron2 import TacotronSTFT
        # These are the default tacotron values for the MEL spectrogram.
        filter_length = opt_get(opt, ['filter_length'], 1024)
        hop_length = opt_get(opt, ['hop_length'], 256)
        win_length = opt_get(opt, ['win_length'], 1024)
        n_mel_channels = opt_get(opt, ['n_mel_channels'], 80)
        mel_fmin = opt_get(opt, ['mel_fmin'], 0)
        mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
        sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
        self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate,
                                 mel_fmin, mel_fmax)
        # Unlike TorchMelSpectrogramInjector, this simply normalizes the MEL to the range [-1, 1].
        self.do_normalization = opt_get(opt, ['do_normalization'], None)

    def forward(self, state):
        inp = state[self.input]
        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono audio).
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2
        self.stft = self.stft.to(inp.device)
        mel = self.stft.mel_spectrogram(inp)
        if self.do_normalization:
            mel = normalize_mel(mel)
        return {self.output: mel}


class TorchMelSpectrogramInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = opt_get(opt, ['filter_length'], 1024)
        self.hop_length = opt_get(opt, ['hop_length'], 256)
        self.win_length = opt_get(opt, ['win_length'], 1024)
        self.n_mel_channels = opt_get(opt, ['n_mel_channels'], 80)
        self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
        self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
        self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
        norm = opt_get(opt, ['normalize'], False)
        self.true_norm = opt_get(opt, ['true_normalization'], False)
        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
                                                             win_length=self.win_length, power=2, normalized=norm,
                                                             sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                             f_max=self.mel_fmax, n_mels=self.n_mel_channels,
                                                             norm="slaney")
        self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
        if self.mel_norm_file is not None:
            self.mel_norms = torch.load(self.mel_norm_file)
        else:
            self.mel_norms = None

    def forward(self, state):
        with torch.no_grad():
            inp = state[self.input]
            if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono audio).
                inp = inp.squeeze(1)
            assert len(inp.shape) == 2
            self.mel_stft = self.mel_stft.to(inp.device)
            mel = self.mel_stft(inp)
            # Perform dynamic range compression.
            mel = torch.log(torch.clamp(mel, min=1e-5))
            if self.mel_norms is not None:
                self.mel_norms = self.mel_norms.to(mel.device)
                mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
            if self.true_norm:
                mel = normalize_torch_mel(mel)
        return {self.output: mel}
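

# Usage sketch (illustrative; 'wav'/'mel' are arbitrary state keys, mirroring how GptVoiceLatentInjector
# constructs this injector inline below): all parameters fall back to the tacotron defaults.
def _example_torch_mel():
    inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel'}, {})
    wav = torch.randn(2, 1, 22050)  # One second of fake mono audio at the default 22050Hz.
    return inj({'wav': wav})['mel']  # -> (2, 80, frames) log-MEL spectrogram.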


class RandomAudioCropInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        if 'crop_size' in opt.keys():
            self.min_crop_sz = opt['crop_size']
            self.max_crop_sz = self.min_crop_sz
        else:
            self.min_crop_sz = opt['min_crop_size']
            self.max_crop_sz = opt['max_crop_size']
        self.lengths_key = opt['lengths_key']
        self.crop_start_key = opt['crop_start_key']
        self.min_buffer = opt_get(opt, ['min_buffer'], 0)
        self.rand_buffer_ptr = 9999
        self.rand_buffer_sz = 5000

    def forward(self, state):
        inp = state[self.input]
        if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
            # All processes must agree on the crop size, otherwise every process effectively waits on the one
            # processing max_crop_sz. Agreeing on every step is expensive, so agree on a "chunk" at a time.
            if self.rand_buffer_ptr >= self.rand_buffer_sz:
                # +1 because torch.randint's upper bound is exclusive, unlike random.randint below.
                self.rand_buffer = torch.randint(self.min_crop_sz, self.max_crop_sz + 1, (self.rand_buffer_sz,),
                                                 dtype=torch.long, device=inp.device)
                torch.distributed.broadcast(self.rand_buffer, 0)
                self.rand_buffer_ptr = 0
            crop_sz = self.rand_buffer[self.rand_buffer_ptr].item()
            self.rand_buffer_ptr += 1
        else:
            crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
        if self.lengths_key is not None:
            lens = state[self.lengths_key]
            min_len = torch.min(lens).item()  # Avoid shadowing the builtin `len`.
        else:
            min_len = inp.shape[-1]
        margin = min_len - crop_sz - self.min_buffer * 2
        if margin < 0:
            start = self.min_buffer
        else:
            start = random.randint(0, margin) + self.min_buffer
        res = {self.output: inp[:, :, start:start + crop_sz]}
        if self.crop_start_key is not None:
            res[self.crop_start_key] = start
        return res
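

# Usage sketch (illustrative; 'clip'/'cropped' are hypothetical state keys). With no lengths key, the crop
# window is drawn uniformly from the full sample dimension.
def _example_random_crop():
    opt = {'in': 'clip', 'out': 'cropped', 'min_crop_size': 11025, 'max_crop_size': 22050,
           'lengths_key': None, 'crop_start_key': None, 'min_buffer': 0}
    inj = RandomAudioCropInjector(opt, {})
    out = inj({'clip': torch.randn(2, 1, 44100)})['cropped']
    assert 11025 <= out.shape[-1] <= 22050
    return out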


class AudioClipInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.clip_size = opt['clip_size']
        self.ctc_codes = opt['ctc_codes_key']
        self.output_ctc = opt['ctc_out_key']

    def forward(self, state):
        inp = state[self.input]
        ctc = state[self.ctc_codes]
        inp_len = inp.shape[-1]  # Avoid shadowing the builtin `len`.
        if inp_len > self.clip_size:
            # Clip the CTC codes proportionally to the amount of audio removed.
            proportion_inp_remaining = self.clip_size / inp_len
            inp = inp[:, :, :self.clip_size]
            ctc = ctc[:, :int(proportion_inp_remaining * ctc.shape[-1])]
        return {self.output: inp, self.output_ctc: ctc}


class AudioResampleInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.input_sr = opt['input_sample_rate']
        self.output_sr = opt['output_sample_rate']

    def forward(self, state):
        inp = state[self.input]
        return {self.output: torchaudio.functional.resample(inp, self.input_sr, self.output_sr)}
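

# Usage sketch (illustrative; 'wav'/'resampled' are hypothetical state keys): with an integer rate ratio,
# torchaudio.functional.resample halves the sample count exactly when going 44100 -> 22050.
def _example_resample():
    inj = AudioResampleInjector({'in': 'wav', 'out': 'resampled', 'input_sample_rate': 44100,
                                 'output_sample_rate': 22050}, {})
    out = inj({'wav': torch.randn(1, 1, 44100)})['resampled']
    assert out.shape[-1] == 22050
    return out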


class DiscreteTokenInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        cfg = opt_get(opt, ['dvae_config'], "../experiments/train_diffusion_vocoder_22k_level.yml")
        dvae_name = opt_get(opt, ['dvae_name'], 'dvae')
        self.dvae = load_model_from_config(cfg, dvae_name, device=f'cuda:{env["device"]}').eval()

    def forward(self, state):
        inp = state[self.input]
        with torch.no_grad():
            self.dvae = self.dvae.to(inp.device)
            codes = self.dvae.get_codebook_indices(inp)
            return {self.output: codes}


class GptVoiceLatentInjector(Injector):
    """
    This injector does all the legwork to generate latents out of a UnifiedVoice model, including encoding all audio
    inputs into a MEL spectrogram and discretizing the inputs.
    """
    def __init__(self, opt, env):
        super().__init__(opt, env)
        # For discrete tokenization.
        cfg = opt_get(opt, ['dvae_config'], "../experiments/train_diffusion_vocoder_22k_level.yml")
        dvae_name = opt_get(opt, ['dvae_name'], 'dvae')
        self.dvae = load_model_from_config(cfg, dvae_name).cuda().eval()
        # The unified_voice model.
        cfg = opt_get(opt, ['gpt_config'], "../experiments/train_gpt_tts_unified.yml")
        model_name = opt_get(opt, ['gpt_name'], 'gpt')
        pretrained_path = opt['gpt_path']
        self.gpt = load_model_from_config(cfg, model_name=model_name,
                                          also_load_savepoint=False, load_path=pretrained_path).cuda().eval()
        self.needs_move = True
        # MEL converter.
        self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel',
                                                    'mel_norm_file': '../experiments/clips_mel_norms.pth'}, {})
        # Aux input keys.
        self.conditioning_key = opt['conditioning_clip']
        self.text_input_key = opt['text']
        self.text_lengths_key = opt['text_lengths']
        self.input_lengths_key = opt['input_lengths']

    def to_mel(self, t):
        return self.mel_inj({'wav': t})['mel']

    def forward(self, state):
        with torch.no_grad():
            mel_inputs = self.to_mel(state[self.input])
            state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
            mel_conds = []
            for k in range(state_cond.shape[1]):
                mel_conds.append(self.to_mel(state_cond[:, k]))
            mel_conds = torch.stack(mel_conds, dim=1)
            if self.needs_move:
                self.dvae = self.dvae.to(mel_inputs.device)
                self.gpt = self.gpt.to(mel_inputs.device)
                self.needs_move = False
            codes = self.dvae.get_codebook_indices(mel_inputs)
            latents = self.gpt(mel_conds, state[self.text_input_key],
                               state[self.text_lengths_key], codes, state[self.input_lengths_key],
                               text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
            assert latents.shape[1] == codes.shape[1]
            return {self.output: latents}


class ReverseUnivnetInjector(Injector):
    """
    This injector specifically builds inputs and labels for a univnet detector.
    """
    def __init__(self, opt, env):
        super().__init__(opt, env)
        from scripts.audio.gen.speech_synthesis_utils import load_univnet_vocoder
        self.univnet = load_univnet_vocoder().cuda()
        self.mel_input_key = opt['mel']
        self.label_output_key = opt['labels']
        self.do_augmentations = opt_get(opt, ['do_aug'], True)

    def forward(self, state):
        with torch.no_grad():
            original_audio = state[self.input]
            mel = state[self.mel_input_key]
            decoded_mel = self.univnet.inference(mel)[:, :, :original_audio.shape[-1]]
            if self.do_augmentations:
                # Additive noise and random resampling keep the detector from keying on trivial artifacts.
                original_audio = original_audio + torch.rand_like(original_audio) * random.random() * .005
                decoded_mel = decoded_mel + torch.rand_like(decoded_mel) * random.random() * .005
                if random.random() < .5:
                    original_audio = torchaudio.functional.resample(
                        torchaudio.functional.resample(original_audio, 24000, 10000), 10000, 24000)
                if random.random() < .5:
                    decoded_mel = torchaudio.functional.resample(
                        torchaudio.functional.resample(decoded_mel, 24000, 10000), 10000, 24000)
                if random.random() < .5:
                    original_audio = torchaudio.functional.resample(original_audio, 24000, 22000 + random.randint(0, 2000))
                if random.random() < .5:
                    decoded_mel = torchaudio.functional.resample(decoded_mel, 24000, 22000 + random.randint(0, 2000))
            smallest_dim = min(original_audio.shape[-1], decoded_mel.shape[-1])
            original_audio = original_audio[:, :, :smallest_dim]
            decoded_mel = decoded_mel[:, :, :smallest_dim]
            # Randomly pick either the real or the vocoded audio per batch element; the label records which.
            labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
            output = torch.where(labels, original_audio, decoded_mel)
            return {self.output: output, self.label_output_key: labels[:, 0, 0].long()}


class ConditioningLatentDistributionDivergenceInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        if 'gpt_config' in opt.keys():
            # The unified_voice model.
            cfg = opt_get(opt, ['gpt_config'], "../experiments/train_gpt_tts_unified.yml")
            model_name = opt_get(opt, ['gpt_name'], 'gpt')
            pretrained_path = opt['gpt_path']
            self.latent_producer = load_model_from_config(cfg, model_name=model_name,
                                                          also_load_savepoint=False, load_path=pretrained_path).eval()
            self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel',
                                                        'mel_norm_file': '../experiments/clips_mel_norms.pth'}, {})
        else:
            from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
            self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                                    in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
                                                    num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
            self.latent_producer.load_state_dict(torch.load(opt['diffusion_path']))
            self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_fmax': 12000,
                                                        'sampling_rate': 24000, 'n_mel_channels': 100}, {})
        self.needs_move = True
        # Aux input keys.
        self.conditioning_key = opt['conditioning_clip']
        # Output keys.
        self.var_loss_key = opt['var_loss']

    def to_mel(self, t):
        return self.mel_inj({'wav': t})['mel']

    def forward(self, state):
        with torch.no_grad():
            state_preds = state[self.input]
            state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
            mel_conds = []
            for k in range(state_cond.shape[1]):
                mel_conds.append(self.to_mel(state_cond[:, k]))
            mel_conds = torch.stack(mel_conds, dim=1)
            if self.needs_move:
                self.latent_producer = self.latent_producer.to(mel_conds.device)
                self.needs_move = False
            latents = self.latent_producer.get_conditioning_latent(mel_conds)
            # Match the first two moments of the predicted latent distribution to the reference distribution.
            sp_means, sp_vars = state_preds.mean(dim=0), state_preds.var(dim=0)
            tr_means, tr_vars = latents.mean(dim=0), latents.var(dim=0)
            mean_loss = F.mse_loss(sp_means, tr_means)
            var_loss = F.mse_loss(sp_vars, tr_vars)
            return {self.output: mean_loss, self.var_loss_key: var_loss}


class RandomScaleInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.min_samples = opt['min_samples']

    def forward(self, state):
        inp = state[self.input]
        if self.min_samples < inp.shape[-1]:
            samples = random.randint(self.min_samples, inp.shape[-1])
            start = random.randint(0, inp.shape[-1] - samples)
            inp = inp[:, :, start:start + samples]
        return {self.output: inp}


def pixel_shuffle_1d(x, upscale_factor):
    # 1D analogue of F.pixel_shuffle: (B, C*r, S) -> (B, C, S*r).
    batch_size, channels, steps = x.size()
    channels //= upscale_factor
    input_view = x.contiguous().view(batch_size, channels, upscale_factor, steps)
    shuffle_out = input_view.permute(0, 1, 3, 2).contiguous()
    return shuffle_out.view(batch_size, channels, steps * upscale_factor)


def pixel_unshuffle_1d(x, downscale):
    # Inverse of pixel_shuffle_1d: (B, C, S) -> (B, C*d, S/d).
    b, c, s = x.size()
    x = x.view(b, c, s // downscale, downscale)
    x = x.permute(0, 1, 3, 2).contiguous()
    x = x.view(b, c * downscale, s // downscale)
    return x
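

# Sanity-check sketch (illustrative, hypothetical helper): pixel_unshuffle_1d and pixel_shuffle_1d are exact
# inverses, trading the channel dimension against the step dimension by the same factor.
def _check_pixel_shuffle_roundtrip():
    x = torch.randn(2, 4, 16)  # (batch, channels, steps)
    down = pixel_unshuffle_1d(x, 4)
    assert down.shape == (2, 16, 4)  # Channels x4, steps /4.
    assert torch.equal(pixel_shuffle_1d(down, 4), x)  # Shuffling undoes the unshuffle exactly.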


class AudioUnshuffleInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.compression = opt['compression']

    def forward(self, state):
        inp = state[self.input]
        return {self.output: pixel_unshuffle_1d(inp, self.compression)}


class Mel2vecCodesInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.m2v = get_music_codegen()
        del self.m2v.quantizer.encoder  # This is a big memory sink which will not get used.
        self.needs_move = True
        self.inj_vector = opt_get(opt, ['vector'], False)

    def forward(self, state):
        mels = state[self.input]
        with torch.no_grad():
            if self.needs_move:
                self.m2v = self.m2v.to(mels.device)
                self.needs_move = False
            codes = self.m2v.get_codes(mels, project=self.inj_vector)
        return {self.output: codes}


class ClvpTextInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        from scripts.audio.gen.speech_synthesis_utils import load_clvp
        self.clvp = load_clvp()
        del self.clvp.speech_transformer  # We will only be using the text transformer.
        self.needs_move = True

    def forward(self, state):
        codes = state[self.input]
        with torch.no_grad():
            if self.needs_move:
                self.clvp = self.clvp.to(codes.device)
                self.needs_move = False
            latents = self.clvp.embed_text(codes)
        return {self.output: latents}


class NormalizeMelInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        mel = state[self.input]
        with torch.no_grad():
            return {self.output: normalize_mel(mel)}


class ChannelClipInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.lo = opt['channel_low']
        self.hi = opt['channel_high']

    def forward(self, state):
        inp = state[self.input]
        return {self.output: inp[:, self.lo:self.hi]}
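

# Usage sketch (illustrative; 'x'/'clipped' are hypothetical state keys): keeps channels [lo, hi) of a
# (batch, channel, ...) tensor.
def _example_channel_clip():
    inj = ChannelClipInjector({'in': 'x', 'out': 'clipped', 'channel_low': 2, 'channel_high': 6}, {})
    assert inj({'x': torch.randn(3, 8, 100)})['clipped'].shape == (3, 4, 100)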


class MusicCheaterLatentInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        from models.audio.music.gpt_music2 import UpperEncoder
        self.encoder = UpperEncoder(256, 1024, 256).eval()
        self.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth',
                                                map_location=torch.device('cpu')))

    def forward(self, state):
        with torch.no_grad():
            mel = state[self.input]
            self.encoder = self.encoder.to(mel.device)
            proj = self.encoder(mel)
            return {self.output: proj}


class KmeansQuantizerInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        _, self.centroids = torch.load(opt['centroids'])
        # centroids: (k, c) -> (c, k) so they can be right-multiplied against the flattened inputs.
        self.centroids = self.centroids.permute(1, 0)

    def forward(self, state):
        with torch.no_grad():
            x = state[self.input]
            self.centroids = self.centroids.to(x.device)
            b, c, s = x.shape
            x = x.permute(0, 2, 1).reshape(b * s, c)
            # Squared euclidean distance via the expansion |x - c|^2 = |x|^2 - 2*x.c + |c|^2.
            distances = x.pow(2).sum(1, keepdim=True) - 2 * x @ self.centroids + self.centroids.pow(2).sum(0, keepdim=True)
            distances[distances.isnan()] = 9999999999
            distances = distances.reshape(b, s, self.centroids.shape[-1])
            labels = distances.argmin(-1)
            return {self.output: labels}
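

# Verification sketch (illustrative, hypothetical helper): the expanded-form distance used by
# KmeansQuantizerInjector should produce the same assignments as torch.cdist.
def _check_kmeans_assignment():
    x = torch.randn(64, 8)          # 64 flattened feature vectors of dimension 8.
    centroids = torch.randn(10, 8)  # 10 centroids in the (k, c) layout the injector loads.
    cent_t = centroids.permute(1, 0)
    dist = x.pow(2).sum(1, keepdim=True) - 2 * x @ cent_t + cent_t.pow(2).sum(0, keepdim=True)
    assert torch.equal(dist.argmin(-1), torch.cdist(x, centroids).argmin(-1))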


class MusicCheaterArInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.cheater_ar = ConditioningAR(1024, layers=24, dropout=0, cond_free_percent=0).eval()
        self.cheater_ar.load_state_dict(torch.load('../experiments/music_cheater_ar.pth',
                                                   map_location=torch.device('cpu')))
        self.cond_key = opt['cheater_latent_key']
        self.needs_move = True

    def forward(self, state):
        codes = state[self.input]
        cond = state[self.cond_key]
        if self.needs_move:
            self.cheater_ar = self.cheater_ar.to(codes.device)
            self.needs_move = False
        with torch.no_grad():
            latents = self.cheater_ar(codes, cond, return_latent=True)
        return {self.output: latents}