forked from mrq/DL-Art-School

some dumb stuff

parent e6387c7613
commit 6fc4f49e86
@@ -43,7 +43,8 @@ class Wav2VecWrapper(nn.Module):
     """
     def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
                  checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
-                 remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5):
+                 remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1,
+                 ramp_dropout_max=.5, layer_drop_pct=.1):
         super().__init__()
         self.provide_attention_mask = provide_attention_mask

@@ -55,6 +56,7 @@ class Wav2VecWrapper(nn.Module):
         self.w2v.config.pad_token_id = 0
         self.w2v.config.ctc_loss_reduction = 'sum'
         self.w2v.config.apply_spec_augment = spec_augment
+        self.w2v.config.layerdrop = layer_drop_pct
         self.remove_feature_extractor = remove_feature_extractor

         # This is a provision for distilling by ramping up dropout.
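Note: the new `layer_drop_pct` argument is written into the HuggingFace `layerdrop` config field, which enables LayerDrop (Fan et al., 2019): during training, each transformer layer is skipped with some probability. A minimal sketch of that mechanism, for illustration only (this is not the wrapper's actual forward pass, and the names below are hypothetical):

import torch
import torch.nn as nn

class LayerDropStack(nn.Module):
    """Applies a stack of layers, randomly skipping each one during training."""
    def __init__(self, layers, drop_pct=0.1):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.drop_pct = drop_pct

    def forward(self, x):
        for layer in self.layers:
            # LayerDrop: at train time, skip this layer with probability drop_pct.
            if self.training and torch.rand(()).item() < self.drop_pct:
                continue
            x = layer(x)
        return x

# Toy usage: a 4-layer stack with 10% LayerDrop.
stack = LayerDropStack([nn.Linear(16, 16) for _ in range(4)], drop_pct=0.1)
out = stack(torch.randn(2, 16))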
@@ -183,7 +183,6 @@ class AudioDiffusionFid(evaluator.Evaluator):
         gen_loss, real_loss = results
         return gen_loss - real_loss

-
     def compute_frechet_distance(self, proj1, proj2):
         # I really REALLY FUCKING HATE that this is going to numpy. Why does "pytorch_fid" operate in numpy land. WHY?
         proj1 = proj1.cpu().numpy()
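Note: the comment above refers to pytorch_fid doing its Fréchet computation in numpy, which is why the projections are moved off the GPU. A minimal sketch of how such a distance can be computed with pytorch_fid's numpy routine; the mean/covariance fitting here is an assumption about what the evaluator does, not a copy of its code:

import numpy as np
from pytorch_fid.fid_score import calculate_frechet_distance

def frechet_distance(proj1: np.ndarray, proj2: np.ndarray) -> float:
    # proj1 / proj2: [num_samples, feature_dim] arrays of projected features.
    mu1, sigma1 = proj1.mean(axis=0), np.cov(proj1, rowvar=False)
    mu2, sigma2 = proj2.mean(axis=0), np.cov(proj2, rowvar=False)
    return calculate_frechet_distance(mu1, sigma1, mu2, sigma2)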
@@ -264,13 +263,14 @@ if __name__ == '__main__':

 if __name__ == '__main__':
     from utils.util import load_model_from_config
+    # 34k; no conditioning_free: {'frechet_distance': tensor(1.4559, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(151.9112, device='cuda:0')}
+    # 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\12000_generator_ema.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\34000_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 1,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 558, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 560, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
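Note: per the new result comments, enabling `conditioning_free` at the 34k checkpoint lowers both the Fréchet distance (1.4559 -> 1.4059) and the intelligibility loss (151.9 -> 118.3). "Conditioning-free" guidance here is the classifier-free guidance idea; a generic sketch of the standard formulation follows, but how `conditioning_free_k` maps onto the guidance weight in this repo's diffusion code is an assumption, not something shown in this diff:

import torch

def conditioning_free_guidance(cond_out: torch.Tensor, uncond_out: torch.Tensor, k: float) -> torch.Tensor:
    # Push the prediction from the unconditional output toward the conditional
    # output by the guidance weight k (k > 1 amplifies the conditioning signal).
    return uncond_out + k * (cond_out - uncond_out)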