Reduce complexity of the encoder for gpt_asr_hf

This commit is contained in:
James Betker 2021-11-01 17:02:28 -06:00
parent da55ca0438
commit 4cff774b0e

View File

@@ -1,14 +1,9 @@
from time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from munch import munchify
from transformers import GPT2Model, GPT2Config from transformers import GPT2Model, GPT2Config
from models.gpt_voice.lucidrains_gpt import Transformer from models.tacotron2.text import symbols
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
@@ -17,10 +12,10 @@ class ResBlock(nn.Module):
def __init__(self, chan): def __init__(self, chan):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=5, padding = 2), nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.BatchNorm1d(chan), nn.BatchNorm1d(chan),
nn.ReLU(), nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=5, padding = 2), nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.BatchNorm1d(chan) nn.BatchNorm1d(chan)
) )
@@ -32,17 +27,15 @@ class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80): def __init__(self, channels, mel_channels=80):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3), self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=5, padding=2),
ResBlock(channels//4), ResBlock(channels//4),
ResBlock(channels//4), ResBlock(channels//4),
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2), nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm1d(channels//2), nn.BatchNorm1d(channels//2),
nn.ReLU(), nn.ReLU(),
ResBlock(channels//2), ResBlock(channels//2),
ResBlock(channels//2), ResBlock(channels//2),
ResBlock(channels//2), nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
ResBlock(channels),
ResBlock(channels), ResBlock(channels),
ResBlock(channels) ResBlock(channels)
) )