few things with gap filling

This commit is contained in:
James Betker 2022-05-06 14:33:44 -06:00
parent b83b53cf84
commit d8925ccde5
5 changed files with 108 additions and 39 deletions

View File

@ -186,15 +186,14 @@ class MusicGenerator(nn.Module):
def timestep_independent(self, truth, expected_seq_len, return_code_pred): def timestep_independent(self, truth, expected_seq_len, return_code_pred):
code_emb = self.conditioner(truth) truth_emb = self.conditioner(truth)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0: if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), unconditioned_batches = torch.rand((truth_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage device=truth_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1), truth_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
code_emb) truth_emb)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest') return truth_emb
return expanded_code_emb
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False): def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
@ -212,20 +211,21 @@ class MusicGenerator(nn.Module):
unused_params = [] unused_params = []
if conditioning_free: if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) truth_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.conditioner.parameters())) unused_params.extend(list(self.conditioner.parameters()))
else: else:
if precomputed_aligned_embeddings is not None: if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings truth_emb = precomputed_aligned_embeddings
else: else:
truth = self.do_masking(truth) if self.training:
code_emb = self.timestep_independent(truth, x.shape[-1], True) truth = self.do_masking(truth)
truth_emb = self.timestep_independent(truth, x.shape[-1], True)
unused_params.append(self.unconditioned_embedding) unused_params.append(self.unconditioned_embedding)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) truth_emb = self.conditioning_timestep_integrator(truth_emb, time_emb)
x = self.inp_block(x) x = self.inp_block(x)
x = torch.cat([x, code_emb], dim=1) x = torch.cat([x, truth_emb], dim=1)
x = self.integrating_conv(x) x = self.integrating_conv(x)
for i, lyr in enumerate(self.layers): for i, lyr in enumerate(self.layers):
# Do layer drop where applicable. Do not drop first and last layers. # Do layer drop where applicable. Do not drop first and last layers.

View File

@ -95,7 +95,7 @@ class ResBlock(TimestepBlock):
h = self.out_layers(h) h = self.out_layers(h)
return self.skip_connection(x) + h return self.skip_connection(x) + h
class DiffusionTts(nn.Module): class DiffusionWaveformGen(nn.Module):
""" """
The full UNet model with attention and timestep embedding. The full UNet model with attention and timestep embedding.
@ -465,7 +465,7 @@ class DiffusionTts(nn.Module):
@register_model @register_model
def register_unet_diffusion_waveform_gen(opt_net, opt): def register_unet_diffusion_waveform_gen(opt_net, opt):
return DiffusionTts(**opt_net['kwargs']) return DiffusionWaveformGen(**opt_net['kwargs'])
if __name__ == '__main__': if __name__ == '__main__':
@ -473,17 +473,17 @@ if __name__ == '__main__':
aligned_latent = torch.randn(2,388,1024) aligned_latent = torch.randn(2,388,1024)
aligned_sequence = torch.randn(2,120,220) aligned_sequence = torch.randn(2,120,220)
ts = torch.LongTensor([600, 600]) ts = torch.LongTensor([600, 600])
model = DiffusionTts(128, model = DiffusionWaveformGen(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8], channel_mult=[1,1.5,2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1], num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
token_conditioning_resolutions=[1,4,16,64], token_conditioning_resolutions=[1,4,16,64],
attention_resolutions=[], attention_resolutions=[],
num_heads=8, num_heads=8,
kernel_size=3, kernel_size=3,
scale_factor=2, scale_factor=2,
time_embed_dim_multiplier=4, time_embed_dim_multiplier=4,
super_sampling=False, super_sampling=False,
efficient_convs=False) efficient_convs=False)
# Test with latent aligned conditioning # Test with latent aligned conditioning
o = model(clip, ts, aligned_latent) o = model(clip, ts, aligned_latent)
# Test with sequence aligned conditioning # Test with sequence aligned conditioning

View File

@ -17,7 +17,7 @@ from models.clip.mel_text_clip import MelTextCLIP
from models.audio.tts.tacotron2 import text_to_sequence from models.audio.tts.tacotron2 import text_to_sequence
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \ from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
from trainer.injectors.audio_injectors import denormalize_tacotron_mel from trainer.injectors.audio_injectors import denormalize_mel
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
@ -161,7 +161,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
model_kwargs={'aligned_conditioning': mel_codes, model_kwargs={'aligned_conditioning': mel_codes,
'conditioning_input': univnet_mel}) 'conditioning_input': univnet_mel})
# denormalize mel # denormalize mel
gen_mel = denormalize_tacotron_mel(gen_mel) gen_mel = denormalize_mel(gen_mel)
gen_wav = self.local_modules['vocoder'].inference(gen_mel) gen_wav = self.local_modules['vocoder'].inference(gen_mel)
real_dec = self.local_modules['vocoder'].inference(univnet_mel) real_dec = self.local_modules['vocoder'].inference(univnet_mel)

