2023-08-02 21:53:35 +00:00
# todo: clean this mess up
import copy
import h5py
import json
import logging
import numpy as np
import os
import random
import torch
2023-08-27 00:53:23 +00:00
import itertools
2023-08-02 21:53:35 +00:00
from . config import cfg
2023-08-23 21:43:03 +00:00
from . emb . qnt import trim , trim_random , repeat_extend_audio , merge_audio , decode_to_file
2023-09-04 02:27:13 +00:00
from . utils . sampler import Sampler
2023-08-02 21:53:35 +00:00
from collections import defaultdict
from functools import cache , cached_property
from itertools import groupby , zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch . utils . data import DataLoader , Dataset as _Dataset
2023-08-14 03:07:45 +00:00
from torch . utils . data . distributed import DistributedSampler
2023-08-02 21:53:35 +00:00
from tqdm . auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
# Module-level logger; used further down to report symbol maps, sample counts
# and durations once the dataloaders have been built.
_logger = logging.getLogger(__name__)
def get_phone_symmap():
    """Return the mapping from phoneme symbol to integer token id.

    When HDF5 is enabled and the opened file contains a ``symmap`` entry, the
    mapping is decoded from that JSON string. Otherwise the hard-coded table
    below is returned.
    """
    if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
        return json.loads(cfg.hdf5['symmap'].asstr()[()])

    # Ids 1, 2 and 3 are '<s>', '</s>' and the space token; other code in this
    # file relies on those literal values.
    # NOTE(review): '?' and 'p' both map to 7, so the two symbols are
    # indistinguishable after encoding. Left untouched because changing the
    # table would presumably invalidate already-encoded data -- confirm.
    symmap = {
        '<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7,
        'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15,
        'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23,
        't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31,
        'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39,
        'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47,
        'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55,
        'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63,
        'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71,
        'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79,
        'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87,
        'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95,
        'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102,
        'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109,
        'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116,
        'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123,
        'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130,
        'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137,
        'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144,
        'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151,
        '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158,
        'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165,
        '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172,
        ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179,
        '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185}
    return symmap
2023-08-19 03:22:13 +00:00
def get_task_symmap():
    """Return the mapping from task marker (e.g. ``"<ns>"``) to token id.

    ``"<tts>"`` is pinned to -100; every other marker is numbered
    consecutively starting at 1024, in the order listed below.
    """
    base = 1024
    # Order matters: a marker's position here is its offset from ``base``.
    numbered_tasks = ("ns", "sr", "tse", "soe", "mask", "eoe", "tts-c")

    symmap = {"<tts>": -100}
    for offset, task_name in enumerate(numbered_tasks):
        symmap[f"<{task_name}>"] = base + offset
    return symmap
2023-08-02 21:53:35 +00:00
def _replace_file_extension(path, suffix):
    """Return *path* with everything after the first dot of its name replaced.

    Unlike ``Path.with_suffix`` alone, this strips compound extensions:
    ``utt.qnt.pt`` with suffix ``.phn.txt`` becomes ``utt.phn.txt``.
    """
    base_name, _, _ = path.name.partition(".")
    return path.parent.joinpath(base_name).with_suffix(suffix)
2023-08-27 00:53:23 +00:00
def _get_quant_path(path):
    """Return the ``.qnt.pt`` file that sits next to *path* with the same base name."""
    quant_suffix = ".qnt.pt"
    return _replace_file_extension(path, quant_suffix)
def _get_phone_path(path):
    """Return the ``.phn.txt`` file that sits next to *path* with the same base name."""
    phone_suffix = ".phn.txt"
    return _replace_file_extension(path, phone_suffix)
2023-09-12 20:54:41 +00:00
# Running total of utterance durations per dataset type ("training", ...).
# Filled in as a side effect of the validation helpers in this module.
_total_durations = {}


@cfg.diskcache()
def _calculate_durations(type="training"):
    """Return the duration accumulated so far for *type*, or 0 if none was recorded.

    NOTE(review): the result passes through ``cfg.diskcache()``, which is
    defined elsewhere; presumably it persists results across runs -- confirm.
    """
    return _total_durations.get(type, 0)
2023-09-02 21:29:53 +00:00
@cfg.diskcache()
def _load_paths(dataset, type="training"):
    """Map each speaker to the list of utterance paths found in its directory.

    *dataset* is an iterable of data directories. The speaker key is obtained
    by passing a placeholder file inside the directory to ``cfg.get_spkr``.
    Entries are validated only when ``cfg.dataset.validate`` is set and
    *type* is ``"training"``. If two directories resolve to the same speaker,
    the later one replaces the earlier one.
    """
    paths_by_speaker = {}
    for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}"):
        speaker = cfg.get_spkr(data_dir / "dummy")
        paths_by_speaker[speaker] = _load_paths_from_metadata(
            data_dir,
            type=type,
            validate=cfg.dataset.validate and type == "training",
        )
    return paths_by_speaker
