277 lines
10 KiB
Python
277 lines
10 KiB
Python
# Source: https://github.com/SeanNaren/deepspeech.pytorch
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchaudio.functional import magphase
|
|
|
|
from data.audio.unsupervised_audio_dataset import load_audio
|
|
from models.deepspeech.decoder import GreedyDecoder
|
|
from trainer.networks import register_model
|
|
|
|
|
|
class SequenceWise(nn.Module):
    """Run a wrapped module over a sequence batch by folding time into batch.

    The input's first two dimensions (T and N) are merged into a single
    dimension of size T*N, the wrapped module is applied to that 2-D view,
    and the result is unfolded back to T x N x (whatever remains). This
    allows handling of variable sequence lengths and minibatch sizes.

    :param module: Module to apply to the flattened (T*N) x H input.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        :param x: Tensor whose first two dimensions are T and N. It is
            flattened with ``view``, so it must be viewable as (T*N) x H.
        :return: Output of the wrapped module, viewed as T x N x -1.
        """
        steps, batch = x.shape[:2]
        flat_out = self.module(x.view(steps * batch, -1))
        return flat_out.view(steps, batch, -1)

    def __repr__(self):
        # Same layout as before: "<ClassName> (\n<wrapped module repr>)".
        return f"{type(self).__name__} (\n{self.module.__repr__()})"
|
|
|
|
|
|
class MaskConv(nn.Module):
    """Apply a stack of modules, zeroing the padded time steps after each one.

    Adds padding to the output of the module based on the given lengths. This
    is to ensure that the results of the model do not change when batch sizes
    change during inference. Input needs to be in the shape of (BxCxDxT).

    :param seq_module: The sequential module containing the conv stack. It is
        iterated over, so it must be iterable (e.g. ``nn.Sequential``).
    """

    def __init__(self, seq_module):
        super().__init__()
        self.seq_module = seq_module

    def forward(self, x, lengths):
        """
        :param x: The input of size BxCxDxT
        :param lengths: The actual length of each sequence in the batch
            (a 1D tensor or any iterable of integer-like values).
        :return: Tuple of (masked output from the module, ``lengths`` unchanged)
        """
        for module in self.seq_module:
            x = module(x)
            total_steps = x.size(3)
            # Allocate the mask directly on x's device. The previous
            # BoolTensor(...).cuda() form always used the *current* CUDA
            # device, which may differ from the device x lives on.
            mask = torch.zeros(x.size(), dtype=torch.bool, device=x.device)
            for i, length in enumerate(lengths):
                # int() accepts both 0-dim tensors and plain Python ints.
                length = int(length)
                padding = total_steps - length
                if padding > 0:
                    # Mark every time step at or beyond this sample's length.
                    mask[i].narrow(2, length, padding).fill_(True)
            # masked_fill is out-of-place, so the module's output tensor is
            # not mutated here.
            x = x.masked_fill(mask, 0)
        return x, lengths
|
|
|
|
|
|
class InferenceBatchSoftmax(nn.Module):
    """Softmax over the last dimension in eval mode; identity in training mode."""

    def forward(self, input_):
        # While training, the input is handed back untouched.
        if self.training:
            return input_
        return F.softmax(input_, dim=-1)
|
|
|
|
|
|
class BatchRNN(nn.Module):
    """A single recurrent layer with optional sequence-wise batch norm.

    The input is optionally batch-normalised over its feature dimension,
    packed according to the given lengths, run through the recurrent layer
    and padded back. For bidirectional layers the two directions are summed
    so the output feature size is always ``hidden_size``.

    :param input_size: Number of input features per time step.
    :param hidden_size: Hidden size of the recurrent layer.
    :param rnn_type: Recurrent layer class to instantiate (default ``nn.LSTM``).
    :param bidirectional: Whether the recurrent layer is bidirectional.
    :param batch_norm: Whether to apply ``BatchNorm1d`` to the input first.
    """

    def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
        self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
                            bidirectional=bidirectional, bias=True)
        self.num_directions = 2 if bidirectional else 1

    def flatten_parameters(self):
        """Delegate to the wrapped recurrent layer's ``flatten_parameters``."""
        self.rnn.flatten_parameters()

    def forward(self, x, output_lengths):
        """
        :param x: Input of shape TxNxH (time-major, as expected by the
            recurrent layer created in ``__init__``).
        :param output_lengths: Valid length of each sequence in the batch.
        :return: Tensor of shape T'xNxhidden_size, where T' is the longest
            length in ``output_lengths`` (the result of ``pad_packed_sequence``).
        """
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        # enforce_sorted=False lets callers pass batches in any order; for
        # batches already sorted by decreasing length the result is unchanged.
        # pad_packed_sequence restores the original batch order.
        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths, enforce_sorted=False)
        x, _ = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        if self.bidirectional:
            # (TxNxH*2) -> (TxNxH) by summing the forward and backward halves.
            x = x.view(x.size(0), x.size(1), 2, -1).sum(2)
        return x
