diff --git a/codes/models/audio/tts/mini_encoder.py b/codes/models/audio/tts/mini_encoder.py
index d8e3f90a..23283173 100644
--- a/codes/models/audio/tts/mini_encoder.py
+++ b/codes/models/audio/tts/mini_encoder.py
@@ -135,11 +135,12 @@ class AudioMiniEncoder(nn.Module):
 
 
 class AudioMiniEncoderWithClassifierHead(nn.Module):
-    def __init__(self, classes, **kwargs):
+    def __init__(self, classes, distribute_zero_label=True, **kwargs):
         super().__init__()
         self.enc = AudioMiniEncoder(**kwargs)
         self.head = nn.Linear(self.enc.dim, classes)
         self.num_classes = classes
+        self.distribute_zero_label = distribute_zero_label
 
     def forward(self, x, labels=None):
         h = self.enc(x)
@@ -147,13 +148,16 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
         if labels is None:
             return logits
         else:
-            oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
-            zeros_indices = (labels == 0).unsqueeze(-1)
-            # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
-            zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
-            zero_extra_mass[:, 0] = -.2
-            zero_extra_mass = zero_extra_mass * zeros_indices
-            oh_labels = oh_labels + zero_extra_mass
+            if self.distribute_zero_label:
+                oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
+                zeros_indices = (labels == 0).unsqueeze(-1)
+                # When the zero label is given, move 20% of its probability mass evenly onto the other classes to compensate for dataset noise.
+                zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
+                zero_extra_mass[:, 0] = -.2
+                zero_extra_mass = zero_extra_mass * zeros_indices
+                oh_labels = oh_labels + zero_extra_mass
+            else:
+                oh_labels = labels
             loss = nn.functional.cross_entropy(logits, oh_labels)
             return loss
 
diff --git a/codes/train.py b/codes/train.py
index 70578721..e2601790 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_reverse_classifier.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 56b77a33..dfa6a035 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -52,7 +52,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         self.local_modules = {}
         if mode == 'standard':
             self.diffusion_fn = self.perform_diffusion_standard
-            self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 120, 'mel_fmax': 11000, 'in': 'in', 'out': 'out'}, {})
+            self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 128, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})
 
     def load_data(self, path):
         return list(glob(f'{path}/*.wav'))
@@ -141,8 +141,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator', also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen\\models\\36000_generator_ema.pth').cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
+                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen_r3\\models\\11200_generator_ema.pth').cuda()
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
                 'conditioning_free': False, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'standard'}
     env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index b7d563aa..6393f52e 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -190,3 +190,26 @@ class GptVoiceLatentInjector(Injector):
                                                   clip_inputs=False)
         assert latents.shape[1] == codes.shape[1]
         return {self.output: latents}
+
+
+class ReverseUnivnetInjector(Injector):
+    """
+    This injector builds input/label pairs for a univnet detector (real audio vs. univnet-vocoded audio).
+    """
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        from scripts.audio.gen.speech_synthesis_utils import load_univnet_vocoder
+        self.univnet = load_univnet_vocoder().cuda()
+        self.mel_input_key = opt['mel']
+        self.label_output_key = opt['labels']
+
+    def forward(self, state):
+        with torch.no_grad():
+            original_audio = state[self.input]
+            mel = state[self.mel_input_key]
+            decoded_audio = self.univnet.inference(mel)[:, :, :original_audio.shape[-1]]
+
+            labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
+            output = torch.where(labels, original_audio, decoded_audio)
+
+            return {self.output: output, self.label_output_key: labels[:, 0, 0].long()}
\ No newline at end of file
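For reference, a minimal standalone sketch of the soft-label construction that the new `distribute_zero_label` flag gates in `AudioMiniEncoderWithClassifierHead`. The `num_classes` and `labels` values are illustrative, not taken from the diff:

```python
import torch
import torch.nn as nn

num_classes = 5
labels = torch.tensor([0, 3])  # one noisy "zero" label, one ordinary label

oh_labels = nn.functional.one_hot(labels, num_classes=num_classes)
zeros_indices = (labels == 0).unsqueeze(-1)
# Build a correction that takes 0.2 from class 0 and spreads it over the rest.
zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2 / (num_classes - 1))
zero_extra_mass[:, 0] = -.2
# Apply the correction only to rows whose label is 0.
zero_extra_mass = zero_extra_mass * zeros_indices
soft_labels = oh_labels + zero_extra_mass

print(soft_labels)
# tensor([[0.8000, 0.0500, 0.0500, 0.0500, 0.0500],   <- zero label keeps 80%
#         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000]])  <- other labels untouched
print(soft_labels.sum(dim=-1))  # each row still sums to 1
```

Each row remains a valid probability distribution, so it can be fed to `nn.functional.cross_entropy` as a soft target (supported in PyTorch 1.10+); with `distribute_zero_label=False`, plain integer class indices are passed instead.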
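Similarly, a sketch of the batch-mixing trick in `ReverseUnivnetInjector.forward`: each sample is independently swapped for its vocoded counterpart with probability 0.5, and the coin flips double as classifier targets. `fake_audio` here is a dummy stand-in for the output of `self.univnet.inference(mel)`, which this sketch does not load:

```python
import torch

b, t = 4, 22050
real_audio = torch.randn(b, 1, t)  # ground-truth waveforms
fake_audio = torch.randn(b, 1, t)  # stand-in for univnet-vocoded waveforms

# Per-sample coin flip: True -> keep the real clip, False -> use the vocoded one.
labels = torch.rand(b, 1, 1) > .5
output = torch.where(labels, real_audio, fake_audio)  # broadcasts over the time axis

# Detector targets: 1 = real, 0 = vocoded.
targets = labels[:, 0, 0].long()
print(targets)
```

Keeping `labels` shaped `(b, 1, 1)` lets the same boolean tensor both select waveforms via broadcasting in `torch.where` and collapse to per-sample class indices for the loss.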