Reduce complexity of the encoder for gpt_asr_hf

This commit is contained in:
James Betker 2021-11-01 17:02:28 -06:00
parent da55ca0438
commit 4cff774b0e

View File

@ -1,14 +1,9 @@
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from transformers import GPT2Model, GPT2Config
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols, sequence_to_text
from models.tacotron2.text import symbols
from trainer.networks import register_model
from utils.util import opt_get
@ -17,10 +12,10 @@ class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.BatchNorm1d(chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.BatchNorm1d(chan)
)
@ -32,17 +27,15 @@ class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=5, padding=2),
ResBlock(channels//4),
ResBlock(channels//4),
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm1d(channels//2),
nn.ReLU(),
ResBlock(channels//2),
ResBlock(channels//2),
ResBlock(channels//2),
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
ResBlock(channels),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
ResBlock(channels),
ResBlock(channels)
)