2023-08-02 21:53:35 +00:00
# todo: clean this mess up
import copy
import h5py
import json
import logging
import numpy as np
import os
import random
import torch
2023-08-27 00:53:23 +00:00
import itertools
2023-08-02 21:53:35 +00:00
from . config import cfg
2023-08-23 21:43:03 +00:00
from . emb . qnt import trim , trim_random , repeat_extend_audio , merge_audio , decode_to_file
2023-09-04 02:27:13 +00:00
from . utils . sampler import Sampler
2023-08-02 21:53:35 +00:00
from collections import defaultdict
from functools import cache , cached_property
from itertools import groupby , zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch . utils . data import DataLoader , Dataset as _Dataset
2023-08-14 03:07:45 +00:00
from torch . utils . data . distributed import DistributedSampler
2023-08-02 21:53:35 +00:00
from tqdm . auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
# module logger; e.g. Dataset.sample_prompts warns through it when an HDF5 entry has no "audio"
_logger = logging.getLogger(__name__)
2024-04-16 00:54:32 +00:00
# to-do: clean up this symmap mess
2023-08-02 21:53:35 +00:00
def get_phone_symmap():
    """Return the configured text tokenizer's vocabulary (token -> id)."""
    tokenizer = cfg.tokenizer
    return tokenizer.get_vocab()
2023-08-02 21:53:35 +00:00
2024-04-21 19:49:18 +00:00
def tokenize(phones):
    """Concatenate *phones* into one string and encode it with the configured tokenizer."""
    joined = "".join(phones)
    return cfg.tokenizer.encode(joined)
2023-10-12 01:38:40 +00:00
def get_lang_symmap():
    """Return the language-code -> id mapping used for the ``lang`` field of samples."""
    return dict(en=0, ja=1)
2024-04-16 00:54:32 +00:00
def get_tone_symmap():
    """Return the tone-label -> id mapping (currently only "neutral").

    The trailing ``return symmap`` that used to follow this return was unreachable and
    referred to an undefined name, so it has been dropped.
    """
    return {
        "neutral": 0,
    }
2023-08-19 03:22:13 +00:00
def get_task_symmap():
    """Return the task-token -> id mapping; ids follow the declaration order below."""
    task_tokens = (
        "<tts>",
        "<tts-c>",
        "<ns>",
        "<sr>",
        "<tse>",
        "<soe>",
        "<mask>",
        "<eoe>",
    )
    return {token: token_id for token_id, token in enumerate(task_tokens)}
2023-08-02 21:53:35 +00:00
def _replace_file_extension ( path , suffix ) :
return ( path . parent / path . name . split ( " . " ) [ 0 ] ) . with_suffix ( suffix )
2024-04-19 02:24:06 +00:00
def _get_quant_extension():
    """Return the quantized-audio file extension for the configured audio backend."""
    if cfg.inference.audio_backend == "dac":
        return ".dac"
    return ".qnt.pt"
2024-04-19 02:24:06 +00:00
def _get_phone_extension ( ) :
2024-05-12 18:02:15 +00:00
return " .json " # if cfg.inference.audio_backend == "dac" else ".phn.txt"
2024-04-19 02:24:06 +00:00
2023-08-27 00:53:23 +00:00
def _get_quant_path(path):
    """Return the quantized-audio sibling of *path* (same stem, quant extension)."""
    quant_suffix = _get_quant_extension()
    return _replace_file_extension(path, quant_suffix)
2023-08-27 00:53:23 +00:00
def _get_phone_path(path):
    """Return the phoneme-transcript sibling of *path* (same stem, phone extension)."""
    phone_suffix = _get_phone_extension()
    return _replace_file_extension(path, phone_suffix)
2023-08-27 00:53:23 +00:00
2023-09-12 20:54:41 +00:00
# per-split running total of utterance durations; added to by the path validators
# (_load_paths_from_metadata / _get_hdf5_paths) while validation is enabled
_total_durations = {}


@cfg.diskcache()
def _calculate_durations(type="training"):
    """Return the duration tallied for split *type*, or 0 if nothing was tallied."""
    return _total_durations.get(type, 0)
2023-09-02 21:29:53 +00:00
@cfg.diskcache()
def _load_paths(dataset, type="training"):
    """Return ``{speaker_name: utterance_paths}`` for every data directory in *dataset*.

    The speaker name is derived via ``cfg.get_spkr`` from a placeholder file inside each
    directory. Validation is requested only for the "training" split and only when
    ``cfg.dataset.validate`` is enabled. The function is wrapped by ``cfg.diskcache()``.
    """
    paths_by_speaker = {}
    for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}"):
        speaker = cfg.get_spkr(cfg.data_dir / data_dir / "dummy")
        paths_by_speaker[speaker] = _load_paths_from_metadata(
            data_dir,
            type=type,
            validate=cfg.dataset.validate and type == "training",
        )
    return paths_by_speaker
2024-05-16 04:04:19 +00:00
def _load_paths_from_metadata(group_name, type="training", validate=False):
    """Collect the utterance entries of one dataset group.

    When ``cfg.dataset.use_metadata`` is set and ``cfg.metadata_dir / "<group_name>.json"``
    exists, entries come from that JSON object (id -> entry). Otherwise this falls back to
    ``_get_hdf5_paths`` (HDF5 mode) or ``_get_paths_of_extensions`` (filesystem mode).

    Args:
        group_name: dataset group name; joined onto ``cfg.data_dir`` in filesystem mode.
        type: split name (e.g. "training", "validation", "noise"); used in HDF5 keys and
            as the ``_total_durations`` bucket durations are added to.
        validate: if true, keep only entries whose ``duration`` (0 when absent) lies within
            ``[cfg.dataset.min_duration, cfg.dataset.max_duration]``.

    Returns:
        From metadata: ``"/<type>/<group>/<id>"`` strings in HDF5 mode, otherwise
        ``data_dir / id`` paths. The fallbacks return whatever those helpers return.
    """
    data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name

    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions

    def key(id, entry=None):
        return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else data_dir / id

    metadata_path = cfg.metadata_dir / f'{group_name}.json'
    metadata = {}

    if cfg.dataset.use_metadata and metadata_path.exists():
        # use a context manager so the handle is closed instead of leaking until GC
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

    if len(metadata) == 0:
        return _fn(data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate)

    def _validate(id, entry):
        duration = entry['duration'] if "duration" in entry else 0

        # every inspected entry adds to the split total, including ones rejected below
        _total_durations[type] = _total_durations.get(type, 0) + duration

        # to-do: in HDF5 mode, also reject ids whose key is missing from cfg.hdf5 or lacks "audio"/"text"
        # to-do: re-enable phoneme bounds (cfg.dataset.min_phones / max_phones) using entry['phones']
        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration

    return [key(id, entry) for id, entry in metadata.items() if not validate or _validate(id, entry)]