|
|
|
|
|
|
class Lookahead(nn.Module):
    """Lookahead convolution over the time axis.

    Wang et al 2016 - Lookahead Convolution Layer for Unidirectional
    Recurrent Neural Networks.

    Input shape is sequence, batch, feature (TxNxH); the output has the
    same shape. Each feature channel is convolved independently
    (``groups=n_features``) with a kernel spanning ``context`` time steps,
    and the sequence is zero-padded on the right by ``context - 1`` steps so
    the length is preserved.

    :param n_features: Number of feature channels (H).
    :param context: Kernel size along time; must be positive.
    """

    def __init__(self, n_features, context):
        super().__init__()
        assert context > 0
        self.context = context
        self.n_features = n_features
        # (left, right) padding applied to the last (time) dimension.
        self.pad = (0, self.context - 1)
        self.conv = nn.Conv1d(
            in_channels=self.n_features,
            out_channels=self.n_features,
            kernel_size=self.context,
            stride=1,
            padding=0,
            groups=self.n_features,
            bias=False,
        )

    def forward(self, x):
        # TxNxH -> NxHxT, the layout Conv1d expects.
        channels_first = x.permute(1, 2, 0)
        padded = F.pad(channels_first, pad=self.pad, value=0)
        convolved = self.conv(padded)
        # NxHxT -> TxNxH.
        return convolved.permute(2, 0, 1).contiguous()

    def __repr__(self):
        return f"{type(self).__name__}(n_features={self.n_features}, context={self.context})"
