forked from mrq/DL-Art-School
get rid of relative position embeddings, which do not work with DDP & checkpointing
commit 4c6bdfc9e2
parent b6d62aca5d
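The rationale in the commit message matches a well-known interaction: with re-entrant activation checkpointing, a learned relative-position-bias table that is shared by every transformer layer gets read inside several checkpointed segments, and DistributedDataParallel's gradient reducer can then see the same parameter marked ready more than once during backward. The sketch below is purely illustrative and is not code from this repository; ToyStack, rel_bias, and all dimensions are made up to show that pattern under the stated assumption of re-entrant checkpointing.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyStack(nn.Module):
    """Hypothetical stand-in for a transformer stack with a shared relative bias."""
    def __init__(self, dim=32, depth=4, max_len=64):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        # One bias table reused by every layer, like a relative position bias.
        self.rel_bias = nn.Parameter(torch.zeros(max_len))

    def forward(self, x):
        for layer in self.layers:
            def block(t, layer=layer):
                # Every checkpointed segment reads self.rel_bias again.
                return layer(t) + self.rel_bias[:t.shape[1]].unsqueeze(-1)
            x = checkpoint(block, x, use_reentrant=True)
        return x

model = ToyStack()
x = torch.randn(2, 16, 32, requires_grad=True)
model(x).sum().backward()  # fine in a single process; wrapping the model in
                           # DistributedDataParallel is where this pattern tends to break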
@@ -196,17 +196,15 @@ class CheckpointedXTransformerWrapper(nn.Module):


 class AutoregressiveCodegen(nn.Module):
-    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
-                 max_mel_tokens=4000, dropout=.1):
+    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
         super().__init__()

         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
-        self.max_mel_tokens = max_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
         self.encoder = CheckpointedXTransformerWrapper(
             num_tokens=num_text_tokens,
-            max_seq_len=max_text_tokens,
+            use_pos_emb=False,
             attn_layers = Encoder(
                 depth=depth//2,
                 heads=model_dim//64,
@@ -221,7 +219,7 @@ class AutoregressiveCodegen(nn.Module):
             ))
         self.decoder = CheckpointedXTransformerWrapper(
             num_tokens=num_mel_tokens,
-            max_seq_len=max_mel_tokens,
+            use_pos_emb=False,
             attn_layers=Decoder(
                 depth=depth,
                 heads=model_dim//64,
@@ -268,7 +266,7 @@ class AutoregressiveCodegen(nn.Module):
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
         return loss_mel

-    def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
+    def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
             self.inference_model = InferenceModel(self)

@@ -283,7 +281,7 @@ class AutoregressiveCodegen(nn.Module):
         self.inference_model.store_context(context)

         gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
-                                            max_length=self.max_mel_tokens, output_attentions=False, return_dict_in_generate=True,
+                                            max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                             **hf_generate_kwargs)
         return gen

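With max_mel_tokens gone from the constructor, the generation length cap is now supplied per call. A minimal usage sketch follows, assuming AutoregressiveCodegen can be imported from the module this diff edits (the file path is not shown in this excerpt), that conditioning_signal is a batch of 80-channel mel spectrograms, and that text_codes are integer token ids; the shapes and the model_dim/depth values are assumptions, and num_beams is only an example of a kwarg forwarded to the HuggingFace generate call.

import torch

# from <module edited by this diff> import AutoregressiveCodegen  # path not shown in this excerpt

model = AutoregressiveCodegen(model_dim=512, depth=8).eval()

conditioning = torch.randn(1, 80, 200)         # assumed (batch, mel_channels, frames)
text_codes = torch.randint(0, 256, (1, 120))   # assumed (batch, num_text_tokens)

with torch.no_grad():
    # The cap on generated mel codes now comes from max_tokens at call time,
    # not from a max_mel_tokens constructor argument.
    gen = model.generate(conditioning, text_codes, max_tokens=1024, num_beams=1)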