From 32529720579e74ed7092ec9aee232a8d5e300cb9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 12 Feb 2022 19:59:54 -0700
Subject: [PATCH] ctc_code_gen mods

---
 codes/models/gpt_voice/ctc_code_generator.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py
index d45e993f..95136c67 100644
--- a/codes/models/gpt_voice/ctc_code_generator.py
+++ b/codes/models/gpt_voice/ctc_code_generator.py
@@ -45,11 +45,11 @@ class CheckpointedTransformerWrapper(nn.Module):
 
 
 class CtcCodeGenerator(nn.Module):
-    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_prob=.1):
+    def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_probability=.1):
         super().__init__()
         self.max_pad = max_pad
         self.max_repeat = max_repeat
-        self.mask_probability = mask_prob
+        self.mask_probability = mask_probability
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True)
         self.initial_embedding = nn.Embedding(ctc_codes, model_dim)
         self.combiner = nn.Linear(model_dim*2, model_dim)
@@ -131,8 +131,10 @@ class CtcCodeGenerator(nn.Module):
         cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1)
         h = torch.cat([cond, self.initial_embedding(codes)], dim=-1)
         h = self.combiner(h)
-        logits = self.transformer(h)
-        generate = torch.argmax(logits, dim=-1)
+        with torch.autocast(codes.device.type):
+            logits = self.transformer(h)
+            ctc_pred = self.ctc_head(logits)
+        generate = torch.argmax(ctc_pred, dim=-1)
 
         # De-compress the codes from the generated output
         pads = generate % self.max_pad
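
A minimal sketch of the patched inference path follows, for illustration only; it is not part of the patch. The ctc_head layer is defined outside the hunks shown above, so its output shape is assumed, and since only the `pads = generate % self.max_pad` decode line appears in the hunk, recovering the repeat count as the quotient is an inference about the packing scheme. `sketch_generate` is a hypothetical helper name, not a method in the file.

import torch

def sketch_generate(model, speech_conditioning_input, codes):
    # Condition every code position on a pooled speech embedding, mirroring
    # the context lines of the second hunk.
    cond = model.conditioning_encoder(speech_conditioning_input)
    cond = cond.unsqueeze(1).repeat(1, codes.shape[1], 1)
    h = torch.cat([cond, model.initial_embedding(codes)], dim=-1)
    h = model.combiner(h)
    # The substantive change: run the transformer and the new ctc_head
    # projection under autocast, so matmul-heavy ops execute in reduced
    # precision (float16 on CUDA, bfloat16 on CPU).
    with torch.autocast(codes.device.type):
        logits = model.transformer(h)
        ctc_pred = model.ctc_head(logits)
    # argmax is insensitive to dtype, so it can safely sit outside autocast.
    generate = torch.argmax(ctc_pred, dim=-1)
    # De-compress each predicted class back into (pad, repeat): the pad count
    # is the remainder mod max_pad (shown in the patch); taking the quotient
    # as the repeat count is an assumption.
    pads = generate % model.max_pad
    repeats = generate // model.max_pad  # hypothetical decode step
    return pads, repeats

Under these assumptions, usage would look like pads, repeats = sketch_generate(model, mel_conditioning, ctc_codes) with a trained CtcCodeGenerator.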