2023-08-27 00:53:23 +00:00
2023-08-02 21:53:35 +00:00
def _get_hdf5_path ( path ) :
2024-04-29 03:28:29 +00:00
# to-do: better validation
#print(path)
return str ( path )
2023-08-02 21:53:35 +00:00
2023-08-27 00:53:23 +00:00
def _get_hdf5_paths(data_dir, type="training", validate=False):
    """List the utterances stored under the HDF5 group ``/<type>/<data_dir>``.

    Args:
        data_dir: group name (converted with ``str``).
        type: split name; first component of the group key and the ``_total_durations`` bucket.
        validate: if true, keep only entries whose ``duration`` attribute lies within
            ``[cfg.dataset.min_duration, cfg.dataset.max_duration]``.

    Returns:
        ``Path("/<type>/<data_dir>/<id>")`` for each kept entry; ``[]`` if the group is absent.
    """
    data_dir = str(data_dir)

    def _validate(id, entry):
        duration = entry.attrs['duration']
        # every inspected entry adds to the split total, including ones rejected below
        _total_durations[type] = _total_durations.get(type, 0) + duration
        # the phoneme count (entry.attrs['phonemes']) is only needed once the phone bounds are
        # re-enabled; reading it unconditionally made entries without that attribute raise KeyError
        #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration

    key = f"/{type}/{_get_hdf5_path(data_dir)}"
    if key not in cfg.hdf5:
        return []
    return [Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if not validate or _validate(id, entry)]
2023-08-27 00:53:23 +00:00
2024-04-19 02:24:06 +00:00
def _get_paths_of_extensions(path, extensions=None, validate=False):
    """Return the files in directory *path* that look like complete utterances.

    A file qualifies when its combined suffixes (e.g. ``".qnt.pt"``) exactly equal one of
    *extensions* and both its phoneme file and its quantized-audio file exist.

    Args:
        path: directory to scan (``str`` or ``Path``).
        extensions: one extension string or a collection of them. ``None`` (the default)
            means the quantized-audio extension of the configured backend, resolved per call.
        validate: if true, also require the phoneme count to lie within
            ``[cfg.dataset.min_phones, cfg.dataset.max_phones]``.

    Returns:
        Matching ``Path`` objects; ``[]`` if *path* is not an existing directory.
    """
    if isinstance(path, str):
        path = Path(path)
    if extensions is None:
        # resolved at call time; an import-time default would freeze the backend chosen
        # by whatever config was loaded when this module was first imported
        extensions = _get_quant_extension()
    if isinstance(extensions, str):
        # a bare string would be tested by substring (`"" in ".qnt.pt"` and
        # `".pt" in ".qnt.pt"` are True), letting suffix-less entries through
        extensions = (extensions,)

    def _validate(path):
        if "".join(path.suffixes) not in extensions:
            return False
        if not _get_phone_path(path).exists() or not _get_quant_path(path).exists():
            return False
        if not validate:
            return True
        # to-do: find an easy way to determine size from pickled quants without loading
        # to-do: find a consistent way to derive phoneme count from filesize (probably can't due to utf-8)
        phones = len(_get_phones(_get_phone_path(path)))
        return cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones

    if not (path.exists() and path.is_dir()):
        return []
    return [p for p in path.iterdir() if _validate(p)]
2023-08-02 21:53:35 +00:00
def _load_quants(path) -> Tensor:
    """Load the quantized audio codes that sit next to *path* as an int16 tensor.

    ``.dac`` files are NumPy pickles holding a dict with a ``"codes"`` array; any other
    quant file is read with ``torch.load``. Both take element ``[0]`` and transpose it;
    callers treat axis 0 of the result as time and axis 1 as codebook level.
    """
    if _get_quant_extension() != ".dac":
        return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)

    # NOTE: allow_pickle unpickles arbitrary objects; only load trusted dataset files
    artifact = np.load(_get_quant_path(path), allow_pickle=True)[()]
    codes = torch.from_numpy(artifact["codes"].astype(int))
    return codes[0][:, :].t().to(torch.int16)
2023-08-02 21:53:35 +00:00
2023-10-11 00:18:24 +00:00
# prune consecutive spaces
def _cleanup_phones ( phones , targets = [ " " ] ) :
return [ p for i , p in enumerate ( phones ) if p not in targets or ( p in targets and p != phones [ i - 1 ] ) ]
2023-08-02 21:53:35 +00:00
@cache
def _get_phones(path, language="en"):
    """Return the phoneme string for the utterance at *path* (memoized per arguments).

    Reads the phoneme sibling of *path*: a JSON document whose ``"phonemes"`` field is used
    when the phone extension is ``.json``, otherwise a space-separated text file.

    ``language`` is unused; it is kept for interface compatibility.
    """
    phone_path = _get_phone_path(path)
    with open(phone_path, "r", encoding="utf-8") as f:
        # branch on the *phone* file's format: the quant extension (".dac" / ".qnt.pt") can
        # never be ".json", which made the JSON branch unreachable and returned raw JSON text
        if _get_phone_extension() == ".json":
            content = json.load(f)["phonemes"]
        else:
            content = f.read().split(" ")
    return "".join(content)
