reverse univnet classifier

James Betker 2022-04-20 21:37:55 -06:00
parent 9df85c902e
commit f02b01bd9d
4 changed files with 39 additions and 12 deletions

View File

@@ -135,11 +135,12 @@ class AudioMiniEncoder(nn.Module):
 class AudioMiniEncoderWithClassifierHead(nn.Module):
-    def __init__(self, classes, **kwargs):
+    def __init__(self, classes, distribute_zero_label=True, **kwargs):
         super().__init__()
         self.enc = AudioMiniEncoder(**kwargs)
         self.head = nn.Linear(self.enc.dim, classes)
         self.num_classes = classes
+        self.distribute_zero_label = distribute_zero_label

     def forward(self, x, labels=None):
         h = self.enc(x)
@@ -147,13 +147,16 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
         if labels is None:
             return logits
         else:
-            oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
-            zeros_indices = (labels == 0).unsqueeze(-1)
-            # Distribute 20% of the probability mass across all classes when zero is specified, to compensate for dataset noise.
-            zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
-            zero_extra_mass[:, 0] = -.2
-            zero_extra_mass = zero_extra_mass * zeros_indices
-            oh_labels = oh_labels + zero_extra_mass
+            if self.distribute_zero_label:
+                oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
+                zeros_indices = (labels == 0).unsqueeze(-1)
+                # Distribute 20% of the probability mass across all classes when zero is specified, to compensate for dataset noise.
+                zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
+                zero_extra_mass[:, 0] = -.2
+                zero_extra_mass = zero_extra_mass * zeros_indices
+                oh_labels = oh_labels + zero_extra_mass
+            else:
+                oh_labels = labels
             loss = nn.functional.cross_entropy(logits, oh_labels)
             return loss
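
A minimal standalone sketch of what the retained smoothing branch produces, assuming a hypothetical num_classes=4 and a batch of two labels [0, 2] (only the .2 constant and the tensor ops come from the code above; the sizes are illustrative):

import torch
import torch.nn.functional as F

num_classes = 4                      # assumed for illustration
labels = torch.tensor([0, 2])

oh_labels = F.one_hot(labels, num_classes=num_classes)   # [[1,0,0,0], [0,0,1,0]]
zeros_indices = (labels == 0).unsqueeze(-1)              # [[True], [False]]

# Move .2 of the mass off class 0 and spread it evenly over the other 3 classes.
zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2 / (num_classes - 1))
zero_extra_mass[:, 0] = -.2
zero_extra_mass = zero_extra_mass * zeros_indices        # only applies to label-0 rows
oh_labels = oh_labels + zero_extra_mass
# Row 0 becomes [.8, .0667, .0667, .0667]; row 1 stays one-hot [0, 0, 1, 0].

logits = torch.randn(2, num_classes)
loss = F.cross_entropy(logits, oh_labels)                # soft-target CE (PyTorch >= 1.10)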

View File

@@ -327,7 +327,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_reverse_classifier.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

View File

@@ -52,7 +52,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         self.local_modules = {}
         if mode == 'standard':
             self.diffusion_fn = self.perform_diffusion_standard
-            self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 120, 'mel_fmax': 11000, 'in': 'in', 'out': 'out'}, {})
+            self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 128, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})

     def load_data(self, path):
         return list(glob(f'{path}/*.wav'))
@@ -141,8 +141,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen\\models\\36000_generator_ema.pth').cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
+                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen_r3\\models\\11200_generator_ema.pth').cuda()
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
                 'conditioning_free': False, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'standard'}
     env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
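
For reference, a rough torchaudio approximation of the new mel settings in the TorchMelSpectrogramInjector change above. This is a sketch only: n_mels=128 and f_max=22000 come from the diff, while the sample rate, FFT size, hop length, and the exact behavior of 'normalize' are assumptions.

import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=44100,     # assumed; f_max=22000 requires at least 44kHz audio
    n_fft=2048,            # assumed
    hop_length=256,        # assumed
    n_mels=128,            # from the diff: 'n_mel_channels': 128
    f_max=22000,           # from the diff: 'mel_fmax': 22000
)
wav = torch.randn(1, 44100)                 # one second of placeholder audio
spec = mel(wav)                             # -> (1, 128, frames)
spec = torch.log(spec.clamp(min=1e-5))      # log compression; 'normalize': True is assumed to do something similar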

View File

@@ -190,3 +190,26 @@ class GptVoiceLatentInjector(Injector):
                                                clip_inputs=False)
         assert latents.shape[1] == codes.shape[1]
         return {self.output: latents}
+
+
+class ReverseUnivnetInjector(Injector):
+    """
+    This injector specifically builds inputs and labels for a univnet detector.
+    """
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        from scripts.audio.gen.speech_synthesis_utils import load_univnet_vocoder
+        self.univnet = load_univnet_vocoder().cuda()
+        self.mel_input_key = opt['mel']
+        self.label_output_key = opt['labels']
+
+    def forward(self, state):
+        with torch.no_grad():
+            original_audio = state[self.input]
+            mel = state[self.mel_input_key]
+            decoded_mel = self.univnet.inference(mel)[:,:,:original_audio.shape[-1]]
+            labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
+            output = torch.where(labels, original_audio, decoded_mel)
+
+            return {self.output: output, self.label_output_key: labels[:,0,0].long()}
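
A minimal, self-contained sketch of the label/mix logic in ReverseUnivnetInjector.forward, with random tensors standing in for the real waveforms and the univnet output (the sizes are assumptions):

import torch

B, C, T = 4, 1, 16000                     # assumed batch/channel/sample sizes
original_audio = torch.randn(B, C, T)     # stands in for state[self.input]
decoded_audio = torch.randn(B, C, T)      # stands in for self.univnet.inference(mel)

# Per-item coin flip; the (B, 1, 1) shape broadcasts over channels and time,
# so each item is either entirely real or entirely vocoded.
labels = torch.rand(B, 1, 1) > .5
output = torch.where(labels, original_audio, decoded_audio)

targets = labels[:, 0, 0].long()          # 1 = real audio, 0 = univnet reconstruction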