James Betker 2022-05-27 11:40:47 -06:00
parent 490d39b967
commit 34ee1d0bc3
3 changed files with 21 additions and 96 deletions


@@ -57,14 +57,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
        if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
        elif 'gap_fill_' in mode:
            self.diffusion_fn = self.perform_diffusion_gap_fill
            if '_freq' in mode:
                self.gap_gen_fn = self.gen_freq_gap
            else:
                self.gap_gen_fn = self.gen_time_gap
        elif 'rerender' in mode:
            self.diffusion_fn = self.perform_rerender
        elif 'from_codes' == mode:
            self.diffusion_fn = self.perform_diffusion_from_codes
            self.local_modules['codegen'] = get_music_codegen()
@@ -88,75 +80,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate

    def gen_freq_gap(self, mel, band_range=(60,100)):
        gap_start, gap_end = band_range
        mask = torch.ones_like(mel)
        mask[:, gap_start:gap_end] = 0
        return mel * mask, mask

    def gen_time_gap(self, mel):
        mask = torch.ones_like(mel)
        mask[:, :, 86*4:86*6] = 0
        return mel * mask, mask

    def perform_diffusion_gap_fill(self, audio, sample_rate=22050):
        if sample_rate != sample_rate:
            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
        else:
            real_resampled = audio
        audio = audio.unsqueeze(0)
        # Fetch the MEL and mask out the requested bands.
        mel = self.spec_fn({'in': audio})['out']
        mel = normalize_mel(mel)
        mel, mask = self.gap_gen_fn(mel)
        # Repair the MEL with the given model.
        spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel})
        spec = denormalize_mel(spec)
        # Re-convert the resulting MEL back into audio using the spectrogram decoder.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        # Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
        gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                          model_kwargs={'aligned_conditioning': spec})
        gen = pixel_shuffle_1d(gen, 16)
        return gen, real_resampled, normalize_mel(spec), mel, sample_rate

    def perform_rerender(self, audio, sample_rate=22050):
        if sample_rate != sample_rate:
            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
        else:
            real_resampled = audio
        audio = audio.unsqueeze(0)
        # Fetch the MEL and mask out the requested bands.
        mel = self.spec_fn({'in': audio})['out']
        mel = normalize_mel(mel)
        segments = [(0,10),(10,25),(25,45),(45,60),(60,80),(80,100),(100,130),(130,170),(170,210),(210,256)]
        shuffle(segments)
        spec = mel
        for i, segment in enumerate(segments):
            mel, mask = self.gen_freq_gap(mel, band_range=segment)
            # Repair the MEL with the given model.
            spec = self.diffuser.p_sample_loop_with_guidance(self.model, spec, mask, model_kwargs={'truth': spec})
            torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, f"{i}_rerender.png")
        spec = denormalize_mel(spec)
        # Re-convert the resulting MEL back into audio using the spectrogram decoder.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        # Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
        gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
                                          model_kwargs={'aligned_conditioning': spec})
        gen = pixel_shuffle_1d(gen, 16)
        return gen, real_resampled, normalize_mel(spec), mel, sample_rate

    def perform_diffusion_from_codes(self, audio, sample_rate=22050):
        if sample_rate != sample_rate:
            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
@@ -164,15 +87,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
        else:
            real_resampled = audio
        audio = audio.unsqueeze(0)
        # Fetch the MEL and mask out the requested bands.
        mel = self.spec_fn({'in': audio})['out']
        codegen = self.local_modules['codegen'].to(mel.device)
        codes = codegen.get_codes(mel)
        mel_norm = normalize_mel(mel)
        precomputed = self.model.timestep_independent(aligned_conditioning=codes, conditioning_input=mel[:,:,:112],
                                                      expected_seq_len=mel_norm.shape[-1], return_code_pred=False)
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, noise=torch.zeros_like(mel_norm),
                                              model_kwargs={'precomputed_aligned_embeddings': precomputed})
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'codes': codes, 'conditioning_input': mel_norm[:,:,:140]})
        gen_mel_denorm = denormalize_mel(gen_mel)
        output_shape = (1,16,audio.shape[-1]//16)
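
For context on the decode step used throughout this evaluator: the spectrogram decoder is sampled at shape (1, 16, L // 16) and the result is folded back into a single-channel waveform with pixel_shuffle_1d(gen, 16). A minimal sketch of a 1-D pixel shuffle that performs this fold (illustrative only; the project's own pixel_shuffle_1d helper may differ in detail):

import torch

def pixel_shuffle_1d(x, upscale_factor):
    # Fold (batch, channels * r, steps) into (batch, channels, steps * r) by
    # interleaving the r sub-channels along the time axis.
    batch, channels, steps = x.size()
    channels //= upscale_factor
    x = x.contiguous().view(batch, channels, upscale_factor, steps)
    x = x.permute(0, 1, 3, 2).contiguous()
    return x.view(batch, channels, steps * upscale_factor)

# Example: a (1, 16, 1024) decoder output becomes a (1, 1, 16384) waveform.
wav = pixel_shuffle_1d(torch.randn(1, 16, 1024), 16)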


@@ -6,6 +6,7 @@ import torchaudio
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
from trainer.inject import Injector
from utils.music_utils import get_music_codegen
from utils.util import opt_get, load_model_from_config, pad_or_truncate
TACOTRON_MEL_MAX = 2.3143386840820312
@@ -326,15 +327,7 @@ class AudioUnshuffleInjector(Injector):
class Mel2vecCodesInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        for_what = opt_get(opt, ['for'], 'music')
        from models.audio.mel2vec import ContrastiveTrainingWrapper
        self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
                                              mask_time_prob=0,
                                              mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
                                              disable_custom_linear_init=True, do_reconstruction_loss=True)
        self.m2v.load_state_dict(torch.load(f"../experiments/m2v_{for_what}.pth", map_location=torch.device('cpu')))
        self.m2v = self.m2v.eval()
        self.m2v = get_music_codegen()
        del self.m2v.m2v.encoder  # This is a big memory sink which will not get used.
        self.needs_move = True
@@ -366,4 +359,14 @@ class ClvpTextInjector(Injector):
        if self.needs_move:
            self.clvp = self.clvp.to(codes.device)
        latents = self.clvp.embed_text(codes)
        return {self.output: latents}

class NormalizeMelInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        mel = state[self.input]
        with torch.no_grad():
            return {self.output: normalize_mel(mel)}
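
The new NormalizeMelInjector relies on the same normalize_mel convention used by the evaluator above (note the (spec + 1) / 2 image rescaling there, which implies a [-1, 1] range). A minimal sketch of the normalize/denormalize pair under that assumption; only TACOTRON_MEL_MAX appears in this file, so the lower bound below is an assumed companion constant:

import torch

TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254  # assumption: roughly ln(1e-5), the usual log-mel floor

def normalize_mel(mel: torch.Tensor) -> torch.Tensor:
    # Map a TacoTron-style log-mel from [MIN, MAX] onto [-1, 1] for the diffusion model.
    return 2 * (mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) - 1

def denormalize_mel(norm_mel: torch.Tensor) -> torch.Tensor:
    # Inverse mapping: [-1, 1] back to the log-mel range before vocoding.
    return (norm_mel + 1) / 2 * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN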


@@ -13,8 +13,10 @@ def get_mel2wav_model():
    return model

def get_music_codegen():
    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
    model.load_state_dict(torch.load("../experiments/m2v_music.pth", map_location=torch.device('cpu')))
    model.eval()
    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
                                       mask_time_prob=0,
                                       mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
                                       disable_custom_linear_init=True, do_reconstruction_loss=True)
    model.load_state_dict(torch.load(f"../experiments/m2v_music.pth", map_location=torch.device('cpu')))
    model = model.eval()
    return model
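
A hedged usage sketch of the updated codegen, mirroring how perform_diffusion_from_codes consumes it in the first file above; the mel tensor is a stand-in (256 channels per mel_input_channels=256), and the checkpoint path must exist for get_music_codegen() to load:

import torch
from utils.music_utils import get_music_codegen

codegen = get_music_codegen()       # 24-layer wrapper, codebook_size=16, codebook_groups=4
mel = torch.randn(1, 256, 400)      # stand-in for spec_fn output: (batch, mel bins, frames)
with torch.no_grad():
    codes = codegen.get_codes(mel)  # discrete codes, passed as 'codes' to the diffusion model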