2023-08-02 21:53:35 +00:00
def _interleaved_reorder ( l , fn ) :
groups = defaultdict ( list )
for e in l :
groups [ fn ( e ) ] . append ( e )
groups = { k : groups [ k ] for k in sorted ( groups ) }
for interleaved in zip_longest ( * groups . values ( ) ) :
for value in interleaved :
if value is not None :
yield value
class Dataset ( _Dataset ) :
def __init__ (
self ,
phone_symmap = None ,
training = False ,
extra_paths_by_spkr_name : dict [ str , list ] = { } ,
) :
super ( ) . __init__ ( )
self . _head = None
2023-08-19 01:58:07 +00:00
self . sampler = None
2023-08-02 21:53:35 +00:00
2023-08-27 00:53:23 +00:00
self . paths = [ ]
self . training = training
self . dataset_type = " training " if self . training else " validation "
self . dataset = cfg . dataset . training if self . training else cfg . dataset . validation
2023-10-22 14:01:47 +00:00
self . sampler_type = cfg . dataset . sample_type if self . dataset_type == " training " else " path "
2023-08-30 23:23:05 +00:00
# to-do: do not do validation if there's nothing in the validation
# this just makes it be happy
if len ( self . dataset ) == 0 :
self . dataset = cfg . dataset . training
2023-08-27 00:53:23 +00:00
2023-10-17 00:30:38 +00:00
# dict of paths keyed by speaker names
2023-08-27 00:53:23 +00:00
self . paths_by_spkr_name = _load_paths ( self . dataset , self . dataset_type )
2023-09-12 20:54:41 +00:00
# cull speakers if they do not have enough utterances
if cfg . dataset . min_utterances > 0 :
keys = list ( self . paths_by_spkr_name . keys ( ) )
for key in keys :
if len ( self . paths_by_spkr_name [ key ] ) < cfg . dataset . min_utterances :
del self . paths_by_spkr_name [ key ]
2023-08-27 00:53:23 +00:00
self . paths = list ( itertools . chain . from_iterable ( self . paths_by_spkr_name . values ( ) ) )
2023-09-04 02:27:13 +00:00
self . samplers = { name : Sampler ( paths , keep_all = True ) for name , paths in self . paths_by_spkr_name . items ( ) }
2023-10-17 00:30:38 +00:00
# dict of speakers keyed by speaker group
self . spkrs_by_spkr_group = { }
for data_dir in self . dataset :
spkr = cfg . get_spkr ( data_dir / " dummy " )
spkr_group = cfg . get_spkr_group ( data_dir / " dummy " )
2023-10-22 14:01:47 +00:00
if spkr not in self . paths_by_spkr_name or len ( self . paths_by_spkr_name [ spkr ] ) < cfg . dataset . min_utterances :
continue
2023-10-17 00:30:38 +00:00
if spkr_group not in self . spkrs_by_spkr_group :
self . spkrs_by_spkr_group [ spkr_group ] = [ ]
self . spkrs_by_spkr_group [ spkr_group ] . append ( spkr )
self . spkr_groups = list ( self . spkrs_by_spkr_group . keys ( ) )
2023-12-21 00:45:58 +00:00
2023-10-17 00:30:38 +00:00
self . spkr_samplers = { name : Sampler ( [ * set ( speakers ) ] , keep_all = True ) for name , speakers in self . spkrs_by_spkr_group . items ( ) }
2023-09-04 02:27:13 +00:00
2023-10-22 14:01:47 +00:00
if self . sampler_type == " path " :
2023-08-27 00:53:23 +00:00
self . paths = [ * _interleaved_reorder ( self . paths , self . get_speaker ) ]
self . noise_paths = _load_paths ( cfg . dataset . noise , " noise " )
self . noise_paths = list ( itertools . chain . from_iterable ( self . noise_paths . values ( ) ) )
2023-08-02 21:53:35 +00:00
self . phone_symmap = phone_symmap or self . _get_phone_symmap ( )
2023-08-24 15:25:33 +00:00
self . spkr_symmap = self . _get_spkr_symmap ( )
2023-10-17 00:30:38 +00:00
self . spkr_group_symmap = self . _get_spkr_group_symmap ( )
2023-10-12 01:38:40 +00:00
self . lang_symmap = self . _get_lang_symmap ( )
2024-04-16 00:54:32 +00:00
self . tone_symmap = self . _get_tone_symmap ( )
2023-08-24 15:25:33 +00:00
self . task_symmap = self . _get_task_symmap ( )
2023-08-02 21:53:35 +00:00
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
self . text_dtype = torch . uint8 if len ( self . phone_symmap ) < 256 else torch . int16
2023-10-17 00:30:38 +00:00
if len ( self . paths ) == 0 :
raise ValueError ( f " No valid path is found for { self . dataset_type } " )
2023-08-02 21:53:35 +00:00
2023-09-12 20:54:41 +00:00
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self . duration = _calculate_durations ( self . dataset_type )
2023-08-02 21:53:35 +00:00
@cached_property
def phones ( self ) :
return sorted ( set ( ) . union ( * [ _get_phones ( path ) for path in self . paths ] ) )
2023-08-27 00:53:23 +00:00
def get_speaker ( self , path ) :
if isinstance ( path , str ) :
path = Path ( path )
res = cfg . get_spkr ( path )
return res
2023-10-12 01:38:40 +00:00
def get_speaker_group ( self , path ) :
if isinstance ( path , str ) :
path = Path ( path )
res = cfg . get_spkr_group ( path )
return res
def get_language ( self , speaker_group ) :
lang = " en "
for k , v in cfg . dataset . speaker_languages . items ( ) :
if speaker_group in v :
lang = k
break
return lang
2023-08-02 21:53:35 +00:00
@cached_property
def spkrs ( self ) :
2023-08-27 00:53:23 +00:00
return sorted ( { self . get_speaker ( path ) for path in self . paths } )
2023-08-02 21:53:35 +00:00
2023-08-19 03:22:13 +00:00
@cached_property
def tasks ( self ) :
2023-08-19 05:16:08 +00:00
return cfg . dataset . tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
2023-08-19 03:22:13 +00:00
2023-09-04 02:27:13 +00:00
def save_state_dict ( self , path ) :
state_dict = {
" samplers " : { name : sampler . current_pool for name , sampler in self . samplers . items ( ) }
}
torch . save ( state_dict , path )
def load_state_dict ( self , path ) :
state_dict = torch . load ( path )
if " samplers " in state_dict :
# better than naively setting the entire object
for name , sampler in state_dict [ " samplers " ] . items ( ) :
if name not in self . samplers :
continue
self . samplers [ name ] . current_pool = sampler
2023-08-19 03:22:13 +00:00
def _get_phone_symmap ( self ) :
return get_phone_symmap ( )
2023-08-02 21:53:35 +00:00
def _get_spkr_symmap ( self ) :
return { s : i for i , s in enumerate ( self . spkrs ) }
2023-10-17 00:30:38 +00:00
def _get_spkr_group_symmap ( self ) :
return { s : i for i , s in enumerate ( self . spkr_groups ) }
2023-10-12 01:38:40 +00:00
def _get_lang_symmap ( self ) :
return get_lang_symmap ( )
2024-04-16 00:54:32 +00:00
def _get_tone_symmap ( self ) :
return get_tone_symmap ( )
2023-08-19 03:22:13 +00:00
def _get_task_symmap ( self ) :
return get_task_symmap ( )
2023-10-12 01:38:40 +00:00
"""
2024-04-16 00:54:32 +00:00
def get_task_token ( self , token , levels = cfg . model . max_levels ) :
2023-08-19 05:16:08 +00:00
if not hasattr ( self , " task_symmap " ) :
self . task_symmap = self . _get_task_symmap ( )
2023-08-19 20:06:33 +00:00
return torch . Tensor ( [ [ self . task_symmap [ f ' < { token } > ' ] for _ in range ( levels ) ] ] ) . to ( dtype = torch . int16 )
2023-10-12 01:38:40 +00:00
"""
2023-08-19 03:22:13 +00:00
2023-08-19 20:06:33 +00:00
def sample_noise ( self ) :
2023-08-27 00:53:23 +00:00
path = random . choice ( self . noise_paths )
2023-08-19 05:16:08 +00:00
2023-08-27 00:53:23 +00:00
if cfg . dataset . use_hdf5 :
key = _get_hdf5_path ( path )
2023-08-19 20:06:33 +00:00
qnt = torch . from_numpy ( cfg . hdf5 [ key ] [ " audio " ] [ : , : ] ) . to ( torch . int16 )
2023-08-19 05:16:08 +00:00
else :
qnt = _load_quants ( path )
return qnt
2023-08-19 03:22:13 +00:00
2023-08-18 00:07:59 +00:00
def sample_speakers ( self , ignore = [ ] ) :
choices = set ( self . spkrs ) - set ( ignore )
return random . choice ( [ * choices ] )
2023-08-02 21:53:35 +00:00
def sample_prompts ( self , spkr_name , ignore ) :
prom_list = [ ]
choices = set ( self . paths_by_spkr_name [ spkr_name ] ) - { ignore }
choices = [ * choices ]
2024-04-16 00:54:32 +00:00
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
2023-08-02 21:53:35 +00:00
if len ( choices ) == 0 :
2023-08-17 00:39:21 +00:00
choices = [ * set ( self . paths_by_spkr_name [ spkr_name ] ) ]
"""
2023-08-02 21:53:35 +00:00
raise ValueError (
f " Failed to find another different utterance for { spkr_name } . "
)
2023-08-17 00:39:21 +00:00
"""
2023-08-02 21:53:35 +00:00
2023-08-19 04:55:40 +00:00
prom_length = 0
2024-05-11 14:50:54 +00:00
trim_length = random . randint ( cfg . dataset . prompt_duration_range [ 0 ] , cfg . dataset . prompt_duration_range [ 1 ] ) * cfg . dataset . frames_per_second
2023-08-19 04:55:40 +00:00
2023-08-02 21:53:35 +00:00
for _ in range ( cfg . dataset . max_prompts ) :
path = random . choice ( choices )
if cfg . dataset . use_hdf5 :
key = _get_hdf5_path ( path )
2024-05-10 01:28:20 +00:00
if " audio " not in cfg . hdf5 [ key ] :
2024-05-12 03:58:38 +00:00
_logger . warning ( f ' MISSING AUDIO: { key } ' )
2024-05-10 01:28:20 +00:00
continue
2023-08-19 20:06:33 +00:00
qnt = torch . from_numpy ( cfg . hdf5 [ key ] [ " audio " ] [ : , : ] ) . to ( torch . int16 )
2023-08-02 21:53:35 +00:00
else :
qnt = _load_quants ( path )
2024-05-11 14:50:54 +00:00
if 0 < trim_length and trim_length < qnt . shape [ 0 ] :
2023-08-23 21:43:03 +00:00
qnt = trim ( qnt , trim_length )
2023-08-02 21:53:35 +00:00
prom_list . append ( qnt )
2023-08-19 04:55:40 +00:00
prom_length + = qnt . shape [ 0 ]
2023-08-02 21:53:35 +00:00
2023-08-19 04:55:40 +00:00
if prom_length > = trim_length or random . random ( ) > cfg . dataset . random_utterance :
2023-08-02 21:53:35 +00:00
break
2023-10-11 22:32:45 +00:00
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
2023-08-02 21:53:35 +00:00
prom = torch . cat ( prom_list )
2024-05-11 14:50:54 +00:00
if 0 < trim_length and trim_length < prom . shape [ 0 ] :
2023-08-23 21:43:03 +00:00
prom = trim ( prom , trim_length )
2023-08-02 21:53:35 +00:00
return prom
def __getitem__ ( self , index ) :
2023-10-22 14:01:47 +00:00
if self . sampler_type == " group " :
2023-10-17 00:30:38 +00:00
spkr_group = self . spkr_groups [ index ]
2023-10-22 14:01:47 +00:00
#spkr_group_id = self.spkr_group_symmap[spkr_group]
2023-10-17 00:30:38 +00:00
spkr_name = self . spkr_samplers [ spkr_group ] . sample ( )
2023-10-19 01:38:33 +00:00
spkr_id = self . spkr_symmap [ spkr_name ]
path = self . samplers [ spkr_name ] . sample ( )
2023-10-22 14:01:47 +00:00
elif self . sampler_type == " speaker " :
2023-08-17 00:39:21 +00:00
spkr_name = self . spkrs [ index ]
spkr_id = self . spkr_symmap [ spkr_name ]
2023-09-04 02:27:13 +00:00
path = self . samplers [ spkr_name ] . sample ( )
2023-10-17 00:30:38 +00:00
spkr_group = self . get_speaker_group ( path )
2023-10-22 14:01:47 +00:00
#spkr_group_id = self.spkr_group_symmap[spkr_group]
2023-08-02 21:53:35 +00:00
else :
2023-08-18 19:47:48 +00:00
path = self . paths [ index ]
2023-08-27 00:53:23 +00:00
spkr_name = self . get_speaker ( path )
2023-08-17 00:39:21 +00:00
spkr_id = self . spkr_symmap [ spkr_name ]
2023-10-17 00:30:38 +00:00
spkr_group = self . get_speaker_group ( path )
2023-10-22 14:01:47 +00:00
#spkr_group_id = self.spkr_group_symmap[spkr_group]
2023-08-02 21:53:35 +00:00
if cfg . dataset . use_hdf5 :
key = _get_hdf5_path ( path )
2023-10-12 01:38:40 +00:00
2024-04-29 03:28:29 +00:00
if key not in cfg . hdf5 :
raise RuntimeError ( f ' Key of Path ( { path } ) not in HDF5: { key } ' )
2023-10-11 00:18:24 +00:00
text = cfg . hdf5 [ key ] [ " text " ] [ : ]
resps = cfg . hdf5 [ key ] [ " audio " ] [ : , : ]
text = torch . from_numpy ( text ) . to ( self . text_dtype )
resps = torch . from_numpy ( resps ) . to ( torch . int16 )
2023-08-02 21:53:35 +00:00
else :
2024-04-21 19:49:18 +00:00
text = torch . tensor ( tokenize ( _get_phones ( path ) ) ) . to ( self . text_dtype )
2023-08-02 21:53:35 +00:00
resps = _load_quants ( path )
2023-10-11 22:32:45 +00:00
2023-10-13 03:21:43 +00:00
lang = torch . tensor ( [ self . lang_symmap [ self . get_language ( spkr_group ) ] ] ) . to ( torch . uint8 )
2023-10-12 01:38:40 +00:00
2023-10-11 22:32:45 +00:00
# append additional prompts in an attempt to artifically increase lengths / offer new data
if cfg . experimental and cfg . dataset . max_resps > 1 and random . random ( ) < cfg . dataset . p_resp_append :
choices = [ * ( set ( self . paths_by_spkr_name [ spkr_name ] ) - { path } ) ]
if len ( choices ) > 0 :
for _ in range ( cfg . dataset . max_resps - 1 ) :
sampled_path = random . choice ( choices )
choices = [ * ( set ( choices ) - { sampled_path } ) ]
if cfg . dataset . use_hdf5 :
2023-10-19 01:38:33 +00:00
key = _get_hdf5_path ( sampled_path )
2023-10-11 22:32:45 +00:00
txt = cfg . hdf5 [ key ] [ " text " ] [ : ]
qnt = cfg . hdf5 [ key ] [ " audio " ] [ : , : ]
2024-04-21 22:43:20 +00:00
txt = np . array ( txt )
2023-10-11 22:32:45 +00:00
txt = torch . from_numpy ( txt ) . to ( self . text_dtype )
qnt = torch . from_numpy ( qnt ) . to ( torch . int16 )
else :
2024-04-21 22:43:20 +00:00
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
txt = torch . tensor ( tokenize ( _get_phones ( sampled_path ) ) ) . to ( self . text_dtype )
2023-10-11 22:32:45 +00:00
qnt = _load_quants ( sampled_path )
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
text = torch . concat ( [ text [ : - 1 ] , torch . tensor ( [ self . phone_symmap [ " " ] ] ) . to ( torch . int16 ) , txt [ 1 : ] ] )
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = torch . concat ( [ resps , qnt ] )
2023-08-02 21:53:35 +00:00
2023-10-09 18:01:40 +00:00
task = " tts "
2024-05-11 14:50:54 +00:00
trim_length = random . randint ( cfg . dataset . prompt_duration_range [ 0 ] , cfg . dataset . prompt_duration_range [ 1 ] ) * cfg . dataset . frames_per_second
2023-10-09 18:01:40 +00:00
proms = self . sample_prompts ( spkr_name , ignore = path ) if random . random ( ) < cfg . dataset . random_utterance else resps
2023-10-11 22:32:45 +00:00
2023-10-09 18:01:40 +00:00
# Disabled until I swap over to a better method
"""
2023-08-17 23:56:37 +00:00
task = random . choice ( self . tasks )
2023-08-19 03:22:13 +00:00
2023-08-19 14:50:07 +00:00
# ensure a speaker has at least four utterances
# default to tts if not
if len ( set ( self . paths_by_spkr_name [ spkr_name ] ) - { path } ) < 4 :
task = " tts "
2023-08-20 11:29:17 +00:00
noise_scale = 0.25
2023-09-02 17:23:40 +00:00
if task == " tts " or task == " tts-c " :
2024-05-04 17:05:41 +00:00
trim_length = int ( cfg . dataset . prompt_duration * cfg . dataset . frames_per_second )
2023-09-02 18:39:17 +00:00
# demote if the target is too short
if task == " tts-c " and trim_length * 2 > = resps . shape [ 0 ] :
task = " tts "
2023-09-02 21:29:53 +00:00
2023-09-01 22:19:34 +00:00
# VALL-E continuous
# ignore if target utterance is shorter than prompt duration
# to-do: actually do this for the AR only as I don't think the paper trained the NAR for this
2023-09-02 18:39:17 +00:00
if task == " tts-c " :
2023-09-01 22:19:34 +00:00
proms = resps [ : trim_length , : ]
resps = resps [ trim_length : , : ]
2023-09-02 21:29:53 +00:00
proms = torch . cat ( [ self . get_task_token ( task ) , proms ] )
2023-09-01 22:19:34 +00:00
else :
proms = self . sample_prompts ( spkr_name , ignore = path ) if random . random ( ) < cfg . dataset . random_utterance else resps
2023-08-19 03:22:13 +00:00
# noise suppression || speech removal
elif task == " ns " or task == " sr " :
# sample random noise
2023-08-18 00:07:59 +00:00
noise = self . sample_noise ( )
2023-08-19 03:22:13 +00:00
# extend the noise to fill the target audio
noise = repeat_extend_audio ( noise , resps . shape [ 0 ] )
# create the input prompt by merging the target audio with the noise
2023-08-19 20:06:33 +00:00
proms = merge_audio ( resps , noise , scale = [ 1 , noise_scale ] , device = " cpu " )
2023-08-19 03:22:13 +00:00
# set the target to just be the noise if <sr>
if task == " sr " :
resps = noise
# prepend the task token
proms = torch . cat ( [ self . get_task_token ( task ) , proms ] )
2023-08-19 04:55:40 +00:00
# set the text prompt to empty to train without a guided text prompt
if random . random ( ) < 0.5 :
text = torch . tensor ( [ 1 , 2 ] ) . to ( self . text_dtype )
2023-08-19 03:22:13 +00:00
# target speech extraction
2023-08-18 19:47:48 +00:00
elif task == " tse " :
2023-08-19 03:22:13 +00:00
# sample a random, clean, utterance for the target speaker
2023-08-19 06:16:46 +00:00
clean_proms = self . sample_prompts ( spkr_name , ignore = path )
2023-08-19 03:22:13 +00:00
# sample a random, clean utterance from a different speaker
2023-08-19 04:55:40 +00:00
other_proms = self . sample_prompts ( self . sample_speakers ( ignore = [ spkr_name ] ) , ignore = " " )
2023-08-19 03:22:13 +00:00
# overlay the random speaker over the target audio
2023-08-19 04:55:40 +00:00
smallest_size = min ( resps . shape [ 0 ] , other_proms . shape [ 0 ] )
if other_proms . shape [ 0 ] == smallest_size :
2023-08-19 05:16:08 +00:00
noisy_proms = merge_audio ( resps [ : smallest_size , : ] , other_proms , scale = [ 1 , random . uniform ( 0.5 , 0.75 ) ] , device = " cpu " )
2023-08-19 04:55:40 +00:00
noisy_proms = torch . cat ( [ noisy_proms , resps [ smallest_size : , : ] ] )
else :
2023-08-19 05:16:08 +00:00
noisy_proms = merge_audio ( resps , other_proms [ : smallest_size , : ] , scale = [ 1 , random . uniform ( 0.5 , 0.75 ) ] , device = " cpu " )
2023-08-19 04:55:40 +00:00
noisy_proms = torch . cat ( [ noisy_proms , other_proms [ smallest_size : , : ] ] )
# stitch together the promps
2023-08-19 03:22:13 +00:00
proms = torch . cat ( [ clean_proms , self . get_task_token ( task ) , noisy_proms ] )
2023-08-19 04:55:40 +00:00
# set the text prompt to empty to train without a guided text prompt
if random . random ( ) < 0.5 :
2023-08-19 06:16:46 +00:00
text = torch . tensor ( [ 1 , 2 ] ) . to ( self . text_dtype ) # <s></s>
2023-08-19 04:55:40 +00:00
2023-08-18 19:47:48 +00:00
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
2023-08-19 03:22:13 +00:00
# clean speech editing
2023-08-19 06:16:46 +00:00
elif task == " cse " or task == " nse " :
choices = set ( self . paths_by_spkr_name [ spkr_name ] ) - { path }
2023-08-19 14:50:07 +00:00
sampled = random . sample ( [ * choices ] , 4 )
2023-08-19 06:16:46 +00:00
if cfg . dataset . use_hdf5 :
texts = [ torch . from_numpy ( cfg . hdf5 [ _get_hdf5_path ( path ) ] [ " text " ] [ : ] ) . to ( self . text_dtype ) for path in sampled ]
2023-08-19 20:06:33 +00:00
qnts = [ torch . from_numpy ( cfg . hdf5 [ _get_hdf5_path ( path ) ] [ " audio " ] [ : , : ] ) . to ( torch . int16 ) for path in sampled ]
2023-08-19 06:16:46 +00:00
else :
texts = [ torch . tensor ( [ * map ( self . phone_symmap . get , _get_phones ( path ) ) ] ) . to ( self . text_dtype ) for path in sampled ]
qnts = [ _load_quants ( path ) for path in sampled ]
# remove <s></s>
2023-08-19 14:50:07 +00:00
for i in range ( len ( texts ) ) :
texts [ i ] = texts [ i ] [ 1 : - 1 ]
2023-08-19 06:16:46 +00:00
pre_text , mid_text , post_text , edit_text = texts
pre_prom , mid_prom , post_prom , edit_prom = qnts
# randomly drop out pre
if random . random ( ) < 0.125 :
pre_text = None
pre_prom = None
# randomly drop out post
if random . random ( ) < 0.125 :
post_text = None
post_prom = None
# create new text
text = torch . cat (
2023-08-19 14:50:07 +00:00
[ torch . Tensor ( [ 1 ] ) . to ( dtype = self . text_dtype ) ] + # <s>
( [ pre_text , torch . Tensor ( [ 3 ] ) . to ( dtype = self . text_dtype ) ] if pre_text is not None else [ ] ) + # pre_text + space'
[ edit_text ] + # 'edit text'
( [ torch . Tensor ( [ 3 ] ) . to ( dtype = self . text_dtype ) , post_text ] if post_text is not None else [ ] ) + # 'space' + edit_text
[ torch . Tensor ( [ 2 ] ) . to ( dtype = self . text_dtype ) ] # </s>
2023-08-19 06:16:46 +00:00
)
if task == " nse " :
# sample random noise
noise = self . sample_noise ( )
# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
# but it's noise, it's supposed to be random
2023-08-19 20:06:33 +00:00
def noise_proms ( p ) :
2023-08-19 06:16:46 +00:00
# ignore if we turned it off
2023-08-19 20:06:33 +00:00
if p is None :
2023-08-19 06:16:46 +00:00
return None
# extend the noise to fill the target audio
2023-08-19 20:06:33 +00:00
n = repeat_extend_audio ( noise , p . shape [ 0 ] )
2023-08-19 06:16:46 +00:00
# merge the noise over the utterance
2023-08-19 20:06:33 +00:00
return merge_audio ( p , n , scale = [ 1 , noise_scale ] , device = " cpu " )
2023-08-19 06:16:46 +00:00
# apply noise to all pieces
pre_prom = noise_proms ( pre_prom )
mid_prom = noise_proms ( mid_prom )
post_prom = noise_proms ( post_prom )
edit_prom = noise_proms ( edit_prom )
else :
mid_prom = self . get_task_token ( " mask " )
# create new proms
proms = torch . cat (
( [ pre_prom ] if pre_prom is not None else [ ] ) +
[ self . get_task_token ( " soe " ) ] +
[ mid_prom ] + # is <mask> if task is CSE
[ self . get_task_token ( " eoe " ) ] +
( [ post_prom ] if post_prom is not None else [ ] )
)
# create new resp
resps = torch . cat (
( [ pre_prom ] if pre_prom is not None else [ ] ) +
[ edit_prom ] +
( [ post_prom ] if post_prom is not None else [ ] )
)
2023-08-19 20:06:33 +00:00
else :
2023-09-02 17:23:40 +00:00
raise Exception ( f ' Undefined task: { task } ' )
2023-10-09 18:01:40 +00:00
"""
2023-08-19 06:16:46 +00:00
2023-08-19 04:55:40 +00:00
"""
# emulate SVC
# takes in an utterance of the target speaker, a target utterenace as a reference clip as the input prompt
# targets an utterance of the target speaker with the same tempo + pitch + etc as the reference clip
# NOTE: I do not have a clue how to go about this. I *could* dynamically generate clips through RVC here, but I imagine the penalty would be astronomical
# ahead-of-time dataset preparation of a shit ton of RVC clips might be the key.
# aside from that, I have no clue how to go about training this, as this is entirely a proof of concept task.
elif task == " svc " :
# sample a random, clean utterance for the target speaker
proms = self . sample_prompts ( spkr_name , ignore = path ) if random . random ( ) < cfg . dataset . random_utterance else resps
# sample a reference clip from a different speaker
ref_proms = self . sample_rvc ( self . sample_speakers ( ignore = [ spkr_name ] ) )
#
resps =
# stitch together the promps
proms = torch . cat ( [ proms , self . get_task_token ( task ) , ref_proms ] )
# set the text prompt to empty to train without a guided text prompt
if random . random ( ) < 0.5 :
text = torch . tensor ( [ 1 , 2 ] ) . to ( self . text_dtype )
"""
2023-08-18 00:07:59 +00:00
2023-08-19 20:06:33 +00:00
# trim to fit to requested prom/resps levels
2024-04-16 00:54:32 +00:00
proms = proms [ : , : cfg . model . prom_levels ]
resps = resps [ : , : cfg . model . prom_levels ]
2023-08-19 20:06:33 +00:00
2023-08-02 21:53:35 +00:00
return dict (
index = index ,
2023-08-27 17:26:12 +00:00
path = Path ( path ) ,
2023-08-02 21:53:35 +00:00
spkr_name = spkr_name ,
spkr_id = spkr_id ,
2023-08-17 23:56:37 +00:00
task = task ,
2023-10-12 01:38:40 +00:00
lang = lang ,
2023-08-02 21:53:35 +00:00
text = text ,
proms = proms ,
resps = resps ,
)
def head_(self, n):
    """Limit the number of items ``__len__`` reports to ``n`` (in place).

    ``__len__`` uses ``self._head or <full length>``, so a falsy ``n`` (None/0)
    effectively removes the limit.
    """
    setattr(self, "_head", n)