|
|
|
|
|
|
class DeepSpeech(nn.Module):
    """DeepSpeech-style acoustic model operating directly on raw waveforms.

    The waveform is converted to a log-magnitude spectrogram, passed through
    a masked two-layer 2D convolution stack, a stack of LSTM layers, an
    optional lookahead convolution (unidirectional models only) and a final
    batch-norm + linear projection onto the character labels.

    :param hidden_size: Hidden size of every recurrent layer.
    :param hidden_layers: Total number of recurrent layers.
    :param lookahead_context: Context of the lookahead convolution; only used
        when ``bidirectional`` is False.
    :param bidirectional: Whether the recurrent layers are bidirectional.
    :param sample_rate: Sample rate used to derive the STFT sizes.
    :param window_size: STFT window size in seconds.
    :param window_stride: STFT hop size in seconds.
    """

    def __init__(self,
                 hidden_size: int = 1024,
                 hidden_layers: int = 5,
                 lookahead_context: int = 20,
                 bidirectional: bool = True,
                 sample_rate: int = 16000,
                 window_size: float = .02,
                 window_stride: float = .01
                 ):
        super().__init__()
        self.bidirectional = bidirectional
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.window_stride = window_stride

        self.labels = [ "_", "'", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q",
                        "R", "S", "T", "U", "V", "W", "X", "Y", "Z", " " ]
        num_classes = len(self.labels)
        self.conv = MaskConv(nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True)
        ))
        # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1.
        # Integer floor division gives the same values as the earlier
        # float-then-truncate form for these non-negative sizes.
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = (rnn_input_size + 2 * 20 - 41) // 2 + 1
        rnn_input_size = (rnn_input_size + 2 * 10 - 21) // 2 + 1
        rnn_input_size *= 32  # 32 output channels of the conv stack

        self.rnns = nn.Sequential(
            # The first layer consumes the conv features and has no batch norm.
            BatchRNN(
                input_size=rnn_input_size,
                hidden_size=hidden_size,
                rnn_type=nn.LSTM,
                bidirectional=self.bidirectional,
                batch_norm=False
            ),
            *(
                BatchRNN(
                    input_size=hidden_size,
                    hidden_size=hidden_size,
                    rnn_type=nn.LSTM,
                    bidirectional=self.bidirectional
                ) for _ in range(hidden_layers - 1)
            )
        )

        self.lookahead = nn.Sequential(
            # consider adding batch norm?
            Lookahead(hidden_size, context=lookahead_context),
            nn.Hardtanh(0, 20, inplace=True)
        ) if not self.bidirectional else None

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, num_classes, bias=False)
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()
        self.evaluation_decoder = GreedyDecoder(self.labels)  # Decoder used for inference.

    def forward(self, wav, lengths=None):
        """
        :param wav: Waveform batch; 2D (batch, samples) or 3D with a singleton
            channel dimension (see ``audio_to_spectrogram``).
        :param lengths: Optional 1D tensor with the valid number of samples of
            each clip. When omitted every clip is treated as full length.
        :return: Tuple of (per-frame class scores with batch first,
            output length of each sequence after the conv stack).
        """
        if lengths is None:
            lengths = torch.full((wav.shape[0],), wav.shape[-1], dtype=torch.int32, device=wav.device)

        x = self.audio_to_spectrogram(wav)

        # Convert sample counts to spectrogram frame counts using the observed
        # samples-per-frame ratio (160 with the default sample rate / stride).
        samples_per_frame = math.ceil(wav.shape[-1] / x.shape[-1])
        lengths = (lengths // samples_per_frame).cpu().int()
        output_lengths = self.get_seq_lens(lengths)
        x, _ = self.conv(x, output_lengths)

        batch, channels, freq, steps = x.size()
        x = x.view(batch, channels * freq, steps)  # Collapse feature dimension
        x = x.permute(2, 0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        # NOTE: self.inference_softmax is deliberately not applied here (the
        # call was disabled upstream with the remark "doesn't work?"), so the
        # values returned are the raw outputs of the final linear layer.
        return x, output_lengths

    def infer(self, inputs, lengths=None):
        """Run the model and decode its output with ``self.evaluation_decoder``.

        :param inputs: Waveform batch, as accepted by ``forward``.
        :param lengths: Optional per-clip sample counts, as in ``forward``.
        :return: The first element returned by the decoder's ``decode`` call.
        """
        out, output_sizes = self(inputs, lengths)
        decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes)
        return decoded_output

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the size sequences that will be output by the network.
        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        for m in self.conv.modules():
            if isinstance(m, nn.Conv2d):
                # Standard conv output-length formula along the time axis
                # (index 1 of the 2D kernel / stride / padding / dilation).
                seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1)
        return seq_len.int()

    def audio_to_spectrogram(self, y):
        """Convert a waveform batch to a log-magnitude spectrogram.

        :param y: Waveform of shape (batch, samples), or (batch, 1, samples)
            in which case the singleton channel dimension is dropped.
        :return: ``log(1 + |STFT(y)|)`` with a channel dimension inserted at
            index 1, i.e. (batch, 1, freq bins, frames).
        """
        if len(y.shape) == 3:
            assert y.shape[1] == 1
            y = y.squeeze(1)
        n_fft = int(self.sample_rate * self.window_size)
        win_length = n_fft
        hop_length = int(self.sample_rate * self.window_stride)
        # STFT. return_complex=True is required for real input on recent
        # PyTorch releases; the magnitude is then abs() of the complex result,
        # the same value torchaudio's removed ``magphase`` helper produced.
        D = torch.stft(y, n_fft=n_fft, hop_length=hop_length,
                       win_length=win_length, window=torch.hamming_window(win_length, device=y.device),
                       return_complex=True)
        spect = D.abs()
        # S = log(S+1)
        spect = torch.log1p(spect)

        return spect.unsqueeze(1)  # Deepspeech operates in a 2D spectrogram regime.
|
|
|
|
|
|
@register_model
def register_deepspeech(opt_net, opt):
    """Build a ``DeepSpeech`` model from a network options mapping.

    :param opt_net: Mapping whose ``'kwargs'`` entry holds the keyword
        arguments forwarded to the ``DeepSpeech`` constructor.
    :param opt: Not used by this factory.
    :return: The constructed ``DeepSpeech`` instance.
    """
    model_kwargs = opt_net['kwargs']
    return DeepSpeech(**model_kwargs)
|
|
|
|
|
|
# Manual smoke test on a single audio clip. load_audio is called with 16000
# as its second argument (presumably the target sample rate - confirm against
# load_audio). Requires CUDA and the hard-coded local paths below.
if __name__ == '__main__':
    audio_path = 'D:\\data\\audio\\libritts\\test-clean\\1089\\134686\\1089_134686_000008_000000.wav'
    checkpoint_path = '\\\\192.168.5.3\\rtx3080_drv\\deepspeech.pytorch\\checkpoint_sd.pth'

    clip = load_audio(audio_path, 16000).cuda()
    model = DeepSpeech().cuda()
    model.eval()
    state = torch.load(checkpoint_path)
    with torch.no_grad():
        model.load_state_dict(state)
        # First print the shape of the model output, then the decoded text.
        scores = model(clip)[0]
        print(scores.shape)
        transcript = model.infer(clip)
        print(transcript)
|