adf update

commit 3db862dd32
parent 8587a18717
Author: James Betker
Date:   2022-05-27 09:25:53 -06:00

6 changed files with 47 additions and 24 deletions


@@ -40,7 +40,7 @@ class AttentionBlock(TimestepBlock):
     def __init__(self, dim, heads, dropout):
         super().__init__()
         self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
-        self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True)
+        self.ff = FeedForward(dim, mult=1, dropout=dropout, zero_init_output=True)
         self.rms_scale_norm = RMSScaleShiftNorm(dim)

     def forward(self, x, timestep_emb, rotary_emb):


@@ -261,13 +261,13 @@ class DiffusionTts(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

-    def forward(self, x, timesteps, tokens, conditioning_input=None):
+    def forward(self, x, timesteps, codes, conditioning_input=None):
         """
         Apply the model to an input batch.

         :param x: an [N x C x ...] Tensor of inputs.
         :param timesteps: a 1-D batch of timesteps.
-        :param tokens: an aligned text input.
+        :param codes: an aligned text input.
         :return: an [N x C x ...] Tensor of outputs.
         """
         with autocast(x.device.type):
@@ -276,16 +276,16 @@ class DiffusionTts(nn.Module):
             if cm != 0:
                 pc = (cm-x.shape[-1])/x.shape[-1]
                 x = F.pad(x, (0,cm-x.shape[-1]))
-                tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
+                codes = F.pad(codes, (0, int(pc * codes.shape[-1])))

             hs = []
             time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

             # Mask out guidance tokens for un-guided diffusion.
             if self.training and self.nil_guidance_fwd_proportion > 0:
-                token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
-                tokens = torch.where(token_mask, self.mask_token_id, tokens)
-            code_emb = self.code_embedding(tokens).permute(0,2,1)
+                token_mask = torch.rand(codes.shape, device=codes.device) < self.nil_guidance_fwd_proportion
+                codes = torch.where(token_mask, self.mask_token_id, codes)
+            code_emb = self.code_embedding(codes).permute(0, 2, 1)
             cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
             cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
             code_emb = self.codes_encoder(code_emb.permute(0,2,1), norm_scale_shift_inp=cond_emb).permute(0,2,1)
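
The masking above is what later enables conditioning-free guidance: during training a random fraction of the (renamed) codes is replaced by a reserved mask token, so the model also learns an unconditioned pathway. A minimal standalone sketch of that idea, assuming mask_token_id is a reserved integer id; the helper name below is illustrative and not part of the repository:

import torch

def mask_codes_for_free_guidance(codes: torch.Tensor, mask_token_id: int, proportion: float, training: bool = True) -> torch.Tensor:
    # Randomly replace a fraction of the code tokens with the mask id so the
    # model also sees "unconditioned" inputs during training; at inference the
    # fully-masked pathway provides the unconditional prediction for guidance.
    if not training or proportion <= 0:
        return codes
    drop = torch.rand(codes.shape, device=codes.device) < proportion
    return torch.where(drop, torch.full_like(codes, mask_token_id), codes)

# Example: mask roughly 10% of tokens in a batch of code sequences.
codes = torch.randint(0, 255, (2, 100))
masked = mask_codes_for_free_guidance(codes, mask_token_id=255, proportion=0.1)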


@@ -332,7 +332,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_code2mel.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)


@@ -13,6 +13,7 @@ import trainer.networks as networks
 from trainer.base_model import BaseModel
 from trainer.batch_size_optimizer import create_batch_size_optimizer
 from trainer.inject import create_injector
+from trainer.injectors.audio_injectors import normalize_mel
 from trainer.steps import ConfigurableStep
 from trainer.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
@@ -354,9 +355,9 @@ class ExtensibleTrainer(BaseModel):
         if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0:
             def fix_image(img):
                 if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
+                    if img.min() < -2:
+                        img = normalize_mel(img)
                     img = img.unsqueeze(dim=1)
-                    # Normalize so spectrogram is easier to view.
-                    img = (img - img.mean()) / img.std()
                 if img.shape[1] > 3:
                     img = img[:, :3, :, :]
                 if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
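
For reference, the img.min() < -2 check distinguishes a raw log-mel (whose minimum sits far below -1) from one already normalized into [-1, 1], and normalize_mel performs that mapping. Its exact constants are not visible in this diff; a minimal sketch assuming a simple linear rescale of a fixed log-mel range:

import torch

# Assumed bounds for the raw log-mel range; the real values live in
# trainer/injectors/audio_injectors.py and are not shown in this diff.
MEL_MIN, MEL_MAX = -11.512925, 2.3143387

def normalize_mel_sketch(mel: torch.Tensor) -> torch.Tensor:
    # Linearly map [MEL_MIN, MEL_MAX] -> [-1, 1] so spectrograms are easy to view and log.
    return 2 * (mel - MEL_MIN) / (MEL_MAX - MEL_MIN) - 1

def denormalize_mel_sketch(mel_norm: torch.Tensor) -> torch.Tensor:
    # Inverse mapping, analogous to the denormalize_mel used by the evaluator below.
    return (mel_norm + 1) / 2 * (MEL_MAX - MEL_MIN) + MEL_MIN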


@@ -2,6 +2,7 @@ import os
 import os.path as osp
 import torch
 import torchaudio
+import torchvision.utils
 from pytorch_fid.fid_score import calculate_frechet_distance
 from torch import distributed
 from tqdm import tqdm
@@ -17,7 +18,7 @@ from models.clip.mel_text_clip import MelTextCLIP
 from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
     convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
-from trainer.injectors.audio_injectors import denormalize_mel
+from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector
 from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
@@ -57,6 +58,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
         elif mode == 'vocoder':
             self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.diffusion_fn = self.perform_diffusion_vocoder
+        elif mode == 'ctc_to_mel':
+            self.diffusion_fn = self.perform_diffusion_ctc
+            self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
         elif 'tts9_mel' in mode:
             mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
             self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
@@ -167,6 +171,22 @@ class AudioDiffusionFid(evaluator.Evaluator):
         real_dec = self.local_modules['vocoder'].inference(univnet_mel)
         return gen_wav.float(), real_dec, SAMPLE_RATE
+
+    def perform_diffusion_ctc(self, audio, codes, text):
+        SAMPLE_RATE = 24000
+        real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
+        univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=True)
+        output_shape = univnet_mel.shape
+        cond_mel = TorchMelSpectrogramInjector({'n_mel_channels': 100, 'mel_fmax': 11000, 'filter_length': 8000, 'normalize': True,
+                                                'true_normalization': True, 'in': 'in', 'out': 'out'}, {})({'in': audio})['out']
+        gen_mel = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': codes.unsqueeze(0),
+                                                                                      'conditioning_input': cond_mel})
+
+        gen_mel_denorm = denormalize_mel(gen_mel)
+        gen_wav = self.local_modules['vocoder'].inference(gen_mel_denorm)
+
+        real_dec = self.local_modules['vocoder'].inference(denormalize_mel(univnet_mel))
+        return gen_wav.float(), real_dec, gen_mel, univnet_mel, SAMPLE_RATE

     def load_projector(self):
         """
         Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading
@@ -237,12 +257,14 @@ class AudioDiffusionFid(evaluator.Evaluator):
             path, text, codes = self.data[i + self.env['rank']]
             audio = load_audio(path, 22050).to(self.dev)
             codes = codes.to(self.dev)
-            sample, ref, sample_rate = self.diffusion_fn(audio, codes, text)
+            sample, ref, gen_mel, ref_mel, sample_rate = self.diffusion_fn(audio, codes, text)

             gen_projections.append(self.project(projector, sample, sample_rate).cpu())  # Store on CPU to avoid wasting GPU memory.
             real_projections.append(self.project(projector, ref, sample_rate).cpu())
             intelligibility_losses.append(self.intelligibility_loss(w2v, sample, ref, sample_rate, text))
+            torchvision.utils.save_image((gen_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f'{self.env["rank"]}_{i}_mel.png'))
+            torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f'{self.env["rank"]}_{i}_mel_target.png'))
             torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
             torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)

         gen_projections = torch.stack(gen_projections, dim=0)
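
The save_image calls shift the mels from the model's [-1, 1] range into [0, 1] because torchvision.utils.save_image expects float images in that range, and unsqueeze(1) adds the channel dimension it needs. A small standalone example of the same pattern:

import torch
import torchvision.utils

mel = torch.rand(1, 100, 400) * 2 - 1          # stand-in for a normalized mel in [-1, 1]
img = (mel.unsqueeze(1) + 1) / 2               # [N, 1, H, W] in [0, 1]
torchvision.utils.save_image(img, 'mel.png')   # written as a grayscale spectrogram image
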
@@ -283,12 +305,12 @@ if __name__ == '__main__':
 if __name__ == '__main__':
     # 34k; no conditioning_free: {'frechet_distance': tensor(1.4559, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(151.9112, device='cuda:0')}
     # 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat_autoregressive_inputs.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_speech_diffusion_from_ctc_und10\\train.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\2000_generator.pth').cuda()
-    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': True, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel_autoin'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 563, 'device': 'cuda', 'opt': {}}
+                                       load_path='X:\\dlas\\experiments\\train_speech_diffusion_from_ctc_und10\\models\\43000_generator_ema.pth').cuda()
+    opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-oco-realtext.tsv', 'diffusion_steps': 100,
+                'conditioning_free': False, 'conditioning_free_k': 1,
+                'diffusion_schedule': 'linear', 'diffusion_type': 'ctc_to_mel'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 223, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
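
The conditioning_free and conditioning_free_k options drive guidance at sampling time: the diffuser evaluates the model once with the real codes and once with the mask-token codes and blends the two predictions. The exact blend inside this repository's diffuser is not shown in the diff; a generic sketch of the usual classifier-free-guidance formulation, with illustrative names:

import torch

def conditioning_free_guidance(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, k: float) -> torch.Tensor:
    # Standard classifier-free guidance blend: k = 0 returns the unconditional
    # prediction, k = 1 the conditional one, and larger k extrapolates past it.
    return eps_uncond + k * (eps_cond - eps_uncond)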


@@ -246,13 +246,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_flat.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       #load_path='X:\\dlas\\experiments\\train_music_diffusion_flat\\models\\33000_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen_reformed_mel\\models\\57500_generator_ema.pth'
                                        ).cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 25, 'device': 'cuda', 'opt': {}}
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
+                'conditioning_free': True, 'conditioning_free_k': 1,
+                'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 26, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())