def training_(self, value):
    """Set the ``training`` attribute in place; returns None (trailing ``_`` = mutator)."""
    setattr(self, "training", value)
def __len__(self):
    """Number of items one pass over the dataset yields.

    The unit depends on ``self.sampler_type``:
    ``"group"`` counts speaker groups, ``"speaker"`` counts speakers, anything
    else counts individual paths. The count is capped at ``self._head`` when
    that is truthy (see ``head_``).
    """
    if self.sampler_type == "group":
        pool = self.spkr_groups
    elif self.sampler_type == "speaker":
        pool = self.spkrs
    else:
        pool = self.paths

    total = len(pool)
    return min(total, self._head or total)
def pin_memory(self):
    """Replace ``text``, ``proms``, ``resps`` and ``resp`` with their ``pin_memory()`` results.

    The attributes are processed in that order and ``self`` is returned, mirroring
    the custom-batch ``pin_memory`` hook that DataLoader calls when pinning.

    NOTE(review): nothing visible here assigns these attributes on the dataset;
    confirm they exist (``resp`` next to ``resps`` looks like a possible typo).
    """
    for attr in ("text", "proms", "resps", "resp"):
        setattr(self, attr, getattr(self, attr).pin_memory())
    return self
def collate_fn(samples: list[dict]):
    """Transpose a list of sample dicts into one dict of lists (no tensor stacking).

    Keys come from the first sample, in its order, and every sample must provide
    each of them. An empty ``samples`` list raises IndexError.
    """
    batch: dict[str, Any] = {}
    for key in samples[0]:
        batch[key] = [sample[key] for sample in samples]
    return batch
