Reduce complexity of the encoder for gpt_asr_hf
This commit is contained in:
parent
da55ca0438
commit
4cff774b0e
|
@ -1,14 +1,9 @@
|
||||||
from time import time
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from munch import munchify
|
|
||||||
from transformers import GPT2Model, GPT2Config
|
from transformers import GPT2Model, GPT2Config
|
||||||
|
|
||||||
from models.gpt_voice.lucidrains_gpt import Transformer
|
from models.tacotron2.text import symbols
|
||||||
from models.tacotron2.taco_utils import get_mask_from_lengths
|
|
||||||
from models.tacotron2.text import symbols, sequence_to_text
|
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
@ -17,10 +12,10 @@ class ResBlock(nn.Module):
|
||||||
def __init__(self, chan):
|
def __init__(self, chan):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
|
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm1d(chan),
|
nn.BatchNorm1d(chan),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv1d(chan, chan, kernel_size=5, padding = 2),
|
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm1d(chan)
|
nn.BatchNorm1d(chan)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,17 +27,15 @@ class MelEncoder(nn.Module):
|
||||||
def __init__(self, channels, mel_channels=80):
|
def __init__(self, channels, mel_channels=80):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=7, padding=3),
|
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=5, padding=2),
|
||||||
ResBlock(channels//4),
|
ResBlock(channels//4),
|
||||||
ResBlock(channels//4),
|
ResBlock(channels//4),
|
||||||
nn.Conv1d(channels//4, channels//2, kernel_size=5, stride=2, padding=2),
|
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||||
nn.BatchNorm1d(channels//2),
|
nn.BatchNorm1d(channels//2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
ResBlock(channels//2),
|
ResBlock(channels//2),
|
||||||
ResBlock(channels//2),
|
ResBlock(channels//2),
|
||||||
ResBlock(channels//2),
|
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||||
nn.Conv1d(channels//2, channels, kernel_size=5, stride=2, padding=2),
|
|
||||||
ResBlock(channels),
|
|
||||||
ResBlock(channels),
|
ResBlock(channels),
|
||||||
ResBlock(channels)
|
ResBlock(channels)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user