forked from mrq/DL-Art-School
get rid of relative position embeddings, which do not work with DDP & checkpointing
parent b6d62aca5d
commit 4c6bdfc9e2
@@ -196,17 +196,15 @@ class CheckpointedXTransformerWrapper(nn.Module):
 
 
 class AutoregressiveCodegen(nn.Module):
-    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
-                 max_mel_tokens=4000, dropout=.1):
+    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
         super().__init__()
 
         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
-        self.max_mel_tokens = max_mel_tokens
         self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
         self.encoder = CheckpointedXTransformerWrapper(
             num_tokens=num_text_tokens,
-            max_seq_len=max_text_tokens,
+            use_pos_emb=False,
             attn_layers = Encoder(
                 depth=depth//2,
                 heads=model_dim//64,
@@ -221,7 +219,7 @@ class AutoregressiveCodegen(nn.Module):
             ))
         self.decoder = CheckpointedXTransformerWrapper(
             num_tokens=num_mel_tokens,
-            max_seq_len=max_mel_tokens,
+            use_pos_emb=False,
             attn_layers=Decoder(
                 depth=depth,
                 heads=model_dim//64,
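For context, here is a minimal sketch of what a use_pos_emb switch like the one added above typically controls in a token wrapper. This is plain PyTorch, not the repo's CheckpointedXTransformerWrapper, and the class and argument names are illustrative only: with the flag off, no learned absolute position table is created, so any position information must come from the attention layers themselves.

import torch
import torch.nn as nn

class TokenWrapperSketch(nn.Module):
    """Illustrative only: token embedding with an optional absolute position table."""
    def __init__(self, num_tokens, dim, max_seq_len=4000, use_pos_emb=True):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        # With use_pos_emb=False, no position parameters exist at all, so there is
        # nothing extra for DDP to track through checkpointed segments.
        self.pos_emb = nn.Embedding(max_seq_len, dim) if use_pos_emb else None

    def forward(self, tokens):
        x = self.token_emb(tokens)
        if self.pos_emb is not None:
            positions = torch.arange(tokens.shape[1], device=tokens.device)
            x = x + self.pos_emb(positions)
        return x

emb = TokenWrapperSketch(num_tokens=256, dim=512, use_pos_emb=False)
out = emb(torch.randint(0, 256, (2, 100)))   # -> (2, 100, 512), token embeddings only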
@@ -268,7 +266,7 @@ class AutoregressiveCodegen(nn.Module):
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
         return loss_mel
 
-    def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
+    def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
             self.inference_model = InferenceModel(self)
 
@@ -283,7 +281,7 @@ class AutoregressiveCodegen(nn.Module):
         self.inference_model.store_context(context)
 
         gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
-                                            max_length=self.max_mel_tokens, output_attentions=False, return_dict_in_generate=True,
+                                            max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
                                             **hf_generate_kwargs)
         return gen
 
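A hedged usage sketch of the updated call site: max_length is now driven by the new max_tokens argument rather than the removed self.max_mel_tokens attribute, and remaining keyword arguments still flow through to the Hugging Face-style generate(). The function, tensor names, and sampling settings below are assumptions for illustration, not taken from the repo.

import torch

def synthesize(model, conditioning_mel: torch.Tensor, text_codes: torch.Tensor):
    # `model` is assumed to be a trained AutoregressiveCodegen; tensor shapes are
    # whatever its forward path expects (hypothetical names).
    return model.generate(conditioning_mel, text_codes,
                          max_tokens=600,             # caps decoding length (was self.max_mel_tokens)
                          do_sample=True, top_p=0.8)  # passed through as hf_generate_kwargs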