def _seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: derive Python/NumPy seeds from torch's seed.

    torch seeds each worker differently; folding that seed into 32 bits (NumPy's
    accepted range) and reusing it keeps ``random`` and ``np.random`` distinct per
    worker as well. ``worker_id`` itself is not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    random.seed(seed)
    np.random.seed(seed)
def _create_dataloader(dataset, training):
    """Wrap ``dataset`` in a DataLoader configured for training or evaluation.

    Training loaders use ``cfg.hyperparameters.batch_size`` and drop the last
    incomplete batch; evaluation loaders use ``cfg.evaluation.batch_size`` and
    keep it. With ``cfg.distributed`` a training loader is driven by a
    ``DistributedSampler``; every other loader shuffles on its own.
    """
    sampler = DistributedSampler(dataset) if (cfg.distributed and training) else None

    if training:
        batch_size = cfg.hyperparameters.batch_size
    else:
        batch_size = cfg.evaluation.batch_size

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        # DataLoader treats an explicit sampler and shuffle=True as mutually exclusive
        shuffle=sampler is None,
        drop_last=training,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        # only keep workers alive across epochs when more than one is configured
        persistent_workers=cfg.dataset.workers > 1,
        pin_memory=False,  # True,
        worker_init_fn=_seed_worker,
        sampler=sampler,
    )
