forked from mrq/DL-Art-School
few things with gap filling
This commit is contained in:
parent
b83b53cf84
commit
d8925ccde5
codes
models/audio/music
trainer
|
@ -186,15 +186,14 @@ class MusicGenerator(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def timestep_independent(self, truth, expected_seq_len, return_code_pred):
|
def timestep_independent(self, truth, expected_seq_len, return_code_pred):
|
||||||
code_emb = self.conditioner(truth)
|
truth_emb = self.conditioner(truth)
|
||||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
if self.training and self.unconditioned_percentage > 0:
|
if self.training and self.unconditioned_percentage > 0:
|
||||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
unconditioned_batches = torch.rand((truth_emb.shape[0], 1, 1),
|
||||||
device=code_emb.device) < self.unconditioned_percentage
|
device=truth_emb.device) < self.unconditioned_percentage
|
||||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
|
truth_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
|
||||||
code_emb)
|
truth_emb)
|
||||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
return truth_emb
|
||||||
return expanded_code_emb
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
|
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
|
||||||
|
@ -212,20 +211,21 @@ class MusicGenerator(nn.Module):
|
||||||
|
|
||||||
unused_params = []
|
unused_params = []
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
truth_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||||
unused_params.extend(list(self.conditioner.parameters()))
|
unused_params.extend(list(self.conditioner.parameters()))
|
||||||
else:
|
else:
|
||||||
if precomputed_aligned_embeddings is not None:
|
if precomputed_aligned_embeddings is not None:
|
||||||
code_emb = precomputed_aligned_embeddings
|
truth_emb = precomputed_aligned_embeddings
|
||||||
else:
|
else:
|
||||||
|
if self.training:
|
||||||
truth = self.do_masking(truth)
|
truth = self.do_masking(truth)
|
||||||
code_emb = self.timestep_independent(truth, x.shape[-1], True)
|
truth_emb = self.timestep_independent(truth, x.shape[-1], True)
|
||||||
unused_params.append(self.unconditioned_embedding)
|
unused_params.append(self.unconditioned_embedding)
|
||||||
|
|
||||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
truth_emb = self.conditioning_timestep_integrator(truth_emb, time_emb)
|
||||||
x = self.inp_block(x)
|
x = self.inp_block(x)
|
||||||
x = torch.cat([x, code_emb], dim=1)
|
x = torch.cat([x, truth_emb], dim=1)
|
||||||
x = self.integrating_conv(x)
|
x = self.integrating_conv(x)
|
||||||
for i, lyr in enumerate(self.layers):
|
for i, lyr in enumerate(self.layers):
|
||||||
# Do layer drop where applicable. Do not drop first and last layers.
|
# Do layer drop where applicable. Do not drop first and last layers.
|
||||||
|
|
|
@ -95,7 +95,7 @@ class ResBlock(TimestepBlock):
|
||||||
h = self.out_layers(h)
|
h = self.out_layers(h)
|
||||||
return self.skip_connection(x) + h
|
return self.skip_connection(x) + h
|
||||||
|
|
||||||
class DiffusionTts(nn.Module):
|
class DiffusionWaveformGen(nn.Module):
|
||||||
"""
|
"""
|
||||||
The full UNet model with attention and timestep embedding.
|
The full UNet model with attention and timestep embedding.
|
||||||
|
|
||||||
|
@ -465,7 +465,7 @@ class DiffusionTts(nn.Module):
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_unet_diffusion_waveform_gen(opt_net, opt):
|
def register_unet_diffusion_waveform_gen(opt_net, opt):
|
||||||
return DiffusionTts(**opt_net['kwargs'])
|
return DiffusionWaveformGen(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -473,7 +473,7 @@ if __name__ == '__main__':
|
||||||
aligned_latent = torch.randn(2,388,1024)
|
aligned_latent = torch.randn(2,388,1024)
|
||||||
aligned_sequence = torch.randn(2,120,220)
|
aligned_sequence = torch.randn(2,120,220)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = DiffusionTts(128,
|
model = DiffusionWaveformGen(128,
|
||||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
||||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
||||||
token_conditioning_resolutions=[1,4,16,64],
|
token_conditioning_resolutions=[1,4,16,64],
|
||||||
|
|
|
@ -17,7 +17,7 @@ from models.clip.mel_text_clip import MelTextCLIP
|
||||||
from models.audio.tts.tacotron2 import text_to_sequence
|
from models.audio.tts.tacotron2 import text_to_sequence
|
||||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||||
from trainer.injectors.audio_injectors import denormalize_tacotron_mel
|
from trainer.injectors.audio_injectors import denormalize_mel
|
||||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,7 +161,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
model_kwargs={'aligned_conditioning': mel_codes,
|
model_kwargs={'aligned_conditioning': mel_codes,
|
||||||
'conditioning_input': univnet_mel})
|
'conditioning_input': univnet_mel})
|
||||||
# denormalize mel
|
# denormalize mel
|
||||||
gen_mel = denormalize_tacotron_mel(gen_mel)
|
gen_mel = denormalize_mel(gen_mel)
|
||||||
|
|
||||||
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
|
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
|
||||||
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
||||||
|
|
|
@ -16,13 +16,15 @@ import trainer.eval.evaluator as evaluator
|
||||||
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||||
from data.audio.unsupervised_audio_dataset import load_audio
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||||
|
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
||||||
from models.clip.mel_text_clip import MelTextCLIP
|
from models.clip.mel_text_clip import MelTextCLIP
|
||||||
from models.audio.tts.tacotron2 import text_to_sequence
|
from models.audio.tts.tacotron2 import text_to_sequence
|
||||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||||
from trainer.injectors.audio_injectors import denormalize_tacotron_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d
|
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
||||||
|
normalize_mel
|
||||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,15 +52,28 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k)
|
conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k)
|
||||||
self.dev = self.env['device']
|
self.dev = self.env['device']
|
||||||
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
||||||
self.local_modules = {}
|
|
||||||
if mode == 'standard':
|
self.spec_decoder = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32,
|
||||||
self.diffusion_fn = self.perform_diffusion_standard
|
channel_mult=[1,2,3,4], num_res_blocks=[3,3,3,3], token_conditioning_resolutions=[1,4],
|
||||||
|
num_heads=8,
|
||||||
|
dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
||||||
|
self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu')))
|
||||||
|
self.local_modules = {'spec_decoder': self.spec_decoder}
|
||||||
|
|
||||||
|
if mode == 'spec_decode':
|
||||||
|
self.diffusion_fn = self.perform_diffusion_spec_decode
|
||||||
|
elif 'gap_fill_' in mode:
|
||||||
|
self.diffusion_fn = self.perform_diffusion_gap_fill
|
||||||
|
if '_freq' in mode:
|
||||||
|
self.gap_gen_fn = self.gen_freq_gap
|
||||||
|
else:
|
||||||
|
self.gap_gen_fn = self.gen_time_gap
|
||||||
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})
|
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})
|
||||||
|
|
||||||
def load_data(self, path):
|
def load_data(self, path):
|
||||||
return list(glob(f'{path}/*.wav'))
|
return list(glob(f'{path}/*.wav'))
|
||||||
|
|
||||||
def perform_diffusion_standard(self, audio, sample_rate=22050):
|
def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
|
||||||
if sample_rate != sample_rate:
|
if sample_rate != sample_rate:
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
|
@ -69,7 +84,47 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
gen = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
gen = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||||
model_kwargs={'aligned_conditioning': mel})
|
model_kwargs={'aligned_conditioning': mel})
|
||||||
gen = pixel_shuffle_1d(gen, 16)
|
gen = pixel_shuffle_1d(gen, 16)
|
||||||
real_resampled = real_resampled + torch.FloatTensor(real_resampled.shape).uniform_(0.0, 1e-5).to(real_resampled.device)
|
|
||||||
|
return gen, real_resampled, sample_rate
|
||||||
|
|
||||||
|
def gen_freq_gap(self, mel, band_range=(130,150)):
|
||||||
|
gap_start, gap_end = band_range
|
||||||
|
mel[:, gap_start:gap_end] = 0
|
||||||
|
return mel
|
||||||
|
|
||||||
|
def gen_time_gap(self, mel):
|
||||||
|
mel[:, :, 22050*5:22050*6] = 0
|
||||||
|
return mel
|
||||||
|
|
||||||
|
def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
|
||||||
|
if sample_rate != sample_rate:
|
||||||
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
real_resampled = audio
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
|
||||||
|
# Fetch the MEL and mask out the requested bands.
|
||||||
|
mel = self.spec_fn({'in': audio})['out']
|
||||||
|
mel = normalize_mel(mel)
|
||||||
|
mel = self.gap_gen_fn(mel)
|
||||||
|
output_shape = (1, mel.shape[1], mel.shape[2])
|
||||||
|
|
||||||
|
# Repair the MEL with the given model.
|
||||||
|
spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||||
|
model_kwargs={'truth': mel})
|
||||||
|
import torchvision
|
||||||
|
torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
|
||||||
|
torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
|
||||||
|
spec = denormalize_mel(spec)
|
||||||
|
|
||||||
|
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
|
||||||
|
output_shape = (1, 16, audio.shape[-1] // 16)
|
||||||
|
self.spec_decoder = self.spec_decoder.to(audio.device)
|
||||||
|
# Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
|
||||||
|
gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||||
|
model_kwargs={'aligned_conditioning': spec})
|
||||||
|
gen = pixel_shuffle_1d(gen, 16)
|
||||||
|
|
||||||
return gen, real_resampled, sample_rate
|
return gen, real_resampled, sample_rate
|
||||||
|
|
||||||
def load_projector(self):
|
def load_projector(self):
|
||||||
|
@ -148,12 +203,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen3.yml', 'generator',
|
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
||||||
also_load_savepoint=False,
|
also_load_savepoint=False,
|
||||||
load_path='X:\\dlas\\experiments\\train_music_waveform_gen3_r1\\models\\10000_generator_ema.pth').cuda()
|
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\5000_generator.pth').cuda()
|
||||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
|
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
||||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'standard'}
|
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
|
||||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
||||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||||
print(eval.perform_eval())
|
print(eval.perform_eval())
|
||||||
|
|
|
@ -282,6 +282,20 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
|
||||||
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
||||||
|
|
||||||
|
|
||||||
|
class RandomScaleInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.min_samples = opt['min_samples']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
inp = state[self.input]
|
||||||
|
if self.min_samples < inp.shape[-1]:
|
||||||
|
samples = random.randint(self.min_samples, inp.shape[-1])
|
||||||
|
start = random.randint(0, inp.shape[-1]-samples)
|
||||||
|
inp = inp[:, :, start:start+samples]
|
||||||
|
return {self.output: inp}
|
||||||
|
|
||||||
|
|
||||||
def pixel_shuffle_1d(x, upscale_factor):
|
def pixel_shuffle_1d(x, upscale_factor):
|
||||||
batch_size, channels, steps = x.size()
|
batch_size, channels, steps = x.size()
|
||||||
channels //= upscale_factor
|
channels //= upscale_factor
|
||||||
|
|
Loading…
Reference in New Issue
Block a user