def _load_paths_from_metadata(data_dir, type="training", validate=False):
    """List the utterances of one speaker directory, preferring ``metadata.json``.

    If metadata use is disabled or ``data_dir / "metadata.json"`` is missing,
    the work is delegated to the HDF5 or file-extension scanner, depending on
    ``cfg.dataset.use_hdf5``. Otherwise every id in the metadata file is turned
    into a key: ``data_dir / id`` for on-disk data, or the string
    ``"/{type}{hdf5 path}/{id}"`` for HDF5.

    When *validate* is true, an entry is kept only if it carries both
    ``phones`` and ``duration`` and both fall inside the configured min/max
    bounds. Validating an entry also adds its duration to
    ``_total_durations[type]``, before the range check, so rejected entries
    are counted as well.
    """
    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions

    def _validate(entry):
        if "phones" not in entry or "duration" not in entry:
            return False

        phones = entry['phones']
        duration = entry['duration']

        _total_durations[type] = _total_durations.get(type, 0) + duration

        return (
            cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration
            and cfg.dataset.min_phones <= phones <= cfg.dataset.max_phones
        )

    metadata_path = data_dir / "metadata.json"
    if not cfg.dataset.use_metadata or not metadata_path.exists():
        return _fn(data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate)

    # Use a context manager so the handle is released promptly.
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    def key(utterance_id):
        if not cfg.dataset.use_hdf5:
            return data_dir / utterance_id
        return f"/{type}{_get_hdf5_path(data_dir)}/{utterance_id}"

    return [
        key(utterance_id)
        for utterance_id, entry in metadata.items()
        if not validate or _validate(entry)
    ]
2023-08-02 21:53:35 +00:00
def _get_hdf5_path(path):
    """Convert a filesystem path into the key used inside the HDF5 file.

    The path is stringified, given a leading ``./`` if it lacks one, and every
    occurrence of ``cfg.cfg_path`` is then removed from it.
    """
    text = str(path)
    prefixed = text if text.startswith("./") else f'./{text}'
    return prefixed.replace(cfg.cfg_path, "")
2023-08-27 00:53:23 +00:00
def _get_hdf5_paths(data_dir, type="training", validate=False):
    """List the HDF5 keys of every utterance stored under one speaker group.

    The group is looked up at ``"/{type}{hdf5 path of data_dir}"``; an empty
    list is returned when that group does not exist. Each child contributes
    ``Path("{group}/{child id}")``.

    When *validate* is true, a child is kept only if its ``phonemes`` and
    ``duration`` attributes fall inside the configured min/max bounds.
    Validating a child also adds its duration to ``_total_durations[type]``,
    before the range check.
    """
    data_dir = str(data_dir)

    def _validate(child):
        phones = child.attrs['phonemes']
        duration = child.attrs['duration']

        # Fixed: this used to read ``entry['duration']``, but no ``entry``
        # exists in this scope, so validation raised NameError.
        _total_durations[type] = _total_durations.get(type, 0) + duration

        return (
            cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration
            and cfg.dataset.min_phones <= phones <= cfg.dataset.max_phones
        )

    key = f"/{type}{_get_hdf5_path(data_dir)}"
    if key not in cfg.hdf5:
        return []

    return [
        Path(f"{key}/{child.attrs['id']}")
        for child in cfg.hdf5[key].values()
        if not validate or _validate(child)
    ]
def _get_paths_of_extensions(path, extensions=".qnt.pt", validate=False):
    """List the entries of directory *path* that form a usable utterance.

    An entry qualifies when its joined suffixes are contained in *extensions*
    and both its ``.phn.txt`` and ``.qnt.pt`` siblings exist. With *validate*
    set, the phoneme count must additionally lie within
    ``cfg.dataset.min_phones`` and ``cfg.dataset.max_phones``. Returns an empty
    list if *path* is missing or is not a directory.
    """
    root = Path(path) if isinstance(path, str) else path

    def _is_usable(candidate):
        if "".join(candidate.suffixes) not in extensions:
            return False

        phone_path = _get_phone_path(candidate)
        if not (phone_path.exists() and _get_quant_path(candidate).exists()):
            return False

        if not validate:
            return True

        # to-do: find an easy way to determine size from pickled quants without loading
        # to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
        phones = len(_get_phones(phone_path))  # _get_phone_path(path).stat().st_size // 2 + 1
        return cfg.dataset.min_phones <= phones <= cfg.dataset.max_phones

    if not (root.exists() and root.is_dir()):
        return []
    return [entry for entry in root.iterdir() if _is_usable(entry)]