def create_datasets():
    """Build the training and validation datasets.

    The validation dataset shares the training dataset's phone symmap. If a
    saved training-dataset state exists at ``cfg.relpath / "train_dataset.pt"``,
    it is restored into the training dataset.

    Returns:
        ``(train_dataset, val_dataset)``
    """
    train_set = Dataset(training=True)
    val_set = Dataset(phone_symmap=train_set.phone_symmap, training=False)

    saved_state = cfg.relpath / "train_dataset.pt"
    if saved_state.exists():
        train_set.load_state_dict(saved_state)

    return train_set, val_set
def create_train_val_dataloader():
    """Create the train, subtrain (evaluation) and validation dataloaders.

    The subtrain dataset is a deep copy of the training dataset; when it samples
    by path it is truncated to ``cfg.evaluation.size`` items. Symmaps and
    per-split sample counts/durations are logged.

    Returns:
        ``(train_dl, subtrain_dl, val_dl)``
    """
    train_dataset, val_dataset = create_datasets()

    subtrain_dataset = copy.deepcopy(train_dataset)
    if subtrain_dataset.sampler_type == "path":
        subtrain_dataset.head_(cfg.evaluation.size)

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)
    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

    for attr in ("phone_symmap", "spkr_symmap", "spkr_group_symmap"):
        _logger.info(str(getattr(train_dataset, attr)))

    splits = (("train", train_dataset), ("val", val_dataset), ("subtrain", subtrain_dataset))
    for label, dataset in splits:
        _logger.info(f"#samples ({label}): {len(dataset)}.")
    for label, dataset in splits:
        _logger.info(f"#duration ({label}): {str(dataset.duration)}.")

    assert isinstance(subtrain_dl.dataset, Dataset)
    return train_dl, subtrain_dl, val_dl
