diff --git a/codes/data/audio/stop_prediction_dataset.py b/codes/data/audio/stop_prediction_dataset.py
new file mode 100644
index 00000000..acdc98dd
--- /dev/null
+++ b/codes/data/audio/stop_prediction_dataset.py
@@ -0,0 +1,142 @@
+import os
+import pathlib
+import random
+
+import audio2numpy
+import numpy as np
+import torch
+import torch.utils.data
+import torch.nn.functional as F
+from tqdm import tqdm
+
+import models.tacotron2.layers as layers
+from data.audio.nv_tacotron_dataset import load_mozilla_cv, load_voxpopuli
+from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
+
+from models.tacotron2.text import text_to_sequence
+from utils.util import opt_get
+
+
+def get_similar_files_libritts(filename):
+    filedir = os.path.dirname(filename)
+    return list(pathlib.Path(filedir).glob('*.wav'))
+
+
+class StopPredictionDataset(torch.utils.data.Dataset):
+    """
+    1) loads audio clips and concatenates clips from "similar" files (same directory) until max_mel_length frames are reached
+    2) computes mel-spectrograms from the concatenated audio
+    3) produces a termination mask with a 1 at the final frame of each constituent clip
+    """
+    def __init__(self, hparams):
+        self.path = hparams['path']
+        if not isinstance(self.path, list):
+            self.path = [self.path]
+
+        fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
+        if not isinstance(fetcher_mode, list):
+            fetcher_mode = [fetcher_mode]
+        assert len(self.path) == len(fetcher_mode)
+
+        self.audiopaths_and_text = []
+        for p, fm in zip(self.path, fetcher_mode):
+            if fm == 'lj' or fm == 'libritts':
+                fetcher_fn = load_filepaths_and_text
+                self.get_similar_files = get_similar_files_libritts
+            elif fm == 'voxpopuli':
+                fetcher_fn = load_voxpopuli
+                self.get_similar_files = None  # TODO: Fix.
+            else:
+                raise NotImplementedError()
+            self.audiopaths_and_text.extend(fetcher_fn(p))
+        self.sampling_rate = hparams.sampling_rate
+        self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate)
+        self.stft = layers.TacotronSTFT(
+            hparams.filter_length, hparams.hop_length, hparams.win_length,
+            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
+            hparams.mel_fmax)
+        random.seed(hparams.seed)
+        random.shuffle(self.audiopaths_and_text)
+        self.max_mel_len = opt_get(hparams, ['max_mel_length'], None)
+        self.max_text_len = opt_get(hparams, ['max_text_length'], None)
+
+    def get_mel(self, filename):
+        filename = str(filename)
+        if filename.endswith('.wav'):
+            audio, sampling_rate = load_wav_to_torch(filename)
+        else:
+            audio, sampling_rate = audio2numpy.audio_from_file(filename)
+            audio = torch.tensor(audio)
+
+        if sampling_rate != self.input_sample_rate:
+            if sampling_rate < self.input_sample_rate:
+                print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.')
+            audio_norm = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='nearest', recompute_scale_factor=False).squeeze()
+        else:
+            audio_norm = audio
+        if audio_norm.std() > 1:
+            print(f"Something is very wrong with the given audio. std_dev={audio_norm.std()}. file={filename}")
+            return None
+        audio_norm.clip_(-1, 1)
+        audio_norm = audio_norm.unsqueeze(0)
+        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+        if self.input_sample_rate != self.sampling_rate:
+            ratio = self.sampling_rate / self.input_sample_rate
+            audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
+        melspec = self.stft.mel_spectrogram(audio_norm)
+        melspec = torch.squeeze(melspec, 0)
+
+        return melspec
+
+    def __getitem__(self, index):
+        path = self.audiopaths_and_text[index][0]
+        similar_files = self.get_similar_files(path)
+        mel = self.get_mel(path)
+        terms = torch.zeros(mel.shape[1])
+        terms[-1] = 1
+        while mel.shape[-1] < self.max_mel_len:
+            another_file = random.choice(similar_files)
+            another_mel = self.get_mel(another_file)
+            oterms = torch.zeros(another_mel.shape[1])
+            oterms[-1] = 1
+            mel = torch.cat([mel, another_mel], dim=-1)
+            terms = torch.cat([terms, oterms], dim=-1)
+        mel = mel[:, :self.max_mel_len]
+        terms = terms[:self.max_mel_len]
+
+
+        return {
+            'padded_mel': mel,
+            'termination_mask': terms,
+        }
+
+    def __len__(self):
+        return len(self.audiopaths_and_text)
+
+
+if __name__ == '__main__':
+    params = {
+        'mode': 'stop_prediction',
+        'path': 'E:\\audio\\LibriTTS\\train-clean-360_list.txt',
+        'phase': 'train',
+        'n_workers': 0,
+        'batch_size': 16,
+        'fetcher_mode': 'libritts',
+        'max_mel_length': 800,
+        #'return_wavs': True,
+        #'input_sample_rate': 22050,
+        #'sampling_rate': 8000
+    }
+    from data import create_dataset, create_dataloader
+
+    ds, c = create_dataset(params, return_collate=True)
+    dl = create_dataloader(ds, params, collate_fn=c, shuffle=True)
+    i = 0
+    m = None
+    for k in range(1000):
+        for i, b in tqdm(enumerate(dl)):
+            continue
+            pm = b['padded_mel']
+            pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1]))
+            m = pm if m is None else torch.cat([m, pm], dim=0)
+            print(m.mean(), m.std())
\ No newline at end of file
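The sample format produced by the new dataset is a fixed-length mel paired with a binary termination mask whose 1s mark the last frame of each concatenated clip. Below is a minimal standalone sketch of that packing, with random tensors standing in for get_mel() output; the 80 mel channels and max_mel_length of 800 are assumptions matching the __main__ block above, not part of the commit.

import random
import torch

# Stand-in for StopPredictionDataset.get_mel(): a mel with a random number of frames.
def fake_mel(n_mel_channels=80):
    return torch.randn(n_mel_channels, random.randint(100, 400))

max_mel_len = 800
mel = fake_mel()
terms = torch.zeros(mel.shape[1])
terms[-1] = 1                                   # last frame of the first clip is a termination point
while mel.shape[-1] < max_mel_len:
    another_mel = fake_mel()                    # the dataset draws this from the same directory
    oterms = torch.zeros(another_mel.shape[1])
    oterms[-1] = 1
    mel = torch.cat([mel, another_mel], dim=-1)
    terms = torch.cat([terms, oterms], dim=-1)
mel, terms = mel[:, :max_mel_len], terms[:max_mel_len]
print(mel.shape, int(terms.sum()))              # torch.Size([80, 800]) and the surviving clip-boundary count
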
file={filename}") + return None + audio_norm.clip_(-1, 1) + audio_norm = audio_norm.unsqueeze(0) + audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) + if self.input_sample_rate != self.sampling_rate: + ratio = self.sampling_rate / self.input_sample_rate + audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0) + melspec = self.stft.mel_spectrogram(audio_norm) + melspec = torch.squeeze(melspec, 0) + + return melspec + + def __getitem__(self, index): + path = self.audiopaths_and_text[index][0] + similar_files = self.get_similar_files(path) + mel = self.get_mel(path) + terms = torch.zeros(mel.shape[1]) + terms[-1] = 1 + while mel.shape[-1] < self.max_mel_len: + another_file = random.choice(similar_files) + another_mel = self.get_mel(another_file) + oterms = torch.zeros(another_mel.shape[1]) + oterms[-1] = 1 + mel = torch.cat([mel, another_mel], dim=-1) + terms = torch.cat([terms, oterms], dim=-1) + mel = mel[:, :self.max_mel_len] + terms = terms[:self.max_mel_len] + + + return { + 'padded_mel': mel, + 'termination_mask': terms, + } + + def __len__(self): + return len(self.audiopaths_and_text) + + +if __name__ == '__main__': + params = { + 'mode': 'stop_prediction', + 'path': 'E:\\audio\\LibriTTS\\train-clean-360_list.txt', + 'phase': 'train', + 'n_workers': 0, + 'batch_size': 16, + 'fetcher_mode': 'libritts', + 'max_mel_length': 800, + #'return_wavs': True, + #'input_sample_rate': 22050, + #'sampling_rate': 8000 + } + from data import create_dataset, create_dataloader + + ds, c = create_dataset(params, return_collate=True) + dl = create_dataloader(ds, params, collate_fn=c, shuffle=True) + i = 0 + m = None + for k in range(1000): + for i, b in tqdm(enumerate(dl)): + continue + pm = b['padded_mel'] + pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1])) + m = pm if m is None else torch.cat([m, pm], dim=0) + print(m.mean(), m.std()) \ No newline at end of file diff --git a/codes/models/gpt_voice/gpt_audio_segmentor.py b/codes/models/gpt_voice/gpt_audio_segmentor.py index 5ae57e71..10276cf9 100644 --- a/codes/models/gpt_voice/gpt_audio_segmentor.py +++ b/codes/models/gpt_voice/gpt_audio_segmentor.py @@ -49,7 +49,6 @@ class MelEncoder(nn.Module): class GptSegmentor(nn.Module): - MAX_SYMBOLS_PER_PHRASE = 200 MAX_MEL_FRAMES = 2000 // 4 def __init__(self, layers=8, model_dim=512, heads=8): @@ -59,30 +58,28 @@ class GptSegmentor(nn.Module): self.max_mel_frames = self.MAX_MEL_FRAMES self.mel_encoder = MelEncoder(model_dim) self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim) - self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads, + self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=self.MAX_MEL_FRAMES, heads=heads, attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES) self.final_norm = nn.LayerNorm(model_dim) self.stop_head = nn.Linear(model_dim, 1) - def forward(self, mel_inputs, mel_lengths): - max_len = mel_lengths.max() # This can be done in the dataset layer, but it is easier to do here. - mel_inputs = mel_inputs[:, :, :max_len] - + def forward(self, mel_inputs, termination_points): mel_emb = self.mel_encoder(mel_inputs) - mel_lengths = mel_lengths // 4 # The encoder decimates the mel by a factor of 4. 
diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py
index 419dc112..0e65f7b1 100644
--- a/codes/scripts/audio/test_audio_gen.py
+++ b/codes/scripts/audio/test_audio_gen.py
@@ -30,14 +30,14 @@ def forward_pass(model, denoiser, data, output_dir, opt, b):
         ground_truth_waveforms = denoiser(ground_truth_waveforms)
     for i in range(pred_waveforms.shape[0]):
         # Output predicted mels and waveforms.
-        pred_mel = model.eval_state[opt['eval']['pred_mel']][i]
+        pred_mel = model.eval_state[opt['eval']['pred_mel']][0][i].unsqueeze(0)
         pred_mel = ((pred_mel - pred_mel.mean()) / max(abs(pred_mel.min()), pred_mel.max())).unsqueeze(1)
         torchvision.utils.save_image(pred_mel, osp.join(output_dir, f'{b}_{i}_pred_mel.png'))
         audio = pred_waveforms[i][0].cpu().numpy()
         wavfile.write(osp.join(output_dir, f'{b}_{i}.wav'), 22050, audio)
         if gt:
-            gt_mel = model.eval_state[opt['eval']['ground_truth_mel']][i]
+            gt_mel = model.eval_state[opt['eval']['ground_truth_mel']][0][i].unsqueeze(0)
             gt_mel = ((gt_mel - gt_mel.mean()) / max(abs(gt_mel.min()), gt_mel.max())).unsqueeze(1)
             torchvision.utils.save_image(gt_mel, osp.join(output_dir, f'{b}_{i}_gt_mel.png'))
             audio = ground_truth_waveforms[i][0].cpu().numpy()
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_lrdvae_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_stop_pred_dataset.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
diff --git a/codes/train.py b/codes/train.py
index f129db54..7b2da295 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -282,7 +282,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_stop_libritts.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
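The indexing change in test_audio_gen.py implies that each eval_state entry is now treated as a list of batched tensors, so [0][i] picks sample i out of the first tensor and unsqueeze(0) restores the leading dimension the per-image normalization expects. A quick shape check under that assumption, with a dummy tensor standing in for the model output:

import torch

# Dummy stand-in for model.eval_state[opt['eval']['pred_mel']] under the assumed new layout:
# a list whose first element is a batch of mels shaped (batch, n_mel_channels, frames).
eval_state_entry = [torch.randn(16, 80, 200)]

i = 0
pred_mel = eval_state_entry[0][i].unsqueeze(0)    # (1, 80, 200)
pred_mel = ((pred_mel - pred_mel.mean()) / max(abs(pred_mel.min()), pred_mel.max())).unsqueeze(1)
print(pred_mel.shape)                             # torch.Size([1, 1, 80, 200]), ready for torchvision's save_image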