2023-08-02 21:53:35 +00:00
def _load_quants(path) -> Tensor:
    """Load the quantized codes stored next to *path* as an int16 tensor.

    The ``.qnt.pt`` sibling is loaded with ``torch.load``; its first element
    is transposed and cast to ``torch.int16``.
    """
    loaded = torch.load(_get_quant_path(path))
    codes = loaded[0][:, :]
    return codes.t().to(torch.int16)
2023-08-02 21:53:35 +00:00
@cache
def _get_phones(path, language="en"):
    """Read the phoneme file next to *path* and return it as a token list.

    The file content is split on single spaces. An empty piece (produced by
    two consecutive spaces) becomes a literal ``" "`` token, and the sequence
    is wrapped in ``"<s>"`` / ``"</s>"``. *language* is currently unused.

    The result is memoized per argument set, so callers receive a shared list
    and must not mutate it.
    """
    # Use a context manager so the handle is closed instead of leaking.
    with open(_get_phone_path(path), "r", encoding="utf-8") as f:
        content = f.read().split(" ")
    return ["<s>"] + [p if p else " " for p in content] + ["</s>"]
2023-08-02 21:53:35 +00:00
def _interleaved_reorder(l, fn):
    """Yield the items of *l* round-robin across the groups defined by *fn*.

    Items are bucketed by ``fn(item)``, keeping their relative order inside a
    bucket. Buckets are visited in sorted key order, one item from each per
    round, until all are exhausted. Keys must therefore be mutually orderable.
    """
    # A private sentinel marks the padding added by zip_longest, so items that
    # are themselves ``None`` are no longer mistaken for padding and dropped.
    missing = object()

    groups = defaultdict(list)
    for e in l:
        groups[fn(e)].append(e)

    ordered_groups = [groups[k] for k in sorted(groups)]
    for interleaved in zip_longest(*ordered_groups, fillvalue=missing):
        for value in interleaved:
            if value is not missing:
                yield value
