This commit is contained in:
James Betker 2021-08-24 17:12:04 -06:00
parent 9dfe936c16
commit d05cc1f46c
2 changed files with 4 additions and 3 deletions

View File

@ -90,8 +90,9 @@ if __name__ == "__main__":
clip_size = model.max_mel_frames
while start+clip_size < mels.shape[-1]:
clip = mels[:, :, start:start+clip_size]
preds = torch.nn.functional.sigmoid(model(clip)).squeeze(-1).squeeze(0) # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
indices = torch.nonzero(preds > cutoff_pred_percent)
pred_starts, pred_ends = model(clip)
pred_ends = torch.nn.functional.sigmoid(pred_ends).squeeze(-1).squeeze(0) # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
indices = torch.nonzero(pred_ends > cutoff_pred_percent)
for i in indices:
i = i.item()
sentence = mels[0, :, last_detection_start:start+i]

View File

@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_stop_libritts.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()