2023-08-02 21:53:35 +00:00
# todo: clean this mess up
import copy
import h5py
import json
import logging
import numpy as np
import os
import random
import torch
2023-08-27 00:53:23 +00:00
import itertools
2023-08-02 21:53:35 +00:00
from . config import cfg
2023-08-23 21:43:03 +00:00
from . emb . qnt import trim , trim_random , repeat_extend_audio , merge_audio , decode_to_file
2024-06-29 03:28:54 +00:00
from . utils . sampler import PoolSampler , OrderedSampler , BatchedOrderedSampler , RandomSampler
2024-06-01 14:29:49 +00:00
from . utils . distributed import global_rank , local_rank , world_size
2023-08-02 21:53:35 +00:00
from collections import defaultdict
from functools import cache , cached_property
from itertools import groupby , zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch . utils . data import DataLoader , Dataset as _Dataset
2023-08-14 03:07:45 +00:00
from torch . utils . data . distributed import DistributedSampler
2024-06-04 02:28:49 +00:00
from torch . nn . utils . rnn import pad_sequence
2023-08-02 21:53:35 +00:00
from tqdm . auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging . getLogger ( __name__ )
2024-06-04 02:28:49 +00:00
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs(
    text_list=[],
    prom_list=[],
    resp_list=[],
    targ_list=[],
    ignore_index=None,
    sep=3,
    stop=3,
    text_tokens=256,
    audio_tokens=1024,
    audio_rvq_levels=cfg.model.max_levels,
    quant_levels=None,
):
    """Fold text / prompt / response codes into one token-ID sequence per sample.

    Token-ID layout produced by this function:
      * ``[0, text_tokens)``: text tokens, unshifted
      * ``[text_tokens, text_tokens + audio_rvq_levels)``: target quant-level marker
      * next ``audio_tokens * audio_rvq_levels`` IDs: prompt codes
      * everything after that: response / target codes

    Args:
        text_list: one 1-D token tensor per sample; its first element decides the
            device of the output (so it must not be empty).
        prom_list: per-sample prompt codes, indexed as ``[frame, level]``.
        resp_list: per-sample response codes (a tensor indexed ``[frame, level]``,
            or a list with one tensor per level).
        targ_list: per-sample target codes, indexed as ``[frame, level]``.
        ignore_index: when not None, prompt positions are filled with this value
            instead of the shifted prompt codes.
        sep, stop: IDs appended after each segment.
        quant_levels: when given, selects a single level per sample
            ("deinterleaved"); otherwise all levels are flattened ("interleaved").

    Returns:
        ``(input_ids, mask)`` as produced by ``pad_sequence`` plus a length mask,
        both cast to the dtype/device of ``input_ids``.
    """
    def _create_mask(l, device):
        seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
        stop = torch.tensor(l, device=device).unsqueeze(1)  # (b 1)
        return (seq < stop).float()  # (b t)

    def list_to_tensor(x_list: list[Tensor]):
        l = list(map(len, x_list))
        x = pad_sequence(x_list).t()
        m = _create_mask(l, x_list[0].device)
        m = m.to(x)
        return x, m

    def _shift_level(codes, base, level):
        # Out-of-place add: never writes through into the caller's tensor, which the
        # old per-token `token += ...` loop could do when `.to()` returned a view.
        return codes.to("cpu", dtype=torch.int64) + (base + audio_tokens * level)

    def _shift_interleaved(codes, base):
        flat = codes.flatten().to("cpu", dtype=torch.int64)
        # position p in the flattened sequence belongs to level p % audio_rvq_levels
        levels = torch.arange(flat.shape[0], dtype=torch.int64) % audio_rvq_levels
        return flat + base + audio_tokens * levels

    device = text_list[0].device
    batch_size = len(text_list)
    input_ids = [[] for _ in range(batch_size)]

    # keep separators integral so the final concat needs no float round-trip
    sep = torch.tensor([sep], dtype=torch.int64)
    stop = torch.tensor([stop], dtype=torch.int64)

    for i, text in enumerate(text_list):
        input_ids[i].append(text.to("cpu", dtype=torch.int64))
        input_ids[i].append(sep)

    # inject target quant_level
    offset = text_tokens
    if quant_levels is not None:
        for i, rvq in enumerate(quant_levels):
            input_ids[i].append(torch.tensor([offset + int(rvq)], dtype=torch.int64))
            input_ids[i].append(sep)

    offset = text_tokens + audio_rvq_levels
    for i, prom in enumerate(prom_list):
        if quant_levels is not None:
            # deinterleaved
            if ignore_index is not None:
                seq = torch.full((prom.shape[0],), ignore_index, dtype=torch.int64)
            else:
                quant_level = int(quant_levels[i])
                seq = _shift_level(prom[:, quant_level], offset, quant_level)
        else:
            # interleaved
            if ignore_index is not None:
                seq = torch.full((prom.shape[0] * prom.shape[1],), ignore_index, dtype=torch.int64)
            else:
                seq = _shift_interleaved(prom, offset)

        input_ids[i].append(seq)
        input_ids[i].append(sep)

    offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)
    for i, resp in enumerate(resp_list):
        if quant_levels is not None:
            # deinterleaved: feed the *previous* rvq level as context
            quant_level = int(quant_levels[i]) - 1
            # level 0 has no previous level; emitting nothing signals that we want
            # to inference for rvq level 0
            if quant_level < 0:
                continue
            # responses may arrive as a list with one tensor per level
            level_codes = resp[quant_level] if isinstance(resp, list) else resp[:, quant_level]
            seq = _shift_level(level_codes, offset, quant_level)
        else:
            # interleaved
            seq = _shift_interleaved(resp, offset)

        input_ids[i].append(seq)
        input_ids[i].append(stop)

    for i, resp in enumerate(targ_list):
        if quant_levels is not None:
            # deinterleaved
            quant_level = int(quant_levels[i])
            seq = _shift_level(resp[:, quant_level], offset, quant_level)
        else:
            # interleaved
            seq = _shift_interleaved(resp, offset)

        input_ids[i].append(seq)
        input_ids[i].append(stop)

    for i, batch in enumerate(input_ids):
        input_ids[i] = torch.concat(batch, dim=-1).to(device=device, dtype=torch.int64)

    return list_to_tensor(input_ids)
