forked from mrq/DL-Art-School
some dumb stuff
This commit is contained in:
parent
e6387c7613
commit
6fc4f49e86
|
@ -43,7 +43,8 @@ class Wav2VecWrapper(nn.Module):
|
|||
"""
|
||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
|
||||
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
|
||||
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
|
||||
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1,
|
||||
ramp_dropout_max=.5, layer_drop_pct=.1):
|
||||
super().__init__()
|
||||
self.provide_attention_mask = provide_attention_mask
|
||||
|
||||
|
@ -55,6 +56,7 @@ class Wav2VecWrapper(nn.Module):
|
|||
self.w2v.config.pad_token_id = 0
|
||||
self.w2v.config.ctc_loss_reduction = 'sum'
|
||||
self.w2v.config.apply_spec_augment = spec_augment
|
||||
self.w2v.config.layerdrop = layer_drop_pct
|
||||
self.remove_feature_extractor = remove_feature_extractor
|
||||
|
||||
# This is a provision for distilling by ramping up dropout.
|
||||
|
|
|
@ -183,7 +183,6 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
gen_loss, real_loss = results
|
||||
return gen_loss - real_loss
|
||||
|
||||
|
||||
def compute_frechet_distance(self, proj1, proj2):
|
||||
# I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
|
||||
proj1 = proj1.cpu().numpy()
|
||||
|
@ -264,13 +263,14 @@ if __name__ == '__main__':
|
|||
|
||||
if __name__ == '__main__':
|
||||
from utils.util import load_model_from_config
|
||||
|
||||
# 34k; no conditioning_free: {'frechet_distance': tensor(1.4559, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(151.9112, device='cuda:0')}
|
||||
# 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\12000_generator_ema.pth').cuda()
|
||||
load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\34000_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 558, 'device': 'cuda', 'opt': {}}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 560, 'device': 'cuda', 'opt': {}}
|
||||
eval = AudioDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
Loading…
Reference in New Issue
Block a user