Add deepspeech model and support for decoding with it

James Betker 2021-10-27 13:09:46 -06:00
parent 15437b2fc3
commit 5d714bc566
6 changed files with 464 additions and 4 deletions

View File

@@ -128,6 +128,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
gap = audio_norm.shape[-1] - self.pad_to
start = min(max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1)
clips.append(audio_norm[:, start:start+self.pad_to])
else:
clips.append(audio_norm)
output = {
'clip': clips[0],

View File

View File

@@ -0,0 +1,181 @@
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015-2016 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
# Modified to support pytorch Tensors
import torch
class Decoder(object):
"""
Basic decoder class from which all other decoders inherit. Implements several
helper functions. Subclasses should implement the decode() method.
Arguments:
labels (list): mapping from integers to characters.
blank_index (int, optional): index for the blank '_' character. Defaults to 0.
"""
def __init__(self, labels, blank_index=0):
self.labels = labels
        self.int_to_char = dict(enumerate(labels))
self.blank_index = blank_index
space_index = len(labels) # To prevent errors in decode, we add an out of bounds index for the space
if ' ' in labels:
space_index = labels.index(' ')
self.space_index = space_index
def decode(self, probs, sizes=None):
"""
Given a matrix of character probabilities, returns the decoder's
best guess of the transcription
Arguments:
            probs: Tensor of character probabilities of shape (batch, seq_length, num_labels)
sizes(optional): Size of each sequence in the mini-batch
Returns:
string: sequence of the model's best guess for the transcription
"""
raise NotImplementedError
class BeamCTCDecoder(Decoder):
def __init__(self,
labels,
lm_path=None,
alpha=0,
beta=0,
cutoff_top_n=40,
cutoff_prob=1.0,
beam_width=100,
num_processes=4,
blank_index=0):
        super(BeamCTCDecoder, self).__init__(labels, blank_index)
try:
from ctcdecode import CTCBeamDecoder
except ImportError:
            raise ImportError("BeamCTCDecoder requires the ctcdecode package.")
labels = list(labels) # Ensure labels are a list before passing to decoder
self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
num_processes, blank_index)
def convert_to_strings(self, out, seq_len):
results = []
for b, batch in enumerate(out):
utterances = []
for p, utt in enumerate(batch):
size = seq_len[b][p]
if size > 0:
transcript = ''.join(map(lambda x: self.int_to_char[x.item()], utt[0:size]))
else:
transcript = ''
utterances.append(transcript)
results.append(utterances)
return results
def convert_tensor(self, offsets, sizes):
results = []
for b, batch in enumerate(offsets):
utterances = []
for p, utt in enumerate(batch):
size = sizes[b][p]
                if size > 0:
utterances.append(utt[0:size])
else:
utterances.append(torch.tensor([], dtype=torch.int))
results.append(utterances)
return results
def decode(self, probs, sizes=None):
"""
Decodes probability output using ctcdecode package.
Arguments:
            probs: Tensor of character probabilities of shape (batch, seq_length, num_labels), softmax-normalized
sizes: Size of each sequence in the mini-batch
Returns:
string: sequences of the model's best guess for the transcription
"""
probs = probs.cpu()
out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes)
strings = self.convert_to_strings(out, seq_lens)
offsets = self.convert_tensor(offsets, seq_lens)
return strings, offsets
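# A minimal usage sketch (a hypothetical call, assuming the optional ctcdecode
# package is installed and `probs` is a softmax-normalized
# batch x seq_length x num_labels tensor):
#
#     decoder = BeamCTCDecoder(labels, beam_width=32)
#     strings, offsets = decoder.decode(probs, sizes)
#     best = strings[0][0]  # top beam for the first utterance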
class GreedyDecoder(Decoder):
def __init__(self, labels, blank_index=0):
super(GreedyDecoder, self).__init__(labels, blank_index)
def convert_to_strings(self,
sequences,
sizes=None,
remove_repetitions=False,
return_offsets=False):
"""Given a list of numeric sequences, returns the corresponding strings"""
strings = []
offsets = [] if return_offsets else None
        for x in range(len(sequences)):
seq_len = sizes[x] if sizes is not None else len(sequences[x])
string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions)
strings.append([string]) # We only return one path
if return_offsets:
offsets.append([string_offsets])
if return_offsets:
return strings, offsets
else:
return strings
def process_string(self,
sequence,
size,
remove_repetitions=False):
string = ''
offsets = []
for i in range(size):
char = self.int_to_char[sequence[i].item()]
if char != self.int_to_char[self.blank_index]:
# if this char is a repetition and remove_repetitions=true, then skip
if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]:
pass
elif char == self.labels[self.space_index]:
string += ' '
offsets.append(i)
else:
string = string + char
offsets.append(i)
return string, torch.tensor(offsets, dtype=torch.int)
def decode(self, probs, sizes=None):
"""
Returns the argmax decoding given the probability matrix. Removes
repeated elements in the sequence, as well as blanks.
Arguments:
probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
sizes(optional): Size of each sequence in the mini-batch
Returns:
strings: sequences of the model's best guess for the transcription on inputs
offsets: time step per character predicted
"""
_, max_probs = torch.max(probs, 2)
strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)),
sizes,
remove_repetitions=True,
return_offsets=True)
return strings, offsets
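# Minimal sanity check for GreedyDecoder (hypothetical label set; random logits
# stand in for real network output, so the decoded text is meaningless):
if __name__ == '__main__':
    labels = ['_', "'", 'A', 'B', 'C', ' ']
    decoder = GreedyDecoder(labels)
    probs = torch.randn(2, 50, len(labels)).softmax(dim=-1)  # batch x time x labels
    strings, offsets = decoder.decode(probs)
    print(strings[0][0])  # best-guess transcription for the first item
    print(offsets[0][0])  # time step of each emitted character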

View File

@@ -0,0 +1,276 @@
# Source: https://github.com/SeanNaren/deepspeech.pytorch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import magphase
from data.audio.unsupervised_audio_dataset import load_audio
from models.deepspeech.decoder import GreedyDecoder
from trainer.networks import register_model
class SequenceWise(nn.Module):
def __init__(self, module):
"""
        Collapses an input of dim TxNxH to (T*N)xH and applies the given module to it.
        Allows handling of variable sequence lengths and minibatch sizes.
:param module: Module to apply input to.
"""
super(SequenceWise, self).__init__()
self.module = module
def forward(self, x):
t, n = x.size(0), x.size(1)
x = x.view(t * n, -1)
x = self.module(x)
x = x.view(t, n, -1)
return x
def __repr__(self):
tmpstr = self.__class__.__name__ + ' (\n'
tmpstr += self.module.__repr__()
tmpstr += ')'
return tmpstr
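# For example, a (T=50, N=8, H=512) input is viewed as (400, 512), passed through
# the wrapped module (e.g. BatchNorm1d, which expects 2D input), then restored to
# (50, 8, -1).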
class MaskConv(nn.Module):
def __init__(self, seq_module):
"""
        Masks (zeroes) the output of the module beyond each sequence's true length. This ensures that the
        results of the model do not change when batch sizes (and hence padding) change during inference.
Input needs to be in the shape of (BxCxDxT)
:param seq_module: The sequential module containing the conv stack.
"""
super(MaskConv, self).__init__()
self.seq_module = seq_module
def forward(self, x, lengths):
"""
:param x: The input of size BxCxDxT
:param lengths: The actual length of each sequence in the batch
:return: Masked output from the module
"""
for module in self.seq_module:
x = module(x)
            mask = torch.zeros(x.size(), dtype=torch.bool, device=x.device)
for i, length in enumerate(lengths):
length = length.item()
if (mask[i].size(2) - length) > 0:
mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
x = x.masked_fill(mask, 0)
return x, lengths
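# For example, with lengths = [201, 150] and a conv output 201 steps wide, time
# steps 150..200 of the second item are zeroed, so a padded batch yields the same
# activations as running each clip on its own.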
class InferenceBatchSoftmax(nn.Module):
def forward(self, input_):
if not self.training:
return F.softmax(input_, dim=-1)
else:
return input_
class BatchRNN(nn.Module):
def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
super(BatchRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bidirectional = bidirectional
self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
bidirectional=bidirectional, bias=True)
self.num_directions = 2 if bidirectional else 1
def flatten_parameters(self):
self.rnn.flatten_parameters()
def forward(self, x, output_lengths):
if self.batch_norm is not None:
x = self.batch_norm(x)
x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
x, h = self.rnn(x)
x, _ = nn.utils.rnn.pad_packed_sequence(x)
if self.bidirectional:
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum
return x
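# Summing (rather than concatenating) the two directions keeps the layer width at
# hidden_size: a bidirectional LSTM's (T, N, 2*hidden_size) output folds back to
# (T, N, hidden_size), so every stacked BatchRNN can share the same input_size.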
class Lookahead(nn.Module):
# Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
# input shape - sequence, batch, feature - TxNxH
# output shape - same as input
def __init__(self, n_features, context):
super(Lookahead, self).__init__()
assert context > 0
self.context = context
self.n_features = n_features
self.pad = (0, self.context - 1)
self.conv = nn.Conv1d(
self.n_features,
self.n_features,
kernel_size=self.context,
stride=1,
groups=self.n_features,
padding=0,
bias=False
)
def forward(self, x):
x = x.transpose(0, 1).transpose(1, 2)
x = F.pad(x, pad=self.pad, value=0)
x = self.conv(x)
x = x.transpose(1, 2).transpose(0, 1).contiguous()
return x
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'n_features=' + str(self.n_features) \
+ ', context=' + str(self.context) + ')'
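# With context=20, forward() right-pads 19 zero frames and applies a depthwise
# (groups=n_features) conv of kernel size 20, so output step t is a per-feature
# weighted sum of input steps t..t+19: the current frame plus 19 future frames.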
class DeepSpeech(nn.Module):
def __init__(self,
hidden_size: int = 1024,
hidden_layers: int = 5,
lookahead_context: int = 20,
bidirectional: bool = True,
sample_rate: int = 16000,
                 window_size: float = 0.02,
                 window_stride: float = 0.01
):
super().__init__()
self.bidirectional = bidirectional
self.sample_rate = sample_rate
self.window_size = window_size
self.window_stride = window_stride
self.labels = [ "_", "'", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q",
"R", "S", "T", "U", "V", "W", "X", "Y", "Z", " " ]
num_classes = len(self.labels)
self.conv = MaskConv(nn.Sequential(
nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
nn.BatchNorm2d(32),
nn.Hardtanh(0, 20, inplace=True),
nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
nn.BatchNorm2d(32),
nn.Hardtanh(0, 20, inplace=True)
))
        # Based on the convolutions above and the spectrogram size, using the conv output formula (W - F + 2P) / S + 1
rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
rnn_input_size *= 32
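        # Worked example with the defaults (sample_rate=16000, window_size=0.02):
        # n_fft = 320 -> 320 // 2 + 1 = 161 freq bins; conv1: (161 + 40 - 41) // 2 + 1 = 81;
        # conv2: (81 + 20 - 21) // 2 + 1 = 41; 41 bins * 32 channels = 1312 RNN input features.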
self.rnns = nn.Sequential(
BatchRNN(
input_size=rnn_input_size,
hidden_size=hidden_size,
rnn_type=nn.LSTM,
bidirectional=self.bidirectional,
batch_norm=False
),
*(
BatchRNN(
input_size=hidden_size,
hidden_size=hidden_size,
rnn_type=nn.LSTM,
bidirectional=self.bidirectional
) for x in range(hidden_layers - 1)
)
)
self.lookahead = nn.Sequential(
# consider adding batch norm?
Lookahead(hidden_size, context=lookahead_context),
nn.Hardtanh(0, 20, inplace=True)
) if not self.bidirectional else None
fully_connected = nn.Sequential(
nn.BatchNorm1d(hidden_size),
nn.Linear(hidden_size, num_classes, bias=False)
)
self.fc = nn.Sequential(
SequenceWise(fully_connected),
)
self.inference_softmax = InferenceBatchSoftmax()
self.evaluation_decoder = GreedyDecoder(self.labels) # Decoder used for inference.
def forward(self, wav, lengths=None):
if lengths is None:
lengths = torch.tensor([wav.shape[-1] for _ in range(wav.shape[0])], dtype=torch.int32, device=wav.device)
x = self.audio_to_spectrogram(wav)
        lengths = (lengths // math.ceil(wav.shape[-1] / x.shape[-1])).cpu().int()  # Scale lengths down by the spectrogram compression factor (the hop length; 160 with the defaults).
output_lengths = self.get_seq_lens(lengths)
x, _ = self.conv(x, output_lengths)
sizes = x.size()
x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension
x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH
for rnn in self.rnns:
x = rnn(x, output_lengths)
if not self.bidirectional: # no need for lookahead layer in bidirectional
x = self.lookahead(x)
x = self.fc(x)
x = x.transpose(0, 1)
#x = self.inference_softmax(x) <-- doesn't work?
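        # Shape walk-through for a (1, 1, 64000) 16 kHz clip: spectrogram (1, 1, 161, 401)
        # -> conv out (1, 32, 41, 201) -> RNN input (201, 1, 1312) -> logits (1, 201, 29).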
return x, output_lengths
def infer(self, inputs, lengths=None):
out, output_sizes = self(inputs, lengths)
decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes)
return decoded_output
def get_seq_lens(self, input_length):
"""
Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the sizes of the sequences that will be output by the network.
:param input_length: 1D Tensor
:return: 1D Tensor scaled by model
"""
seq_len = input_length
for m in self.conv.modules():
if type(m) == nn.modules.conv.Conv2d:
seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1)
return seq_len.int()
def audio_to_spectrogram(self, y):
if len(y.shape) == 3:
assert y.shape[1] == 1
y = y.squeeze(1)
n_fft = int(self.sample_rate * self.window_size)
win_length = n_fft
hop_length = int(self.sample_rate * self.window_stride)
# STFT
D = torch.stft(y, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, window=torch.hamming_window(win_length, device=y.device))
spect, phase = magphase(D)
# S = log(S+1)
spect = torch.log1p(spect)
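        # e.g. 4s of 16 kHz audio (64000 samples) with n_fft=320 and hop_length=160
        # yields (batch, 161, 401) log-magnitude frames (401 because torch.stft
        # centers its windows by default), unsqueezed below to (batch, 1, 161, 401).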
return spect.unsqueeze(1) # Deepspeech operates in a 2D spectrogram regime.
@register_model
def register_deepspeech(opt_net, opt):
return DeepSpeech(**opt_net['kwargs'])
# Test with a ~4 second audio clip, loaded at 16000Hz
if __name__ == '__main__':
clip = load_audio('D:\\data\\audio\\libritts\\test-clean\\1089\\134686\\1089_134686_000008_000000.wav', 16000).cuda()
model = DeepSpeech().cuda()
model.eval()
sd = torch.load('\\\\192.168.5.3\\rtx3080_drv\\deepspeech.pytorch\\checkpoint_sd.pth')
with torch.no_grad():
model.load_state_dict(sd)
print(model(clip)[0].shape)
print(model.infer(clip))

View File

@@ -41,7 +41,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_deepspeech_libri.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@@ -71,11 +71,12 @@ if __name__ == "__main__":
tq = tqdm(test_loader)
for data in tq:
if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
continue
#if data['clips'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
# continue
pred = forward_pass(model, data, dataset_dir, opt, batch)
pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n')
print(pred)
output.flush()
batch += 1

View File

@@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()