forked from mrq/DL-Art-School
few things with gap filling
This commit is contained in:
parent
b83b53cf84
commit
d8925ccde5
|
@ -186,15 +186,14 @@ class MusicGenerator(nn.Module):
|
|||
|
||||
|
||||
def timestep_independent(self, truth, expected_seq_len, return_code_pred):
|
||||
code_emb = self.conditioner(truth)
|
||||
truth_emb = self.conditioner(truth)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||
return expanded_code_emb
|
||||
unconditioned_batches = torch.rand((truth_emb.shape[0], 1, 1),
|
||||
device=truth_emb.device) < self.unconditioned_percentage
|
||||
truth_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
|
||||
truth_emb)
|
||||
return truth_emb
|
||||
|
||||
|
||||
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
|
||||
|
@ -212,20 +211,21 @@ class MusicGenerator(nn.Module):
|
|||
|
||||
unused_params = []
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
truth_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
truth_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
truth = self.do_masking(truth)
|
||||
code_emb = self.timestep_independent(truth, x.shape[-1], True)
|
||||
if self.training:
|
||||
truth = self.do_masking(truth)
|
||||
truth_emb = self.timestep_independent(truth, x.shape[-1], True)
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
truth_emb = self.conditioning_timestep_integrator(truth_emb, time_emb)
|
||||
x = self.inp_block(x)
|
||||
x = torch.cat([x, code_emb], dim=1)
|
||||
x = torch.cat([x, truth_emb], dim=1)
|
||||
x = self.integrating_conv(x)
|
||||
for i, lyr in enumerate(self.layers):
|
||||
# Do layer drop where applicable. Do not drop first and last layers.
|
||||
|
|
|
@ -95,7 +95,7 @@ class ResBlock(TimestepBlock):
|
|||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
class DiffusionTts(nn.Module):
|
||||
class DiffusionWaveformGen(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
||||
|
@ -465,7 +465,7 @@ class DiffusionTts(nn.Module):
|
|||
|
||||
@register_model
|
||||
def register_unet_diffusion_waveform_gen(opt_net, opt):
|
||||
return DiffusionTts(**opt_net['kwargs'])
|
||||
return DiffusionWaveformGen(**opt_net['kwargs'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -473,17 +473,17 @@ if __name__ == '__main__':
|
|||
aligned_latent = torch.randn(2,388,1024)
|
||||
aligned_sequence = torch.randn(2,120,220)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = DiffusionTts(128,
|
||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
||||
token_conditioning_resolutions=[1,4,16,64],
|
||||
attention_resolutions=[],
|
||||
num_heads=8,
|
||||
kernel_size=3,
|
||||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4,
|
||||
super_sampling=False,
|
||||
efficient_convs=False)
|
||||
model = DiffusionWaveformGen(128,
|
||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
||||
token_conditioning_resolutions=[1,4,16,64],
|
||||
attention_resolutions=[],
|
||||
num_heads=8,
|
||||
kernel_size=3,
|
||||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4,
|
||||
super_sampling=False,
|
||||
efficient_convs=False)
|
||||
# Test with latent aligned conditioning
|
||||
o = model(clip, ts, aligned_latent)
|
||||
# Test with sequence aligned conditioning
|
||||
|
|
|
@ -17,7 +17,7 @@ from models.clip.mel_text_clip import MelTextCLIP
|
|||
from models.audio.tts.tacotron2 import text_to_sequence
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||
from trainer.injectors.audio_injectors import denormalize_tacotron_mel
|
||||
from trainer.injectors.audio_injectors import denormalize_mel
|
||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
||||
|
||||
|
||||
|
@ -161,7 +161,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
model_kwargs={'aligned_conditioning': mel_codes,
|
||||
'conditioning_input': univnet_mel})
|
||||
# denormalize mel
|
||||
gen_mel = denormalize_tacotron_mel(gen_mel)
|
||||
gen_mel = denormalize_mel(gen_mel)
|
||||
|
||||
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
|
||||
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
||||
|
|
|
@ -16,13 +16,15 @@ import trainer.eval.evaluator as evaluator
|
|||
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
||||
from models.clip.mel_text_clip import MelTextCLIP
|
||||
from models.audio.tts.tacotron2 import text_to_sequence
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||
from trainer.injectors.audio_injectors import denormalize_tacotron_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d
|
||||
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
||||
normalize_mel
|
||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
||||
|
||||
|
||||
|
@ -50,15 +52,28 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k)
|
||||
self.dev = self.env['device']
|
||||
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
||||
self.local_modules = {}
|
||||
if mode == 'standard':
|
||||
self.diffusion_fn = self.perform_diffusion_standard
|
||||
|
||||
self.spec_decoder = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32,
|
||||
channel_mult=[1,2,3,4], num_res_blocks=[3,3,3,3], token_conditioning_resolutions=[1,4],
|
||||
num_heads=8,
|
||||
dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
||||
self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu')))
|
||||
self.local_modules = {'spec_decoder': self.spec_decoder}
|
||||
|
||||
if mode == 'spec_decode':
|
||||
self.diffusion_fn = self.perform_diffusion_spec_decode
|
||||
elif 'gap_fill_' in mode:
|
||||
self.diffusion_fn = self.perform_diffusion_gap_fill
|
||||
if '_freq' in mode:
|
||||
self.gap_gen_fn = self.gen_freq_gap
|
||||
else:
|
||||
self.gap_gen_fn = self.gen_time_gap
|
||||
self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})
|
||||
|
||||
def load_data(self, path):
|
||||
return list(glob(f'{path}/*.wav'))
|
||||
|
||||
def perform_diffusion_standard(self, audio, sample_rate=22050):
|
||||
def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
|
||||
if sample_rate != sample_rate:
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||
else:
|
||||
|
@ -69,7 +84,47 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
gen = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||
model_kwargs={'aligned_conditioning': mel})
|
||||
gen = pixel_shuffle_1d(gen, 16)
|
||||
real_resampled = real_resampled + torch.FloatTensor(real_resampled.shape).uniform_(0.0, 1e-5).to(real_resampled.device)
|
||||
|
||||
return gen, real_resampled, sample_rate
|
||||
|
||||
def gen_freq_gap(self, mel, band_range=(130,150)):
|
||||
gap_start, gap_end = band_range
|
||||
mel[:, gap_start:gap_end] = 0
|
||||
return mel
|
||||
|
||||
def gen_time_gap(self, mel):
|
||||
mel[:, :, 22050*5:22050*6] = 0
|
||||
return mel
|
||||
|
||||
def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
|
||||
if sample_rate != sample_rate:
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||
else:
|
||||
real_resampled = audio
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
# Fetch the MEL and mask out the requested bands.
|
||||
mel = self.spec_fn({'in': audio})['out']
|
||||
mel = normalize_mel(mel)
|
||||
mel = self.gap_gen_fn(mel)
|
||||
output_shape = (1, mel.shape[1], mel.shape[2])
|
||||
|
||||
# Repair the MEL with the given model.
|
||||
spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||
model_kwargs={'truth': mel})
|
||||
import torchvision
|
||||
torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
|
||||
torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
|
||||
spec = denormalize_mel(spec)
|
||||
|
||||
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
|
||||
output_shape = (1, 16, audio.shape[-1] // 16)
|
||||
self.spec_decoder = self.spec_decoder.to(audio.device)
|
||||
# Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
|
||||
gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||
model_kwargs={'aligned_conditioning': spec})
|
||||
gen = pixel_shuffle_1d(gen, 16)
|
||||
|
||||
return gen, real_resampled, sample_rate
|
||||
|
||||
def load_projector(self):
|
||||
|
@ -148,12 +203,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen3.yml', 'generator',
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_waveform_gen3_r1\\models\\10000_generator_ema.pth').cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
|
||||
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\5000_generator.pth').cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'standard'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
|
@ -282,6 +282,20 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
|
|||
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
||||
|
||||
|
||||
class RandomScaleInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.min_samples = opt['min_samples']
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
if self.min_samples < inp.shape[-1]:
|
||||
samples = random.randint(self.min_samples, inp.shape[-1])
|
||||
start = random.randint(0, inp.shape[-1]-samples)
|
||||
inp = inp[:, :, start:start+samples]
|
||||
return {self.output: inp}
|
||||
|
||||
|
||||
def pixel_shuffle_1d(x, upscale_factor):
|
||||
batch_size, channels, steps = x.size()
|
||||
channels //= upscale_factor
|
||||
|
|
Loading…
Reference in New Issue
Block a user