diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index 2ecc26da..9acad597 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -339,13 +339,17 @@ class GptAsrHf2(nn.Module): loss_text = F.cross_entropy(text_logits, text_targets.long()) return loss_text.mean(), text_logits - def inference(self, mel_inputs, do_sample=False, temperature=1.0, num_beams=8): + def inference(self, mel_inputs, wav_lengths, do_sample=False, temperature=1.0, num_beams=8): """ Performs inference by transcribing mel_inputs into text. Returns the text tokens. """ if not hasattr(self, 'inference_model'): self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head) + # TODO: get rid of this.. + max_mel_len = wav_lengths.max() // self.mel_compression + mel_inputs = mel_inputs[:, :, :max_mel_len] + mel_emb = self.mel_encoder(mel_inputs) assert mel_emb.shape[-1] <= self.max_mel_frames mel_emb = mel_emb.permute(0,2,1).contiguous() diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py index 58eaa63b..e871231b 100644 --- a/codes/scripts/audio/gen/use_gpt_tts.py +++ b/codes/scripts/audio/gen/use_gpt_tts.py @@ -78,7 +78,7 @@ if __name__ == '__main__': 'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'dan_carlin': 'Y:\\clips\\books1\5_dchha06 Shield of the West\\00476.wav', - 'libri_test': 'Z:\\bigasr_dataset\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav' + 'libri_test': 'Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav' } parser = argparse.ArgumentParser() diff --git a/codes/train.py b/codes/train.py index a3bd9b35..dfe5f5c3 100644 --- a/codes/train.py +++ b/codes/train.py @@ -296,7 +296,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_hf2_lg_distill.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py index 980262c0..6ea27b64 100644 --- a/codes/trainer/losses.py +++ b/codes/trainer/losses.py @@ -24,7 +24,7 @@ def create_loss(opt_loss, env): elif 'lightweight_gan_divergence' == type: from models.lightweight_gan import LightweightGanDivergenceLoss return LightweightGanDivergenceLoss(opt_loss, env) - elif type == 'crossentropy': + elif type == 'crossentropy' or type == 'cross_entropy': return CrossEntropy(opt_loss, env) elif type == 'distillation': return Distillation(opt_loss, env)