more work
This commit is contained in:
parent
f541610256
commit
6c8032b4be
|
@ -117,8 +117,8 @@ class MusicGenerator(nn.Module):
|
||||||
layer_drop=.1,
|
layer_drop=.1,
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
# Masking parameters.
|
# Masking parameters.
|
||||||
time_mask_percent_max=.4,
|
frequency_mask_percent_max=0,
|
||||||
frequency_mask_percent_max=.4,
|
time_mask_percent_max=0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ class MusicGenerator(nn.Module):
|
||||||
def do_masking(self, truth):
|
def do_masking(self, truth):
|
||||||
b, c, s = truth.shape
|
b, c, s = truth.shape
|
||||||
mask = torch.ones_like(truth)
|
mask = torch.ones_like(truth)
|
||||||
if random.random() < .5:
|
if self.frequency_mask_percent_mask > 0:
|
||||||
# Frequency mask
|
# Frequency mask
|
||||||
cs = random.randint(0, c-10)
|
cs = random.randint(0, c-10)
|
||||||
ce = min(c-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*c)))
|
ce = min(c-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*c)))
|
||||||
|
|
|
@ -219,6 +219,12 @@ class ContrastiveAudio(nn.Module):
|
||||||
def update_for_step(self, step, __):
|
def update_for_step(self, step, __):
|
||||||
self.to_latent2.weight.data = self.to_latent2.weight.data * .99 + self.to_latent.weight.data * .01
|
self.to_latent2.weight.data = self.to_latent2.weight.data * .99 + self.to_latent.weight.data * .01
|
||||||
|
|
||||||
|
def project(self, mel):
|
||||||
|
h1 = self.emb(mel).permute(0, 2, 1)
|
||||||
|
h1 = self.transformer(h1)
|
||||||
|
h1 = self.to_latent(h1)
|
||||||
|
return h1
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
mel_input1,
|
mel_input1,
|
||||||
|
|
|
@ -17,6 +17,7 @@ from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||||
from data.audio.unsupervised_audio_dataset import load_audio
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||||
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
||||||
|
from models.clip.contrastive_audio import ContrastiveAudio
|
||||||
from models.clip.mel_text_clip import MelTextCLIP
|
from models.clip.mel_text_clip import MelTextCLIP
|
||||||
from models.audio.tts.tacotron2 import text_to_sequence
|
from models.audio.tts.tacotron2 import text_to_sequence
|
||||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||||
|
@ -58,7 +59,9 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
num_heads=8,
|
num_heads=8,
|
||||||
dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0)
|
||||||
self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu')))
|
self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu')))
|
||||||
self.local_modules = {'spec_decoder': self.spec_decoder}
|
self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
|
||||||
|
#self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))
|
||||||
|
self.local_modules = {'spec_decoder': self.spec_decoder, 'projector': self.projector}
|
||||||
|
|
||||||
if mode == 'spec_decode':
|
if mode == 'spec_decode':
|
||||||
self.diffusion_fn = self.perform_diffusion_spec_decode
|
self.diffusion_fn = self.perform_diffusion_spec_decode
|
||||||
|
@ -127,20 +130,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
return gen, real_resampled, sample_rate
|
return gen, real_resampled, sample_rate
|
||||||
|
|
||||||
def load_projector(self):
|
def project(self, sample, sample_rate):
|
||||||
# TODO: implement for music.
|
|
||||||
model = MelTextCLIP(dim_text=512, dim_latent=512, dim_speech=512, num_text_tokens=148, text_enc_depth=8,
|
|
||||||
text_seq_len=400, text_heads=8, speech_enc_depth=10, speech_heads=8, speech_seq_len=1000,
|
|
||||||
text_mask_percentage=.15, voice_mask_percentage=.15)
|
|
||||||
weights = torch.load('../experiments/clip_text_to_voice_for_speech_fid.pth')
|
|
||||||
model.load_state_dict(weights)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def project(self, projector, sample, sample_rate):
|
|
||||||
# TODO: implement for music.
|
|
||||||
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
|
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
|
||||||
mel = wav_to_mel(sample)
|
mel = self.spec_fn({'in': sample})['out']
|
||||||
return projector.get_speech_projection(mel).squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim]
|
projection = self.projector.project(mel)
|
||||||
|
return projection.squeeze(0) # Getting rid of the batch dimension means it's just [hidden_dim]
|
||||||
|
|
||||||
def compute_frechet_distance(self, proj1, proj2):
|
def compute_frechet_distance(self, proj1, proj2):
|
||||||
# I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
|
# I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
|
||||||
|
@ -156,41 +150,35 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
|
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
|
||||||
os.makedirs(save_path, exist_ok=True)
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
|
||||||
#projector = self.load_projector().to(self.env['device'])
|
self.projector = self.projector.to(self.dev)
|
||||||
#projector.eval()
|
self.projector.eval()
|
||||||
|
|
||||||
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
||||||
rng_state = torch.get_rng_state()
|
rng_state = torch.get_rng_state()
|
||||||
torch.manual_seed(5)
|
torch.manual_seed(5)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
frechet_distance = 0
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gen_projections = []
|
gen_projections = []
|
||||||
real_projections = []
|
real_projections = []
|
||||||
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
||||||
path = self.data[i + self.env['rank']]
|
path = self.data[i + self.env['rank']]
|
||||||
audio = load_audio(path, 22050).to(self.dev)
|
audio = load_audio(path, 22050).to(self.dev)
|
||||||
mel = self.spec_fn({'in': audio})['out']
|
audio = audio[:, :22050*5]
|
||||||
mel_norm = (mel + mel.min().abs())
|
|
||||||
mel_norm = mel_norm / mel_norm.max(dim=-1, keepdim=True).values
|
|
||||||
torchvision.utils.save_image(mel_norm.unsqueeze(1), 'mel.png')
|
|
||||||
sample, ref, sample_rate = self.diffusion_fn(audio)
|
sample, ref, sample_rate = self.diffusion_fn(audio)
|
||||||
|
|
||||||
#gen_projections.append(self.project(projector, sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
||||||
#real_projections.append(self.project(projector, ref, sample_rate).cpu())
|
real_projections.append(self.project(ref, sample_rate).cpu())
|
||||||
|
|
||||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
||||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
|
||||||
#gen_projections = torch.stack(gen_projections, dim=0)
|
gen_projections = torch.stack(gen_projections, dim=0)
|
||||||
#real_projections = torch.stack(real_projections, dim=0)
|
real_projections = torch.stack(real_projections, dim=0)
|
||||||
#frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
||||||
|
|
||||||
#if distributed.is_initialized() and distributed.get_world_size() > 1:
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
# distributed.all_reduce(frechet_distance)
|
distributed.all_reduce(frechet_distance)
|
||||||
# frechet_distance = frechet_distance / distributed.get_world_size()
|
frechet_distance = frechet_distance / distributed.get_world_size()\
|
||||||
# distributed.all_reduce(intelligibility_loss)
|
|
||||||
# intelligibility_loss = intelligibility_loss / distributed.get_world_size()
|
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
|
@ -205,8 +193,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
||||||
also_load_savepoint=False,
|
also_load_savepoint=False,
|
||||||
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\5000_generator.pth').cuda()
|
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda()
|
||||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
|
||||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
|
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
|
||||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user