2023-08-27 00:53:23 +00:00
# parse the dataset into metadata that is easier to sample from
2024-04-29 14:09:26 +00:00
def create_dataset_metadata(skip_existing=True):
    """Write a per-speaker JSON metadata file for every configured data directory.

    Every directory in ``cfg.dataset.training``, ``cfg.dataset.validation`` and
    ``cfg.dataset.noise`` (resolved under ``cfg.data_dir``) is scanned for
    utterance IDs. For each ID, the metadata embedded in the quantized-audio
    artifact (text, phonemes, language, duration) is collected. When the artifact
    carries none and texts are requested, the phoneme file (parsed as JSON) is
    used instead. The result is merged into ``cfg.metadata_dir/<speaker>.json``.

    An ID is only recorded once it has been processed successfully, so a failed
    utterance is retried on the next run.

    Args:
        skip_existing: if True, IDs already present in an existing metadata file
            are left untouched instead of being re-read.
    """
    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

    def _read_json(path):
        # context manager so the handle is closed even when parsing fails
        with open(path, "r", encoding="utf-8") as f:
            return json.loads(f.read())

    def _artifact_metadata(artifact):
        # copy the known per-utterance fields out of the artifact's embedded metadata
        embedded = artifact["metadata"]
        found = {k: embedded[k] for k in ("text", "phonemes", "language") if k in embedded}
        if "original_length" in embedded and "sample_rate" in embedded:
            found["duration"] = embedded["original_length"] / embedded["sample_rate"]
        return found

    # `type` is accepted for symmetry with create_dataset_hdf5; metadata is keyed by ID only
    def add(directory, type="training", audios=True, texts=True):
        name = str(directory).replace(root, "")
        speaker_name = name

        # nothing to index for a missing directory; return before creating any metadata folders
        if not os.path.isdir(f'{root}/{name}/'):
            return

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        try:
            metadata = _read_json(metadata_path) if metadata_path.exists() else {}
        except Exception as e:
            # best-effort: an unreadable metadata file is rebuilt from scratch, but report it
            tqdm.write(f'Error while reading {metadata_path}: {e}')
            metadata = {}

        # grab IDs for every file
        ids = {
            file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "")
            for file in os.listdir(f'{root}/{name}/')
        }

        for utterance_id in tqdm(ids, desc=f"Processing {name}"):
            try:
                audio_path = f'{root}/{name}/{utterance_id}{_get_quant_extension()}'
                text_path = f'{root}/{name}/{utterance_id}{_get_phone_extension()}'

                audio_exists = os.path.exists(audio_path) if audios else True
                text_exists = os.path.exists(text_path) if texts else True
                if not audio_exists or not text_exists:
                    continue

                if skip_existing and utterance_id in metadata:
                    continue

                utterance_metadata = {}
                if audios:
                    # NOTE: allow_pickle=True can execute arbitrary code; only use on trusted datasets
                    artifact = np.load(audio_path, allow_pickle=True)[()]
                    # the codes are not needed here, but an artifact without them is broken:
                    # report it instead of indexing it
                    if "codes" not in artifact:
                        raise KeyError("codes")
                    utterance_metadata = _artifact_metadata(artifact)

                # fall back to the phoneme file when the artifact carried no metadata
                if texts and not utterance_metadata:
                    utterance_metadata = _read_json(text_path)

                # record the ID only after success, so failures are not skipped as "existing" later
                metadata.setdefault(utterance_id, {}).update(utterance_metadata)
            except Exception as e:
                tqdm.write(f'Error while processing {utterance_id}: {e}')

        with open(metadata_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata))

    # training
    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
        add(data_dir, type="training")

    # validation
    for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
        add(data_dir, type="validation")

    # noise
    for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
        add(data_dir, type="noise", texts=False)
