reverse univnet classifier
This commit is contained in:
parent 9df85c902e
commit f02b01bd9d
@@ -135,11 +135,12 @@ class AudioMiniEncoder(nn.Module):
 class AudioMiniEncoderWithClassifierHead(nn.Module):
-    def __init__(self, classes, **kwargs):
+    def __init__(self, classes, distribute_zero_label=True, **kwargs):
         super().__init__()
         self.enc = AudioMiniEncoder(**kwargs)
         self.head = nn.Linear(self.enc.dim, classes)
         self.num_classes = classes
+        self.distribute_zero_label = distribute_zero_label

     def forward(self, x, labels=None):
         h = self.enc(x)
@@ -147,13 +148,16 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
         if labels is None:
             return logits
         else:
-            oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
-            zeros_indices = (labels == 0).unsqueeze(-1)
-            # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
-            zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
-            zero_extra_mass[:, 0] = -.2
-            zero_extra_mass = zero_extra_mass * zeros_indices
-            oh_labels = oh_labels + zero_extra_mass
+            if self.distribute_zero_label:
+                oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
+                zeros_indices = (labels == 0).unsqueeze(-1)
+                # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
+                zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
+                zero_extra_mass[:, 0] = -.2
+                zero_extra_mass = zero_extra_mass * zeros_indices
+                oh_labels = oh_labels + zero_extra_mass
+            else:
+                oh_labels = labels
             loss = nn.functional.cross_entropy(logits, oh_labels)
             return loss
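For reference, here is a minimal standalone sketch (not part of the commit) of the soft-label redistribution above. With an assumed `num_classes=5`, a zero label keeps 0.8 of the probability mass and spreads the remaining 0.2 evenly over the other four classes; `cross_entropy` accepts such probability rows as targets since PyTorch 1.10. The printed values are illustrative.

```python
import torch
import torch.nn.functional as F

num_classes = 5
labels = torch.tensor([0, 3])                # one "noisy" zero label, one ordinary label
oh = F.one_hot(labels, num_classes=num_classes)
zeros = (labels == 0).unsqueeze(-1)          # (2, 1) bool mask selecting zero-label rows
extra = torch.full_like(oh, dtype=torch.float, fill_value=.2 / (num_classes - 1))
extra[:, 0] = -.2                            # take 0.2 off class 0, spread it elsewhere
soft = oh + extra * zeros                    # only zero-label rows are modified
# soft -> [[0.80, 0.05, 0.05, 0.05, 0.05],
#          [0.00, 0.00, 0.00, 1.00, 0.00]]
logits = torch.randn(2, num_classes)
loss = F.cross_entropy(logits, soft)         # soft targets: each row still sums to 1.0
```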
@@ -327,7 +327,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_reverse_classifier.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -52,7 +52,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         self.local_modules = {}
         if mode == 'standard':
             self.diffusion_fn = self.perform_diffusion_standard
-            self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 120, 'mel_fmax': 11000, 'in': 'in', 'out': 'out'}, {})
+            self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 128, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})

     def load_data(self, path):
         return list(glob(f'{path}/*.wav'))
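A rough standalone approximation of the new spectrogram settings, using plain torchaudio rather than the injector (whose internals are not shown in this diff): 128 mel bins up to 22kHz suit 44.1kHz music audio. The `sample_rate`, `n_fft`, and `hop_length` values here are illustrative guesses, not the injector's actual defaults.

```python
import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=44100, n_fft=2048, hop_length=256, n_mels=128, f_max=22000.0)
wav = torch.randn(1, 44100)                  # one second of placeholder audio
spec = mel(wav).clamp(min=1e-5).log()        # log-mel, shape (1, 128, frames);
                                             # 'normalize': True presumably rescales this
```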
@@ -141,8 +141,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen\\models\\36000_generator_ema.pth').cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
+                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen_r3\\models\\11200_generator_ema.pth').cuda()
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
                 'conditioning_free': False, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'standard'}
     env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 1, 'device': 'cuda', 'opt': {}}
@@ -190,3 +190,26 @@ class GptVoiceLatentInjector(Injector):
                                              clip_inputs=False)
         assert latents.shape[1] == codes.shape[1]
         return {self.output: latents}
+
+
+class ReverseUnivnetInjector(Injector):
+    """
+    This injector specifically builds inputs and labels for a univnet detector.
+    """
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        from scripts.audio.gen.speech_synthesis_utils import load_univnet_vocoder
+        self.univnet = load_univnet_vocoder().cuda()
+        self.mel_input_key = opt['mel']
+        self.label_output_key = opt['labels']
+
+    def forward(self, state):
+        with torch.no_grad():
+            original_audio = state[self.input]
+            mel = state[self.mel_input_key]
+            decoded_mel = self.univnet.inference(mel)[:,:,:original_audio.shape[-1]]
+
+            labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
+            output = torch.where(labels, original_audio, decoded_mel)
+
+            return {self.output: output, self.label_output_key: labels[:,0,0].long()}
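A hedged usage sketch of how this injector could feed the two-class head added above: per batch item it emits either the real waveform or its univnet re-synthesis, with a matching binary label. The `'mel'`/`'labels'` opt keys come from the diff; the state keys, `env` contents, and encoder kwargs below are assumptions for illustration only.

```python
# Assumed wiring (state keys, env contents, encoder kwargs are hypothetical).
injector = ReverseUnivnetInjector(
    {'in': 'wav', 'out': 'clf_in', 'mel': 'wav_mel', 'labels': 'clf_labels'},
    env={})
state = {'wav': wav_batch, 'wav_mel': mel_batch}   # hypothetical (B, 1, T) audio and its mel
out = injector.forward(state)
# out['clf_in']:     (B, 1, T) batch, roughly half real and half univnet-vocoded clips
# out['clf_labels']: (B,) long tensor; 1 = real audio, 0 = vocoded audio
clf = AudioMiniEncoderWithClassifierHead(classes=2, **encoder_kwargs).cuda()
loss = clf(out['clf_in'], labels=out['clf_labels'])
```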