some dumb stuff

This commit is contained in:
James Betker 2022-04-07 11:32:34 -06:00
parent e6387c7613
commit 6fc4f49e86
3 changed files with 8 additions and 6 deletions

View File

@ -43,7 +43,8 @@ class Wav2VecWrapper(nn.Module):
"""
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1,
ramp_dropout_max=.5, layer_drop_pct=.1):
super().__init__()
self.provide_attention_mask = provide_attention_mask
@ -55,6 +56,7 @@ class Wav2VecWrapper(nn.Module):
self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'sum'
self.w2v.config.apply_spec_augment = spec_augment
self.w2v.config.layerdrop = layer_drop_pct
self.remove_feature_extractor = remove_feature_extractor
# This is a provision for distilling by ramping up dropout.

View File

@ -183,7 +183,6 @@ class AudioDiffusionFid(evaluator.Evaluator):
gen_loss, real_loss = results
return gen_loss - real_loss
def compute_frechet_distance(self, proj1, proj2):
# I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
proj1 = proj1.cpu().numpy()
@ -264,13 +263,14 @@ if __name__ == '__main__':
if __name__ == '__main__':
from utils.util import load_model_from_config
# 34k; no conditioning_free: {'frechet_distance': tensor(1.4559, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(151.9112, device='cuda:0')}
# 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\12000_generator_ema.pth').cuda()
load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\34000_generator_ema.pth').cuda()
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
'conditioning_free': False, 'conditioning_free_k': 1,
'conditioning_free': True, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 558, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 560, 'device': 'cuda', 'opt': {}}
eval = AudioDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())