misc
This commit is contained in:
parent
6706591d3d
commit
65ffe38fce
|
@ -339,13 +339,17 @@ class GptAsrHf2(nn.Module):
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
return loss_text.mean(), text_logits
|
return loss_text.mean(), text_logits
|
||||||
|
|
||||||
def inference(self, mel_inputs, do_sample=False, temperature=1.0, num_beams=8):
|
def inference(self, mel_inputs, wav_lengths, do_sample=False, temperature=1.0, num_beams=8):
|
||||||
"""
|
"""
|
||||||
Performs inference by transcribing mel_inputs into text. Returns the text tokens.
|
Performs inference by transcribing mel_inputs into text. Returns the text tokens.
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
||||||
|
|
||||||
|
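# TODO: get rid of this. For now, trim mel_inputs along its last axis to wav_lengths.max() // self.mel_compression frames before encoding.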
|
||||||
|
max_mel_len = wav_lengths.max() // self.mel_compression
|
||||||
|
mel_inputs = mel_inputs[:, :, :max_mel_len]
|
||||||
|
|
||||||
mel_emb = self.mel_encoder(mel_inputs)
|
mel_emb = self.mel_encoder(mel_inputs)
|
||||||
assert mel_emb.shape[-1] <= self.max_mel_frames
|
assert mel_emb.shape[-1] <= self.max_mel_frames
|
||||||
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
||||||
|
|
|
@ -78,7 +78,7 @@ if __name__ == '__main__':
|
||||||
'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
|
'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
|
||||||
'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav',
|
'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav',
|
||||||
'dan_carlin': 'Y:\\clips\\books1\5_dchha06 Shield of the West\\00476.wav',
|
'dan_carlin': 'Y:\\clips\\books1\5_dchha06 Shield of the West\\00476.wav',
|
||||||
'libri_test': 'Z:\\bigasr_dataset\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'
|
'libri_test': 'Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'
|
||||||
}
|
}
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
@ -296,7 +296,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_hf2_lg_distill.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -24,7 +24,7 @@ def create_loss(opt_loss, env):
|
||||||
elif 'lightweight_gan_divergence' == type:
|
elif 'lightweight_gan_divergence' == type:
|
||||||
from models.lightweight_gan import LightweightGanDivergenceLoss
|
from models.lightweight_gan import LightweightGanDivergenceLoss
|
||||||
return LightweightGanDivergenceLoss(opt_loss, env)
|
return LightweightGanDivergenceLoss(opt_loss, env)
|
||||||
elif type == 'crossentropy':
|
elif type == 'crossentropy' or type == 'cross_entropy':
|
||||||
return CrossEntropy(opt_loss, env)
|
return CrossEntropy(opt_loss, env)
|
||||||
elif type == 'distillation':
|
elif type == 'distillation':
|
||||||
return Distillation(opt_loss, env)
|
return Distillation(opt_loss, env)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user