View File

@ -16,13 +16,15 @@ import trainer.eval.evaluator as evaluator
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
from data.audio.unsupervised_audio_dataset import load_audio from data.audio.unsupervised_audio_dataset import load_audio
from data.audio.voice_tokenizer import VoiceBpeTokenizer from data.audio.voice_tokenizer import VoiceBpeTokenizer
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
from models.clip.mel_text_clip import MelTextCLIP from models.clip.mel_text_clip import MelTextCLIP
from models.audio.tts.tacotron2 import text_to_sequence from models.audio.tts.tacotron2 import text_to_sequence
from models.diffusion.gaussian_diffusion import get_named_beta_schedule from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion from models.diffusion.respace import space_timesteps, SpacedDiffusion
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \ from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
from trainer.injectors.audio_injectors import denormalize_tacotron_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
normalize_mel
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
@ -50,15 +52,28 @@ class MusicDiffusionFid(evaluator.Evaluator):
conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k) conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k)
self.dev = self.env['device'] self.dev = self.env['device']
mode = opt_get(opt_eval, ['diffusion_type'], 'tts') mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
self.local_modules = {}
if mode == 'standard': self.spec_decoder = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32,
self.diffusion_fn = self.perform_diffusion_standard channel_mult=[1,2,3,4], num_res_blocks=[3,3,3,3], token_conditioning_resolutions=[1,4],
num_heads=8,
dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0)
self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu')))
self.local_modules = {'spec_decoder': self.spec_decoder}
if mode == 'spec_decode':
self.diffusion_fn = self.perform_diffusion_spec_decode
elif 'gap_fill_' in mode:
self.diffusion_fn = self.perform_diffusion_gap_fill
if '_freq' in mode:
self.gap_gen_fn = self.gen_freq_gap
else:
self.gap_gen_fn = self.gen_time_gap
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {}) self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})
def load_data(self, path): def load_data(self, path):
return list(glob(f'{path}/*.wav')) return list(glob(f'{path}/*.wav'))
def perform_diffusion_standard(self, audio, sample_rate=22050): def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
if sample_rate != sample_rate: if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else: else:
@ -69,7 +84,47 @@ class MusicDiffusionFid(evaluator.Evaluator):
gen = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device), gen = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
model_kwargs={'aligned_conditioning': mel}) model_kwargs={'aligned_conditioning': mel})
gen = pixel_shuffle_1d(gen, 16) gen = pixel_shuffle_1d(gen, 16)
real_resampled = real_resampled + torch.FloatTensor(real_resampled.shape).uniform_(0.0, 1e-5).to(real_resampled.device)
return gen, real_resampled, sample_rate
def gen_freq_gap(self, mel, band_range=(130,150)):
gap_start, gap_end = band_range
mel[:, gap_start:gap_end] = 0
return mel
def gen_time_gap(self, mel):
mel[:, :, 22050*5:22050*6] = 0
return mel
def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
real_resampled = audio
audio = audio.unsqueeze(0)
# Fetch the MEL and mask out the requested bands.
mel = self.spec_fn({'in': audio})['out']
mel = normalize_mel(mel)
mel = self.gap_gen_fn(mel)
output_shape = (1, mel.shape[1], mel.shape[2])
# Repair the MEL with the given model.
spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
model_kwargs={'truth': mel})
import torchvision
torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
spec = denormalize_mel(spec)
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
output_shape = (1, 16, audio.shape[-1] // 16)
self.spec_decoder = self.spec_decoder.to(audio.device)
# Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
model_kwargs={'aligned_conditioning': spec})
gen = pixel_shuffle_1d(gen, 16)
return gen, real_resampled, sample_rate return gen, real_resampled, sample_rate
def load_projector(self): def load_projector(self):
@ -148,12 +203,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__': if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen3.yml', 'generator', diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
also_load_savepoint=False, also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_waveform_gen3_r1\\models\\10000_generator_ema.pth').cuda() load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\5000_generator.pth').cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50, opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
'conditioning_free': False, 'conditioning_free_k': 1, 'conditioning_free': False, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'standard'} 'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env) eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval()) print(eval.perform_eval())

View File

@ -282,6 +282,20 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
return {self.output: mean_loss, self.var_loss_key: var_loss} return {self.output: mean_loss, self.var_loss_key: var_loss}
class RandomScaleInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.min_samples = opt['min_samples']
def forward(self, state):
inp = state[self.input]
if self.min_samples < inp.shape[-1]:
samples = random.randint(self.min_samples, inp.shape[-1])
start = random.randint(0, inp.shape[-1]-samples)
inp = inp[:, :, start:start+samples]
return {self.output: inp}
def pixel_shuffle_1d(x, upscale_factor): def pixel_shuffle_1d(x, upscale_factor):
batch_size, channels, steps = x.size() batch_size, channels, steps = x.size()
channels //= upscale_factor channels //= upscale_factor