get rid of relative position embeddings, which do not work with DDP & checkpointing

This commit is contained in:
James Betker 2022-04-02 21:55:32 -06:00
parent b6d62aca5d
commit 4c6bdfc9e2

View File

@ -196,17 +196,15 @@ class CheckpointedXTransformerWrapper(nn.Module):
class AutoregressiveCodegen(nn.Module):
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
max_mel_tokens=4000, dropout=.1):
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
super().__init__()
self.START_TOKEN=8192
self.STOP_TOKEN=8193
self.max_mel_tokens = max_mel_tokens
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
self.encoder = CheckpointedXTransformerWrapper(
num_tokens=num_text_tokens,
max_seq_len=max_text_tokens,
use_pos_emb=False,
attn_layers = Encoder(
depth=depth//2,
heads=model_dim//64,
@ -221,7 +219,7 @@ class AutoregressiveCodegen(nn.Module):
))
self.decoder = CheckpointedXTransformerWrapper(
num_tokens=num_mel_tokens,
max_seq_len=max_mel_tokens,
use_pos_emb=False,
attn_layers=Decoder(
depth=depth,
heads=model_dim//64,
@ -268,7 +266,7 @@ class AutoregressiveCodegen(nn.Module):
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
return loss_mel
def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = InferenceModel(self)
@ -283,7 +281,7 @@ class AutoregressiveCodegen(nn.Module):
self.inference_model.store_context(context)
gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
max_length=self.max_mel_tokens, output_attentions=False, return_dict_in_generate=True,
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
**hf_generate_kwargs)
return gen