adf update
This commit is contained in:
parent
8587a18717
commit
3db862dd32
|
@ -40,7 +40,7 @@ class AttentionBlock(TimestepBlock):
|
|||
def __init__(self, dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
|
||||
self.ff = FeedForward(dim, mult=2, dropout=dropout, zero_init_output=True)
|
||||
self.ff = FeedForward(dim, mult=1, dropout=dropout, zero_init_output=True)
|
||||
self.rms_scale_norm = RMSScaleShiftNorm(dim)
|
||||
|
||||
def forward(self, x, timestep_emb, rotary_emb):
|
||||
|
|
|
@ -261,13 +261,13 @@ class DiffusionTts(nn.Module):
|
|||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, tokens, conditioning_input=None):
|
||||
def forward(self, x, timesteps, codes, conditioning_input=None):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param tokens: an aligned text input.
|
||||
:param codes: an aligned text input.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
with autocast(x.device.type):
|
||||
|
@ -276,16 +276,16 @@ class DiffusionTts(nn.Module):
|
|||
if cm != 0:
|
||||
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
|
||||
codes = F.pad(codes, (0, int(pc * codes.shape[-1])))
|
||||
|
||||
hs = []
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||
token_mask = torch.rand(codes.shape, device=codes.device) < self.nil_guidance_fwd_proportion
|
||||
codes = torch.where(token_mask, self.mask_token_id, codes)
|
||||
code_emb = self.code_embedding(codes).permute(0, 2, 1)
|
||||
cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
|
||||
cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
|
||||
code_emb = self.codes_encoder(code_emb.permute(0,2,1), norm_scale_shift_inp=cond_emb).permute(0,2,1)
|
||||
|
|
|
@ -332,7 +332,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_code2mel.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -13,6 +13,7 @@ import trainer.networks as networks
|
|||
from trainer.base_model import BaseModel
|
||||
from trainer.batch_size_optimizer import create_batch_size_optimizer
|
||||
from trainer.inject import create_injector
|
||||
from trainer.injectors.audio_injectors import normalize_mel
|
||||
from trainer.steps import ConfigurableStep
|
||||
from trainer.experiments.experiments import get_experiment_for_name
|
||||
import torchvision.utils as utils
|
||||
|
@ -354,9 +355,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
def fix_image(img):
|
||||
if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
|
||||
if img.min() < -2:
|
||||
img = normalize_mel(img)
|
||||
img = img.unsqueeze(dim=1)
|
||||
# Normalize so spectrogram is easier to view.
|
||||
img = (img - img.mean()) / img.std()
|
||||
if img.shape[1] > 3:
|
||||
img = img[:, :3, :, :]
|
||||
if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import os.path as osp
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchvision.utils
|
||||
from pytorch_fid.fid_score import calculate_frechet_distance
|
||||
from torch import distributed
|
||||
from tqdm import tqdm
|
||||
|
@ -17,7 +18,7 @@ from models.clip.mel_text_clip import MelTextCLIP
|
|||
from models.audio.tts.tacotron2 import text_to_sequence
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||
from trainer.injectors.audio_injectors import denormalize_mel
|
||||
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector
|
||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
||||
|
||||
|
||||
|
@ -57,6 +58,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
elif mode == 'vocoder':
|
||||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||
elif mode == 'ctc_to_mel':
|
||||
self.diffusion_fn = self.perform_diffusion_ctc
|
||||
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
|
||||
elif 'tts9_mel' in mode:
|
||||
mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
|
||||
self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
|
||||
|
@ -167,6 +171,22 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
||||
return gen_wav.float(), real_dec, SAMPLE_RATE
|
||||
|
||||
def perform_diffusion_ctc(self, audio, codes, text):
|
||||
SAMPLE_RATE = 24000
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
|
||||
univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=True)
|
||||
output_shape = univnet_mel.shape
|
||||
cond_mel = TorchMelSpectrogramInjector({'n_mel_channels': 100, 'mel_fmax': 11000, 'filter_length': 8000, 'normalize': True,
|
||||
'true_normalization': True, 'in': 'in', 'out': 'out'}, {})({'in': audio})['out']
|
||||
|
||||
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': codes.unsqueeze(0),
|
||||
'conditioning_input': cond_mel})
|
||||
gen_mel_denorm = denormalize_mel(gen_mel)
|
||||
|
||||
gen_wav = self.local_modules['vocoder'].inference(gen_mel_denorm)
|
||||
real_dec = self.local_modules['vocoder'].inference(denormalize_mel(univnet_mel))
|
||||
return gen_wav.float(), real_dec, gen_mel, univnet_mel, SAMPLE_RATE
|
||||
|
||||
def load_projector(self):
|
||||
"""
|
||||
Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading
|
||||
|
@ -237,12 +257,14 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
path, text, codes = self.data[i + self.env['rank']]
|
||||
audio = load_audio(path, 22050).to(self.dev)
|
||||
codes = codes.to(self.dev)
|
||||
sample, ref, sample_rate = self.diffusion_fn(audio, codes, text)
|
||||
sample, ref, gen_mel, ref_mel, sample_rate = self.diffusion_fn(audio, codes, text)
|
||||
|
||||
gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
||||
real_projections.append(self.project(projector, ref, sample_rate).cpu())
|
||||
intelligibility_losses.append(self.intelligibility_loss(w2v, sample, ref, sample_rate, text))
|
||||
|
||||
torchvision.utils.save_image((gen_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f'{self.env["rank"]}_{i}_mel.png'))
|
||||
torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f'{self.env["rank"]}_{i}_mel_target.png'))
|
||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
|
||||
gen_projections = torch.stack(gen_projections, dim=0)
|
||||
|
@ -283,12 +305,12 @@ if __name__ == '__main__':
|
|||
if __name__ == '__main__':
|
||||
# 34k; no conditioning_free: {'frechet_distance': tensor(1.4559, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(151.9112, device='cuda:0')}
|
||||
# 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat_autoregressive_inputs.yml', 'generator',
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_speech_diffusion_from_ctc_und10\\train.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\2000_generator.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel_autoin'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 563, 'device': 'cuda', 'opt': {}}
|
||||
load_path='X:\\dlas\\experiments\\train_speech_diffusion_from_ctc_und10\\models\\43000_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-oco-realtext.tsv', 'diffusion_steps': 100,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'ctc_to_mel'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 223, 'device': 'cuda', 'opt': {}}
|
||||
eval = AudioDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
|
@ -246,13 +246,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_flat.yml', 'generator',
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
#load_path='X:\\dlas\\experiments\\train_music_diffusion_flat\\models\\33000_generator_ema.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_waveform_gen_reformed_mel\\models\\57500_generator_ema.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 25, 'device': 'cuda', 'opt': {}}
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 26, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
Loading…
Reference in New Issue
Block a user