class Dataset(_Dataset):
    """Map-style dataset yielding (text, prompt, response) samples per task.

    Utterance paths are discovered through ``_load_paths`` for either the
    training or the validation directories of ``cfg.dataset``. Depending on
    ``cfg.dataset.sample_type`` an index addresses a speaker (``"speaker"``)
    or a single utterance (any other value). Each item picks a task at random
    from ``cfg.dataset.tasks_list`` and builds the matching prompt/response
    pair.

    Audio lengths are handled in frames; this file treats 75 frames as one
    second.
    """

    def __init__(
        self,
        phone_symmap=None,
        training=False,
        extra_paths_by_spkr_name: "dict[str, list] | None" = None,
    ):
        """Discover paths and build the symbol maps.

        Args:
            phone_symmap: phoneme-to-id mapping to reuse; when falsy the
                module default is loaded.
            training: selects the training (True) or validation (False)
                directories.
            extra_paths_by_spkr_name: accepted for interface compatibility;
                not read anywhere in this class. The default is ``None``
                rather than a shared mutable ``{}``.

        Raises:
            ValueError: if *training* is true and no path was found.
        """
        super().__init__()
        self._head = None
        self.sampler = None

        self.paths = []

        self.training = training
        self.dataset_type = "training" if self.training else "validation"
        self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation

        # to-do: do not do validation if there's nothing in the validation
        # this just makes it be happy
        if len(self.dataset) == 0:
            self.dataset = cfg.dataset.training

        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)

        # cull speakers if they do not have enough utterances
        if cfg.dataset.min_utterances > 0:
            # iterate over a snapshot of the keys because entries are deleted
            for key in list(self.paths_by_spkr_name.keys()):
                if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
                    del self.paths_by_spkr_name[key]

        self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))

        # one sampler per speaker; used when sampling by speaker
        self.samplers = {
            name: Sampler(paths, keep_all=True)
            for name, paths in self.paths_by_spkr_name.items()
        }

        if cfg.dataset.sample_type == "path":
            self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]

        self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
        self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))

        self.phone_symmap = phone_symmap or self._get_phone_symmap()
        self.spkr_symmap = self._get_spkr_symmap()
        self.task_symmap = self._get_task_symmap()

        # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
        self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

        if len(self.paths) == 0 and training:
            raise ValueError("No valid path is found for training.")

        #self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
        self.duration = _calculate_durations(self.dataset_type)

    @cached_property
    def phones(self):
        """Sorted list of every distinct phoneme token across all paths."""
        return sorted(set().union(*[_get_phones(path) for path in self.paths]))

    def get_speaker(self, path):
        """Return the speaker name that ``cfg.get_spkr`` derives from *path*."""
        if isinstance(path, str):
            path = Path(path)
        res = cfg.get_spkr(path)
        return res

    @cached_property
    def spkrs(self):
        """Sorted list of the distinct speakers present in ``self.paths``."""
        return sorted({self.get_speaker(path) for path in self.paths})

    @cached_property
    def tasks(self):
        """Task names to sample from, as configured in ``cfg.dataset.tasks_list``."""
        return cfg.dataset.tasks_list  # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"

    def save_state_dict(self, path):
        """Save each speaker sampler's ``current_pool`` to *path* with ``torch.save``."""
        state_dict = {
            "samplers": {name: sampler.current_pool for name, sampler in self.samplers.items()}
        }
        torch.save(state_dict, path)

    def load_state_dict(self, path):
        """Restore sampler pools saved by ``save_state_dict``.

        Pools for speakers that are no longer present are ignored.
        """
        state_dict = torch.load(path)
        if "samplers" in state_dict:
            # better than naively setting the entire object
            for name, pool in state_dict["samplers"].items():
                if name not in self.samplers:
                    continue
                self.samplers[name].current_pool = pool

    def _get_phone_symmap(self):
        return get_phone_symmap()

    def _get_spkr_symmap(self):
        return {s: i for i, s in enumerate(self.spkrs)}

    def _get_task_symmap(self):
        return get_task_symmap()

    def get_task_token(self, token, levels=cfg.models.max_levels):
        """Return a ``(1, levels)`` int16 tensor filled with the id of ``<token>``."""
        if not hasattr(self, "task_symmap"):
            self.task_symmap = self._get_task_symmap()
        return torch.Tensor([[self.task_symmap[f'<{token}>'] for _ in range(levels)]]).to(dtype=torch.int16)

    def _load_audio(self, path):
        """Load the quantized audio for *path* as int16, from HDF5 or from disk."""
        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            return torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
        return _load_quants(path)

    def _load_text(self, path):
        """Load the phoneme ids for *path* as ``self.text_dtype``, from HDF5 or from disk."""
        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            return torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
        return torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)

    def sample_noise(self):
        """Return the quantized audio of a randomly chosen noise clip."""
        path = random.choice(self.noise_paths)
        return self._load_audio(path)

    def sample_speakers(self, ignore=None):
        """Return a random speaker name that is not listed in *ignore*."""
        # ``None`` replaces the former shared mutable ``[]`` default
        choices = set(self.spkrs) - set(ignore or ())
        return random.choice([*choices])

    def sample_prompts(self, spkr_name, ignore):
        """Build an acoustic prompt from random utterances of *spkr_name*.

        The utterance *ignore* is excluded unless it is the speaker's only
        one. Up to ``cfg.dataset.max_prompts`` utterances are concatenated.
        A target length between 3 and 9 seconds is drawn once; when
        ``cfg.dataset.prompt_duration`` is positive, each utterance and the
        final prompt are trimmed to that length.
        """
        prom_list = []

        choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
        choices = [*choices]

        # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
        if len(choices) == 0:
            # fall back to the speaker's own utterances instead of raising
            # ValueError(f"Failed to find another different utterance for {spkr_name}.")
            choices = [*set(self.paths_by_spkr_name[spkr_name])]

        # shuffle it up a bit
        prom_length = 0
        trim_length = random.randint(75 * 3, 75 * 9)  # [3 seconds, 9 seconds]
        #trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

        for _ in range(cfg.dataset.max_prompts):
            path = random.choice(choices)
            qnt = self._load_audio(path)

            if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
                qnt = trim(qnt, trim_length)

            prom_list.append(qnt)
            prom_length += qnt.shape[0]

            if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
                break

        prom = torch.cat(prom_list)

        if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
            prom = trim(prom, trim_length)

        return prom

    def __getitem__(self, index):
        """Return one training sample as a dict.

        Keys: ``index``, ``path``, ``spkr_name``, ``spkr_id``, ``task``,
        ``text``, ``proms``, ``resps``. The task falls back to ``"tts"`` when
        the speaker has fewer than four other utterances.

        Raises:
            Exception: if the sampled task name is not handled.
        """
        if cfg.dataset.sample_type == "speaker":
            spkr_name = self.spkrs[index]
            spkr_id = self.spkr_symmap[spkr_name]
            path = self.samplers[spkr_name].sample()
        else:
            path = self.paths[index]
            spkr_name = self.get_speaker(path)
            spkr_id = self.spkr_symmap[spkr_name]

        text = self._load_text(path)
        resps = self._load_audio(path)

        task = random.choice(self.tasks)

        # ensure a speaker has at least four utterances
        # default to tts if not
        if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
            task = "tts"

        noise_scale = 0.25
        # text-to-speech
        if task == "tts" or task == "tts-c":
            trim_length = int(cfg.dataset.prompt_duration * 75)

            # demote if the target is too short
            if task == "tts-c" and trim_length * 2 >= resps.shape[0]:
                task = "tts"

            # VALL-E continuous
            # ignore if target utterance is shorter than prompt duration
            # to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
            if task == "tts-c":
                proms = resps[:trim_length, :]
                resps = resps[trim_length:, :]

                proms = torch.cat([self.get_task_token(task), proms])
            else:
                proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
        # noise suppression || speech removal
        elif task == "ns" or task == "sr":
            # sample random noise
            noise = self.sample_noise()
            # extend the noise to fill the target audio
            noise = repeat_extend_audio(noise, resps.shape[0])
            # create the input prompt by merging the target audio with the noise
            proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
            # set the target to just be the noise if <sr>
            if task == "sr":
                resps = noise
            # prepend the task token
            proms = torch.cat([self.get_task_token(task), proms])

            # set the text prompt to empty to train without a guided text prompt
            if random.random() < 0.5:
                text = torch.tensor([1, 2]).to(self.text_dtype)  # <s></s>
        # target speech extraction
        elif task == "tse":
            # sample a random, clean, utterance for the target speaker
            clean_proms = self.sample_prompts(spkr_name, ignore=path)
            # sample a random, clean utterance from a different speaker
            other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
            # overlay the random speaker over the target audio

            smallest_size = min(resps.shape[0], other_proms.shape[0])
            if other_proms.shape[0] == smallest_size:
                noisy_proms = merge_audio(resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu")
                noisy_proms = torch.cat([noisy_proms, resps[smallest_size:, :]])
            else:
                noisy_proms = merge_audio(resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu")
                noisy_proms = torch.cat([noisy_proms, other_proms[smallest_size:, :]])

            # stitch together the prompts
            proms = torch.cat([clean_proms, self.get_task_token(task), noisy_proms])

            # set the text prompt to empty to train without a guided text prompt
            if random.random() < 0.5:
                text = torch.tensor([1, 2]).to(self.text_dtype)  # <s></s>

        # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
        # as I need to get a good clean point to trim into
        # clean speech editing
        elif task == "cse" or task == "nse":
            choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
            sampled = random.sample([*choices], 4)

            texts = [self._load_text(p) for p in sampled]
            qnts = [self._load_audio(p) for p in sampled]

            # remove <s></s>
            texts = [t[1:-1] for t in texts]

            pre_text, mid_text, post_text, edit_text = texts
            pre_prom, mid_prom, post_prom, edit_prom = qnts

            # randomly drop out pre
            if random.random() < 0.125:
                pre_text = None
                pre_prom = None
            # randomly drop out post
            if random.random() < 0.125:
                post_text = None
                post_prom = None

            # create new text
            text = torch.cat(
                [torch.Tensor([1]).to(dtype=self.text_dtype)] +  # <s>
                ([pre_text, torch.Tensor([3]).to(dtype=self.text_dtype)] if pre_text is not None else []) +  # pre_text + space
                [edit_text] +  # edit text
                ([torch.Tensor([3]).to(dtype=self.text_dtype), post_text] if post_text is not None else []) +  # space + post_text
                [torch.Tensor([2]).to(dtype=self.text_dtype)]  # </s>
            )

            if task == "nse":
                # sample random noise
                noise = self.sample_noise()

                # it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
                # but it's noise, it's supposed to be random
                def noise_proms(p):
                    # ignore if we turned it off
                    if p is None:
                        return None

                    # extend the noise to fill the target audio
                    n = repeat_extend_audio(noise, p.shape[0])
                    # merge the noise over the utterance
                    return merge_audio(p, n, scale=[1, noise_scale], device="cpu")

                # apply noise to all pieces
                pre_prom = noise_proms(pre_prom)
                mid_prom = noise_proms(mid_prom)
                post_prom = noise_proms(post_prom)
                edit_prom = noise_proms(edit_prom)
            else:
                mid_prom = self.get_task_token("mask")

            # create new proms
            proms = torch.cat(
                ([pre_prom] if pre_prom is not None else []) +
                [self.get_task_token("soe")] +
                [mid_prom] +  # is <mask> if task is CSE
                [self.get_task_token("eoe")] +
                ([post_prom] if post_prom is not None else [])
            )
            # create new resp
            resps = torch.cat(
                ([pre_prom] if pre_prom is not None else []) +
                [edit_prom] +
                ([post_prom] if post_prom is not None else [])
            )
        else:
            raise Exception(f'Undefined task: {task}')

        # Unimplemented idea: emulate SVC.
        # It would take an utterance of the target speaker plus a reference
        # clip as the input prompt, and target an utterance of the target
        # speaker with the same tempo, pitch, etc. as the reference clip.
        # Clips could be generated dynamically through RVC here, but the cost
        # would likely be prohibitive; preparing a large amount of RVC clips
        # ahead of time might be the key. How to train this is an open
        # question; the task is entirely a proof of concept. Sketch:
        #   elif task == "svc":
        #       proms = <random clean utterance of the target speaker, or resps>
        #       ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
        #       resps = <undecided>
        #       proms = torch.cat([proms, self.get_task_token(task), ref_proms])
        #       if random.random() < 0.5:
        #           text = torch.tensor([1, 2]).to(self.text_dtype)

        # trim to fit to requested prom/resps levels
        # NOTE(review): resps is also trimmed with prom_levels -- confirm this is intended
        proms = proms[:, :cfg.models.prom_levels]
        resps = resps[:, :cfg.models.prom_levels]

        return dict(
            index=index,
            path=Path(path),
            spkr_name=spkr_name,
            spkr_id=spkr_id,
            task=task,
            text=text,
            proms=proms,
            resps=resps,
        )

    def head_(self, n):
        """Cap the reported length of the dataset at *n* items (see ``__len__``)."""
        self._head = n

    def training_(self, value):
        """Set the ``training`` flag."""
        self.training = value

    def __len__(self):
        """Number of speakers or of paths (per sample type), capped by ``_head`` when set."""
        if cfg.dataset.sample_type == "speaker":
            return min(len(self.spkrs), self._head or len(self.spkrs))
        return min(len(self.paths), self._head or len(self.paths))

    def pin_memory(self):
        """Pin the ``text``/``proms``/``resps``/``resp`` attributes.

        NOTE(review): nothing in this class assigns those attributes, so
        calling this would raise AttributeError as written -- confirm whether
        it is still needed.
        """
        self.text = self.text.pin_memory()
        self.proms = self.proms.pin_memory()
        self.resps = self.resps.pin_memory()
        self.resp = self.resp.pin_memory()
        return self
def collate_fn(samples: list[dict]):
    """Turn a list of sample dicts into a dict of lists.

    The key set is taken from the first sample; every sample must provide
    each of those keys. No stacking or padding is performed.
    """
    batch: dict[str, Any] = {}
    for field in samples[0]:
        batch[field] = [sample[field] for sample in samples]
    return batch
def _seed_worker(worker_id):
    """Seed NumPy's and Python's RNGs from torch's initial seed.

    The torch seed is reduced modulo 2**32 before use. *worker_id* is not
    used.
    """
    seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(seed)
def _create_dataloader(dataset, training):
    """Wrap *dataset* in a ``DataLoader`` configured from ``cfg``.

    For distributed training a ``DistributedSampler`` is attached and
    loader-level shuffling is switched off; in every other case the loader
    shuffles and no sampler is set. The batch size comes from the
    hyperparameters when training and from the evaluation settings otherwise.
    """
    use_distributed_sampler = cfg.distributed and training
    sampler = DistributedSampler(dataset) if use_distributed_sampler else None
    shuffle = not use_distributed_sampler

    if training:
        batch_size = cfg.hyperparameters.batch_size
    else:
        batch_size = cfg.evaluation.batch_size

    loader_options = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=training,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 1,
        pin_memory=False,  # True,
        worker_init_fn=_seed_worker,
        sampler=sampler,
    )
    return DataLoader(**loader_options)
def create_datasets():
    """Build the training and validation datasets.

    The validation set reuses the training set's phoneme map. If
    ``cfg.relpath / "train_dataset.pt"`` exists, the training set's sampler
    state is restored from it.

    Returns:
        ``(train_dataset, val_dataset)``
    """
    training_set = Dataset(training=True)
    validation_set = Dataset(phone_symmap=training_set.phone_symmap, training=False)

    saved_state = cfg.relpath / "train_dataset.pt"
    if saved_state.exists():
        training_set.load_state_dict(saved_state)

    return training_set, validation_set
def create_train_val_dataloader():
    """Create the train, sub-train and validation dataloaders.

    The sub-train set is a deep copy of the training set, loaded in
    evaluation mode. When sampling by path it is capped at
    ``cfg.evaluation.size`` items. Symbol maps, sample counts and durations
    are logged.

    Returns:
        ``(train_dl, subtrain_dl, val_dl)``
    """
    train_dataset, val_dataset = create_datasets()

    subtrain_dataset = copy.deepcopy(train_dataset)
    if cfg.dataset.sample_type == "path":
        subtrain_dataset.head_(cfg.evaluation.size)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    for symmap in (train_dataset.phone_symmap, train_dataset.spkr_symmap):
        _logger.info(str(symmap))

    labelled_datasets = (
        ("train", train_dataset),
        ("val", val_dataset),
        ("subtrain", subtrain_dataset),
    )
    for label, ds in labelled_datasets:
        _logger.info(f"#samples ({label}): {len(ds)}.")
    for label, ds in labelled_datasets:
        _logger.info(f"#duration ({label}): {str(ds.duration)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)

    return train_dl, subtrain_dl, val_dl
2023-08-27 00:53:23 +00:00
def create_dataset_metadata():
    """Scan the dataset and write ``metadata.json`` files that are cheaper to sample from.

    Validation and HDF5 are switched off on ``cfg.dataset`` first. For every
    utterance found under the training, validation and noise directories, the
    phoneme count and the duration in seconds (frames / 75) are recorded; a
    missing file counts as 0. One ``metadata.json`` is written next to each
    speaker's files, and a combined one under ``cfg.relpath``.
    """
    cfg.dataset.validate = False
    cfg.dataset.use_hdf5 = False

    paths_by_spkr_name = {}
    # each config attribute shares its name with the dataset type passed along
    for dataset_type in ("training", "validation", "noise"):
        paths_by_spkr_name.update(_load_paths(getattr(cfg.dataset, dataset_type), dataset_type))

    paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))

    metadata = {}
    for path in tqdm(paths, desc="Parsing paths"):
        speaker = cfg.get_spkr(path)
        speaker_entries = metadata.setdefault(speaker, {})

        if cfg.dataset.use_hdf5:
            attrs = cfg.hdf5[_get_hdf5_path(path)].attrs
            phones = attrs['phonemes']
            duration = attrs['duration']
        else:
            phns_path = _get_phone_path(path)
            qnts_path = _get_quant_path(path)

            phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
            duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0

        utterance_id = path.name.split(".")[0]
        speaker_entries[utterance_id] = {
            "phones": phones,
            "duration": duration
        }

    for speaker, speaker_paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
        if len(speaker_paths) == 0:
            continue
        with open(speaker_paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata[speaker]))

    with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
        f.write(json.dumps(metadata))
