# todo: clean this mess up

import copy
import h5py
import json
import logging
import numpy as np
import os
import random
import torch

from .config import cfg
from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file

from collections import defaultdict
from functools import cache, cached_property
from itertools import groupby, zip_longest
from pathlib import Path
from typing import Any

from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler

from tqdm.auto import tqdm

# torch.multiprocessing.set_sharing_strategy("file_system")

_logger = logging.getLogger(__name__)

def get_phone_symmap():
	if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
		return json.loads(cfg.hdf5['symmap'].asstr()[()])

	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
	return symmap
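
# Example usage of get_phone_symmap() with the fallback table above (not an HDF5-stored symmap):
#   symmap = get_phone_symmap()
#   [symmap[p] for p in ["<s>", "h", "ə", "l", "ə", "</s>"]]  # -> [1, 38, 34, 18, 34, 2]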

def get_task_symmap():
	start = 1024
	symmap = {
		"<tts>": -100,
		"<ns>": start + 0,
		"<sr>": start + 1,
		"<tse>": start + 2,
		"<soe>": start + 3,
		"<mask>": start + 4,
		"<eoe>": start + 5,
		"<svc>": start + 6,
	}
	return symmap
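
# Task token ids start at 1024, presumably to keep them clear of the audio code values that share the
# prompt sequence, while "<tts>" maps to -100 and is never actually prepended (see Dataset.__getitem__).
#   get_task_symmap()["<ns>"]  # -> 1024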

def _replace_file_extension(path, suffix):
	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

def _get_hdf5_path(path):
	path = str(path)
	if path[:2] != "./":
		path = f'./{path}'
	return path.replace(cfg.cfg_path, "")

def _get_quant_path(path):
	return _replace_file_extension(path, ".qnt.pt")

def _get_phone_path(path):
	return _replace_file_extension(path, ".phn.txt")

def _load_quants(path) -> Tensor:
	path = _get_quant_path(path)
	return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)
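
# Example with a hypothetical path: _get_quant_path(Path("./data/spkr/utt.wav")) -> Path("./data/spkr/utt.qnt.pt");
# _load_quants() then returns the first cfg.models.prom_levels codebooks as an int16 (frames, levels) tensor.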

@cache
def _get_phones(path, lang_marker="en"):
	path = _get_phone_path(path)
	with open(path, "r", encoding="utf8") as f:
		content = f.read()
	split = content.split(" ")
	return ["<s>"] + [" " if not p else p for p in split] + ["</s>"]

def _interleaved_reorder(l, fn):
	groups = defaultdict(list)
	for e in l:
		groups[fn(e)].append(e)
	groups = {k: groups[k] for k in sorted(groups)}
	for interleaved in zip_longest(*groups.values()):
		for value in interleaved:
			if value is not None:
				yield value
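
# Example: with fn returning the speaker, ["a/1", "a/2", "b/1"] is yielded as "a/1", "b/1", "a/2",
# i.e. one path per speaker per round, so speakers end up interleaved rather than clustered.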

@cache
def _validate(path, min_phones, max_phones, min_duration, max_duration):
	if cfg.dataset.use_hdf5:
		key = _get_hdf5_path(path)
		if key not in cfg.hdf5:
			return False
		phones = cfg.hdf5[key].attrs['phonemes']
		duration = cfg.hdf5[key].attrs['duration']
		if phones < min_phones or phones > max_phones:
			return False
		if duration < min_duration or duration > max_duration:
			return False
		return True

	if not os.path.exists(_get_phone_path(path)) or not os.path.exists(_get_quant_path(path)):
		return False

	phones = _get_phones(path)
	unique_phones = list(set(phones))

	if len(unique_phones) == 0:
		return False
	if len(unique_phones) == 1 and unique_phones[0] == " ":
		return False
	if len(phones) < min_phones or len(phones) > max_phones:
		return False
	return True

class Dataset(_Dataset):
	def __init__(
		self,
		paths,
		phone_symmap=None,
		spkr_symmap=None,
		task_symmap=None,
		min_phones=cfg.dataset.phones_range[0],
		max_phones=cfg.dataset.phones_range[1],
		min_duration=cfg.dataset.duration_range[0],
		max_duration=cfg.dataset.duration_range[1],
		training=False,
		extra_paths_by_spkr_name: dict[str, list] = {},
	):
		super().__init__()
		self._head = None
		self.min_phones = min_phones
		self.max_phones = max_phones
		self.min_duration = min_duration
		self.max_duration = max_duration
		self.sampler = None

		if cfg.dataset.validate:
			self.paths = [
				path for path in paths if _validate(path, self.min_phones, self.max_phones, self.min_duration, self.max_duration)
			]
		else:
			self.paths = paths

		self.phone_symmap = phone_symmap or self._get_phone_symmap()
		self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
		self.task_symmap = task_symmap or self._get_task_symmap()

		self.training = training

		# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
		self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

		self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)

		if cfg.dataset.validate:
			self.paths = [
				p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
			]

		if cfg.dataset.sample_type == "path":
			self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]

		if len(self.paths) == 0 and training:
			raise ValueError("No valid path is found for training.")

		self.duration = 0
		self.durations = {}
		if cfg.dataset.use_hdf5:
			for path in self.paths:
				key = _get_hdf5_path(path)
				spkr_name = cfg.get_spkr(path)
				spkr_id = self.spkr_symmap[spkr_name]
				duration = cfg.hdf5[key].attrs['duration']

				self.duration += duration

				if spkr_id not in self.durations:
					self.durations[spkr_id] = duration
				else:
					self.durations[spkr_id] += duration
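
	# With cfg.dataset.sample_type == "speaker", __len__/__getitem__ index by speaker and draw one of the
	# speaker's utterances at random; with "path", they index the (interleaved) utterance paths directly.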

	def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
		ret = defaultdict(list)
		for path in self.paths:
			ret[cfg.get_spkr(path)].append(path)
		for k, v in extra_paths_by_spkr_name.items():
			ret[k].extend(v)
		return {**ret}

	@cached_property
	def phones(self):
		return sorted(set().union(*[_get_phones(path) for path in self.paths]))

	@cached_property
	def spkrs(self):
		return sorted({cfg.get_spkr(path) for path in self.paths})

	@cached_property
	def tasks(self):
		return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"

	def _get_phone_symmap(self):
		return get_phone_symmap()

	def _get_spkr_symmap(self):
		return {s: i for i, s in enumerate(self.spkrs)}

	def _get_task_symmap(self):
		return get_task_symmap()

	def get_task_token(self, token):
		if not hasattr(self, "task_symmap"):
			self.task_symmap = self._get_task_symmap()
		return torch.Tensor([[self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels)]]).to(dtype=torch.int16)

	def sample_noise(self):
		paths = []
		for data_dir in cfg.dataset.noise:
			paths.extend(data_dir.rglob("*.qnt.pt"))
		path = random.choice(paths)

		if False and cfg.dataset.use_hdf5:
			key = f'/noise/{_get_hdf5_path(path)}'
			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
		else:
			qnt = _load_quants(path)
		return qnt

	def sample_speakers(self, ignore=[]):
		choices = set(self.spkrs) - set(ignore)
		return random.choice([*choices])
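
	# sample_prompts() builds the acoustic prompt for a sample: it concatenates up to cfg.dataset.max_prompts
	# other utterances from the same speaker, trimmed to roughly cfg.dataset.prompt_duration seconds
	# (75 quantized frames per second, the same factor used for the duration metadata below).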

	def sample_prompts(self, spkr_name, ignore):
		prom_list = []

		choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
		choices = [*choices]

		# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
		if len(choices) == 0:
			choices = [*set(self.paths_by_spkr_name[spkr_name])]
			"""
			raise ValueError(
				f"Failed to find another different utterance for {spkr_name}."
			)
			"""

		# shuffle it up a bit
		prom_length = 0
		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)

		for _ in range(cfg.dataset.max_prompts):
			path = random.choice(choices)
			if cfg.dataset.use_hdf5:
				key = _get_hdf5_path(path)
				qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
			else:
				qnt = _load_quants(path)

			if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
				qnt = trim_random(qnt, trim_length)

			prom_list.append(qnt)
			prom_length += qnt.shape[0]

			if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
				break

		prom = torch.cat(prom_list)

		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
			prom = trim_random(prom, trim_length)

		return prom
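
	# __getitem__ draws one task per sample from cfg.dataset.tasks_list, falling back to plain "tts"
	# when the speaker has fewer than four other utterances to sample from.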

	def __getitem__(self, index):
		if cfg.dataset.sample_type == "speaker":
			spkr_name = self.spkrs[index]
			spkr_id = self.spkr_symmap[spkr_name]
			path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
		else:
			path = self.paths[index]
			spkr_name = cfg.get_spkr(path)
			spkr_id = self.spkr_symmap[spkr_name]

		if cfg.dataset.use_hdf5:
			key = _get_hdf5_path(path)
			text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
			resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
		else:
			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
			resps = _load_quants(path)

		task = random.choice(self.tasks)

		# ensure a speaker has at least four utterances
		# default to tts if not
		if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
			task = "tts"

		noise_scale = 0.125

		# text-to-speech
		if task == "tts":
			proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
		# noise suppression || speech removal
		elif task == "ns" or task == "sr":
			# sample random noise
			noise = self.sample_noise()
			# extend the noise to fill the target audio
			noise = repeat_extend_audio(noise, resps.shape[0])
			# create the input prompt by merging the target audio with the noise
			proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
			# set the target to just be the noise if <sr>
			if task == "sr":
				resps = noise
			# prepend the task token
			proms = torch.cat([self.get_task_token(task), proms])

			# set the text prompt to empty to train without a guided text prompt
			if random.random() < 0.5:
				text = torch.tensor([1, 2]).to(self.text_dtype)
		# target speech extraction
		elif task == "tse":
			# sample a random, clean utterance for the target speaker
			clean_proms = self.sample_prompts(spkr_name, ignore=path)
			# sample a random, clean utterance from a different speaker
			other_proms = self.sample_prompts(self.sample_speakers(ignore=[spkr_name]), ignore="")
			# overlay the random speaker over the target audio
			smallest_size = min(resps.shape[0], other_proms.shape[0])
			if other_proms.shape[0] == smallest_size:
				noisy_proms = merge_audio(resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu")
				noisy_proms = torch.cat([noisy_proms, resps[smallest_size:, :]])
			else:
				noisy_proms = merge_audio(resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu")
				noisy_proms = torch.cat([noisy_proms, other_proms[smallest_size:, :]])

			# stitch together the prompts
			proms = torch.cat([clean_proms, self.get_task_token(task), noisy_proms])

			# set the text prompt to empty to train without a guided text prompt
			if random.random() < 0.5:
				text = torch.tensor([1, 2]).to(self.text_dtype) # <s></s>

		# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
		# as I need to get a good clean point to trim into
		# clean speech editing
		elif task == "cse" or task == "nse":
			choices = set(self.paths_by_spkr_name[spkr_name]) - {path}
			sampled = random.sample([*choices], 4)

			if cfg.dataset.use_hdf5:
				texts = [torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled]
				qnts = [torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled]
			else:
				texts = [torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled]
				qnts = [_load_quants(path) for path in sampled]

			# remove <s></s>
			for i in range(len(texts)):
				texts[i] = texts[i][1:-1]

			pre_text, mid_text, post_text, edit_text = texts
			pre_prom, mid_prom, post_prom, edit_prom = qnts

			# randomly drop out pre
			if random.random() < 0.125:
				pre_text = None
				pre_prom = None
			# randomly drop out post
			if random.random() < 0.125:
				post_text = None
				post_prom = None

			# create new text
			text = torch.cat(
				[torch.Tensor([1]).to(dtype=self.text_dtype)] + # <s>
				([pre_text, torch.Tensor([3]).to(dtype=self.text_dtype)] if pre_text is not None else []) + # pre_text + 'space'
				[edit_text] + # 'edit text'
				([torch.Tensor([3]).to(dtype=self.text_dtype), post_text] if post_text is not None else []) + # 'space' + post_text
				[torch.Tensor([2]).to(dtype=self.text_dtype)] # </s>
			)

			if task == "nse":
				# sample random noise
				noise = self.sample_noise()

				# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
				# but it's noise, it's supposed to be random
				def noise_proms(proms):
					# ignore if we turned it off
					if proms is None:
						return None
					# extend the noise to fill the target audio
					n = repeat_extend_audio(noise, proms.shape[0])
					# merge the noise over the utterance
					return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")

				# apply noise to all pieces
				pre_prom = noise_proms(pre_prom)
				mid_prom = noise_proms(mid_prom)
				post_prom = noise_proms(post_prom)
				edit_prom = noise_proms(edit_prom)
			else:
				mid_prom = self.get_task_token("mask")

			# create new proms
			proms = torch.cat(
				([pre_prom] if pre_prom is not None else []) +
				[self.get_task_token("soe")] +
				[mid_prom] + # is <mask> if task is CSE
				[self.get_task_token("eoe")] +
				([post_prom] if post_prom is not None else [])
			)

			# create new resp
			resps = torch.cat(
				([pre_prom] if pre_prom is not None else []) +
				[edit_prom] +
				([post_prom] if post_prom is not None else [])
			)

		"""
		# emulate SVC
		# takes in an utterance of the target speaker, a target utterance as a reference clip as the input prompt
		# targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
		# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
		# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
		# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
		elif task == "svc":
			# sample a random, clean utterance for the target speaker
			proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
			# sample a reference clip from a different speaker
			ref_proms = self.sample_rvc(self.sample_speakers(ignore=[spkr_name]))
			#
			resps =
			# stitch together the prompts
			proms = torch.cat([proms, self.get_task_token(task), ref_proms])
			# set the text prompt to empty to train without a guided text prompt
			if random.random() < 0.5:
				text = torch.tensor([1, 2]).to(self.text_dtype)
		"""

		return dict(
			index=index,
			path=path,
			spkr_name=spkr_name,
			spkr_id=spkr_id,
			task=task,
			text=text,
			proms=proms,
			resps=resps,
		)
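
	# Each sample is a dict with keys: index, path, spkr_name, spkr_id, task,
	# text (phoneme ids), proms (prompt codes, possibly with task tokens), and resps (target codes).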

	def head_(self, n):
		self._head = n

	def training_(self, value):
		self.training = value

	def __len__(self):
		if cfg.dataset.sample_type == "speaker":
			return min(len(self.spkrs), self._head or len(self.spkrs))
		return min(len(self.paths), self._head or len(self.paths))

	def pin_memory(self):
		self.text = self.text.pin_memory()
		self.proms = self.proms.pin_memory()
		self.resps = self.resps.pin_memory()
		self.resp = self.resp.pin_memory()
		return self

def collate_fn(samples: list[dict]):
	batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
	return batch
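
# collate_fn keeps each field as a per-sample list instead of stacking tensors, e.g.
#   collate_fn([{"text": t0, ...}, {"text": t1, ...}]) -> {"text": [t0, t1], ...}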

def _seed_worker(worker_id):
	worker_seed = torch.initial_seed() % 2**32
	np.random.seed(worker_seed)
	random.seed(worker_seed)
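
# Each dataloader worker reseeds numpy/random from torch's per-worker initial seed so that random
# augmentation (prompt sampling, noise, task choice) differs across workers rather than repeating.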

def _create_dataloader(dataset, training):
	sampler = None
	shuffle = True

	if cfg.distributed and training:
		sampler = DistributedSampler(dataset)
		shuffle = False

	return DataLoader(
		dataset=dataset,
		batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
		shuffle=shuffle,
		drop_last=training,
		num_workers=cfg.dataset.workers,
		collate_fn=collate_fn,
		persistent_workers=True,
		pin_memory=False, # True,
		worker_init_fn=_seed_worker,
		sampler=sampler,
	)

def _load_dataset_paths():
	hf = cfg.hdf5

	paths = {
		"training": [],
		"validation": [],
	}
	datasets = {
		"training": [],
		"validation": [],
	}

	def get_paths(data_dir, type="training"):
		key = f"/{type}{_get_hdf5_path(data_dir)}"
		if key not in cfg.hdf5:
			return
		paths[type].extend([f"{key}/{child.attrs['id']}" for child in cfg.hdf5[key].values()])

	for data_dir in cfg.dataset.training:
		get_paths(data_dir, "training")

	for data_dir in cfg.dataset.validation:
		get_paths(data_dir, "validation")

	for type in paths:
		dirs = paths[type]
		if len(dirs) == 0:
			continue

		dirs = [Path(p) for p in dirs]

		pairs = sorted([(cfg.get_spkr(p), p) for p in dirs])
		for _, group in groupby(pairs, lambda pair: pair[0]):
			shuffled = sorted([p for _, p in group])
			random.seed(0)
			random.shuffle(shuffled)
			datasets[type].extend(shuffled)

	return datasets["training"], datasets["validation"]
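
# With HDF5 enabled, dataset "paths" are HDF5 keys of the form f"/{type}{data_dir}/{id}" (one per utterance
# group), intended to mirror the groups written by create_dataset_hdf5() below; each speaker's entries are
# shuffled with a fixed seed so the split is deterministic.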

# to-do: mirror the hdf5-based load function
def _load_train_val_paths():
	paths = []
	train_paths = []
	val_paths = []

	for data_dir in cfg.dataset.training:
		paths.extend(data_dir.rglob("*.qnt.pt"))

	if len(paths) > 0:
		pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
		del paths

		for _, group in groupby(pairs, lambda pair: pair[0]):
			paths = sorted([p for _, p in group])
			random.seed(0)
			random.shuffle(paths)
			train_paths.extend(paths)

	paths = []
	for data_dir in cfg.dataset.validation:
		paths.extend(data_dir.rglob("*.qnt.pt"))

	if len(paths) > 0:
		pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
		del paths

		for _, group in groupby(pairs, lambda pair: pair[0]):
			paths = sorted([p for _, p in group])
			random.seed(0)
			random.shuffle(paths)
			val_paths.extend(paths)

	train_paths, val_paths = map(sorted, [train_paths, val_paths])

	if len(train_paths) == 0:
		raise RuntimeError("Failed to find any .qnt.pt file in specified training dataset.")

	# to-do: allow setting aside a fixed portion of the training dataset for validation
	# something like the last X percent of each speaker is set aside
	if len(val_paths) == 0:
		val_paths = [train_paths[0]]

	return train_paths, val_paths

@cfg.diskcache()
def create_datasets():
	train_paths, val_paths = _load_dataset_paths() if cfg.dataset.use_hdf5 else _load_train_val_paths()

	train_dataset = Dataset(
		train_paths,
		training=True,
	)

	val_dataset = Dataset(
		val_paths,
		train_dataset.phone_symmap,
		#train_dataset.spkr_symmap,
		#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
	)

	val_dataset.head_(cfg.evaluation.size)

	return train_dataset, val_dataset

def create_train_val_dataloader():
	train_dataset, val_dataset = create_datasets()

	subtrain_dataset = copy.deepcopy(train_dataset)
	if cfg.dataset.sample_type == "path":
		subtrain_dataset.head_(cfg.evaluation.size)

	train_dl = _create_dataloader(train_dataset, training=True)
	val_dl = _create_dataloader(val_dataset, training=False)
	subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

	_logger.info(str(train_dataset.phone_symmap))
	_logger.info(str(train_dataset.spkr_symmap))

	_logger.info(f"#samples (train): {len(train_dataset)}.")
	_logger.info(f"#samples (val): {len(val_dataset)}.")
	_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

	"""
	_logger.info(f"#durations (train): {str(train_dataset.durations)}.")
	_logger.info(f"#durations (val): {str(val_dataset.durations)}.")
	_logger.info(f"#durations (subtrain): {str(subtrain_dataset.durations)}.")
	"""

	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
	_logger.info(f"#duration (subtrain): {str(subtrain_dataset.duration)}.")

	assert isinstance(subtrain_dl.dataset, Dataset)

	return train_dl, subtrain_dl, val_dl

# parse yaml to create an hdf5 file
def create_dataset_hdf5():
	cfg.dataset.use_hdf5 = True
	cfg.load_hdf5(write=True)

	symmap = get_phone_symmap()

	root = cfg.cfg_path
	hf = cfg.hdf5

	def add(dir, type="training", audios=True, texts=True):
		dir = "./" + str(dir)
		name = dir.replace(root, "")
		print(str(dir), name)

		if not os.path.isdir(f'{root}/{name}/'):
			return
		# tqdm.write(f'{root}/{name}')

		files = os.listdir(f'{root}/{name}/')

		# grab IDs for every file
		ids = {".".join(file.split(".")[:-2]) for file in files}

		for id in tqdm(ids, desc=f"Processing {name}"):
			audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
			text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True

			if not audio_exists or not text_exists:
				continue

			key = f'{type}/{name}/{id}'
			if key in hf:
				# print("Skipping existing entry:", key)
				continue

			group = hf.create_group(key)

			# audio
			if audios:
				qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
				group.create_dataset('audio', data=qnt.numpy(), compression='lzf')

			# text
			if texts:
				with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf8") as f:
					content = f.read()
				split = content.split(" ")
				phones = ["<s>"] + [" " if not p else p for p in split] + ["</s>"]
				for s in set(phones):
					if s not in symmap:
						symmap[s] = len(symmap.keys())
				phn = [symmap[s] for s in phones]

				group.create_dataset('text', data=phn, compression='lzf', chunks=True)

			# metadata
			group.attrs['id'] = id
			group.attrs['type'] = type
			group.attrs['speaker'] = name
			group.attrs['duration'] = qnt.shape[0] / 75 if audios else 0
			group.attrs['phonemes'] = len(phn) if texts else 0

	# training
	for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
		add(data_dir, type="training")

	# validation
	for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
		add(data_dir, type="validation")

	# noise
	for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
		add(data_dir, type="noise", texts=False)

	# write symmap
	try:
		hf.create_dataset('symmap', data=json.dumps(symmap))
	except Exception as e:
		pass

	hf.close()
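
# CLI helper: "--action hdf5" builds the HDF5 dataset from the loaded config, "--action sample" prints a
# couple of batches from each dataloader, and "--action tasks" decodes one batch's proms/resps to wav files
# for the task types passed via "--tasks".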

if __name__ == "__main__":
	import argparse

	parser = argparse.ArgumentParser("Prepare or inspect the dataset.")
	parser.add_argument("--action", type=str)
	parser.add_argument("--tasks", type=str)
	args = parser.parse_args()

	task = args.action

	if args.action == "hdf5":
		create_dataset_hdf5()
	elif args.action == "sample":
		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

		samples = {
			"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
			"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
			"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
		}

		for k, v in samples.items():
			for i in range(len(v)):
				del v[i]['proms']
				del v[i]['resps']
			print(f'{k}:', v)
	elif args.action == "tasks":
		index = 0
		cfg.dataset.tasks_list = args.tasks.split(",")

		train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
		batch = next(iter(train_dl))

		for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
			if task not in cfg.dataset.tasks_list:
				continue

			print(text, task)
			decode_to_file(proms, f"./.{task}.proms.wav", device="cpu")
			decode_to_file(resps, f"./.{task}.resps.wav", device="cpu")
			break