# unfold from one unified token ID space to separate token spaces
2024-06-04 19:19:52 +00:00
# to-do: unfold at a specific RVQ level instead if requested
2024-06-04 02:28:49 +00:00
def unfold_outputs(
    output_ids,
    sep=3,
    stop=3,
    text_tokens=256,
    audio_tokens=1024,
    audio_rvq_levels=cfg.model.max_levels,
    quant_levels=None,
):
    """Split folded token IDs back into separate text / prompt / response tensors.

    Args:
        output_ids: 2-D tensor of folded IDs, one row per sample.
        sep, stop: IDs treated as segment delimiters and skipped.
        text_tokens, audio_tokens, audio_rvq_levels: sizes of the ID ranges; IDs in
            ``[text_tokens, text_tokens + audio_rvq_levels)`` (the quant-level
            marker range) are ignored.
        quant_levels: when given, prompt/response are returned as 1-D tensors
            (single level per sample); otherwise they are regrouped into
            ``(frames, audio_rvq_levels)`` tensors.

    Returns:
        dict with ``text_list``, ``prom_list`` and ``resp_list`` (lists of int64
        tensors on ``output_ids.device``).
    """
    device = output_ids.device
    batch_size = output_ids.shape[0]

    prom_start = text_tokens + audio_rvq_levels
    resp_start = prom_start + (audio_tokens * audio_rvq_levels)

    def _as_tensor(tokens):
        return torch.tensor(tokens, dtype=torch.int64, device=device)

    def _to_frames(tokens):
        # Interleaved layout: position p belongs to level p % audio_rvq_levels.
        # Drop a trailing partial frame so every level has the same length; the
        # previous `bins[:nearest]` truncation was a no-op and ragged bins crashed.
        frames = len(tokens) // audio_rvq_levels
        usable = tokens[: frames * audio_rvq_levels]
        return _as_tensor(usable).reshape(frames, audio_rvq_levels)

    text_list = [[] for _ in range(batch_size)]
    prom_list = [[] for _ in range(batch_size)]
    resp_list = [[] for _ in range(batch_size)]

    for i, batch in enumerate(output_ids):
        # cringe logic to handle the prefix resp for rvq levels > 0:
        # the first run of response tokens is context, so it is discarded at the
        # delimiter that follows it.
        # a better way is to observe if the rvq level increased
        should_flush = False
        flushed = False

        for token in batch:
            token_id = token.item()
            if token_id == sep or token_id == stop:
                if should_flush and quant_levels is not None and quant_levels[i] > 0:
                    resp_list[i] = []
                    should_flush = False
                    flushed = True
                continue

            if 0 <= token_id < text_tokens:
                text_list[i].append(token_id)
            elif prom_start <= token_id < resp_start:
                prom_list[i].append((token_id - prom_start) % audio_tokens)
            elif resp_start <= token_id:
                resp_list[i].append((token_id - prom_start) % audio_tokens)
                if not flushed:
                    should_flush = True

        if quant_levels is not None:
            prom_list[i] = _as_tensor(prom_list[i])
            resp_list[i] = _as_tensor(resp_list[i])
        else:
            prom_list[i] = _to_frames(prom_list[i])
            resp_list[i] = _to_frames(resp_list[i])

        text_list[i] = _as_tensor(text_list[i])

    return dict(
        text_list=text_list,
        prom_list=prom_list,
        resp_list=resp_list,
    )
2024-04-16 00:54:32 +00:00
# to-do: clean up this symmap mess
2023-08-02 21:53:35 +00:00
def get_phone_symmap():
    """Return the vocabulary of the configured tokenizer (``cfg.tokenizer.get_vocab()``)."""
    vocab = cfg.tokenizer.get_vocab()
    return vocab
2023-08-02 21:53:35 +00:00
2024-04-21 19:49:18 +00:00
def tokenize(phones):
    """Concatenate ``phones`` into one string and encode it with the configured tokenizer."""
    joined = "".join(phones)
    return cfg.tokenizer.encode(joined)
2023-10-12 01:38:40 +00:00
def get_lang_symmap():
    """Return the language-code -> integer ID table."""
    languages = ("en", "ja")
    return {code: index for index, code in enumerate(languages)}
2024-04-16 00:54:32 +00:00
def get_tone_symmap():
    """Return the tone-name -> integer ID table (only ``neutral`` is defined)."""
    # The stale `return symmap` that used to follow was unreachable and referred
    # to a name that no longer exists; it has been dropped.
    return {
        "neutral": 0,
    }
2023-08-19 03:22:13 +00:00
def get_task_symmap():
    """Return the task-token -> integer ID table; IDs follow the listed order."""
    task_tokens = (
        "<tts>",
        "<tts-c>",
        "<ns>",
        "<sr>",
        "<tse>",
        "<soe>",
        "<mask>",
        "<eoe>",
    )
    return {token: index for index, token in enumerate(task_tokens)}
2023-08-02 21:53:35 +00:00
def _replace_file_extension ( path , suffix ) :
return ( path . parent / path . name . split ( " . " ) [ 0 ] ) . with_suffix ( suffix )
2024-04-19 02:24:06 +00:00
def _get_quant_extension():
    """Return the file extension used for quantized audio under the configured backend."""
    if cfg.audio_backend == "dac":
        return ".dac"
    return ".enc"
2024-04-19 02:24:06 +00:00
def _get_phone_extension ( ) :
2024-05-25 16:07:52 +00:00
return " .json " # if cfg.audio_backend == "dac" else ".phn.txt"
2024-04-19 02:24:06 +00:00
2023-08-27 00:53:23 +00:00
def _get_quant_path(path):
    """Return ``path`` with its extension swapped for the quantized-audio extension."""
    quant_suffix = _get_quant_extension()
    return _replace_file_extension(path, quant_suffix)
2023-08-27 00:53:23 +00:00
def _get_phone_path(path):
    """Return ``path`` with its extension swapped for the phoneme-metadata extension."""
    phone_suffix = _get_phone_extension()
    return _replace_file_extension(path, phone_suffix)
2023-08-27 00:53:23 +00:00
2024-06-14 03:37:34 +00:00
# dataset type -> {path key -> duration}; filled in as a side effect of path loading
_durations_map = {}


# makeshift caching of the above to disk
@cfg.diskcache()
def _get_duration_map(type="training"):
    """Return the duration table recorded for ``type``, or an empty dict if none exists."""
    return _durations_map.get(type, {})
2023-09-12 20:54:41 +00:00
2023-09-02 21:29:53 +00:00
@cfg.diskcache()
def _load_paths(dataset, type="training"):
    """Map each speaker name to the list of utterance paths found for it.

    ``dataset`` is iterated as a collection of data directories; validation of
    entries is only requested for the training set when ``cfg.dataset.validate``
    is enabled.
    """
    paths_by_speaker = {}
    for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}"):
        # the trailing "dummy" component stands in for a file name so that
        # cfg.get_spkr sees a path shaped like a real utterance path
        speaker = cfg.get_spkr(cfg.data_dir / data_dir / "dummy")
        should_validate = cfg.dataset.validate and type == "training"
        paths_by_speaker[speaker] = _load_paths_from_metadata(data_dir, type=type, validate=should_validate)
    return paths_by_speaker
2024-05-16 04:04:19 +00:00
def _load_paths_from_metadata(group_name, type="training", validate=False):
    """List the utterance keys of one speaker group, preferring its metadata JSON.

    When ``cfg.dataset.use_metadata`` is on and ``<metadata_dir>/<group_name>.json``
    exists and is non-empty, entries come from that file and each entry's duration
    is recorded in ``_durations_map[type]``. Otherwise the HDF5 file or the
    filesystem is scanned instead.

    Args:
        group_name: speaker group directory (relative to ``cfg.data_dir``).
        type: dataset type, used as the HDF5 prefix and the duration-map bucket.
        validate: when True, drop entries whose duration lies outside
            ``[cfg.dataset.min_duration, cfg.dataset.max_duration]``.

    Returns:
        list of keys: HDF5 key strings when HDF5 is in use, otherwise
        ``data_dir / id`` paths.
    """
    data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name

    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions

    def key(id, entry=None):
        return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else data_dir / id

    metadata_path = cfg.metadata_dir / f'{group_name}.json'
    metadata = {}

    if cfg.dataset.use_metadata and metadata_path.exists():
        # context manager so the handle is closed even if parsing fails
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

    if len(metadata) == 0:
        return _fn(data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate)

    durations = _durations_map.setdefault(type, {})

    paths = []
    for id, entry in metadata.items():
        k = key(id, entry)
        duration = entry['duration'] if "duration" in entry else 0

        # record the duration for every entry, including ones rejected below
        durations[k] = duration

        if validate and not (cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration):
            continue
        paths.append(k)

    return paths