2023-08-19 14:50:07 +00:00
# parse yaml to create an hdf5 file
2023-08-27 00:53:23 +00:00
def create_dataset_hdf5(skip_existing=True):
    """Build (or update) the HDF5 dataset from the configured data directories.

    Every directory in ``cfg.dataset.training``, ``cfg.dataset.validation``
    and ``cfg.dataset.noise`` is scanned for ``<id>.qnt.pt`` (audio codes) and
    ``<id>.phn.txt`` (phonemes) files. Each utterance is stored under the key
    ``<type>/<name>/<id>`` with an ``audio`` and/or ``text`` dataset plus the
    attributes ``id``, ``type``, ``speaker``, ``duration`` and ``phonemes``.
    A ``metadata.json`` is written into every processed directory and the
    (possibly extended) symbol map is stored in the file under ``symmap``.
    The HDF5 file is closed before returning, even when an error occurs.

    Args:
        skip_existing: When true, utterances whose key is already present in
            the HDF5 file are left untouched; otherwise they are deleted and
            rebuilt.
    """
    cfg.dataset.use_hdf5 = True
    cfg.load_hdf5(write=True)

    symmap = get_phone_symmap()
    root = cfg.cfg_path
    hf = cfg.hdf5

    def add(data_dir, kind="training", audios=True, texts=True):
        """Import one directory; ``audios``/``texts`` select which files are required and stored."""
        name = "./" + str(data_dir)
        name = name.replace(root, "")
        base = f'{root}/{name}'
        metadata = {}

        if not os.path.isdir(f'{base}/'):
            return

        files = os.listdir(f'{base}/')

        # grab IDs for every file (file name minus its last two dotted suffixes);
        # sorted so processing order, and therefore new symbol ids, are reproducible
        ids = sorted({".".join(file.split(".")[:-2]) for file in files})

        for utt_id in tqdm(ids, desc=f"Processing {name}"):
            key = f'{kind}/{name}/{utt_id}'
            audio_path = f'{base}/{utt_id}.qnt.pt'
            text_path = f'{base}/{utt_id}.phn.txt'
            group = None

            try:
                if audios and not os.path.exists(audio_path):
                    continue
                if texts and not os.path.exists(text_path):
                    continue

                if key in hf:
                    if skip_existing:
                        continue
                    del hf[key]

                group = hf.create_group(key)
                group.attrs['id'] = utt_id
                group.attrs['type'] = kind
                group.attrs['speaker'] = name

                # audio
                if audios:
                    # NOTE: torch.load unpickles the file; only use on trusted data.
                    # Takes the first item and transposes it.
                    qnt = torch.load(audio_path)[0].t()
                    group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
                    # NOTE(review): 75 is presumably the codec frame rate - confirm
                    duration = qnt.shape[0] / 75
                else:
                    duration = 0
                group.attrs['duration'] = duration

                # text
                if texts:
                    with open(text_path, "r", encoding="utf-8") as f:
                        content = f.read().split(" ")
                    # empty tokens (from consecutive spaces) stand for a literal space
                    phones = ["<s>"] + [p if p else " " for p in content] + ["</s>"]
                    # register unseen symbols in order of first appearance
                    # NOTE(review): assumes len(symmap) is not already used as an id - confirm
                    for s in phones:
                        if s not in symmap:
                            symmap[s] = len(symmap)
                    phn = [symmap[s] for s in phones]
                    group.create_dataset('text', data=phn, compression='lzf', chunks=True)
                    phoneme_count = len(phn)
                else:
                    phoneme_count = 0
                group.attrs['phonemes'] = phoneme_count
            except Exception as e:
                # Best effort: report the failure and move on, but do not leave a
                # half-written group behind (it would be skipped by skip_existing
                # on every later run).
                tqdm.write(f"Failed to process {key}: {e}")
                if group is not None:
                    del hf[key]
            else:
                # only reached when the utterance was fully written
                metadata[utt_id] = {"duration": duration, "phones": phoneme_count}

        # NOTE: only utterances written during this call are listed; skipped
        # (already existing) or failed ones are absent.
        with open(data_dir / "metadata.json", "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata))

    try:
        # training
        for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
            add(data_dir, kind="training")

        # validation
        for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
            add(data_dir, kind="validation")

        # noise
        for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
            add(data_dir, kind="noise", texts=False)

        # write symmap
        if "symmap" in hf:
            del hf['symmap']
        hf.create_dataset('symmap', data=json.dumps(symmap))
    finally:
        hf.close()
if __name__ == "__main__":
    import argparse

    # Command-line entry point. ``--action`` selects what to do:
    #   hdf5     - build the HDF5 dataset
    #   metadata - write the metadata.json files
    #   sample   - print a couple of batches from each dataloader
    #   tasks    - decode one sample of the requested task(s) to wav files
    parser = argparse.ArgumentParser(
        description="Dataset utilities: build HDF5/metadata files or inspect sampled batches."
    )
    parser.add_argument("--action", type=str)
    parser.add_argument("--tasks", type=str)
    args = parser.parse_args()

    cfg.dataset.workers = 1

    # Replace the module logger so info messages go straight to stdout.
    class LoggerOveride:
        def info(self, *args):
            print(*args)

    _logger = LoggerOveride()

    if args.action == "hdf5":
        create_dataset_hdf5()
    elif args.action == "metadata":
        create_dataset_metadata()
    elif args.action == "sample":
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

        # each next(iter(...)) builds a fresh iterator and takes its first batch
        samples = {
            "training": [next(iter(train_dl)), next(iter(train_dl))],
            "evaluation": [next(iter(subtrain_dl)), next(iter(subtrain_dl))],
            "validation": [next(iter(val_dl)), next(iter(val_dl))],
        }

        for k, v in samples.items():
            # drop the 'proms' and 'resps' entries before printing
            for sample in v:
                del sample['proms']
                del sample['resps']
            print(f'{k}:', v)

        train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
    elif args.action == "tasks":
        if args.tasks is None:
            parser.error("--tasks is required when --action is 'tasks'")

        cfg.dataset.tasks_list = args.tasks.split(",")
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
        batch = next(iter(train_dl))

        # report and decode the first sample whose task was requested
        for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
            if task not in cfg.dataset.tasks_list:
                continue

            print(text, task, cfg.models.prom_levels)
            print(proms.shape, resps.shape)

            # total token count over the whole batch (text + responses)
            tokens = 0
            tokens += sum(t.shape[0] for t in batch["text"])
            tokens += sum(r.shape[0] for r in batch["resps"])
            print(tokens)

            decode_to_file(proms, f"./data/{task}.proms.wav", device="cpu")
            decode_to_file(resps, f"./data/{task}.resps.wav", device="cpu")
            break