DL-Art-School/codes/models/deepspeech/decoder.py

181 lines
7.2 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015-2016 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
# Modified to support pytorch Tensors
import torch
from six.moves import xrange
class Decoder(object):
    """
    Common base for the decoders in this module.

    Stores the label set, builds the index -> character lookup and records
    the blank and space positions. Concrete subclasses override decode().

    Arguments:
        labels (list): mapping from integers to characters.
        blank_index (int, optional): index for the blank '_' character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        self.labels = labels
        self.int_to_char = dict(enumerate(labels))
        self.blank_index = blank_index
        # When the label set contains no space, an index one past the end of
        # `labels` is stored instead, so it can never equal a real label index.
        self.space_index = labels.index(' ') if ' ' in labels else len(labels)

    def decode(self, probs, sizes=None):
        """
        Turn a tensor of character probabilities into the best-guess
        transcription. Must be overridden; the base class always raises.

        Arguments:
            probs: Tensor of character probabilities (the exact layout is
                defined by the concrete subclass)
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class BeamCTCDecoder(Decoder):
    """
    Beam-search CTC decoder that delegates the search to the external
    ``ctcdecode`` package (``ctcdecode.CTCBeamDecoder``).

    Arguments:
        labels: mapping from integers to characters.
        lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
        num_processes: forwarded unchanged, in this order, to
            ``CTCBeamDecoder``; see that package for their meaning.
        blank_index (int, optional): index for the blank character. Stored on
            this object and also forwarded to ``CTCBeamDecoder``. Defaults to 0.
    Raises:
        ImportError: if the ``ctcdecode`` package cannot be imported.
    """

    def __init__(self,
                 labels,
                 lm_path=None,
                 alpha=0,
                 beta=0,
                 cutoff_top_n=40,
                 cutoff_prob=1.0,
                 beam_width=100,
                 num_processes=4,
                 blank_index=0):
        # Forward blank_index so self.blank_index agrees with the value handed
        # to the underlying beam decoder (it used to be silently reset to 0).
        super(BeamCTCDecoder, self).__init__(labels, blank_index)
        try:
            # Imported lazily so the rest of the module works without ctcdecode.
            from ctcdecode import CTCBeamDecoder
        except ImportError:
            raise ImportError("BeamCTCDecoder requires ctcdecode package.")

        labels = list(labels)  # Ensure labels are a list before passing to decoder
        self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
                                       num_processes, blank_index)

    def convert_to_strings(self, out, seq_len):
        """
        Map decoded label indices to strings.

        Arguments:
            out: nested iterable indexed as out[b][p]; each entry is a
                sequence of elements exposing .item() (label indices)
            seq_len: seq_len[b][p] is the number of valid leading entries
                in out[b][p]
        Returns:
            list of lists of str, with the same [b][p] nesting as `out`;
            an entry whose length is not positive becomes ''
        """
        results = []
        for b, batch in enumerate(out):
            utterances = []
            for p, utt in enumerate(batch):
                size = seq_len[b][p]
                if size > 0:
                    transcript = ''.join(self.int_to_char[x.item()] for x in utt[0:size])
                else:
                    transcript = ''
                utterances.append(transcript)
            results.append(utterances)
        return results

    def convert_tensor(self, offsets, sizes):
        """
        Trim every offsets entry to its valid length.

        Arguments:
            offsets: nested iterable indexed as offsets[b][p]
            sizes: sizes[b][p] is the number of valid leading entries in
                offsets[b][p]
        Returns:
            list of lists with the same [b][p] nesting; each entry is
            offsets[b][p][0:size], or an empty int tensor when the size is
            not positive
        """
        results = []
        for b, batch in enumerate(offsets):
            utterances = []
            for p, utt in enumerate(batch):
                size = sizes[b][p]
                if size > 0:
                    utterances.append(utt[0:size])
                else:
                    utterances.append(torch.tensor([], dtype=torch.int))
            results.append(utterances)
        return results

    def decode(self, probs, sizes=None):
        """
        Decodes probability output using ctcdecode package.

        Arguments:
            probs: Tensor of character probabilities; moved to the CPU before
                being handed to ctcdecode.
                NOTE(review): presumably batch x seq_length x output_dim as
                ctcdecode expects -- confirm against the caller.
            sizes: Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription
            offsets: the offsets returned by ctcdecode, trimmed to each
                sequence's decoded length
        """
        probs = probs.cpu()
        # The beam scores are returned by ctcdecode but not used here.
        out, _scores, offsets, seq_lens = self._decoder.decode(probs, sizes)

        strings = self.convert_to_strings(out, seq_lens)
        offsets = self.convert_tensor(offsets, seq_lens)
        return strings, offsets
class GreedyDecoder(Decoder):
    """
    Argmax (best-path) decoder: takes the highest-scoring label at every
    time step, then drops blanks and, when asked, repeated labels.

    Arguments:
        labels (list): mapping from integers to characters.
        blank_index (int, optional): index for the blank character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        super(GreedyDecoder, self).__init__(labels, blank_index)

    def convert_to_strings(self,
                           sequences,
                           sizes=None,
                           remove_repetitions=False,
                           return_offsets=False):
        """
        Given a list of numeric sequences, returns the corresponding strings

        Arguments:
            sequences: indexable collection of label-index sequences whose
                elements expose .item()
            sizes(optional): number of leading entries to read from each
                sequence; the full length is used when omitted
            remove_repetitions(optional): skip a label equal to the one at
                the previous time step
            return_offsets(optional): also return the time step of every
                emitted character
        Returns:
            strings, or (strings, offsets) when return_offsets is true. Each
            per-sequence result is wrapped in a one-element list.
        """
        strings = []
        offsets = [] if return_offsets else None
        for x, sequence in enumerate(sequences):
            seq_len = sizes[x] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(sequence, seq_len, remove_repetitions)
            strings.append([string])  # We only return one path
            if return_offsets:
                offsets.append([string_offsets])
        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self,
                       sequence,
                       size,
                       remove_repetitions=False):
        """
        Convert the first `size` label indices of one sequence into text.

        Arguments:
            sequence: indexable sequence of label indices exposing .item()
            size: number of leading entries to read
            remove_repetitions(optional): skip a label equal to the one at
                the previous time step
        Returns:
            string: the decoded text with blanks removed
            offsets: int tensor holding the time step of each emitted character
        """
        blank_char = self.int_to_char[self.blank_index]
        chars = []
        offsets = []
        prev_char = None
        for i in range(size):
            char = self.int_to_char[sequence[i].item()]
            # The comparison is against the label at step i - 1, blank or not,
            # so a blank between two equal labels keeps both of them.
            is_repeat = remove_repetitions and i != 0 and char == prev_char
            prev_char = char
            if char == blank_char or is_repeat:
                continue
            # A space label is emitted like any other character. Looking it up
            # through self.labels[self.space_index] is unnecessary and raised
            # IndexError whenever the label set contains no space.
            chars.append(char)
            offsets.append(i)
        return ''.join(chars), torch.tensor(offsets, dtype=torch.int)

    def decode(self, probs, sizes=None):
        """
        Returns the argmax decoding given the probability matrix. Removes
        repeated elements in the sequence, as well as blanks.

        Arguments:
            probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription on inputs
            offsets: time step per character predicted
        """
        # Only the argmax indices along the label dimension are needed.
        _, max_probs = torch.max(probs, 2)
        strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)),
                                                   sizes,
                                                   remove_repetitions=True,
                                                   return_offsets=True)
        return strings, offsets