2023-08-27 00:53:23 +00:00
2023-08-02 21:53:35 +00:00
def _get_hdf5_path ( path ) :
2024-04-29 03:28:29 +00:00
# to-do: better validation
#print(path)
return str ( path )
2023-08-02 21:53:35 +00:00
2023-08-27 00:53:23 +00:00
def _get_hdf5_paths(data_dir, type="training", validate=False):
    """List the entries stored under ``/<type>/<data_dir>`` in the open HDF5 file.

    Each entry's ``duration`` attribute is recorded in ``_durations_map[type]``
    under the same ``Path`` object that is returned, so a caller can look the
    duration up with the returned path directly.

    Args:
        data_dir: group directory; converted to ``str``.
        type: dataset type, used as the HDF5 prefix and the duration-map bucket.
        validate: when True, drop entries whose duration lies outside
            ``[cfg.dataset.min_duration, cfg.dataset.max_duration]``.

    Returns:
        list of ``Path`` keys; empty when the group is absent from ``cfg.hdf5``.
    """
    data_dir = str(data_dir)
    key = f"/{type}/{_get_hdf5_path(data_dir)}"

    if key not in cfg.hdf5:
        return []

    durations = _durations_map.setdefault(type, {})

    paths = []
    for id, entry in cfg.hdf5[key].items():
        entry_path = Path(f"{key}/{id}")
        duration = entry.attrs['duration']

        # Previously stored under the str form while the Path form was returned,
        # so a lookup by returned path could never match.
        durations[entry_path] = duration

        if validate and not (cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration):
            continue
        paths.append(entry_path)

    return paths
2023-08-27 00:53:23 +00:00
2024-04-19 02:24:06 +00:00
def _get_paths_of_extensions(path, extensions=None, validate=False):
    """List files directly inside ``path`` whose full extension is in ``extensions``.

    A file is kept only when both its phoneme file and its quantized-audio file
    exist next to it.

    Args:
        path: directory to scan (``str`` or ``Path``).
        extensions: a single extension string or a collection of them; defaults to
            the quantized-audio extension of the configured backend, resolved at
            call time.
        validate: when True, also require the phoneme count to lie within
            ``[cfg.dataset.min_phones, cfg.dataset.max_phones]``.

    Returns:
        list of matching ``Path`` objects; empty if ``path`` is not a directory.
    """
    if extensions is None:
        # resolved per call instead of being frozen when the module is imported
        extensions = _get_quant_extension()
    if isinstance(extensions, str):
        # wrap so that `in` is an exact match rather than a substring test
        # (otherwise "" or ".e" would match ".enc")
        extensions = (extensions,)

    if isinstance(path, str):
        path = Path(path)

    def _validate(path):
        if "".join(path.suffixes) not in extensions:
            return False
        if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
            return False
        if not validate:
            return True
        # to-do: find an easy way to determine size from pickled quants without loading
        # to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
        phones = len(_get_phones(_get_phone_path(path)))
        return cfg.dataset.min_phones <= phones <= cfg.dataset.max_phones

    if not path.is_dir():
        return []
    return [p for p in path.iterdir() if _validate(p)]
2023-08-02 21:53:35 +00:00
2024-05-18 12:14:26 +00:00
def _load_quants(path, return_metadata=False) -> Tensor:
    """Load the quantized-audio file that belongs to ``path``.

    Returns the ``codes`` array of the stored object, first element, transposed
    and cast to int16; with ``return_metadata=True`` a ``(codes, metadata)``
    tuple is returned instead.
    """
    # NOTE(review): allow_pickle=True unpickles the file contents; only load
    # files from a trusted source.
    stored = np.load(_get_quant_path(path), allow_pickle=True)
    # [()] unwraps the 0-d object array back into the stored dict
    qnt = stored[()]
    codes = torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
    if not return_metadata:
        return codes
    return codes, qnt["metadata"]
