Reduce complexity of the encoder for gpt_asr_hf
This commit is contained in:
parent
da55ca0438
commit
4cff774b0e
|
@@ -1,14 +1,9 @@
|
|||
from time import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from munch import munchify
|
||||
from transformers import GPT2Model, GPT2Config
|
||||
|
||||
from models.gpt_voice.lucidrains_gpt import Transformer
|
||||
from models.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from models.tacotron2.text import symbols, sequence_to_text
|
||||
from models.tacotron2.text import symbols
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@@ -17,10 +12,10 @@ class ResBlock(nn.Module):
|
|||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
|
||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(chan),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
|
||||
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(chan)
|
||||
)
|
||||
|
||||
|
@@ -32,17 +27,15 @@ class MelEncoder(nn.Module):
|
|||
def __init__(self, channels, mel_channels=80):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=5, padding=2),
|
||||
ResBlock(channels//4),
|
||||
ResBlock(channels//4),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||
nn.BatchNorm1d(channels//2),
|
||||
nn.ReLU(),
|
||||
ResBlock(channels//2),
|
||||
ResBlock(channels//2),
|
||||
ResBlock(channels//2),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
|
||||
ResBlock(channels),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||
ResBlock(channels),
|
||||
ResBlock(channels)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user