From 4c6bdfc9e2862fb677f41c07c3031aef01d51a13 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 2 Apr 2022 21:55:32 -0600 Subject: [PATCH] get rid of relative position embeddings, which do not work with DDP & checkpointing --- codes/models/audio/tts/autoregressive_codegen.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py index b3c32ee4..b4003aae 100644 --- a/codes/models/audio/tts/autoregressive_codegen.py +++ b/codes/models/audio/tts/autoregressive_codegen.py @@ -196,17 +196,15 @@ class CheckpointedXTransformerWrapper(nn.Module): class AutoregressiveCodegen(nn.Module): - def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000, - max_mel_tokens=4000, dropout=.1): + def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1): super().__init__() self.START_TOKEN=8192 self.STOP_TOKEN=8193 - self.max_mel_tokens = max_mel_tokens self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False) self.encoder = CheckpointedXTransformerWrapper( num_tokens=num_text_tokens, - max_seq_len=max_text_tokens, + use_pos_emb=False, attn_layers = Encoder( depth=depth//2, heads=model_dim//64, @@ -221,7 +219,7 @@ class AutoregressiveCodegen(nn.Module): )) self.decoder = CheckpointedXTransformerWrapper( num_tokens=num_mel_tokens, - max_seq_len=max_mel_tokens, + use_pos_emb=False, attn_layers=Decoder( depth=depth, heads=model_dim//64, @@ -268,7 +266,7 @@ class AutoregressiveCodegen(nn.Module): loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes) return loss_mel - def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs): + def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs): if not hasattr(self, 'inference_model'): self.inference_model = InferenceModel(self) @@ -283,7 +281,7 @@ class AutoregressiveCodegen(nn.Module): self.inference_model.store_context(context) gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN, - max_length=self.max_mel_tokens, output_attentions=False, return_dict_in_generate=True, + max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, **hf_generate_kwargs) return gen