2023-08-02 21:53:35 +00:00
2023-10-11 00:18:24 +00:00
# prune consecutive spaces
def _cleanup_phones ( phones , targets = [ " " ] ) :
return [ p for i , p in enumerate ( phones ) if p not in targets or ( p in targets and p != phones [ i - 1 ] ) ]
2023-08-02 21:53:35 +00:00
@cache
def _get_phones(path):
    """Return the phoneme string recorded for the utterance at ``path``.

    The phoneme JSON next to ``path`` is preferred; failing that, the metadata
    embedded in the quantized-audio file is used. Results are memoised per path.

    Raises:
        Exception: when neither file exists.
    """
    phone_path = _get_phone_path(path)
    quant_path = _get_quant_path(path)
    if phone_path.exists():
        # context manager so the handle is closed even if parsing fails
        with open(phone_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
    elif quant_path.exists():
        _, metadata = _load_quants(path, return_metadata=True)
    else:
        raise Exception(f"Could not load phonemes: {path}")

    content = metadata["phonemes"]
    return "".join(content)
2023-08-02 21:53:35 +00:00
def _interleaved_reorder ( l , fn ) :
groups = defaultdict ( list )
for e in l :
groups [ fn ( e ) ] . append ( e )
groups = { k : groups [ k ] for k in sorted ( groups ) }
for interleaved in zip_longest ( * groups . values ( ) ) :
for value in interleaved :
if value is not None :
yield value
class Dataset ( _Dataset ) :
def __init__(
    self,
    phone_symmap=None,
    training=False,
    extra_paths_by_spkr_name: dict[str, list] = {},
):
    """Collect utterance paths, order them, and set up the samplers.

    Args:
        phone_symmap: phoneme table to use; defaults to the tokenizer vocabulary.
        training: selects the training dataset (and per-rank sharding when
            ``cfg.distributed``); otherwise the validation dataset is used.
        extra_paths_by_spkr_name: accepted for compatibility; not read here.

    Raises:
        ValueError: when no utterance path survives filtering.
    """
    super().__init__()
    self._head = None
    self.sampler = None

    self.paths = []

    self.training = training
    self.dataset_type = "training" if self.training else "validation"
    self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
    self.sampler_type = cfg.dataset.sample_type
    self.sampler_order = cfg.dataset.sample_order
    self.sampler_shuffle = cfg.dataset.sample_shuffle

    # to-do: do not do validation if there's nothing in the validation
    # this just makes it be happy
    if len(self.dataset) == 0:
        self.dataset = cfg.dataset.training

    # dict of paths keyed by speaker names
    self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)

    # cull speakers if they do not have enough utterances
    if cfg.dataset.min_utterances > 0:
        for key in list(self.paths_by_spkr_name.keys()):
            if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
                del self.paths_by_spkr_name[key]

    # flatten paths
    self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))

    # split dataset accordingly per GPU
    if cfg.distributed and self.training:
        # Strided shard: rank r keeps indices r, r + world, r + 2*world, ...
        # The previous filter compared against 0, which handed every rank the
        # same slice and left the rest of the dataset unused.
        shard_count = world_size()
        shard_index = global_rank()
        self.paths = [path for i, path in enumerate(self.paths) if i % shard_count == shard_index]

        # recreate paths_by_spkr_name
        self.paths_by_spkr_name = {}
        for path in self.paths:
            name = cfg.get_spkr(Path(path))
            if name not in self.paths_by_spkr_name:
                self.paths_by_spkr_name[name] = []
            self.paths_by_spkr_name[name].append(path)

    # do it here due to the above
    self.duration = 0
    self.duration_map = _get_duration_map(self.dataset_type)
    self.duration_buckets = {}

    # store in corresponding bucket
    for path in self.paths:
        duration = self.duration_map[path]
        self.duration += duration

        # only bucket by duration if we're going to order by duration
        if self.sampler_order != "duration":
            continue

        bucket = int(round(duration))
        if bucket not in self.duration_buckets:
            self.duration_buckets[bucket] = []
        self.duration_buckets[bucket].append((Path(path), duration))

    # ensure they're ordered
    self.duration_buckets = dict(sorted(self.duration_buckets.items()))

    if self.sampler_order == "duration":
        flattened = {}
        # sort and interleave
        for bucket in self.duration_buckets:
            # sort by duration (in place, the sampler receives these buckets too)
            self.duration_buckets[bucket].sort(key=lambda x: x[1])
            # keep only the paths, then interleave speakers within the bucket
            bucket_paths = [x[0] for x in self.duration_buckets[bucket]]
            flattened[bucket] = [*_interleaved_reorder(bucket_paths, self.get_speaker)]
        # flatten paths
        self.paths = list(itertools.chain.from_iterable(flattened.values()))
    else:
        # just interleave
        self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]

    # dict of speakers keyed by speaker group
    self.spkrs_by_spkr_group = {}
    for data_dir in self.dataset:
        spkr = cfg.get_spkr(data_dir / "dummy")
        spkr_group = cfg.get_spkr_group(data_dir / "dummy")

        if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
            continue

        if spkr_group not in self.spkrs_by_spkr_group:
            self.spkrs_by_spkr_group[spkr_group] = []

        self.spkrs_by_spkr_group[spkr_group].append(spkr)

    self.spkr_groups = list(self.spkrs_by_spkr_group.keys())

    self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
    self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))

    self.phone_symmap = phone_symmap or self._get_phone_symmap()
    self.spkr_symmap = self._get_spkr_symmap()
    self.spkr_group_symmap = self._get_spkr_group_symmap()
    self.lang_symmap = self._get_lang_symmap()
    self.tone_symmap = self._get_tone_symmap()
    self.task_symmap = self._get_task_symmap()

    # uint8 is enough while the phoneme table has fewer than 256 entries
    self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

    if len(self.paths) == 0:
        raise ValueError(f"No valid path is found for {self.dataset_type}")

    if self.sampler_type == "path":
        if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
            self.sampler = BatchedOrderedSampler(
                # pass nothing if we're just going to load from a state anyways
                self.duration_buckets if not self.sampler_state_dict_path.exists() else {},
                max_duration=cfg.dataset.sample_max_duration_batch,
                max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
                shuffle=self.sampler_shuffle,
            )
        else:
            self.sampler = OrderedSampler(len(self)) if not self.sampler_shuffle else RandomSampler(len(self))
        self.samplers = {}
        self.spkr_samplers = {}
    else:
        self.sampler = RandomSampler(len(self))
        self.samplers = {
            name: PoolSampler(paths, keep_all=True, shuffle=self.sampler_shuffle)
            for name, paths in self.paths_by_spkr_name.items()
        }
        self.spkr_samplers = {
            name: PoolSampler([*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle)
            for name, speakers in self.spkrs_by_spkr_group.items()
        }

    self.load_state_dict()
2024-06-29 15:10:35 +00:00
@cached_property
def sampler_state_dict_path(self):
    """Path of the sampler state file for this sampler type and global rank."""
    file_name = f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
    return cfg.rel_path / file_name
2024-06-14 21:55:40 +00:00
2023-08-27 00:53:23 +00:00
def get_speaker(self, path):
    """Return the speaker name ``cfg.get_spkr`` derives from ``path`` (str or Path)."""
    as_path = Path(path) if isinstance(path, str) else path
    return cfg.get_spkr(as_path)
2023-10-12 01:38:40 +00:00
def get_speaker_group(self, path):
    """Return the speaker group ``cfg.get_spkr_group`` derives from ``path`` (str or Path)."""
    as_path = Path(path) if isinstance(path, str) else path
    return cfg.get_spkr_group(as_path)
def get_language(self, speaker_group):
    """Return the first language in ``cfg.dataset.speaker_languages`` listing ``speaker_group``.

    Falls back to ``"en"`` when no language lists the group.
    """
    matches = (
        language
        for language, groups in cfg.dataset.speaker_languages.items()
        if speaker_group in groups
    )
    return next(matches, "en")
2023-08-02 21:53:35 +00:00
@cached_property
def spkrs(self):
    """Sorted list of the distinct speaker names across ``self.paths``."""
    unique_speakers = set(map(self.get_speaker, self.paths))
    return sorted(unique_speakers)
2023-08-02 21:53:35 +00:00
2023-08-19 03:22:13 +00:00
@cached_property
def tasks(self):
    """Task list taken from ``cfg.dataset.tasks_list``."""
    # e.g. ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]
    task_list = cfg.dataset.tasks_list
    return task_list
2023-08-19 03:22:13 +00:00
2024-06-14 21:55:40 +00:00
def save_state_dict(self, path=None):
    """Write the sampler state to ``path`` (default: ``self.sampler_state_dict_path``).

    For the ``"path"`` sampler type the single sampler's state is saved;
    otherwise the per-speaker and per-speaker-group sampler states are saved.
    """
    target = self.sampler_state_dict_path if path is None else path

    if self.sampler_type == "path":
        state_dict = self.sampler.get_state()
    else:
        per_speaker = {name: sampler.get_state() for name, sampler in self.samplers.items()}
        per_group = {name: sampler.get_state() for name, sampler in self.spkr_samplers.items()}
        state_dict = {
            "samplers": per_speaker,
            "spkr_samplers": per_group,
        }

    torch.save(state_dict, target)
2024-06-14 21:55:40 +00:00
def load_state_dict(self, path=None):
    """Restore sampler state from ``path`` (default: ``self.sampler_state_dict_path``).

    Does nothing when the file does not exist. Saved entries for samplers that
    no longer exist on this dataset are ignored.
    """
    source = self.sampler_state_dict_path if path is None else path
    if not source.exists():
        return

    state_dict = torch.load(source)

    if self.sampler_type == "path":
        self.sampler.set_state(state_dict)
        return

    # restore per-speaker samplers first, then per-group samplers
    for section, live_samplers in (
        ("samplers", self.samplers),
        ("spkr_samplers", self.spkr_samplers),
    ):
        for name, saved_state in state_dict[section].items():
            if name in live_samplers:
                live_samplers[name].set_state(saved_state)
2023-09-04 02:27:13 +00:00
2023-08-19 03:22:13 +00:00
def _get_phone_symmap(self):
    """Return the module-level phoneme table (see ``get_phone_symmap``)."""
    symmap = get_phone_symmap()
    return symmap
2023-08-02 21:53:35 +00:00
def _get_spkr_symmap ( self ) :
return { s : i for i , s in enumerate ( self . spkrs ) }
2023-10-17 00:30:38 +00:00
def _get_spkr_group_symmap ( self ) :
return { s : i for i , s in enumerate ( self . spkr_groups ) }
2023-10-12 01:38:40 +00:00
def _get_lang_symmap(self):
    """Return the module-level language table (see ``get_lang_symmap``)."""
    symmap = get_lang_symmap()
    return symmap
2024-04-16 00:54:32 +00:00
def _get_tone_symmap(self):
    """Return the module-level tone table (see ``get_tone_symmap``)."""
    symmap = get_tone_symmap()
    return symmap
2023-08-19 03:22:13 +00:00
def _get_task_symmap(self):
    """Return the module-level task table (see ``get_task_symmap``)."""
    symmap = get_task_symmap()
    return symmap
2023-08-19 20:06:33 +00:00
def sample_noise(self):
    """Pick a random entry from ``self.noise_paths`` and return its codes as int16."""
    noise_path = random.choice(self.noise_paths)

    if not cfg.dataset.use_hdf5:
        return _load_quants(noise_path, return_metadata=False)

    hdf5_key = _get_hdf5_path(noise_path)
    codes = cfg.hdf5[hdf5_key]["audio"][:, :]
    return torch.from_numpy(codes).to(torch.int16)
2023-08-19 03:22:13 +00:00
2023-08-18 00:07:59 +00:00
def sample_speakers ( self , ignore = [ ] ) :
choices = set ( self . spkrs ) - set ( ignore )
return random . choice ( [ * choices ] )
2023-08-02 21:53:35 +00:00
def sample_prompts(self, spkr_name, ignore):
    """Build an acoustic prompt for ``spkr_name`` from randomly chosen utterances.

    Utterances other than ``ignore`` are preferred; if the speaker has no other
    utterance, all of its utterances (including ``ignore``) become candidates.
    Up to ``cfg.dataset.max_prompts`` utterances are drawn, each trimmed to a
    randomly chosen target length, and the concatenation is trimmed again.
    """
    speaker_paths = self.paths_by_spkr_name[spkr_name]
    candidates = [*(set(speaker_paths) - {ignore})]

    # no other utterances, it'd make more sense to prune speakers with only one
    # utterance in the validation step
    if not candidates:
        candidates = [*set(speaker_paths)]

    # target prompt length in frames, drawn from the configured duration range
    min_seconds = cfg.dataset.prompt_duration_range[0]
    max_seconds = cfg.dataset.prompt_duration_range[1]
    trim_length = int(random.uniform(min_seconds, max_seconds) * cfg.dataset.frames_per_second)

    pieces = []
    total_frames = 0

    for _ in range(cfg.dataset.max_prompts):
        chosen = random.choice(candidates)

        if cfg.dataset.use_hdf5:
            hdf5_key = _get_hdf5_path(chosen)
            if "audio" not in cfg.hdf5[hdf5_key]:
                _logger.warning(f'MISSING AUDIO: {hdf5_key}')
                continue
            codes = torch.from_numpy(cfg.hdf5[hdf5_key]["audio"][:, :]).to(torch.int16)
        else:
            codes = _load_quants(chosen, return_metadata=False)

        if 0 < trim_length < codes.shape[0]:
            codes = trim(codes, trim_length)

        pieces.append(codes)
        total_frames += codes.shape[0]

        # stop once long enough, or by chance (random_utterance is the odds of continuing)
        if total_frames >= trim_length or random.random() > cfg.dataset.random_utterance:
            break

    # might be better to decode => concat waveforms with silence in between => reencode
    # as you technically can't just append encodec sequences together like this without issues
    prompt = torch.cat(pieces)

    if 0 < trim_length < prompt.shape[0]:
        prompt = trim(prompt, trim_length)

    return prompt
def __getitem__(self, index):
    """Build one sample (text, prompt and response codes) for `index`.

    What `index` selects depends on `self.sampler_type`:
      * "group":   a speaker group; a speaker and then a path are sampled from it.
      * "speaker": a speaker; a path is sampled from that speaker.
      * otherwise: an utterance path directly.

    Returns:
        dict with keys `index`, `path`, `spkr_name`, `spkr_id`, `task`,
        `lang`, `text`, `proms` and `resps`.

    Raises:
        RuntimeError: when HDF5 is in use and the key derived from the
            selected path is not present in `cfg.hdf5`.
    """
    # resolve (speaker group, speaker, path) according to the sampler type
    if self.sampler_type == "group":
        spkr_group = self.spkr_groups[index]
        #spkr_group_id = self.spkr_group_symmap[spkr_group]
        spkr_name = self.spkr_samplers[spkr_group].sample()
        spkr_id = self.spkr_symmap[spkr_name]
        path = self.samplers[spkr_name].sample()
    elif self.sampler_type == "speaker":
        spkr_name = self.spkrs[index]
        spkr_id = self.spkr_symmap[spkr_name]
        path = self.samplers[spkr_name].sample()
        spkr_group = self.get_speaker_group(path)
        #spkr_group_id = self.spkr_group_symmap[spkr_group]
    else:
        path = self.paths[index]
        spkr_name = self.get_speaker(path)
        spkr_id = self.spkr_symmap[spkr_name]
        spkr_group = self.get_speaker_group(path)
        #spkr_group_id = self.spkr_group_symmap[spkr_group]

    # load the tokenized text and the audio codes, either from the HDF5 store or from disk
    if cfg.dataset.use_hdf5:
        key = _get_hdf5_path(path)

        if key not in cfg.hdf5:
            raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')

        text = cfg.hdf5[key]["text"][:]
        resps = cfg.hdf5[key]["audio"][:, :]

        text = torch.from_numpy(text).to(self.text_dtype)
        resps = torch.from_numpy(resps).to(torch.int16)
    else:
        resps, metadata = _load_quants(path, return_metadata=True)
        text = torch.tensor(tokenize(metadata["phonemes"])).to(self.text_dtype)

    # single-element tensor holding the language id looked up from the speaker group
    lang = torch.tensor([self.lang_symmap[self.get_language(spkr_group)]]).to(torch.uint8)

    # append additional prompts in an attempt to artificially increase lengths / offer new data
    # (the whole block below is a string literal, i.e. disabled code)
    """
    # disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
    # it probably is better to pad with silence instead of just stitching utterances and ruining things
    if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
        choices = [*(set(self.paths_by_spkr_name[spkr_name]) - {path})]

        if len(choices) > 0:
            for _ in range(cfg.dataset.max_resps - 1):
                sampled_path = random.choice(choices)
                choices = [*(set(choices) - {sampled_path})]
                if cfg.dataset.use_hdf5:
                    key = _get_hdf5_path(sampled_path)
                    txt = cfg.hdf5[key]["text"][:]
                    qnt = cfg.hdf5[key]["audio"][:, :]

                    txt = np.array(txt)

                    txt = torch.from_numpy(txt).to(self.text_dtype)
                    qnt = torch.from_numpy(qnt).to(torch.int16)
                else:
                    #txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
                    #txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
                    qnt, metadata = _load_quants(sampled_path, return_metadata=True)
                    txt = torch.tensor(tokenize(metadata["phonemes"])).to(self.text_dtype)

                # <s>[original text] [new text]</s>
                # removes the original text's </s>, includes a space, and remove the new text's <s>
                text = torch.concat([text[:-1], torch.tensor([self.phone_symmap[" "]]).to(torch.int16), txt[1:]])

                # might be better to decode => concat waveforms with silence in between => reencode
                # as you technically can't just append encodec sequences together like this without issues
                resps = torch.concat([resps, qnt])
    """

    # the task is currently fixed to plain TTS; the multi-task logic further down is disabled
    task = "tts"
    # NOTE(review): trim_length is not referenced again in the live code of this method,
    # but the random.uniform() call still advances the RNG state before the draw below
    trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
    # with probability cfg.dataset.random_utterance use a sampled prompt of the same speaker
    # (excluding this path); otherwise the response itself doubles as the prompt
    proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps

    # Disabled until I swap over to a better method
    """
    task = random.choice(self.tasks)

    # ensure a speaker has at least four utterances
    # default to tts if not
    if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
        task = "tts"
    noise_scale = 0.25
    if task == "tts" or task == "tts-c":
        trim_length = int(cfg.dataset.prompt_duration * cfg.dataset.frames_per_second)
        # demote if the target is too short
        if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
            task = "tts"

        # VALL-E continuous
        # ignore if target utterance is shorter than prompt duration
        # to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
        if task == "tts-c":
            proms = resps[:trim_length, :]
            resps = resps[trim_length:, :]

            proms = torch.cat([self.get_task_token(task), proms])
        else:
            proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
    # noise suppression || speech removal
    elif task == "ns" or task == "sr":
        # sample random noise
        noise = self.sample_noise()
        # extend the noise to fill the target audio
        noise = repeat_extend_audio(noise, resps.shape[0])
        # create the input prompt by merging the target audio with the noise
        proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
        # set the target to just be the noise if <sr>
        if task == "sr":
            resps = noise
        # prepend the task token
        proms = torch.cat([self.get_task_token(task), proms])

        # set the text prompt to empty to train without a guided text prompt
        if random.random() < 0.5:
            text = torch.tensor([1, 2]).to(self.text_dtype)
    # target speech extraction
    elif task == "tse":
        # sample a random, clean, utterance for the target speaker
        clean_proms = self.sample_prompts(spkr_name, ignore=path)
        # sample a random, clean utterance from a different speaker
        other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
        # overlay the random speaker over the target audio

        smallest_size = min(resps.shape[0], other_proms.shape[0])
        if other_proms.shape[0] == smallest_size:
            noisy_proms = merge_audio(resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu")
            noisy_proms = torch.cat([noisy_proms, resps[smallest_size:, :]])
        else:
            noisy_proms = merge_audio(resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu")
            noisy_proms = torch.cat([noisy_proms, other_proms[smallest_size:, :]])

        # stitch together the promps
        proms = torch.cat([clean_proms, self.get_task_token(task), noisy_proms])

        # set the text prompt to empty to train without a guided text prompt
        if random.random() < 0.5:
            text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>

    # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
    # as I need to get a good clean point to trim into
    # clean speech editing
    elif task == "cse" or task == "nse":
        choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
        sampled = random.sample([*choices], 4)

        if cfg.dataset.use_hdf5:
            texts = [torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled]
            qnts = [torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled]
        else:
            texts = [torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled]
            qnts = [_load_quants(path) for path in sampled]

        # remove <s></s>
        for i in range(len(texts)):
            texts[i] = texts[i][1:-1]

        pre_text, mid_text, post_text, edit_text = texts
        pre_prom, mid_prom, post_prom, edit_prom = qnts

        # randomly drop out pre
        if random.random() < 0.125:
            pre_text = None
            pre_prom = None
        # randomly drop out post
        if random.random() < 0.125:
            post_text = None
            post_prom = None

        # create new text
        text = torch.cat(
            [torch.Tensor([1]).to(dtype=self.text_dtype)] + # <s>
            ([pre_text, torch.Tensor([3]).to(dtype=self.text_dtype)] if pre_text is not None else []) + # pre_text + space'
            [edit_text] + # 'edit text'
            ([torch.Tensor([3]).to(dtype=self.text_dtype), post_text] if post_text is not None else []) + # 'space' + edit_text
            [torch.Tensor([2]).to(dtype=self.text_dtype)] # </s>
        )

        if task == "nse":
            # sample random noise
            noise = self.sample_noise()

            # it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
            # but it's noise, it's supposed to be random
            def noise_proms(p):
                # ignore if we turned it off
                if p is None:
                    return None

                # extend the noise to fill the target audio
                n = repeat_extend_audio(noise, p.shape[0])
                # merge the noise over the utterance
                return merge_audio(p, n, scale=[1, noise_scale], device="cpu")

            # apply noise to all pieces
            pre_prom = noise_proms(pre_prom)
            mid_prom = noise_proms(mid_prom)
            post_prom = noise_proms(post_prom)
            edit_prom = noise_proms(edit_prom)
        else:
            mid_prom = self.get_task_token("mask")

        # create new proms
        proms = torch.cat(
            ([pre_prom] if pre_prom is not None else []) +
            [self.get_task_token("soe")] +
            [mid_prom] + # is <mask> if task is CSE
            [self.get_task_token("eoe")] +
            ([post_prom] if post_prom is not None else [])
        )
        # create new resp
        resps = torch.cat(
            ([pre_prom] if pre_prom is not None else []) +
            [edit_prom] +
            ([post_prom] if post_prom is not None else [])
        )
    else:
        raise Exception(f'Undefined task: {task}')
    """

    """
    # emulate SVC
    # takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
    # targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip

    # NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
    # ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
    # aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
    elif task == "svc":
        # sample a random, clean utterance for the target speaker
        proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
        # sample a reference clip from a different speaker
        ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
        #
        resps =
        # stitch together the promps
        proms = torch.cat([proms, self.get_task_token(task), ref_proms])
        # set the text prompt to empty to train without a guided text prompt
        if random.random() < 0.5:
            text = torch.tensor([1, 2]).to(self.text_dtype)
    """

    # trim to fit to requested prom/resps levels
    # (both are sliced along their second axis to the first cfg.model.resp_levels entries)
    proms = proms[:, :cfg.model.resp_levels]
    resps = resps[:, :cfg.model.resp_levels]

    return dict(
        index=index,
        path=Path(path),
        spkr_name=spkr_name,
        spkr_id=spkr_id,
        task=task,
        lang=lang,
        text=text,
        proms=proms,
        resps=resps,
    )
def head_(self, n):
    """Store `n` as the head limit; `__len__` reports at most this many items when it is truthy."""
    self._head = n
def training_(self, value):
    """Set the dataset's `training` flag to `value`."""
    self.training = value
def __len__(self):
    """Return the number of indexable items, capped by the head limit.

    The indexed collection follows the sampler type: speaker groups for
    "group", speakers for "speaker", and utterance paths otherwise.
    A falsy `self._head` (None or 0) means "no cap".
    """
    if self.sampler_type == "group":
        items = self.spkr_groups
    elif self.sampler_type == "speaker":
        items = self.spkrs
    else:
        items = self.paths

    total = len(items)
    return min(total, self._head or total)
def collate_fn(samples: list[dict]):
    """Turn a list of per-sample dicts into one dict of per-key lists.

    The key set (and its order) is taken from the first sample; every sample
    is expected to provide each of those keys.
    """
    batch: dict[str, Any] = {}
    for key in samples[0]:
        batch[key] = [sample[key] for sample in samples]
    return batch
def _seed_worker(worker_id):
    """Seed `random` and NumPy from torch's initial seed, reduced to 32 bits.

    Intended as a DataLoader `worker_init_fn`; `worker_id` is accepted for
    that signature but not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    random.seed(seed)
    np.random.seed(seed)
def _create_dataloader(dataset, training):
    """Wrap `dataset` in a DataLoader driven by the dataset's own sampler.

    A `BatchedOrderedSampler` is handed over as `batch_sampler`; any other
    sampler is passed as `sampler` together with the batch size (training or
    evaluation, depending on `training`), `drop_last=training` and
    `shuffle=False`.
    """
    if isinstance(dataset.sampler, BatchedOrderedSampler):
        # the sampler already yields whole batches
        sampling_args = {"batch_sampler": dataset.sampler}
    else:
        if training:
            batch_size = cfg.hyperparameters.batch_size
        else:
            batch_size = cfg.evaluation.batch_size

        sampling_args = {
            "shuffle": False,
            "batch_size": batch_size,
            "drop_last": training,
            "sampler": dataset.sampler,
        }

    return DataLoader(
        dataset=dataset,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 1,
        pin_memory=False,
        worker_init_fn=_seed_worker,
        **sampling_args,
    )
def create_datasets():
    """Create the training dataset and a validation dataset sharing its phone symmap.

    Returns:
        (train_dataset, val_dataset)
    """
    training_set = Dataset(training=True)
    validation_set = Dataset(
        phone_symmap=training_set.phone_symmap,
        training=False,
    )
    return training_set, validation_set
def create_train_val_dataloader():
    """Create the train, subtrain and validation dataloaders and log a summary.

    The subtrain dataset is a copy of the training dataset that is loaded in
    evaluation mode; with the "path" sampler it is capped to
    `cfg.evaluation.size` items.

    Returns:
        (train_dl, subtrain_dl, val_dl)
    """
    train_dataset, val_dataset = create_datasets()

    # it'll cry about trying to pickle a torch._C_generator or something,
    # in which case a fresh training dataset is built instead of a copy
    try:
        subtrain_dataset = copy.deepcopy(train_dataset)
    except Exception:
        subtrain_dataset = Dataset(training=True)

    if subtrain_dataset.sampler_type == "path":
        subtrain_dataset.head_(cfg.evaluation.size)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(str(train_dataset.spkr_symmap))
    _logger.info(str(train_dataset.spkr_group_symmap))

    labelled = (
        ("train", train_dataset),
        ("val", val_dataset),
        ("subtrain", subtrain_dataset),
    )

    for label, ds in labelled:
        _logger.info(f"#samples ({label}): {len(ds)}.")

    for label, ds in labelled:
        _logger.info(f"#duration ({label}): {str(ds.duration)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
2023-08-27 00:53:23 +00:00
# parse the dataset into per-speaker metadata that is easier to sample from
2024-04-29 14:09:26 +00:00
def create_dataset_metadata(skip_existing=True):
    """Write one `<metadata_dir>/<speaker>.json` file per configured dataset directory.

    For every utterance that has a quantized-audio file, the fields found in
    that file's embedded metadata (`text`, `phonemes`, `language`, and a
    `duration` derived from `original_length / sample_rate`) are recorded
    under the utterance id. When nothing was found there, the utterance's
    phoneme JSON file is used instead (if it exists).

    Args:
        skip_existing: when True, utterance ids already present in the
            speaker's metadata file are left untouched.
    """
    # NOTE(review): the symmap is not used below; the call is kept in case
    # get_phone_symmap() has side effects (e.g. caching) — confirm and drop.
    symmap = get_phone_symmap()

    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

    def add(dir, type="training", audios=True, texts=True):
        name = str(dir)
        name = name.replace(root, "")

        speaker_name = name

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        # best effort: an unreadable/corrupt metadata file is rebuilt from scratch
        try:
            metadata = json.loads(metadata_path.read_text(encoding="utf-8")) if metadata_path.exists() else {}
        except Exception as e:
            tqdm.write(f'Could not read existing metadata {metadata_path}, starting fresh: {e}')
            metadata = {}

        if not os.path.isdir(f'{root}/{name}/'):
            return

        # tqdm.write(f'{root}/{name}')
        files = os.listdir(f'{root}/{name}/')

        quant_ext = _get_quant_extension()
        phone_ext = _get_phone_extension()

        # grab IDs for every file
        ids = {file.replace(quant_ext, "").replace(phone_ext, "") for file in files}

        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            try:
                quant_path = f'{root}/{name}/{utt_id}{quant_ext}'
                phone_path = f'{root}/{name}/{utt_id}{phone_ext}'

                quant_exists = os.path.exists(quant_path) if audios else True
                text_exists = os.path.exists(phone_path) if texts else True

                if not quant_exists:
                    continue

                if skip_existing and utt_id in metadata:
                    continue

                if utt_id not in metadata:
                    metadata[utt_id] = {}

                utterance_metadata = {}
                if audios:
                    # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
                    # NOTE(review): allow_pickle=True unpickles the file; only use on trusted data
                    dac = np.load(quant_path, allow_pickle=True)[()]
                    # the codes themselves are not stored here, but the conversion is kept so that
                    # a file with missing/malformed codes still fails and gets reported below
                    qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

                    if "text" in dac["metadata"]:
                        utterance_metadata["text"] = dac["metadata"]["text"]
                    if "phonemes" in dac["metadata"]:
                        utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
                    if "language" in dac["metadata"]:
                        utterance_metadata["language"] = dac["metadata"]["language"]
                    if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
                        utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]

                # text
                if texts and text_exists and not utterance_metadata:
                    with open(phone_path, "r", encoding="utf-8") as f:
                        utterance_metadata = json.loads(f.read())

                for k, v in utterance_metadata.items():
                    metadata[utt_id][k] = v

            except Exception as e:
                tqdm.write(f'Error while processing {utt_id}: {e}')

        with open(str(metadata_path), "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata))

    # training
    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
        add(data_dir, type="training")

    # validation
    for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
        add(data_dir, type="validation")

    # noise
    for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
        add(data_dir, type="noise", texts=False)
2023-08-27 00:53:23 +00:00
2023-08-19 14:50:07 +00:00
# parse yaml to create an hdf5 file
2023-08-27 00:53:23 +00:00
def create_dataset_hdf5(skip_existing=True):
    """Populate the HDF5 store from the configured training/validation/noise directories.

    Each utterance gets a group `<type>/<speaker>/<id>` holding an `audio`
    dataset (int16 codes), a `text` dataset (uint8 token ids produced by
    `cfg.tokenizer`), and the utterance metadata as group attributes.
    The phone symmap is stored as JSON under `symmap`, and the file is
    closed at the end — also when an error interrupts processing.

    Args:
        skip_existing: when True, utterances whose group already exists in
            the HDF5 file are skipped.
    """
    cfg.dataset.use_hdf5 = True
    cfg.load_hdf5(write=True)
    hf = cfg.hdf5

    symmap = get_phone_symmap()

    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    def add(dir, type="training", audios=True, texts=True):
        name = str(dir)
        name = name.replace(root, "")

        # yucky
        speaker_name = name
        if "LibriTTS-R" in speaker_name:
            speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        # NOTE: the write-back of this dict is disabled further down, so it is only updated in memory
        metadata = json.loads(metadata_path.read_text(encoding="utf-8")) if metadata_path.exists() else {}

        if not os.path.isdir(f'{root}/{name}/'):
            return

        files = os.listdir(f'{root}/{name}/')

        quant_ext = _get_quant_extension()
        phone_ext = _get_phone_extension()

        # grab IDs for every file
        ids = {file.replace(quant_ext, "").replace(phone_ext, "") for file in files}

        """
        # rephonemizes if you fuck up and use and old tokenizer...
        for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
            key = f'{type}/{speaker_name}/{id}'

            if key not in hf:
                continue

            group = hf[key]

            if "phonemes" not in entry:
                continue
            if "text" not in group:
                continue

            txt = entry["phonemes"]
            phn = "".join(txt)
            phn = cfg.tokenizer.encode(phn)
            phn = np.array(phn).astype(np.uint8)

            del group["text"]
            group.create_dataset('text', data=phn, compression='lzf')
        """

        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            try:
                quant_path = f'{root}/{name}/{utt_id}{quant_ext}'
                phone_path = f'{root}/{name}/{utt_id}{phone_ext}'

                quant_exists = os.path.exists(quant_path) if audios else True
                text_exists = os.path.exists(phone_path) if texts else True

                if not quant_exists:
                    continue

                key = f'{type}/{speaker_name}/{utt_id}'

                if skip_existing and key in hf:
                    continue

                group = hf.create_group(key) if key not in hf else hf[key]

                if utt_id not in metadata:
                    metadata[utt_id] = {}

                utterance_metadata = {}

                # audio
                if audios:
                    # NOTE(review): allow_pickle=True unpickles the file; only use on trusted data
                    dac = np.load(quant_path, allow_pickle=True)[()]
                    qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

                    if "text" in dac["metadata"]:
                        utterance_metadata["text"] = dac["metadata"]["text"]
                    if "phonemes" in dac["metadata"]:
                        utterance_metadata["phonemes"] = dac["metadata"]["phonemes"]
                    if "language" in dac["metadata"]:
                        utterance_metadata["language"] = dac["metadata"]["language"]
                    if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
                        utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]

                    if "audio" not in group:
                        group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')

                # text
                if texts:
                    # fall back to the phoneme JSON when the audio file carried no metadata
                    if not utterance_metadata and text_exists:
                        with open(phone_path, "r", encoding="utf-8") as f:
                            utterance_metadata = json.loads(f.read())

                    phn = "".join(utterance_metadata["phonemes"])
                    phn = cfg.tokenizer.encode(phn)
                    phn = np.array(phn).astype(np.uint8)

                    if "text" not in group:
                        group.create_dataset('text', data=phn, compression='lzf')

                for k, v in utterance_metadata.items():
                    group.attrs[k] = v
                    metadata[utt_id][k] = v

            except Exception as e:
                tqdm.write(f'Error while processing {utt_id}: {e}')

        """
        with open(str(metadata_path), "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata))
        """

    # make sure the HDF5 handle is released even if processing fails part-way
    try:
        # training
        for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
            add(data_dir, type="training")

        # validation
        for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
            add(data_dir, type="validation")

        # noise
        for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
            add(data_dir, type="noise", texts=False)

        # write symmap
        if "symmap" in hf:
            del hf['symmap']

        hf.create_dataset('symmap', data=json.dumps(symmap))
    finally:
        hf.close()
2024-06-25 18:41:29 +00:00
def transcribe_dataset():
    """Transcribe every audio file under `./voices/<dataset>/<speaker>/` with WhisperX.

    Results are stored per speaker in
    `./training/metadata/<dataset>/<speaker>/whisper.json`, keyed by file
    name, with the aligned `segments`, the `language`, the joined `text`
    and the overall `start`/`end` of the transcribed speech. Files already
    present in that JSON are skipped, and the JSON is rewritten after each
    processed file so progress survives an interruption.
    """
    # optional heavy dependency, only needed for this action
    import whisperx

    # to-do: use argparser
    batch_size = 16
    device = "cuda"
    dtype = "float16"
    model_name = "large-v3"

    input_audio = "voices"
    output_dataset = "training/metadata"

    skip_existing = True
    diarize = False

    #
    model = whisperx.load_model(model_name, device, compute_type=dtype)
    # the alignment model is loaded lazily and re-loaded whenever the language changes
    align_model, align_model_metadata, align_model_language = (None, None, None)
    diarize_model = whisperx.DiarizationPipeline(device=device) if diarize else None

    for dataset_name in os.listdir(f'./{input_audio}/'):
        if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
            continue

        for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
            if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
                continue

            outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')

            if outpath.exists():
                metadata = json.loads(outpath.read_text(encoding='utf-8'))
            else:
                os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
                metadata = {}

            for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"):
                if skip_existing and filename in metadata:
                    continue

                if ".json" in filename:
                    continue

                inpath = f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}'
                if os.path.isdir(inpath):
                    continue

                metadata[filename] = {
                    "segments": [],
                    "language": "",
                    "text": "",
                    "start": 0,
                    "end": 0,
                }

                audio = whisperx.load_audio(inpath)
                result = model.transcribe(audio, batch_size=batch_size)
                language = result["language"]

                # anything that is not Japanese is aligned as English
                if language[:2] not in ["ja"]:
                    language = "en"

                if align_model_language != language:
                    tqdm.write(f'Loading language: {language}')
                    align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
                    align_model_language = language

                result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)

                metadata[filename]["segments"] = result["segments"]
                metadata[filename]["language"] = language

                if diarize_model is not None:
                    diarize_segments = diarize_model(audio)
                    result = whisperx.assign_word_speakers(diarize_segments, result)

                segments = result["segments"]
                text = [segment["text"] for segment in segments]
                # start is the earliest segment start (seeding min() with 0 would always yield 0);
                # both fall back to 0 when nothing was transcribed
                start = min((segment["start"] for segment in segments), default=0)
                end = max((segment["end"] for segment in segments), default=0)

                metadata[filename]["text"] = " ".join(text).strip()
                metadata[filename]["start"] = start
                metadata[filename]["end"] = end

                with open(outpath, 'w', encoding='utf-8') as f:
                    f.write(json.dumps(metadata))
2023-08-02 21:53:35 +00:00
if __name__ == "__main__":
    # Command-line entry point: `--action` selects one dataset utility to run.
    import argparse

    parser = argparse.ArgumentParser("Save trained model to path.")
    parser.add_argument("--action", type=str)
    parser.add_argument("--tasks", type=str)
    args, unknown = parser.parse_known_args()

    task = args.action

    cfg.dataset.workers = 1

    # print-based stand-in for the module logger when run as a script
    class LoggerOveride:
        def info(self, *args):
            print(*args)

        # the dataset code also calls _logger.warning(), so it must exist here too
        def warning(self, *args):
            print(*args)

    _logger = LoggerOveride()

    # transcription gets its own action name; previously both this branch and the
    # next one tested "hdf5", which made create_dataset_hdf5() unreachable
    if args.action == "transcribe":
        transcribe_dataset()
    elif args.action == "hdf5":
        create_dataset_hdf5()
    elif args.action == "list-dataset":
        # list every non-empty <group>/<name> directory under the data dir as JSON
        dataset = []
        for group in os.listdir(cfg.data_dir):
            for name in os.listdir(cfg.data_dir / group):
                if len(os.listdir(cfg.data_dir / group / name)) == 0:
                    continue
                dataset.append(f'{group}/{name}')

        print(json.dumps(dataset))
    elif args.action == "metadata":
        create_dataset_metadata()
    elif args.action == "sample":
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

        # two first-batches from a fresh iterator of each dataloader
        samples = {
            "training": [next(iter(train_dl)), next(iter(train_dl))],
            "evaluation": [next(iter(subtrain_dl)), next(iter(subtrain_dl))],
            "validation": [next(iter(val_dl)), next(iter(val_dl))],
        }

        Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)

        for k, v in samples.items():
            for i, batch in enumerate(v):
                for j in tqdm(range(len(batch['proms'])), desc="Decoding..."):
                    """
                    try:
                        decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
                    except Exception as e:
                        print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
                    try:
                        decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
                    except Exception as e:
                        print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
                    """
                    # replace the tensors by their shapes so the batch is readable when displayed
                    batch['proms'][j] = batch['proms'][j].shape
                    batch['resps'][j] = batch['resps'][j].shape
2023-08-19 04:55:40 +00:00
for k , v in samples . items ( ) :
for i in range ( len ( v ) ) :
2024-05-12 12:30:59 +00:00
print ( f ' { k } [ { i } ]: ' , v [ i ] )
2023-08-27 00:53:23 +00:00
2023-08-19 14:50:07 +00:00
elif args . action == " tasks " :
2023-08-19 04:55:40 +00:00
index = 0
2023-08-19 14:50:07 +00:00
cfg . dataset . tasks_list = args . tasks . split ( " , " )
train_dl , subtrain_dl , val_dl = create_train_val_dataloader ( )
batch = next ( iter ( train_dl ) )
2023-08-18 19:47:48 +00:00
2023-08-19 14:50:07 +00:00
for text , resps , proms , task in zip ( batch [ " text " ] , batch [ " resps " ] , batch [ " proms " ] , batch [ " task " ] ) :
if task not in cfg . dataset . tasks_list :
continue
2023-08-18 19:47:48 +00:00
2024-07-16 00:59:48 +00:00
print ( text , task , cfg . model . resp_levels )
2023-08-21 02:36:02 +00:00
print ( proms . shape , resps . shape )
2023-08-28 16:02:45 +00:00
tokens = 0
tokens + = sum ( [ text . shape [ 0 ] for text in batch [ " text " ] ] )
tokens + = sum ( [ resps . shape [ 0 ] for resps in batch [ " resps " ] ] )
print ( tokens )
2023-08-20 11:29:17 +00:00
decode_to_file ( proms , f " ./data/ { task } .proms.wav " , device = " cpu " )
decode_to_file ( resps , f " ./data/ { task } .resps.wav " , device = " cpu " )
2023-08-19 14:50:07 +00:00
break