Add deepspeech model and support for decoding with it

parent 15437b2fc3
commit 5d714bc566
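Reviewer note: the commit vendors the DeepSpeech model and CTC decoders from SeanNaren's deepspeech.pytorch, registers the model with the trainer, and repoints the audio test script at it. A minimal sketch of how the new pieces fit together, mirroring the test block at the bottom of the new deepspeech.py (the clip and checkpoint paths here are placeholders, not files shipped with the commit):

    import torch
    from data.audio.unsupervised_audio_dataset import load_audio
    from models.deepspeech.deepspeech import DeepSpeech

    clip = load_audio('example.wav', 16000).cuda()           # placeholder path
    model = DeepSpeech().cuda()
    model.eval()
    model.load_state_dict(torch.load('deepspeech_sd.pth'))   # placeholder checkpoint
    with torch.no_grad():
        print(model.infer(clip))  # greedy-decoded transcript, e.g. [['HELLO WORLD']]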
@@ -128,6 +128,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
                 gap = audio_norm.shape[-1] - self.pad_to
                 start = min(max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1)
                 clips.append(audio_norm[:, start:start+self.pad_to])
+            else:
+                clips.append(audio_norm)
 
         output = {
             'clip': clips[0],
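Reviewer note: this hunk adds a fallback for clips that are not longer than pad_to, which previously would have hit the random-crop path unconditionally. A sketch of the surrounding logic as I read it (the `if gap > 0` guard is an assumption from context; only the `else` branch is new in this commit):

    gap = audio_norm.shape[-1] - self.pad_to
    if gap > 0:
        # Pick a random pad_to-sized window, shifted by sk half-gaps and clamped to [0, gap-1].
        start = min(max(random.randint(0, gap - 1) + sk * gap // 2, 0), gap - 1)
        clips.append(audio_norm[:, start:start + self.pad_to])
    else:
        clips.append(audio_norm)  # short clips pass through uncropped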
codes/models/deepspeech/__init__.py (new file, 0 lines)
codes/models/deepspeech/decoder.py (new file, 181 lines)
@@ -0,0 +1,181 @@
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015-2016 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
# Modified to support pytorch Tensors

import torch
from six.moves import xrange


class Decoder(object):
    """
    Basic decoder class from which all other decoders inherit. Implements several
    helper functions. Subclasses should implement the decode() method.

    Arguments:
        labels (list): mapping from integers to characters.
        blank_index (int, optional): index for the blank '_' character. Defaults to 0.
    """

    def __init__(self, labels, blank_index=0):
        self.labels = labels
        self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
        self.blank_index = blank_index
        space_index = len(labels)  # To prevent errors in decode, we add an out of bounds index for the space
        if ' ' in labels:
            space_index = labels.index(' ')
        self.space_index = space_index

    def decode(self, probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription

        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                   is the probability of character c at time t
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        """
        raise NotImplementedError


class BeamCTCDecoder(Decoder):
    def __init__(self,
                 labels,
                 lm_path=None,
                 alpha=0,
                 beta=0,
                 cutoff_top_n=40,
                 cutoff_prob=1.0,
                 beam_width=100,
                 num_processes=4,
                 blank_index=0):
        super(BeamCTCDecoder, self).__init__(labels)
        try:
            from ctcdecode import CTCBeamDecoder
        except ImportError:
            raise ImportError("BeamCTCDecoder requires the ctcdecode package.")
        labels = list(labels)  # Ensure labels are a list before passing to decoder
        self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
                                       num_processes, blank_index)

    def convert_to_strings(self, out, seq_len):
        results = []
        for b, batch in enumerate(out):
            utterances = []
            for p, utt in enumerate(batch):
                size = seq_len[b][p]
                if size > 0:
                    transcript = ''.join(map(lambda x: self.int_to_char[x.item()], utt[0:size]))
                else:
                    transcript = ''
                utterances.append(transcript)
            results.append(utterances)
        return results

    def convert_tensor(self, offsets, sizes):
        results = []
        for b, batch in enumerate(offsets):
            utterances = []
            for p, utt in enumerate(batch):
                size = sizes[b][p]
                if sizes[b][p] > 0:
                    utterances.append(utt[0:size])
                else:
                    utterances.append(torch.tensor([], dtype=torch.int))
            results.append(utterances)
        return results

    def decode(self, probs, sizes=None):
        """
        Decodes probability output using ctcdecode package.
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                   is the probability of character c at time t
            sizes: Size of each sequence in the mini-batch
        Returns:
            string: sequences of the model's best guess for the transcription
        """
        probs = probs.cpu()
        out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes)

        strings = self.convert_to_strings(out, seq_lens)
        offsets = self.convert_tensor(offsets, seq_lens)
        return strings, offsets


class GreedyDecoder(Decoder):
    def __init__(self, labels, blank_index=0):
        super(GreedyDecoder, self).__init__(labels, blank_index)

    def convert_to_strings(self,
                           sequences,
                           sizes=None,
                           remove_repetitions=False,
                           return_offsets=False):
        """Given a list of numeric sequences, returns the corresponding strings"""
        strings = []
        offsets = [] if return_offsets else None
        for x in xrange(len(sequences)):
            seq_len = sizes[x] if sizes is not None else len(sequences[x])
            string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions)
            strings.append([string])  # We only return one path
            if return_offsets:
                offsets.append([string_offsets])
        if return_offsets:
            return strings, offsets
        else:
            return strings

    def process_string(self,
                       sequence,
                       size,
                       remove_repetitions=False):
        string = ''
        offsets = []
        for i in range(size):
            char = self.int_to_char[sequence[i].item()]
            if char != self.int_to_char[self.blank_index]:
                # if this char is a repetition and remove_repetitions=true, then skip
                if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]:
                    pass
                elif char == self.labels[self.space_index]:
                    string += ' '
                    offsets.append(i)
                else:
                    string = string + char
                    offsets.append(i)
        return string, torch.tensor(offsets, dtype=torch.int)

    def decode(self, probs, sizes=None):
        """
        Returns the argmax decoding given the probability matrix. Removes
        repeated elements in the sequence, as well as blanks.

        Arguments:
            probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription on inputs
            offsets: time step per character predicted
        """
        _, max_probs = torch.max(probs, 2)
        strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)),
                                                   sizes,
                                                   remove_repetitions=True,
                                                   return_offsets=True)
        return strings, offsets
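Reviewer note: a quick usage sketch for GreedyDecoder, to make the decode() contract concrete (random probabilities, so the transcript is gibberish; shapes follow the docstring, batch x seq_length x output_dim). BeamCTCDecoder follows the same contract but additionally needs the optional ctcdecode package installed:

    import torch
    from models.deepspeech.decoder import GreedyDecoder

    labels = ["_", "'"] + [chr(c) for c in range(ord('A'), ord('Z') + 1)] + [" "]
    decoder = GreedyDecoder(labels, blank_index=0)

    probs = torch.rand(2, 50, len(labels))  # batch x time x classes
    sizes = torch.tensor([50, 30])          # valid timesteps per utterance
    strings, offsets = decoder.decode(probs, sizes)
    print(strings[0][0])  # best-path transcript of the first utterance
    print(offsets[0][0])  # timestep of each emitted character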
codes/models/deepspeech/deepspeech.py (new file, 276 lines)
@@ -0,0 +1,276 @@
# Source: https://github.com/SeanNaren/deepspeech.pytorch

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import magphase

from data.audio.unsupervised_audio_dataset import load_audio
from models.deepspeech.decoder import GreedyDecoder
from trainer.networks import register_model


class SequenceWise(nn.Module):
    def __init__(self, module):
        """
        Collapses input of dim T*N*H to (T*N)*H, and applies to a module.
        Allows handling of variable sequence lengths and minibatch sizes.
        :param module: Module to apply input to.
        """
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x):
        t, n = x.size(0), x.size(1)
        x = x.view(t * n, -1)
        x = self.module(x)
        x = x.view(t, n, -1)
        return x

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' (\n'
        tmpstr += self.module.__repr__()
        tmpstr += ')'
        return tmpstr


class MaskConv(nn.Module):
    def __init__(self, seq_module):
        """
        Adds padding to the output of the module based on the given lengths. This is to ensure that the
        results of the model do not change when batch sizes change during inference.
        Input needs to be in the shape of (BxCxDxT)
        :param seq_module: The sequential module containing the conv stack.
        """
        super(MaskConv, self).__init__()
        self.seq_module = seq_module

    def forward(self, x, lengths):
        """
        :param x: The input of size BxCxDxT
        :param lengths: The actual length of each sequence in the batch
        :return: Masked output from the module
        """
        for module in self.seq_module:
            x = module(x)
            mask = torch.BoolTensor(x.size()).fill_(0)
            if x.is_cuda:
                mask = mask.cuda()
            for i, length in enumerate(lengths):
                length = length.item()
                if (mask[i].size(2) - length) > 0:
                    mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
            x = x.masked_fill(mask, 0)
        return x, lengths


class InferenceBatchSoftmax(nn.Module):
    def forward(self, input_):
        if not self.training:
            return F.softmax(input_, dim=-1)
        else:
            return input_


class BatchRNN(nn.Module):
    def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
        super(BatchRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
        self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
                            bidirectional=bidirectional, bias=True)
        self.num_directions = 2 if bidirectional else 1

    def flatten_parameters(self):
        self.rnn.flatten_parameters()

    def forward(self, x, output_lengths):
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
        x, h = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        if self.bidirectional:
            x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)  # (TxNxH*2) -> (TxNxH) by sum
        return x


class Lookahead(nn.Module):
    # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
    def __init__(self, n_features, context):
        super(Lookahead, self).__init__()
        assert context > 0
        self.context = context
        self.n_features = n_features
        self.pad = (0, self.context - 1)
        self.conv = nn.Conv1d(
            self.n_features,
            self.n_features,
            kernel_size=self.context,
            stride=1,
            groups=self.n_features,
            padding=0,
            bias=False
        )

    def forward(self, x):
        x = x.transpose(0, 1).transpose(1, 2)
        x = F.pad(x, pad=self.pad, value=0)
        x = self.conv(x)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'


class DeepSpeech(nn.Module):
    def __init__(self,
                 hidden_size: int = 1024,
                 hidden_layers: int = 5,
                 lookahead_context: int = 20,
                 bidirectional: bool = True,
                 sample_rate: int = 16000,
                 window_size: float = .02,
                 window_stride: float = .01
                 ):
        super().__init__()
        self.bidirectional = bidirectional
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.window_stride = window_stride

        self.labels = ["_", "'", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q",
                       "R", "S", "T", "U", "V", "W", "X", "Y", "Z", " "]
        num_classes = len(self.labels)
        self.conv = MaskConv(nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True)
        ))
        # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32

        self.rnns = nn.Sequential(
            BatchRNN(
                input_size=rnn_input_size,
                hidden_size=hidden_size,
                rnn_type=nn.LSTM,
                bidirectional=self.bidirectional,
                batch_norm=False
            ),
            *(
                BatchRNN(
                    input_size=hidden_size,
                    hidden_size=hidden_size,
                    rnn_type=nn.LSTM,
                    bidirectional=self.bidirectional
                ) for x in range(hidden_layers - 1)
            )
        )

        self.lookahead = nn.Sequential(
            # consider adding batch norm?
            Lookahead(hidden_size, context=lookahead_context),
            nn.Hardtanh(0, 20, inplace=True)
        ) if not self.bidirectional else None

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, num_classes, bias=False)
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()
        self.evaluation_decoder = GreedyDecoder(self.labels)  # Decoder used for inference.

    def forward(self, wav, lengths=None):
        if lengths is None:
            lengths = torch.tensor([wav.shape[-1] for _ in range(wav.shape[0])], dtype=torch.int32, device=wav.device)

        x = self.audio_to_spectrogram(wav)

        lengths = (lengths // math.ceil(wav.shape[-1] / x.shape[-1])).cpu().int()  # Scale lengths by the spectrogram hop (160 samples at the default 16kHz / 10ms stride).
        output_lengths = self.get_seq_lens(lengths)
        x, _ = self.conv(x, output_lengths)

        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        #x = self.inference_softmax(x) <-- doesn't work?
        return x, output_lengths

    def infer(self, inputs, lengths=None):
        out, output_sizes = self(inputs, lengths)
        decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes)
        return decoded_output

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the size sequences that will be output by the network.
        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        for m in self.conv.modules():
            if type(m) == nn.modules.conv.Conv2d:
                seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1)
        return seq_len.int()

    def audio_to_spectrogram(self, y):
        if len(y.shape) == 3:
            assert y.shape[1] == 1
            y = y.squeeze(1)
        n_fft = int(self.sample_rate * self.window_size)
        win_length = n_fft
        hop_length = int(self.sample_rate * self.window_stride)
        # STFT
        D = torch.stft(y, n_fft=n_fft, hop_length=hop_length,
                       win_length=win_length, window=torch.hamming_window(win_length, device=y.device))
        spect, phase = magphase(D)
        # S = log(S+1)
        spect = torch.log1p(spect)

        return spect.unsqueeze(1)  # Deepspeech operates in a 2D spectrogram regime.


@register_model
def register_deepspeech(opt_net, opt):
    return DeepSpeech(**opt_net['kwargs'])


# Test for a ~4 second audio clip, loaded at 16000Hz.
if __name__ == '__main__':
    clip = load_audio('D:\\data\\audio\\libritts\\test-clean\\1089\\134686\\1089_134686_000008_000000.wav', 16000).cuda()
    model = DeepSpeech().cuda()
    model.eval()
    sd = torch.load('\\\\192.168.5.3\\rtx3080_drv\\deepspeech.pytorch\\checkpoint_sd.pth')
    with torch.no_grad():
        model.load_state_dict(sd)
        print(model(clip)[0].shape)
        print(model.infer(clip))
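Reviewer note: sanity-checking the rnn_input_size arithmetic in DeepSpeech.__init__ for the default config (16kHz audio, 20ms window):

    import math

    sample_rate, window_size = 16000, 0.02
    freq_bins = int(math.floor((sample_rate * window_size) / 2) + 1)  # n_fft = 320 -> 161 bins
    after_conv1 = int(math.floor(freq_bins + 2 * 20 - 41) / 2 + 1)    # 41-tap conv, stride 2 -> 81
    after_conv2 = int(math.floor(after_conv1 + 2 * 10 - 21) / 2 + 1)  # 21-tap conv, stride 2 -> 41
    assert after_conv2 * 32 == 1312  # 32 channels -> rnn_input_size = 1312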
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_deepspeech_libri.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -71,11 +71,12 @@ if __name__ == "__main__":
 
     tq = tqdm(test_loader)
     for data in tq:
-        if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
-            continue
+        #if data['clips'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
+        #    continue
         pred = forward_pass(model, data, dataset_dir, opt, batch)
         pred = pred.replace('_', '')
         output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n')
+        print(pred)
         output.flush()
         batch += 1
 
@@ -284,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_distill.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()