2023-08-27 00:53:23 +00:00
2023-08-19 14:50:07 +00:00
# pack the data directories listed in the config into an hdf5 file
2023-08-27 00:53:23 +00:00
def create_dataset_hdf5(skip_existing=True):
    """Pack every configured data directory into the HDF5 file opened via ``cfg.load_hdf5``.

    For each utterance under the directories in ``cfg.dataset.training``,
    ``cfg.dataset.validation`` and ``cfg.dataset.noise``, a group
    ``<type>/<speaker>/<id>`` is written with:

    * ``audio``: the artifact's codes (first entry, transposed) as int16, LZF-compressed;
    * ``text``: the tokenized phonemes as uint8, LZF-compressed (not for noise);
    * attributes copied from the utterance metadata.

    The same metadata is merged into ``cfg.metadata_dir/<speaker>.json``, and the
    phone symmap is stored as the ``symmap`` dataset. Nothing is written to HDF5
    for an utterance that fails to load. The HDF5 file is always closed.

    Args:
        skip_existing: currently unused. Existing groups are always revisited so
            that their attributes are refreshed, while ``audio``/``text`` datasets
            that already exist are never overwritten. Kept for backwards compatibility.
    """
    cfg.dataset.use_hdf5 = True
    cfg.load_hdf5(write=True)
    hf = cfg.hdf5

    symmap = get_phone_symmap()
    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    def _read_json(path):
        # context manager so the handle is closed even when parsing fails
        with open(path, "r", encoding="utf-8") as f:
            return json.loads(f.read())

    def _artifact_metadata(artifact):
        # copy the known per-utterance fields out of the artifact's embedded metadata
        embedded = artifact["metadata"]
        found = {k: embedded[k] for k in ("text", "phonemes", "language") if k in embedded}
        if "original_length" in embedded and "sample_rate" in embedded:
            found["duration"] = embedded["original_length"] / embedded["sample_rate"]
        return found

    def _encode_phonemes(phonemes):
        # the text dataset is stored as uint8: refuse to silently wrap ids outside 0..255
        tokens = np.array(cfg.tokenizer.encode("".join(phonemes)))
        if tokens.size and (tokens.min() < 0 or tokens.max() > 255):
            raise ValueError("phoneme token ids do not fit in uint8")
        return tokens.astype(np.uint8)

    def add(directory, type="training", audios=True, texts=True):
        name = str(directory).replace(root, "")
        # yucky: LibriTTS-R data is stored under a LibriVox speaker name
        speaker_name = name.replace("LibriTTS-R", "LibriVox")

        # nothing to pack for a missing directory; return before creating any metadata folders
        if not os.path.isdir(f'{root}/{name}/'):
            return

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

        try:
            metadata = _read_json(metadata_path) if metadata_path.exists() else {}
        except Exception as e:
            # best-effort (as in create_dataset_metadata): every ID is re-processed below anyway
            tqdm.write(f'Error while reading {metadata_path}: {e}')
            metadata = {}

        # grab IDs for every file
        ids = {
            file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "")
            for file in os.listdir(f'{root}/{name}/')
        }

        for utterance_id in tqdm(ids, desc=f"Processing {name}"):
            try:
                audio_path = f'{root}/{name}/{utterance_id}{_get_quant_extension()}'
                text_path = f'{root}/{name}/{utterance_id}{_get_phone_extension()}'

                if not os.path.exists(audio_path):
                    continue
                # the phoneme file is only consulted when texts are requested (noise has none)
                text_exists = os.path.exists(text_path) if texts else True

                key = f'{type}/{speaker_name}/{utterance_id}'

                utterance_metadata = {}
                qnt = None
                if audios:
                    # NOTE: allow_pickle=True can execute arbitrary code; only use on trusted datasets
                    artifact = np.load(audio_path, allow_pickle=True)[()]
                    qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
                    utterance_metadata = _artifact_metadata(artifact)

                phn = None
                if texts:
                    if not utterance_metadata and text_exists:
                        utterance_metadata = _read_json(text_path)
                    phn = _encode_phonemes(utterance_metadata["phonemes"])

                # everything loaded: only now touch the HDF5 file, so failures leave no empty groups
                group = hf[key] if key in hf else hf.create_group(key)
                if qnt is not None and "audio" not in group:
                    group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
                if phn is not None and "text" not in group:
                    group.create_dataset('text', data=phn, compression='lzf')

                for k, v in utterance_metadata.items():
                    group.attrs[k] = v
                metadata.setdefault(utterance_id, {}).update(utterance_metadata)
            except Exception as e:
                tqdm.write(f'Error while processing {utterance_id}: {e}')

        with open(metadata_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(metadata))

    try:
        # training
        for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
            add(data_dir, type="training")

        # validation
        for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
            add(data_dir, type="validation")

        # noise
        for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
            add(data_dir, type="noise", texts=False)

        # write symmap
        if "symmap" in hf:
            del hf['symmap']

        hf.create_dataset('symmap', data=json.dumps(symmap))
    finally:
        # release the file even if a directory aborts with an unexpected error
        hf.close()
if __name__ == "__main__":
    import argparse

    # ArgumentParser's first positional parameter is `prog` (the program name shown in usage),
    # so the description must be passed by keyword.
    parser = argparse.ArgumentParser(description="Prepare and inspect the training dataset.")
    parser.add_argument("--action", type=str)
    parser.add_argument("--tasks", type=str)
    args = parser.parse_args()

    # use a single dataloader worker for these one-off runs
    cfg.dataset.workers = 1

    class LoggerOverride:
        """Stand-in for the module logger that prints messages to stdout."""

        def info(self, *messages):
            print(*messages)

    # rebinds the module-level _logger, so create_train_val_dataloader's stats are printed
    _logger = LoggerOverride()

    if args.action == "hdf5":
        create_dataset_hdf5()
    elif args.action == "list-dataset":
        # print every non-empty <group>/<name> directory under cfg.data_dir as a JSON list
        dataset = []
        for group in os.listdir(cfg.data_dir):
            group_dir = cfg.data_dir / group
            # skip stray files, which would make os.listdir raise NotADirectoryError
            if not group_dir.is_dir():
                continue
            for name in os.listdir(group_dir):
                speaker_dir = group_dir / name
                if not speaker_dir.is_dir() or not os.listdir(speaker_dir):
                    continue
                dataset.append(f'{group}/{name}')

        print(json.dumps(dataset))
    elif args.action == "metadata":
        create_dataset_metadata()
    elif args.action == "sample":
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

        samples = {
            "training": [next(iter(train_dl)), next(iter(train_dl))],
            "evaluation": [next(iter(subtrain_dl)), next(iter(subtrain_dl))],
            "validation": [next(iter(val_dl)), next(iter(val_dl))],
        }

        Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
        # audio decoding of the samples is disabled; replace tensors with their shapes
        # so that the printout below stays readable
        for batches in samples.values():
            for batch in batches:
                for j in tqdm(range(len(batch['proms'])), desc="Decoding..."):
                    batch['proms'][j] = batch['proms'][j].shape
                    batch['resps'][j] = batch['resps'][j].shape

        for k, batches in samples.items():
            for i, batch in enumerate(batches):
                print(f'{k}[{i}]:', batch)
    elif args.action == "tasks":
        if not args.tasks:
            parser.error("--action tasks requires --tasks (a comma-separated list of tasks)")

        cfg.dataset.tasks_list = args.tasks.split(",")
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
        batch = next(iter(train_dl))

        # dump the first sample of the batch whose task was requested, then stop
        for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
            if task not in cfg.dataset.tasks_list:
                continue

            print(text, task, cfg.model.prom_levels)
            print(proms.shape, resps.shape)

            tokens = sum(t.shape[0] for t in batch["text"]) + sum(r.shape[0] for r in batch["resps"])
            print(tokens)

            decode_to_file(proms, f"./data/{task}.proms.wav", device="cpu")
            decode_to_file(resps, f"./data/{task}.resps.wav", device="cpu")
            break
    else:
        parser.error(f"unknown --action: {args.action}")