GptAsrHf2 checkin
This commit is contained in:
parent
07c2b9907c
commit
c1bef01dfa
|
@ -31,19 +31,19 @@ class ResBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MelEncoder(nn.Module):
|
class MelEncoder(nn.Module):
|
||||||
def __init__(self, channels, mel_channels=80):
|
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||||
ResBlock(channels//4),
|
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||||
nn.GroupNorm(channels//16, channels//2),
|
nn.GroupNorm(channels//16, channels//2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ResBlock(channels//2),
|
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||||
nn.GroupNorm(channels//8, channels),
|
nn.GroupNorm(channels//8, channels),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ResBlock(channels)
|
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -211,20 +211,20 @@ def null_position_embeddings(range, dim):
|
||||||
|
|
||||||
class GptAsrHf2(nn.Module):
|
class GptAsrHf2(nn.Module):
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
|
||||||
number_text_tokens=512, start_token=511, stop_token=0):
|
number_text_tokens=512, start_token=511, stop_token=0, mel_encoder_resblocks_per_level=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number_text_tokens = number_text_tokens
|
self.number_text_tokens = number_text_tokens
|
||||||
self.start_token = start_token
|
self.start_token = start_token
|
||||||
self.stop_token = 0
|
self.stop_token = stop_token
|
||||||
|
|
||||||
self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
|
self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
|
||||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||||
|
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_mel_frames = self.max_mel_frames
|
self.max_mel_frames = self.max_mel_frames
|
||||||
self.mel_encoder = MelEncoder(model_dim)
|
self.mel_encoder = MelEncoder(model_dim, resblocks_per_reduction=mel_encoder_resblocks_per_level)
|
||||||
|
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
self.text_solo_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
|
||||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
|
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
|
||||||
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
|
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
|
self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
|
||||||
|
@ -236,6 +236,8 @@ class GptAsrHf2(nn.Module):
|
||||||
gradient_checkpointing=checkpointing,
|
gradient_checkpointing=checkpointing,
|
||||||
use_cache=not checkpointing)
|
use_cache=not checkpointing)
|
||||||
self.gpt = GPT2Model(self.gpt_config)
|
self.gpt = GPT2Model(self.gpt_config)
|
||||||
|
self.text_solo_embedding = nn.Parameter(torch.randn(1,1,512) * self.gpt.config.initializer_range, requires_grad=True)
|
||||||
|
|
||||||
# Override the built in positional embeddings
|
# Override the built in positional embeddings
|
||||||
del self.gpt.wpe
|
del self.gpt.wpe
|
||||||
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||||
|
@ -244,7 +246,7 @@ class GptAsrHf2(nn.Module):
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||||
|
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
for module in [self.text_pos_embedding, self.text_solo_pos_embedding, self.mel_pos_embedding]:
|
for module in [self.text_pos_embedding, self.mel_pos_embedding]:
|
||||||
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
@ -287,7 +289,8 @@ class GptAsrHf2(nn.Module):
|
||||||
def text_only(self, text_inputs):
|
def text_only(self, text_inputs):
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
|
||||||
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
|
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
|
||||||
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
|
||||||
|
self.text_solo_embedding
|
||||||
text_logits = self.get_logits(None, text_emb)
|
text_logits = self.get_logits(None, text_emb)
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
return loss_text.mean(), text_logits
|
return loss_text.mean(), text_logits
|
||||||
|
|
|
@ -88,7 +88,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
|
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
|
||||||
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_voice.yml')
|
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_voice.yml')
|
||||||
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
|
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
|
||||||
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_voice\\models\\13750_gpt_ema.pth')
|
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_voice\\models\\54000_gpt.pth')
|
||||||
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
||||||
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
|
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
|
||||||
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
|
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
|
||||||
|
|
|
@ -286,7 +286,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_voice_voice_clip.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user