2023-08-02 21:53:35 +00:00
# todo: clean this mess up
import copy
import h5py
import json
2024-09-21 17:29:28 +00:00
import re
2023-08-02 21:53:35 +00:00
import logging
import numpy as np
import os
import random
import torch
2023-08-27 00:53:23 +00:00
import itertools
2023-08-02 21:53:35 +00:00
from . config import cfg
2024-08-07 01:23:33 +00:00
from . emb . qnt import trim , trim_random , repeat_extend_audio , concat_audio , merge_audio , decode_to_file , decode as decode_qnt , encode as encode_qnt , pad_codes_with_silence
2024-09-09 14:57:32 +00:00
from . emb . g2p import encode as encode_phns
2024-06-29 03:28:54 +00:00
from . utils . sampler import PoolSampler , OrderedSampler , BatchedOrderedSampler , RandomSampler
2024-11-11 22:32:08 +00:00
from . utils . distributed import global_rank , local_rank , world_size , is_global_leader
2024-09-18 21:43:57 +00:00
from . utils . io import torch_save , torch_load , json_read , json_write , json_stringify , json_parse
2024-09-21 17:19:34 +00:00
from . utils import setup_logging
2023-08-02 21:53:35 +00:00
from collections import defaultdict
from functools import cache , cached_property
from itertools import groupby , zip_longest
from pathlib import Path
from typing import Any
from torch import Tensor
from torch . utils . data import DataLoader , Dataset as _Dataset
2023-08-14 03:07:45 +00:00
from torch . utils . data . distributed import DistributedSampler
2024-06-04 02:28:49 +00:00
from torch . nn . utils . rnn import pad_sequence
2023-08-02 21:53:35 +00:00
from tqdm . auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)

# NLTK is an optional dependency: sentence_split() uses it to split text into sentences.
# If it cannot be imported or initialized, `nltk` is set to None and sentence_split()
# falls back to splitting on newlines.
try:
    import nltk
    # project-local data directory, so the tokenizer data only has to be fetched once
    nltk.data.path.append("./.nltk/")
    try:
        # only download when the punkt_tab tokenizer data cannot already be found; checking for the
        # resource itself (rather than only for the ./.nltk directory) also recovers from a previous
        # download that was interrupted after the directory had been created
        nltk.data.find("tokenizers/punkt_tab")
    except LookupError:
        nltk.download('punkt_tab', download_dir="./.nltk/")
except Exception as e:
    nltk = None
    _logger.warning(f"Error while querying for NLTK: {str(e)}")
def sentence_split(s, split_by="sentences", quote_placeholder="<QUOTE>"):
    """
    Split text `s` into chunks.

    :param s: text to split
    :param split_by: None to return `s` unsplit; "sentences" to split into sentences with NLTK
        (falling back to newlines when NLTK is unavailable); any other string is used as a
        literal delimiter for str.split
    :param quote_placeholder: temporary stand-in for double quotes while NLTK tokenizes
    :returns: list of chunks (empty sentences are dropped in "sentences" mode)
    """
    if split_by is None:
        return [s]

    # split by an explicit delimiter; this does not need NLTK, so it is honored even when NLTK is missing
    if split_by != "sentences":
        return s.split(split_by)

    # NLTK is not available, fallback to splitting by lines
    if nltk is None:
        return s.split("\n")

    # use NLTK to handle splitting by sentences, because I don't want to write my own parser to split by punctuation
    # nltk does not split quotations all that nicely, so we coerce them into placeholders, then replace afterwards
    s = s.replace('"', quote_placeholder)
    sentences = nltk.sent_tokenize(s)
    return [sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence]
2024-12-05 02:31:44 +00:00
2024-12-11 02:13:21 +00:00
# to-do: improve upon this; it is only a very rough normalization
# this might be better to live in emb.g2p
def normalize_text(s):
    """Lower-case `s` and remove every character that is neither a word character nor whitespace."""
    lowered = s.lower()
    return re.sub(r'[^\w\s]', '', lowered)
2024-10-10 18:40:25 +00:00
@cache
def get_random_prompts(validation=False, min_length=0, tokenized=False):
    """
    Return the pool of text prompts that get_random_prompt() samples from.

    The pool is a built-in subset of the Harvard sentences, replaced wholesale by the non-blank
    lines of ./data/harvard_sentences.txt when that file exists.

    :param validation: when True and cfg.dataset.validation is set, also add the transcripts of
        validation utterances whose duration lies within `duration_range` (below)
    :param min_length: minimum transcript length in characters; only applied to validation transcripts
    :param tokenized: when True, phonemize and tokenize every prompt and return torch.uint8 tensors
    :returns: list of prompt strings, or list of token tensors when `tokenized` is set

    Results are memoized per argument combination (functools.cache): repeated calls return the
    same list object, so callers must not mutate it.
    """
    duration_range = [5.5, 12.0]  # to-do: pull from cfg.dataset.duration_range

    sentences = [
        "The birch canoe slid on the smooth planks.",
        "Glue the sheet to the dark blue background.",
        "It's easy to tell the depth of a well.",
        "These days a chicken leg is a rare dish.",
        "Rice is often served in round bowls.",
        "The juice of lemons makes fine punch.",
        "The box was thrown beside the parked truck.",
        "The hogs were fed chopped corn and garbage.",
        "Four hours of steady work faced us.",
        "A large size in stockings is hard to sell.",
        "The boy was there when the sun rose.",
        "A rod is used to catch pink salmon.",
        "The source of the huge river is the clear spring.",
        "Kick the ball straight and follow through.",
        "Help the woman get back to her feet.",
        "A pot of tea helps to pass the evening.",
        "Smoky fires lack flame and heat.",
        "The soft cushion broke the man's fall.",
        "The salt breeze came across from the sea.",
        "The girl at the booth sold fifty bonds.",
        "The small pup gnawed a hole in the sock.",
        "The fish twisted and turned on the bent hook.",
        "Press the pants and sew a button on the vest.",
        "The swan dive was far short of perfect.",
        "The beauty of the view stunned the young boy.",
        "Two blue fish swam in the tank.",
        "Her purse was full of useless trash.",
        "The colt reared and threw the tall rider.",
        "It snowed, rained, and hailed the same morning.",
        "Read verse out loud for pleasure.",
        "Perfect. Please move quickly to the chamber lock, as the effect of prolonged exposure to the button are not part of this test.",
    ]

    harvard_sentences_path = Path("./data/harvard_sentences.txt")
    if harvard_sentences_path.exists():
        # read_text() closes the file handle (a bare open().read() leaks it); blank lines, e.g. from a
        # trailing newline, are dropped so an empty prompt can never be sampled
        lines = harvard_sentences_path.read_text(encoding="utf-8").splitlines()
        sentences = [line.strip() for line in lines if line.strip()]

    # Pull from validation dataset if existing + requested
    if validation and cfg.dataset.validation:
        paths = _load_paths(cfg.dataset.validation, type="validation", silent=True)
        paths = list(itertools.chain.from_iterable(paths.values()))

        for path in paths:
            if cfg.dataset.use_hdf5:
                key = _get_hdf5_path(path)
                metadata = {f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items()}
            else:
                _, metadata = _load_artifact(path, return_metadata=True)
            metadata = process_artifact_metadata({"metadata": metadata})

            text_string = metadata["text"] if "text" in metadata else ""
            duration = metadata['duration'] if "duration" in metadata else 0

            # skip missing/short transcripts and utterances outside the duration range
            if not text_string or len(text_string) < min_length:
                continue
            if not (duration_range[0] <= duration <= duration_range[1]):
                continue

            sentences.append(text_string)

    # tokenize here because our harvard sentences need to be phonemized anyways
    if tokenized:
        return [torch.tensor(tokenize(encode_phns(text))).to(dtype=torch.uint8) for text in sentences]

    return sentences
# samples a random text prompt
def get_random_prompt(*args, **kwargs):
    """Return one prompt chosen uniformly at random from get_random_prompts(*args, **kwargs)."""
    candidates = get_random_prompts(*args, **kwargs)
    chosen = random.choice(candidates)
    return chosen
2024-06-04 02:28:49 +00:00
# fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs(
    text_list=[],
    lang_list=[],
    task_list=[],
    tone_list=[],
    prom_list=[],
    resp_list=[],
    targ_list=[],

    ignore_index=None,

    sep=3,
    stop=3,
    config=None,

    quant_levels=None,
):
    """
    Fold per-field inputs into one flat token sequence per batch item, using one unified ID space.

    Each field is shifted into its own ID range, in this order:
        text | lang | rvq level | prom (audio_tokens * resp_levels) | task | tone | resp (audio_tokens * resp_levels)
    Per item, segments are emitted as: text, lang, rvq level (only when `quant_levels` is given),
    prompts/tasks, tone, resp, targ. The text/lang/rvq/prompt/task/tone segments are each followed by
    `sep`; resp/targ segments are followed by `stop`. Input tensors are never modified.

    :param text_list: per-item text tokens (tensor or int); text_list[0] must be a tensor (its device is used)
    :param lang_list: per-item language id (tensor or int)
    :param task_list: unused; task names are passed as strings inside `prom_list`
    :param tone_list: per-item tone id (tensor or int)
    :param prom_list: per-item prompt: a codes tensor indexed [:, level], a task name string
        (looked up as "<name>" in get_task_symmap()), None, or a list of these
    :param resp_list: per-item response codes (tensor indexed [:, level], or a list of per-level tensors)
    :param targ_list: per-item target codes (same forms as resp_list)
    :param ignore_index: if set, prompt audio tokens are replaced by this value
    :param sep: separator token id
    :param stop: stop token id appended after resp/targ segments
    :param config: model config providing the vocabulary sizes; defaults to cfg.model
    :param quant_levels: per-item RVQ level. When given, audio is "deinterleaved" (one level per segment:
        prom/targ at that level, resp at the previous level); otherwise all levels are flattened ("interleaved")
    :returns: (input_ids, attention_mask, position_ids), each right-padded to the longest item, shape (b, t)
    """
    if config is None:
        config = cfg.model

    def _create_mask(l, device):
        seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
        stop = torch.tensor(l, device=device).unsqueeze(1)  # (b 1)
        return (seq < stop).float()  # (b t)

    def list_to_tensor(x_list: list[Tensor], mask=True):
        # right-pad the per-item 1-D sequences into (b t), optionally with the matching (b t) mask
        l = list(map(len, x_list))
        x = pad_sequence(x_list).t()

        if not mask:
            return x

        m = _create_mask(l, x_list[0].device)
        m = m.to(x)
        return x, m

    def to_tokens(value, start):
        # ints, 0-d tensors and 1-D tensors all become a 1-D int64 tensor shifted by `start`;
        # widening before adding avoids overflow for narrow input dtypes (e.g. uint8), and the
        # addition allocates, so the caller's tensor is left untouched
        return torch.as_tensor(value, device=device).to(dtype=dtype).reshape(-1) + start

    def level_codes(codes, level):
        # codes are either indexed [:, level] or are a list of per-level tensors
        if isinstance(codes, list):
            return codes[level]
        return codes[:, level]

    def audio_level_tokens(codes, level, start):
        # a single RVQ level, shifted into that level's slice of the audio vocabulary
        return level_codes(codes, level).to(device=device, dtype=dtype) + (start + config.audio_tokens * level)

    def audio_flat_tokens(codes, start):
        # all levels flattened row-major; flat position idx is treated as level idx % resp_levels
        seq = codes.flatten().to(device=device, dtype=dtype)
        levels = torch.arange(seq.shape[0], device=device, dtype=dtype) % config.resp_levels
        return seq + (start + config.audio_tokens * levels)

    def generate_position_ids(length, sep=True):
        # positions restart at 0 for every segment; `sep` reserves one extra position for the terminator
        return list(range(length + (1 if sep else 0)))

    def append_segment(i, seq, terminator):
        input_ids[i].append(seq)
        input_ids[i].append(terminator)
        position_ids[i].append(generate_position_ids(seq.shape[0]))

    def process_prom_or_task(i, prom):
        # appends one prompt/task segment (plus its separator) for item i; returns the positions it occupies
        if prom is None:
            return 0

        if isinstance(prom, str):
            task = get_task_symmap()[f'<{prom}>']
            seq = torch.tensor([task_start + task], device=device, dtype=dtype)
        # deinterleaved
        elif quant_levels is not None:
            if ignore_index is not None:
                seq = torch.full((prom.shape[0],), ignore_index, device=device, dtype=dtype)
            else:
                seq = audio_level_tokens(prom, int(quant_levels[i]), prom_start)
        # interleaved
        else:
            if ignore_index is not None:
                seq = torch.full((prom.shape[0] * prom.shape[1],), ignore_index, device=device, dtype=dtype)
            else:
                seq = audio_flat_tokens(prom, prom_start)

        input_ids[i].append(seq)
        input_ids[i].append(sep_token)
        return seq.shape[0] + 1

    device = text_list[0].device
    dtype = torch.int64

    batch_size = len(text_list)
    input_ids = [[] for _ in range(batch_size)]
    position_ids = [[] for _ in range(batch_size)]

    sep_token = torch.tensor([sep], device=device, dtype=dtype)
    stop_token = torch.tensor([stop], device=device, dtype=dtype)

    text_start = 0
    text_end = text_start + config.text_tokens
    lang_start = text_end
    lang_end = lang_start + config.langs
    rvq_start = lang_end
    rvq_end = rvq_start + config.resp_levels
    prom_start = rvq_end
    prom_end = prom_start + config.audio_tokens * config.resp_levels
    task_start = prom_end
    task_end = task_start + config.tasks
    tone_start = task_end
    tone_end = tone_start + config.tones
    resp_start = tone_end

    # text tokens
    for i, text in enumerate(text_list):
        append_segment(i, to_tokens(text, text_start), sep_token)

    # lang tokens
    for i, lang in enumerate(lang_list):
        append_segment(i, to_tokens(lang, lang_start), sep_token)

    # inject target quant_level
    if quant_levels is not None:
        for i, rvq in enumerate(quant_levels):
            append_segment(i, to_tokens(rvq, rvq_start), sep_token)

    # prom / task tokens: a single prompt, or a list of prompts with possible task names
    for i, prom in enumerate(prom_list):
        proms = prom if isinstance(prom, list) else [prom]
        length = sum(process_prom_or_task(i, p) for p in proms)
        # the returned lengths already include each segment's separator
        position_ids[i].append(generate_position_ids(length, sep=False))

    # tone tokens
    for i, tone in enumerate(tone_list):
        append_segment(i, to_tokens(tone, tone_start), sep_token)

    # resp tokens
    for i, resp in enumerate(resp_list):
        # deinterleaved: grab the previous rvq level
        if quant_levels is not None:
            quant_level = int(quant_levels[i]) - 1
            # omitting the segment is the way to signal we want to inference for rvq level 0
            # without it, it's a random chance for any level to be selected again
            if quant_level < 0:
                continue
            seq = audio_level_tokens(resp, quant_level, resp_start)
        # interleaved
        else:
            seq = audio_flat_tokens(resp, resp_start)
        append_segment(i, seq, stop_token)

    # targ tokens (current rvq level when deinterleaved)
    for i, targ in enumerate(targ_list):
        if quant_levels is not None:
            seq = audio_level_tokens(targ, int(quant_levels[i]), resp_start)
        else:
            seq = audio_flat_tokens(targ, resp_start)
        append_segment(i, seq, stop_token)

    for i in range(batch_size):
        input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
        position_ids[i] = torch.concat([torch.tensor(ids, device=device, dtype=dtype) for ids in position_ids[i]], dim=-1)

    input_ids, attention_mask = list_to_tensor(input_ids)
    position_ids = list_to_tensor(position_ids, mask=False)

    return input_ids, attention_mask, position_ids
2024-06-04 02:28:49 +00:00
# unfold from one unified token ID space to separate token spaces
# to-do: unfold at a specific RVQ level instead if requested
def unfold_outputs(
    output_ids,
    sep=3,
    stop=3,
    config=None,
    quant_levels=None,
):
    """
    Split unified token IDs back into per-field token tensors (the reverse of fold_inputs()'s ID layout).

    :param output_ids: (batch, t) tensor of unified token IDs
    :param sep: separator token id (skipped)
    :param stop: stop token id (skipped)
    :param config: model config providing the vocabulary sizes; defaults to cfg.model
    :param quant_levels: per-item RVQ level for deinterleaved sequences, or None for interleaved ones
    :returns: dict with text_list, prom_list, resp_list, task_list, lang_list and tone_list, each holding
        one int64 tensor per batch item. Interleaved prom/resp come back as (t, resp_levels) tensors,
        deinterleaved ones as 1-D tensors. RVQ-level tokens are recognized but not returned.
    """
    if config is None:
        config = cfg.model

    device = output_ids.device
    dtype = torch.int64

    def bin_to_rvqs(tokens):
        # de-interleave a flat token list (position pos belongs to level pos % resp_levels) into
        # (t, resp_levels). A trailing partial frame is dropped so every level has the same length;
        # otherwise the per-level lists are ragged and torch.tensor raises.
        levels = config.resp_levels
        frames = len(tokens) // levels
        bins = [tokens[level::levels][:frames] for level in range(levels)]
        return torch.tensor(bins, device=device, dtype=dtype).t()

    batch_size = output_ids.shape[0]
    text_list = [[] for _ in range(batch_size)]
    lang_list = [[] for _ in range(batch_size)]
    task_list = [[] for _ in range(batch_size)]
    tone_list = [[] for _ in range(batch_size)]
    prom_list = [[] for _ in range(batch_size)]
    resp_list = [[] for _ in range(batch_size)]

    text_start = 0
    text_end = text_start + config.text_tokens
    lang_start = text_end
    lang_end = lang_start + config.langs
    rvq_start = lang_end
    rvq_end = rvq_start + config.resp_levels
    prom_start = rvq_end
    prom_end = prom_start + config.audio_tokens * config.resp_levels
    task_start = prom_end
    task_end = task_start + config.tasks
    tone_start = task_end
    tone_end = tone_start + config.tones
    resp_start = tone_end
    resp_end = resp_start + config.audio_tokens * config.resp_levels

    for i, batch in enumerate(output_ids):
        # cringe logic to handle prefix resp for rvq levels > 0: the first resp segment holds the
        # previous RVQ level, so it is discarded at the first separator that follows it
        # a better way is to observe if the rvq level increased
        should_flush = False
        flushed = False
        for token in batch:
            tok = token.item()
            if tok == sep or tok == stop:
                if should_flush and quant_levels is not None and quant_levels[i] > 0:
                    resp_list[i] = []
                    should_flush = False
                    flushed = True
                continue

            # text tokens
            if text_start <= tok < text_end:
                text_list[i].append(tok - text_start)
            # lang tokens
            elif lang_start <= tok < lang_end:
                lang_list[i].append(tok - lang_start)
            # rvq levels: recognized so they are not misclassified, but not returned
            elif rvq_start <= tok < rvq_end:
                pass
            # prom tokens
            elif prom_start <= tok < prom_end:
                prom_list[i].append((tok - prom_start) % config.audio_tokens)
            # task tokens
            elif task_start <= tok < task_end:
                task_list[i].append(tok - task_start)
            # tone tokens
            elif tone_start <= tok < tone_end:
                tone_list[i].append(tok - tone_start)
            # resp tokens
            elif resp_start <= tok < resp_end:
                resp_list[i].append((tok - resp_start) % config.audio_tokens)
                if not flushed:
                    should_flush = True

        if quant_levels is not None:
            prom_list[i] = torch.tensor(prom_list[i], device=device, dtype=dtype).t()
            resp_list[i] = torch.tensor(resp_list[i], device=device, dtype=dtype).t()
        else:
            prom_list[i] = bin_to_rvqs(prom_list[i])
            resp_list[i] = bin_to_rvqs(resp_list[i])

        text_list[i] = torch.tensor(text_list[i], device=device, dtype=dtype)
        task_list[i] = torch.tensor(task_list[i], device=device, dtype=dtype)
        lang_list[i] = torch.tensor(lang_list[i], device=device, dtype=dtype)
        tone_list[i] = torch.tensor(tone_list[i], device=device, dtype=dtype)

    return dict(
        text_list=text_list,
        prom_list=prom_list,
        resp_list=resp_list,
        task_list=task_list,
        lang_list=lang_list,
        tone_list=tone_list,
    )
2024-04-16 00:54:32 +00:00
# to-do: clean up this symmap mess
def get_phone_symmap():
    """Return the vocabulary of the configured tokenizer (cfg.tokenizer.get_vocab())."""
    vocab = cfg.tokenizer.get_vocab()
    return vocab
2023-08-02 21:53:35 +00:00
2024-04-21 19:49:18 +00:00
def tokenize(phones):
    """
    Encode phonemes with the configured tokenizer.

    :param phones: a phoneme string, or a list of strings that is concatenated (no separator) first
    :returns: the result of cfg.tokenizer.encode on that string
    """
    text = "".join(phones) if isinstance(phones, list) else phones
    return cfg.tokenizer.encode(text)
2023-10-12 01:38:40 +00:00
def get_lang_symmap():
    """Map language codes to their language token ids."""
    # "zh" is presumably Mandarin
    codes = ["en", "ja", "de", "fr", "zh", "ko"]
    return {code: index for index, code in enumerate(codes)}
2024-04-16 00:54:32 +00:00
def get_tone_symmap():
    """Map tone names to tone token ids; only "neutral" is defined so far."""
    # could use 4 instead of 8 basic emotions, e.g. joy (1), fear (2), surprise (3), anger (4)
    return dict(neutral=0)
2023-08-02 21:53:35 +00:00
2023-08-19 03:22:13 +00:00
def get_task_symmap():
    """
    Map task tags (e.g. "<tts>") to task token ids.

    "<len>", "<nse>" and "<cse>" are "fake" tasks without a token of their own; they reuse an existing id.
    """
    real_tasks = ["<tts>", "<tts-c>", "<ns>", "<sr>", "<tse>", "<soe>", "<mask>", "<eoe>", "<stt>"]
    symmap = {task: index for index, task in enumerate(real_tasks)}
    symmap.update({
        "<len>": 0,  # fake
        "<nse>": 6,  # fake
        "<cse>": 6,  # fake
    })
    return symmap
2023-08-02 21:53:35 +00:00
def _replace_file_extension(path, suffix):
    """Replace everything from the first '.' of the file name onward with `suffix` (path may be str or Path)."""
    path = path if isinstance(path, Path) else Path(path)
    stem, _, _ = path.name.partition(".")
    return (path.parent / stem).with_suffix(suffix)
2024-12-12 02:55:43 +00:00
def _get_artifact_extension():
    """File extension of quantized-audio artifacts: ".dac" for the "dac" audio backend, ".enc" otherwise."""
    if cfg.audio_backend == "dac":
        return ".dac"
    return ".enc"
2024-04-19 02:24:06 +00:00
2024-12-12 02:55:43 +00:00
def _get_metadata_extension():
    """File extension used for metadata files."""
    extension = ".json"
    return extension
2024-04-19 02:24:06 +00:00
2024-12-12 02:55:43 +00:00
def _get_artifact_path(path):
    """Return `path` with its extension replaced by the current artifact extension."""
    extension = _get_artifact_extension()
    return _replace_file_extension(path, extension)
2023-08-27 00:53:23 +00:00
2024-06-14 03:37:34 +00:00
# per dataset type ("training" / "validation"): {path key: duration}
_durations_map = {}

def _get_duration_map(type="training"):
    """Return the recorded duration map for dataset `type`, or an empty dict if none was recorded."""
    return _durations_map.get(type, {})
2023-09-12 20:54:41 +00:00
2024-11-14 00:04:04 +00:00
def _load_paths(dataset, type="training", silent=None, dataset_hash_key=None):
    """
    Load the utterance paths of every group in `dataset`, keyed by speaker name.

    Results are cached as JSON under cfg.cache_dir / <dataset hash key>. A cached duration table is
    always loaded into _durations_map[type]; the cached path list is only reused when cfg.dataset.cache is set.

    :param dataset: list of dataset group names
    :param type: dataset type ("training" / "validation"); selects the cache files and the duration bucket
    :param silent: hide the progress bar; None (default) hides it on every process except the global leader
    :param dataset_hash_key: precomputed cfg.dataset.hash_key(sorted(dataset)); computed when omitted
    :returns: dict mapping speaker name -> list of paths
    """
    # resolved per call: as a default argument it was evaluated once at import time,
    # possibly before the distributed environment was set up
    if silent is None:
        silent = not is_global_leader()

    if not dataset_hash_key:
        dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))

    cached_dir = cfg.cache_dir / dataset_hash_key
    cached_durations_path = cached_dir / f"durations[{type}].json"
    cached_paths_path = cached_dir / f"dataloader[{type}].json"

    # load the duration table first, since this is independent from the loaded paths
    if cached_durations_path.exists():
        _durations_map[type] = json_read(cached_durations_path)

    # load the cached valid paths (if we're requesting cache use)
    if cached_paths_path.exists() and cfg.dataset.cache:
        # to-do: automatic conversion between HDF5 formatted paths and on-disk paths
        return json_read(cached_paths_path)

    # deduce valid paths
    validate = cfg.dataset.validate and type == "training"
    paths = {
        cfg.get_spkr(cfg.data_dir / data_dir / "dummy"): _load_paths_from_metadata(data_dir, type=type, validate=validate)
        for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent)
    }

    # and write if global leader (to avoid other processes writing to the same file at once)
    if is_global_leader():
        cached_dir.mkdir(parents=True, exist_ok=True)
        # nothing is recorded in _durations_map when no group had metadata (or the dataset is empty),
        # so the key may legitimately be missing here
        json_write(_durations_map.get(type, {}), cached_durations_path, truncate=True)
        json_write(paths, cached_paths_path, truncate=True)

    return paths
2024-04-29 03:28:29 +00:00
2024-05-16 04:04:19 +00:00
def _load_paths_from_metadata(group_name, type="training", validate=False):
    """
    List the utterances of one dataset group, preferring its metadata JSON when available.

    With cfg.dataset.use_metadata and an existing, non-empty cfg.metadata_dir/<group_name>.json, the utterance
    ids come from that file. Each entry's duration (0 if absent) is recorded in _durations_map[type], and when
    `validate` is set, entries outside [cfg.dataset.min_duration, cfg.dataset.max_duration] are dropped.
    Otherwise the group is listed directly (HDF5 group or data directory). An unreadable metadata file
    makes the group contribute no paths.

    :returns: list of path keys ("/<type>/<group>/<id>" for HDF5, "<data_dir>/<id>" on disk)
    """
    data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name

    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions

    def key(utt_id, entry=None):
        return f"/{type}/{_get_hdf5_path(data_dir)}/{utt_id}" if cfg.dataset.use_hdf5 else str(data_dir / utt_id)

    metadata_path = cfg.metadata_dir / f'{group_name}.json'
    metadata = {}

    if cfg.dataset.use_metadata and metadata_path.exists():
        try:
            metadata = json_read(metadata_path)
        except Exception as e:
            # best-effort: skip this group as before, but don't hide why it was skipped
            _logger.warning(f"Failed to read metadata {metadata_path}, skipping group {group_name}: {str(e)}")
            return []

    if len(metadata) == 0:
        return _fn(data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate)

    def _validate(utt_id, entry):
        duration = entry['duration'] if "duration" in entry else 0

        # add to duration bucket
        if type not in _durations_map:
            _durations_map[type] = {}
        _durations_map[type][key(utt_id, entry)] = duration

        if not validate:
            return True

        return cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration

    return [key(utt_id, entry) for utt_id, entry in metadata.items() if _validate(utt_id, entry)]
2023-08-27 00:53:23 +00:00
2023-08-02 21:53:35 +00:00
def _get_hdf5_path(path):
    """Return the key used to address `path` in the HDF5 dataset (currently just its string form)."""
    # to-do: better validation
    hdf5_key = str(path)
    return hdf5_key
2023-08-02 21:53:35 +00:00
2023-08-27 00:53:23 +00:00
def _get_hdf5_paths(data_dir, type="training", validate=False):
    """
    List the utterances stored under "/<type>/<data_dir>" in the HDF5 dataset (cfg.hdf5).

    Each entry's `duration` attribute is recorded in _durations_map[type]; when `validate` is set,
    entries outside [cfg.dataset.min_duration, cfg.dataset.max_duration] are dropped.

    :returns: list of Path("/<type>/<data_dir>/<id>"), or [] when that group does not exist
    """
    data_dir = str(data_dir)
    key = f"/{type}/{_get_hdf5_path(data_dir)}"

    if key not in cfg.hdf5:
        return []

    def _validate(utt_id, entry):
        # only the duration is needed; the previously read-but-unused 'phonemes' attribute is no longer
        # accessed, so entries without it no longer abort the listing with a KeyError
        duration = entry.attrs['duration']

        if type not in _durations_map:
            _durations_map[type] = {}
        _durations_map[type][f"{key}/{utt_id}"] = duration

        if not validate:
            return True

        return cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration

    return [Path(f"{key}/{utt_id}") for utt_id, entry in cfg.hdf5[key].items() if _validate(utt_id, entry)]
2023-08-27 00:53:23 +00:00
2024-12-12 02:55:43 +00:00
def _get_paths_of_extensions(path, extensions=_get_artifact_extension(), validate=False):
    """
    List every entry of directory `path` (str or Path); [] if it does not exist or is not a directory.

    `extensions` and `validate` are accepted so this can be called the same way as _get_hdf5_paths,
    but are not used: all entries are returned.
    """
    directory = Path(path) if isinstance(path, str) else path
    if not (directory.exists() and directory.is_dir()):
        return []
    return list(directory.iterdir())
2023-08-02 21:53:35 +00:00
2024-12-12 02:55:43 +00:00
def _load_artifact(path, return_metadata=False) -> Tensor:
    """
    Load the quantized-audio artifact that corresponds to `path`.

    The file is loaded with np.load(..., allow_pickle=True)[()] and must provide "codes" and "metadata".
    The returned codes are codes[0], transposed and cast to torch.int16
    (presumably (t, levels) — TODO confirm against the encoder's output layout).

    :returns: the codes tensor, or (codes, metadata) when `return_metadata` is set
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load artifacts from trusted sources
    artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
    codes = torch.from_numpy(artifact["codes"].astype(int))[0][:, :].t().to(torch.int16)
    if not return_metadata:
        return codes
    return codes, artifact["metadata"]
2023-08-02 21:53:35 +00:00
def _interleaved_reorder(l, fn):
    """
    Yield the items of `l` round-robin across groups, grouping items by the key `fn(item)`.

    Groups are visited in sorted key order and keep their items' original relative order,
    e.g. groups {0: [a, b], 1: [c]} yield a, c, b.
    """
    groups = defaultdict(list)
    for e in l:
        groups[fn(e)].append(e)
    ordered_groups = [groups[k] for k in sorted(groups)]

    # a private sentinel marks exhausted groups; using None for this dropped items that are legitimately None
    exhausted = object()
    for interleaved in zip_longest(*ordered_groups, fillvalue=exhausted):
        for value in interleaved:
            if value is not exhausted:
                yield value
class Dataset ( _Dataset ) :
def __init__(
    self,
    phone_symmap=None,
    training=False,
    extra_paths_by_spkr_name: dict[str, list] = {},
):
    """Index the training or validation split and set up its samplers.

    Args:
        phone_symmap: phoneme-to-id map; when falsy, ``self._get_phone_symmap()`` is used.
        training: True selects ``cfg.dataset.training``, False selects ``cfg.dataset.validation``.
        extra_paths_by_spkr_name: not referenced anywhere in this constructor
            (the shared ``{}`` default is never mutated here).

    Raises:
        Exception: when ``sample_order == "duration"`` is combined with a
            ``sample_type`` other than ``"path"``.
        ValueError: when no valid path is found for the split.
    """
    super().__init__()
    self._head = None
    self.sampler = None

    self.paths = []
    self.training = training
    self.dataset_type = "training" if self.training else "validation"
    self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
    # the validation split always samples by path
    self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
    self.sampler_order = cfg.dataset.sample_order
    # the validation split always shuffles
    self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True

    # NOTE(review): this hash is computed BEFORE the empty-validation fallback below, so an empty
    # validation split is hashed as [] even though it then loads the training directories -- confirm
    # that _load_paths' use of this key cannot serve stale paths across different training configs
    self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))

    self.duration = 0
    self.duration_buckets = {}
    self.current_index = 0
    self.batch_size = cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size

    # to-do: do not do validation if there's nothing in the validation
    # this just makes it be happy
    if len(self.dataset) == 0:
        self.dataset = cfg.dataset.training

    # hard error because I kept getting tricked by this myself
    if self.sampler_order == "duration" and self.sampler_type != "path":
        raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')

    # dict of paths keyed by speaker names
    self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
    self.duration_map = _get_duration_map(self.dataset_type)

    # cull speakers if they do not have enough utterances (or cap speakers with too many utterances)
    if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
        keys = list(self.paths_by_spkr_name.keys())
        for key in keys:
            if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
                del self.paths_by_spkr_name[key]
                continue

            # slice away extraneous utterances
            if cfg.dataset.max_utterances:
                self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]

    # flatten paths
    self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))

    # split dataset accordingly per GPU
    if cfg.distributed and self.training:
        # NOTE(review): `i % world_size() == 0` keeps the SAME subset (every world_size-th path starting
        # at index 0) on every rank, so ranks never see disjoint shards and most data goes unused.
        # Comparing against global_rank() would shard, but would also make per-rank lengths and the
        # speaker symmaps (built from self.paths) differ -- confirm intent before changing.
        self.paths = [path for i, path in enumerate(self.paths) if i % world_size() == 0]

        # recreate paths_by_spkr_name
        self.paths_by_spkr_name = {}
        for path in self.paths:
            name = cfg.get_spkr(Path(path))
            if name not in self.paths_by_spkr_name:
                self.paths_by_spkr_name[name] = []
            self.paths_by_spkr_name[name].append(path)

    # store in corresponding bucket
    for path in self.paths:
        duration = self.duration_map[path]
        self.duration += duration

        # total duration is always accumulated; only bucket if we're going to order by duration
        if self.sampler_order != "duration":
            continue

        # buckets are keyed by duration rounded to the nearest whole unit
        bucket = int(round(duration))
        if bucket not in self.duration_buckets:
            self.duration_buckets[bucket] = []
        self.duration_buckets[bucket].append((Path(path), duration))

    # sort by duration
    if self.sampler_order == "duration":
        # ensure they're ordered
        self.duration_buckets = dict(sorted(self.duration_buckets.items()))

        flattened = {}
        # sort and interleave
        for bucket in self.duration_buckets:
            # sort by duration
            self.duration_buckets[bucket].sort(key=lambda x: x[1])
            # split to retain tuples
            flattened[bucket] = self.duration_buckets[bucket]
            # replace with path
            flattened[bucket] = [x[0] for x in flattened[bucket]]
            # flatten by paths
            flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
        # flatten paths
        self.paths = list(itertools.chain.from_iterable(flattened.values()))
    elif self.sampler_order == "random":
        random.shuffle(self.paths)
    else:
        # just interleave
        self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]

    # dict of speakers keyed by speaker group
    self.spkrs_by_spkr_group = {}
    for data_dir in self.dataset:
        # speaker/group names are derived from a placeholder file inside the dataset directory
        spkr = cfg.get_spkr(data_dir / "dummy")
        spkr_group = cfg.get_spkr_group(data_dir / "dummy")

        if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
            continue

        if spkr_group not in self.spkrs_by_spkr_group:
            self.spkrs_by_spkr_group[spkr_group] = []

        self.spkrs_by_spkr_group[spkr_group].append(spkr)

    self.spkr_groups = list(self.spkrs_by_spkr_group.keys())

    self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
    self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))

    self.phone_symmap = phone_symmap or self._get_phone_symmap()
    self.spkr_symmap = self._get_spkr_symmap()
    self.spkr_group_symmap = self._get_spkr_group_symmap()
    self.lang_symmap = self._get_lang_symmap()
    self.tone_symmap = self._get_tone_symmap()
    self.task_symmap = self._get_task_symmap()

    # grab IDs for bos, space, and eos for easy input creation later
    try:
        self.empty_text = [cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token]
    except Exception as e:
        self.empty_text = [None, None, None]

    # have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of time
    # encoding before parallelizing things causes things to whine
    if self.empty_text[0] is None or self.empty_text[-1] is None:
        self.empty_text = None

    # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
    self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

    if len(self.paths) == 0:
        raise ValueError(f"No valid path is found for {self.dataset_type}")

    if self.sampler_type == "path" and self.training:
        if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
            # NOTE(review): if a state file exists but load_state_dict() below rejects it (dataset hash
            # mismatch), this sampler stays built from {} and presumably yields nothing -- confirm
            self.sampler = BatchedOrderedSampler(
                self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
                max_duration=cfg.dataset.sample_max_duration_batch,
                max_batch_size=self.batch_size,
                shuffle=self.sampler_shuffle,
            )
            # the batched sampler already groups samples into batches
            self.batch_size = 1
        else:
            self.sampler = OrderedSampler(len(self)) if not self.sampler_shuffle else RandomSampler(len(self))
        self.samplers = {}
        self.spkr_samplers = {}
    else:
        self.sampler = RandomSampler(len(self))
        self.samplers = {name: PoolSampler(paths, keep_all=True, shuffle=self.sampler_shuffle) for name, paths in self.paths_by_spkr_name.items()}
        self.spkr_samplers = {name: PoolSampler([*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle) for name, speakers in self.spkrs_by_spkr_group.items()}

    # dereference buckets
    self.duration_map = None
    self.duration_buckets = None

    self.load_state_dict()
2024-06-29 15:10:35 +00:00
@cached_property
def sampler_state_dict_path(self):
    """Path of this rank's sampler state file, computed on first access and cached.

    The file lives in ``cfg.ckpt_dir / <name>``, where ``<name>`` is the LoRA's
    full name when a LoRA is configured, and the model's full name otherwise.
    """
    owner_name = cfg.model.full_name if cfg.lora is None else cfg.lora.full_name
    file_name = f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
    return cfg.ckpt_dir / owner_name / file_name
2024-06-14 21:55:40 +00:00
2023-08-27 00:53:23 +00:00
def get_speaker(self, path):
    """Return the speaker name that ``cfg.get_spkr`` derives from *path* (``str`` or ``Path``)."""
    if not isinstance(path, str):
        return cfg.get_spkr(path)
    return cfg.get_spkr(Path(path))
2023-10-12 01:38:40 +00:00
def get_speaker_group(self, path):
    """Return the speaker group that ``cfg.get_spkr_group`` derives from *path* (``str`` or ``Path``)."""
    normalized = Path(path) if isinstance(path, str) else path
    group_name = cfg.get_spkr_group(normalized)
    return group_name
2024-09-19 00:36:03 +00:00
# Not strictly needed, since the data/metadata already carries language markers. It is kept in case a
# language must be forced for a speaker group (for example, whisperX's language detection isn't always accurate).
def get_language(self, speaker_group, lang="en"):
    """Return the lower-cased language configured for *speaker_group*.

    The first key of ``cfg.dataset.speaker_languages`` whose collection of
    groups contains *speaker_group* wins. Otherwise *lang* is returned,
    lower-cased.
    """
    configured = next(
        (language for language, groups in cfg.dataset.speaker_languages.items() if speaker_group in groups),
        lang,
    )
    return configured.lower()
2023-10-12 01:38:40 +00:00
2023-08-02 21:53:35 +00:00
@cached_property
def spkrs(self):
    """Sorted list of the distinct speaker names across ``self.paths``, cached on first access."""
    unique_speakers = set(map(self.get_speaker, self.paths))
    return sorted(unique_speakers)
2023-08-02 21:53:35 +00:00
2023-08-19 03:22:13 +00:00
@cached_property
def tasks(self):
    """Task names that items are drawn from, taken from ``cfg.dataset.tasks_list`` and cached.

    Duplicated entries weight a task more heavily, e.g.
    ``["tts", "tts", "ns", "sr", "tse", "tts", "tts"]`` (``"cse"``/``"nse"`` are also handled).
    """
    configured_tasks = cfg.dataset.tasks_list
    return configured_tasks
2023-08-19 03:22:13 +00:00
2024-06-14 21:55:40 +00:00
def save_state_dict(self, path=None):
    """Persist the sampler state, tagged with this dataset's hash key, via ``torch_save``.

    Args:
        path: destination file as ``str`` or ``Path``; defaults to
            ``self.sampler_state_dict_path``. Missing parent directories are created.
    """
    # coerce to Path so plain-string destinations work with the directory handling below
    path = self.sampler_state_dict_path if path is None else Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    if self.sampler_type == "path":
        state_dict = self.sampler.get_state()
    else:
        state_dict = {
            "samplers": {name: sampler.get_state() for name, sampler in self.samplers.items()},
            "spkr_samplers": {name: sampler.get_state() for name, sampler in self.spkr_samplers.items()},
        }

    # tag the state so load_state_dict() can reject state saved for a different dataset
    if "dataset_hash_key" not in state_dict:
        state_dict["dataset_hash_key"] = self.dataset_hash_key

    torch_save(state_dict, path)
2023-09-04 02:27:13 +00:00
2024-06-14 21:55:40 +00:00
def load_state_dict(self, path=None):
    """Restore sampler state written by ``save_state_dict`` (training split only).

    Does nothing for a non-training dataset or when the file does not exist.
    Logs a warning and skips loading when the stored dataset hash key differs
    from ``self.dataset_hash_key``.

    Args:
        path: state file as ``str`` or ``Path``; defaults to ``self.sampler_state_dict_path``.
    """
    if not self.training:
        return

    # coerce to Path so plain-string paths work with .exists()
    path = self.sampler_state_dict_path if path is None else Path(path)
    if not path.exists():
        return

    state_dict = torch_load(path)
    # state without a recorded hash key is accepted as-is
    if "dataset_hash_key" in state_dict and state_dict["dataset_hash_key"] != self.dataset_hash_key:
        _logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
        return

    if self.sampler_type == "path":
        self.sampler.set_state(state_dict)
        return

    # only restore samplers that still exist for the current dataset; tolerate missing sections
    for name, sampler_state in state_dict.get("samplers", {}).items():
        if name in self.samplers:
            self.samplers[name].set_state(sampler_state)

    for name, sampler_state in state_dict.get("spkr_samplers", {}).items():
        if name in self.spkr_samplers:
            self.spkr_samplers[name].set_state(sampler_state)
2023-09-04 02:27:13 +00:00
2023-08-19 03:22:13 +00:00
def _get_phone_symmap(self):
    """Return the phoneme symbol map provided by the module-level ``get_phone_symmap()``."""
    symmap = get_phone_symmap()
    return symmap
2023-08-02 21:53:35 +00:00
def _get_spkr_symmap(self):
    """Map each speaker name in ``self.spkrs`` to its position in that sequence."""
    speakers = self.spkrs
    return dict(zip(speakers, range(len(speakers))))
2023-10-17 00:30:38 +00:00
def _get_spkr_group_symmap(self):
    """Map each speaker-group name in ``self.spkr_groups`` to its position in that sequence."""
    groups = self.spkr_groups
    return dict(zip(groups, range(len(groups))))
2023-10-12 01:38:40 +00:00
def _get_lang_symmap(self):
    """Return the language symbol map provided by the module-level ``get_lang_symmap()``."""
    symmap = get_lang_symmap()
    return symmap
2024-04-16 00:54:32 +00:00
def _get_tone_symmap(self):
    """Return the tone symbol map provided by the module-level ``get_tone_symmap()``."""
    symmap = get_tone_symmap()
    return symmap
2023-08-19 03:22:13 +00:00
def _get_task_symmap(self):
    """Return the task symbol map provided by the module-level ``get_task_symmap()``."""
    symmap = get_task_symmap()
    return symmap
2023-08-19 20:06:33 +00:00
def sample_noise(self):
    """Return the quantized codes of one randomly chosen noise clip.

    Codes are read from HDF5 (cast to ``torch.int16``) when
    ``cfg.dataset.use_hdf5`` is set, and otherwise loaded via ``_load_artifact``.
    ``random.choice`` raises ``IndexError`` when ``self.noise_paths`` is empty.
    """
    noise_path = random.choice(self.noise_paths)

    if not cfg.dataset.use_hdf5:
        return _load_artifact(noise_path, return_metadata=False)

    audio = cfg.hdf5[_get_hdf5_path(noise_path)]["audio"][:, :]
    return torch.from_numpy(audio).to(torch.int16)
2023-08-19 03:22:13 +00:00
2023-08-18 00:07:59 +00:00
def sample_speakers(self, ignore=()):
    """Pick a random speaker name from ``self.spkrs`` that is not in *ignore*.

    Candidates keep the order of ``self.spkrs`` instead of going through a
    ``set``. A set of strings iterates in an order that depends on hash
    randomization (PYTHONHASHSEED), so a seeded ``random`` could not reproduce
    the pick across runs.

    Args:
        ignore: iterable of speaker names to exclude. Defaults to an immutable
            tuple rather than a shared list.

    Raises:
        IndexError: when every speaker is excluded.
    """
    ignored = set(ignore)
    # dict.fromkeys de-duplicates like the previous set() did, but in a stable order
    choices = [spkr for spkr in dict.fromkeys(self.spkrs) if spkr not in ignored]
    return random.choice(choices)
2024-07-18 21:16:14 +00:00
def sample_utterance(self, spkr_name, ignore=()):
    """Load a random utterance of *spkr_name* that is not in *ignore*.

    Args:
        spkr_name: key into ``self.paths_by_spkr_name``.
        ignore: iterable of paths to exclude (e.g. ones already sampled).
            Defaults to an immutable tuple rather than a shared list.

    Returns:
        ``(path, text, resps)``, with text tokens as ``self.text_dtype`` and audio
        codes as ``torch.int16``; or ``(None, None, None)`` when no candidate
        remains.

    Raises:
        RuntimeError: when HDF5 is enabled and the path's key is missing from it.
    """
    ignored = set(ignore)
    # keep the stored (de-duplicated) order instead of iterating a set, whose order
    # varies with hash randomization and defeats seeding `random`
    choices = [p for p in dict.fromkeys(self.paths_by_spkr_name[spkr_name]) if p not in ignored]
    if not choices:
        return None, None, None

    path = random.choice(choices)
    if cfg.dataset.use_hdf5:
        key = _get_hdf5_path(path)
        if key not in cfg.hdf5:
            raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
        text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
        resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
    else:
        resps, metadata = _load_artifact(path, return_metadata=True)
        text = torch.tensor(tokenize(metadata["phonemes"])).to(self.text_dtype)

    return path, text, resps
2024-09-21 17:59:51 +00:00
# icky slop
def get_similar_utterance(self, path, offset=None):
    """Return the path of an utterance listed as similar to *path*, or ``None``.

    Reads ``cfg.metadata_dir / <utterance's parent dir>.json``. The entry for
    ``path.name`` may carry a ``"similar"`` list of integer indices into that
    file's keys (in insertion order).

    Args:
        path: ``Path`` of the reference utterance.
        offset: position in the ``"similar"`` list to start from; defaults to
            ``cfg.dataset.prompt_similar_top_k_offset``. Falls back to 0 when the
            list is too short to honor it.

    Returns:
        A sibling ``Path`` of *path*, or ``None`` when there is no metadata entry
        or no similar utterances are listed.
    """
    if offset is None:
        offset = cfg.dataset.prompt_similar_top_k_offset

    reference = path.name
    root = Path(*path.parts[:-1])
    if cfg.dataset.use_hdf5:
        # HDF5 paths look like "/<type>/<dir...>/<id>"; drop the root and split-name parts
        path = Path(*path.parts[2:-1])
    else:
        path = Path(*path.parts[len(cfg.data_dir.parts):-1])

    metadata = json_read(cfg.metadata_dir / path.with_suffix(".json"), default={})
    if reference not in metadata:
        return None

    reference_metadata = metadata[reference]
    if "similar" not in reference_metadata:
        return None

    similar = reference_metadata["similar"]
    if not similar:
        return None

    # Fall back to the most similar entries when the list is too short for the offset.
    # The previous check was inverted: it discarded the offset whenever it was usable,
    # and kept it (indexing out of range) when it was not.
    if offset >= len(similar):
        offset = 0

    metadata_keys = list(metadata.keys())
    if cfg.dataset.prompt_similar_top_k > 1:
        index = random.choice(similar[offset:offset + cfg.dataset.prompt_similar_top_k])
    else:
        index = similar[offset]

    return root / metadata_keys[index]
2024-07-18 21:16:14 +00:00
2024-09-17 04:10:29 +00:00
def sample_prompts(self, spkr_name, reference, should_trim=True):
    """Build an audio prompt for *spkr_name* from other utterances of that speaker.

    Up to ``cfg.dataset.prompt_max_samples`` utterances are loaded and collected
    until the target length is reached, then concatenated and trimmed. When
    *reference* is given, each sample is, with probability
    ``cfg.dataset.prompt_similar_p``, an utterance listed as similar to it.

    Args:
        spkr_name: key into ``self.paths_by_spkr_name``.
        reference: path of the target utterance; excluded from the candidates
            unless it is the only one.
        should_trim: when true, trim each sample and the result to a random
            length drawn from ``cfg.dataset.prompt_duration_range`` multiplied
            by ``cfg.dataset.frames_per_second``. When false, nothing is trimmed
            and only the first sampled utterance is used.

    Returns:
        The prompt codes, or ``None`` when prompts are disabled or the speaker
        has at most one utterance.
    """
    prompt_range = cfg.dataset.prompt_duration_range
    # return no prompt if explicitly requested for who knows why
    # or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
    if not prompt_range or prompt_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1:
        return None

    # Candidates are the speaker's utterances minus the reference, in stored order.
    # A set's iteration order depends on hash randomization and defeats seeding `random`.
    speaker_paths = list(dict.fromkeys(self.paths_by_spkr_name[spkr_name]))
    choices = [p for p in speaker_paths if p != reference]
    # only the reference itself is available: fall back to using it
    if not choices:
        choices = speaker_paths

    duration_lo, duration_hi = prompt_range
    # Target prompt length in frames; 0 disables trimming. This previously tested the
    # imported `trim` function (always truthy) instead of `should_trim`.
    trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if should_trim else 0

    prom_list = []
    prom_length = 0
    for _ in range(cfg.dataset.prompt_max_samples):
        path = None
        if reference is not None and random.random() < cfg.dataset.prompt_similar_p:
            path = self.get_similar_utterance(reference, offset=len(prom_list))
        if not path:
            path = random.choice(choices)

        if cfg.dataset.use_hdf5:
            key = _get_hdf5_path(path)
            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
        else:
            qnt = _load_artifact(path, return_metadata=False)

        if 0 < trim_length < qnt.shape[0]:
            qnt = trim(qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device)

        prom_list.append(qnt)
        prom_length += qnt.shape[0]

        if prom_length >= trim_length:
            break

    # might be better to decode => concat waveforms with silence in between => reencode
    # as you technically can't just append encodec sequences together like this without issues
    prom = concat_audio(*prom_list, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device)

    if 0 < trim_length < prom.shape[0]:
        prom = trim(prom, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device)

    return prom
def __getitem__ ( self , index ) :
2024-11-18 18:46:50 +00:00
self . current_index = index
2024-07-18 22:16:32 +00:00
if self . empty_text is None :
self . empty_text = tokenize ( " " )
2024-07-18 21:16:14 +00:00
bos_id , space_id , eos_id = self . empty_text
2023-10-22 14:01:47 +00:00
if self . sampler_type == " group " :
2023-10-17 00:30:38 +00:00
spkr_group = self . spkr_groups [ index ]
2023-10-22 14:01:47 +00:00
#spkr_group_id = self.spkr_group_symmap[spkr_group]
2023-10-17 00:30:38 +00:00
spkr_name = self . spkr_samplers [ spkr_group ] . sample ( )
2023-10-19 01:38:33 +00:00
spkr_id = self . spkr_symmap [ spkr_name ]
path = self . samplers [ spkr_name ] . sample ( )
2023-10-22 14:01:47 +00:00
elif self . sampler_type == " speaker " :
2023-08-17 00:39:21 +00:00
spkr_name = self . spkrs [ index ]
spkr_id = self . spkr_symmap [ spkr_name ]
2023-09-04 02:27:13 +00:00
path = self . samplers [ spkr_name ] . sample ( )
2023-10-17 00:30:38 +00:00
spkr_group = self . get_speaker_group ( path )
2023-10-22 14:01:47 +00:00
#spkr_group_id = self.spkr_group_symmap[spkr_group]
2023-08-02 21:53:35 +00:00
else :
2023-08-18 19:47:48 +00:00
path = self . paths [ index ]
2023-08-27 00:53:23 +00:00
spkr_name = self . get_speaker ( path )
2023-08-17 00:39:21 +00:00
spkr_id = self . spkr_symmap [ spkr_name ]
2023-10-17 00:30:38 +00:00
spkr_group = self . get_speaker_group ( path )
2023-10-22 14:01:47 +00:00
#spkr_group_id = self.spkr_group_symmap[spkr_group]
2023-08-02 21:53:35 +00:00
2024-09-26 23:37:56 +00:00
if not isinstance ( path , Path ) :
path = Path ( path )
2023-08-02 21:53:35 +00:00
if cfg . dataset . use_hdf5 :
key = _get_hdf5_path ( path )
2023-10-12 01:38:40 +00:00
2024-04-29 03:28:29 +00:00
if key not in cfg . hdf5 :
raise RuntimeError ( f ' Key of Path ( { path } ) not in HDF5: { key } ' )
2024-08-03 03:25:49 +00:00
# I need to do some weird coersion to a normal dict because it'll bitch about Hdf5 objects not being pickleable in worker processes
metadata = { f ' { k } ' : f ' { v } ' for k , v in cfg . hdf5 [ key ] . attrs . items ( ) }
2024-07-18 21:16:14 +00:00
2023-10-11 00:18:24 +00:00
text = cfg . hdf5 [ key ] [ " text " ] [ : ]
resps = cfg . hdf5 [ key ] [ " audio " ] [ : , : ]
text = torch . from_numpy ( text ) . to ( self . text_dtype )
resps = torch . from_numpy ( resps ) . to ( torch . int16 )
2024-07-18 21:16:14 +00:00
lang = metadata [ " language " ] if " language " in metadata else None
tone = metadata [ " tone " ] if " tone " in metadata else None
2024-07-20 01:49:40 +00:00
text_string = metadata [ " text " ] if " text " in metadata else None
2024-09-21 18:08:01 +00:00
if cfg . dataset . retokenize_text and " phonemes " in metadata :
text = torch . tensor ( tokenize ( metadata [ " phonemes " ] ) ) . to ( self . text_dtype )
2023-08-02 21:53:35 +00:00
else :
2024-12-12 02:55:43 +00:00
resps , metadata = _load_artifact ( path , return_metadata = True )
2024-05-18 12:14:26 +00:00
text = torch . tensor ( tokenize ( metadata [ " phonemes " ] ) ) . to ( self . text_dtype )
2023-10-11 22:32:45 +00:00
2024-07-18 21:16:14 +00:00
lang = metadata [ " language " ] if " language " in metadata else None
tone = metadata [ " tone " ] if " tone " in metadata else None
2024-07-20 01:49:40 +00:00
text_string = metadata [ " text " ] if " text " in metadata else None
2024-07-18 21:16:14 +00:00
2024-09-19 00:36:03 +00:00
lang = self . get_language ( spkr_group ) if not lang else lang . lower ( )
2024-07-18 21:16:14 +00:00
if not tone :
tone = " neutral "
lang = torch . tensor ( [ self . lang_symmap [ lang ] ] ) . to ( torch . uint8 )
tone = torch . tensor ( [ self . tone_symmap [ tone ] ] ) . to ( torch . uint8 )
2024-07-18 21:48:41 +00:00
# a bool to easily experiment with two mindsets later
naive = cfg . experimental
2023-10-12 01:38:40 +00:00
2024-07-18 21:16:14 +00:00
# append additional prompts in an attempt to artifically increase lengths / offer new data
2024-10-17 22:06:48 +00:00
if cfg . dataset . resps_max_samples > 1 and random . random ( ) < cfg . dataset . resps_append_p :
2024-07-18 21:16:14 +00:00
ignore_paths = [ ]
2024-10-17 22:06:48 +00:00
for _ in range ( 1 , cfg . dataset . resps_max_samples ) :
2024-07-18 21:16:14 +00:00
path , txt , qnt = self . sample_utterance ( spkr_name , ignore = ignore_paths )
ignore_paths . append ( path )
# <s>[original text]</s><s>[new text]</s>
if naive :
text = torch . concat ( [ text , txt ] )
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
else :
2023-10-11 22:32:45 +00:00
text = torch . concat ( [ text [ : - 1 ] , torch . tensor ( [ self . phone_symmap [ " " ] ] ) . to ( torch . int16 ) , txt [ 1 : ] ] )
2024-07-18 21:16:14 +00:00
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
2024-07-18 21:48:41 +00:00
resps = concat_audio ( resps , qnt , reencode = cfg . dataset . reencode_on_concat , device = cfg . dataset . reencode_device )
2024-07-18 21:16:14 +00:00
2024-07-18 22:16:32 +00:00
task = random . choice ( self . tasks )
if f ' < { task } > ' not in self . task_symmap :
raise Exception ( f ' Task not defined: { task } ' )
2023-10-09 18:01:40 +00:00
2024-07-19 04:25:32 +00:00
# Base TTS (<text><prompt> => <resp>)
2024-07-18 21:16:14 +00:00
if task == " tts " :
2024-09-17 04:10:29 +00:00
proms = self . sample_prompts ( spkr_name , reference = path )
2023-10-11 22:32:45 +00:00
2024-10-18 14:40:06 +00:00
if cfg . dataset . prompt_inject_noise :
2024-07-23 00:36:07 +00:00
# sample random noise
noise = self . sample_noise ( )
# extend the noise to fill the target audio
noise = repeat_extend_audio ( noise , proms . shape [ 0 ] )
# create the input prompt by merging the target audio with the noise
proms = merge_audio ( proms , noise , scale = [ 1 , cfg . dataset . noise_scale ] , device = cfg . dataset . reencode_device )
2024-07-19 04:25:32 +00:00
# VALL-E Continuous (<text><partial resp> => <remaining resp> )
# (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
2024-07-18 21:16:14 +00:00
elif task == " tts-c " :
# trim a piece of the output response
if naive :
2024-10-18 14:40:06 +00:00
duration_lo , duration_hi = cfg . dataset . prompt_duration_range
trim_length = int ( random . uniform ( duration_lo , duration_hi ) * cfg . dataset . frames_per_second )
2023-09-02 21:29:53 +00:00
2023-09-01 22:19:34 +00:00
proms = resps [ : trim_length , : ]
resps = resps [ trim_length : , : ]
else :
2024-07-18 21:16:14 +00:00
path , txt , qnt = self . sample_utterance ( spkr_name )
# <s>[original text]</s><s>[new text]</s>
if naive :
text = torch . concat ( [ text , txt ] )
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
else :
2024-07-19 04:25:32 +00:00
text = torch . concat ( [ text [ : - 1 ] , torch . tensor ( [ space_id ] ) . to ( torch . int16 ) , txt [ 1 : ] ] )
2024-07-18 21:16:14 +00:00
# set prompt as initial response
proms = resps
# set target as newly sampled response
resps = qnt
2024-07-19 04:25:32 +00:00
# inject task token
proms = [
proms ,
task ,
]
2024-09-06 04:21:18 +00:00
# Base STT (<resp> => <text>)
2024-09-06 01:43:20 +00:00
elif task == " stt " :
2024-09-06 04:21:18 +00:00
proms = [
task
]
2024-09-06 01:43:20 +00:00
2024-11-10 18:19:48 +00:00
# Duration prediction (<text><prompt> => len(<resp>))
elif task == " len " :
proms = self . sample_prompts ( spkr_name , reference = path )
2024-07-19 04:25:32 +00:00
# noise suppression (<text>? <resp+noise> => <resp>)
# speech removal (<text>?<resp+noise> => <noise>)
2023-08-19 03:22:13 +00:00
elif task == " ns " or task == " sr " :
# sample random noise
2023-08-18 00:07:59 +00:00
noise = self . sample_noise ( )
2023-08-19 03:22:13 +00:00
# extend the noise to fill the target audio
noise = repeat_extend_audio ( noise , resps . shape [ 0 ] )
# create the input prompt by merging the target audio with the noise
2024-07-18 22:16:32 +00:00
proms = merge_audio ( resps , noise , scale = [ 1 , cfg . dataset . noise_scale ] , device = cfg . dataset . reencode_device )
2024-07-19 04:25:32 +00:00
# set the text prompt to empty to train without a guided text prompt
if random . random ( ) < 0.5 :
text = None
# inject task token
proms = [
task ,
proms
]
2023-08-19 03:22:13 +00:00
# set the target to just be the noise if <sr>
if task == " sr " :
resps = noise
2023-08-19 04:55:40 +00:00
2024-07-18 21:16:14 +00:00
2024-07-19 04:25:32 +00:00
# target speech extraction ( <text><prom><resp + other resp> => <resp> )
2023-08-18 19:47:48 +00:00
elif task == " tse " :
2024-07-19 04:25:32 +00:00
# sample a prompt
2024-09-17 04:10:29 +00:00
proms = self . sample_prompts ( spkr_name , reference = path )
2024-07-19 04:25:32 +00:00
# sample another speaker
_ , __ , other_resps = self . sample_utterance ( self . sample_speakers ( ignore = [ spkr_name ] ) )
2023-08-19 03:22:13 +00:00
# overlay the random speaker over the target audio
2024-07-19 04:25:32 +00:00
other_resps = merge_audio ( resps , other_resps , scale = [ 1 , random . uniform ( 0.5 , 0.75 ) ] , device = cfg . dataset . reencode_device )
# set the text prompt to empty to train without a guided text prompt
if random . random ( ) < 0.5 :
text = None
2023-08-19 04:55:40 +00:00
2024-07-18 21:16:14 +00:00
# stitch together the proms
proms = [
2024-07-19 04:25:32 +00:00
proms ,
2024-07-18 21:16:14 +00:00
task ,
2024-07-19 04:25:32 +00:00
other_resps ,
2024-07-18 21:16:14 +00:00
]
2023-08-19 04:55:40 +00:00
2023-08-19 03:22:13 +00:00
# clean speech editing
2023-08-19 06:16:46 +00:00
elif task == " cse " or task == " nse " :
2024-07-18 21:16:14 +00:00
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
# instead we'll just sample a bunch of utterances
2023-08-19 06:16:46 +00:00
2024-07-18 21:16:14 +00:00
samples = [ ]
for _ in range ( 4 ) :
sampled = self . sample_utterance ( spkr_name , ignore = [ s [ 0 ] for s in samples ] )
samples . append ( sampled )
2023-08-19 06:16:46 +00:00
2024-07-18 21:16:14 +00:00
pre_text , mid_text , post_text , edit_text = [ s [ 1 ] [ 1 : - 1 ] for s in samples ]
pre_prom , mid_prom , post_prom , edit_prom = [ s [ 2 ] for s in samples ]
2023-08-19 06:16:46 +00:00
# randomly drop out pre
if random . random ( ) < 0.125 :
pre_text = None
pre_prom = None
# randomly drop out post
2024-07-19 04:25:32 +00:00
elif random . random ( ) < 0.125 :
2023-08-19 06:16:46 +00:00
post_text = None
post_prom = None
# create new text
2024-07-19 04:25:32 +00:00
text = concat_audio (
2024-08-04 03:10:21 +00:00
torch . tensor ( [ bos_id ] ) . to ( dtype = self . text_dtype ) , # <s>
2024-07-19 04:25:32 +00:00
pre_text ,
2024-08-04 03:10:21 +00:00
None if pre_text is None else torch . tensor ( [ space_id ] ) . to ( dtype = self . text_dtype ) , # " "
2024-07-19 04:25:32 +00:00
edit_text ,
2024-08-04 03:10:21 +00:00
None if post_text is None else torch . tensor ( [ space_id ] ) . to ( dtype = self . text_dtype ) , # " "
2024-07-19 04:25:32 +00:00
post_text ,
2024-08-04 03:10:21 +00:00
torch . tensor ( [ eos_id ] ) . to ( dtype = self . text_dtype ) , # </s>
2024-07-19 04:25:32 +00:00
reencode = False ,
2023-08-19 06:16:46 +00:00
)
if task == " nse " :
# sample random noise
noise = self . sample_noise ( )
# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
# but it's noise, it's supposed to be random
2023-08-19 20:06:33 +00:00
def noise_proms ( p ) :
2023-08-19 06:16:46 +00:00
# ignore if we turned it off
2023-08-19 20:06:33 +00:00
if p is None :
2023-08-19 06:16:46 +00:00
return None
# extend the noise to fill the target audio
2023-08-19 20:06:33 +00:00
n = repeat_extend_audio ( noise , p . shape [ 0 ] )
2023-08-19 06:16:46 +00:00
# merge the noise over the utterance
2024-07-18 22:16:32 +00:00
return merge_audio ( p , n , scale = [ 1 , cfg . dataset . noise_scale ] , device = cfg . dataset . reencode_device )
2023-08-19 06:16:46 +00:00
# apply noise to all pieces
pre_prom = noise_proms ( pre_prom )
mid_prom = noise_proms ( mid_prom )
post_prom = noise_proms ( post_prom )
edit_prom = noise_proms ( edit_prom )
2024-07-18 21:16:14 +00:00
# create new prom
proms = [
pre_prom ,
2024-07-19 04:25:32 +00:00
" soe " ,
" mask " if task == " cse " else mid_prom ,
" eoe " ,
2024-07-18 21:16:14 +00:00
post_prom ,
]
2023-08-19 06:16:46 +00:00
# create new resp
2024-07-18 21:48:41 +00:00
resps = concat_audio (
2024-07-19 04:25:32 +00:00
pre_prom ,
edit_prom ,
post_prom ,
2024-07-18 21:48:41 +00:00
reencode = cfg . dataset . reencode_on_concat ,
device = cfg . dataset . reencode_device ,
2023-08-19 06:16:46 +00:00
)
2023-08-19 20:06:33 +00:00
else :
2023-09-02 17:23:40 +00:00
raise Exception ( f ' Undefined task: { task } ' )
2023-08-02 21:53:35 +00:00
2024-07-19 04:25:32 +00:00
if text is None :
text = torch . tensor ( [ bos_id , eos_id ] ) . to ( self . text_dtype )
2024-08-07 01:23:33 +00:00
# pad the target with silence
2024-10-17 22:06:48 +00:00
if random . random ( ) < cfg . dataset . resps_pad_silence_p :
2024-08-07 01:23:33 +00:00
resps = pad_codes_with_silence ( resps )
2023-08-02 21:53:35 +00:00
return dict (
index = index ,
2023-08-27 17:26:12 +00:00
path = Path ( path ) ,
2023-08-02 21:53:35 +00:00
spkr_name = spkr_name ,
spkr_id = spkr_id ,
2023-08-17 23:56:37 +00:00
task = task ,
2023-10-12 01:38:40 +00:00
lang = lang ,
2024-07-18 21:16:14 +00:00
tone = tone ,
2023-08-02 21:53:35 +00:00
text = text ,
proms = proms ,
resps = resps ,
2024-07-22 00:12:03 +00:00
metadata = metadata ,
2023-08-02 21:53:35 +00:00
)
def head_(self, n):
    """Limit the length reported by ``__len__`` to at most *n* items.

    ``__len__`` evaluates ``self._head or len(...)``, so a falsy *n* (``None`` or 0)
    means "no limit".
    """
    setattr(self, "_head", n)
def training_(self, value):
    """Set this dataset's ``training`` flag to *value*."""
    setattr(self, "training", value)
2024-11-13 04:30:09 +00:00
def index(self):
    """Return the sampler's current ``index()`` floor-divided by ``self.batch_size``.

    Without a sampler the position is taken as -1, so the result is ``-1 // batch_size``
    (i.e. -1 for a positive batch size).
    """
    position = -1 if self.sampler is None else self.sampler.index()
    return position // self.batch_size
2024-11-23 15:45:23 +00:00
def batches(self):
    """Return the number of batches per epoch.

    A ``BatchedOrderedSampler`` is handed to the DataLoader as its ``batch_sampler``,
    so its length is already a batch count. Otherwise the item count (of the sampler
    if present, else of the dataset itself) is floor-divided by ``self.batch_size``.
    """
    sampler = self.sampler
    if isinstance(sampler, BatchedOrderedSampler):
        return len(sampler)
    item_source = self if sampler is None else sampler
    return len(item_source) // self.batch_size
2024-11-13 04:30:09 +00:00
2023-08-02 21:53:35 +00:00
def __len__(self):
    """Return the number of items per epoch.

    The item pool depends on ``self.sampler_type``:
    - ``"group"``: speaker groups;
    - ``"speaker"``: speakers;
    - anything else: paths.

    The result is capped at ``self._head`` when that is truthy.
    """
    if self.sampler_type == "group":
        pool = self.spkr_groups
    elif self.sampler_type == "speaker":
        pool = self.spkrs
    else:
        pool = self.paths
    total = len(pool)
    return min(total, self._head or total)
def collate_fn(samples: list[dict]) -> dict[str, Any]:
    """Transpose a list of per-sample dicts into a dict of per-key lists.

    The keys (and their order) come from the first sample; every sample must
    provide each of those keys.
    """
    batch: dict[str, Any] = {}
    for key in samples[0].keys():
        batch[key] = [sample[key] for sample in samples]
    return batch
def _seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed numpy's and Python's RNGs from torch's seed.

    PyTorch gives each worker its own torch seed. Reducing it to 32 bits and feeding it
    to ``numpy.random`` and ``random`` makes those RNGs differ per worker as well.
    ``worker_id`` is required by the callback signature but unused.
    """
    seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)
def _create_dataloader(dataset, training):
    """Wrap *dataset* in a ``DataLoader`` configured for training or evaluation.

    If ``dataset.sampler`` is a ``BatchedOrderedSampler`` it is passed as the
    ``batch_sampler``, and batching options are omitted because DataLoader forbids
    combining them with a batch sampler.

    Otherwise:
    - Training uses the dataset's sampler, the training batch size and drops the last
      partial batch.
    - Evaluation uses no sampler, shuffles, and uses the evaluation batch size.
    """
    if isinstance(dataset.sampler, BatchedOrderedSampler):
        loader_options = dict(batch_sampler=dataset.sampler)
    else:
        # shuffle must be off whenever an explicit sampler is supplied (training)
        loader_options = dict(
            shuffle=not training,
            batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
            drop_last=training,
            sampler=dataset.sampler if training else None,
        )

    return DataLoader(
        dataset=dataset,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=cfg.dataset.workers > 1,
        pin_memory=False,
        worker_init_fn=_seed_worker,
        **loader_options,
    )
def create_datasets():
    """Construct and return the ``(train, validation)`` Dataset pair.

    The validation dataset is handed the training dataset's ``phone_symmap``.
    """
    train_dataset = Dataset(training=True)
    shared_symmap = train_dataset.phone_symmap
    return train_dataset, Dataset(phone_symmap=shared_symmap, training=False)
2024-07-20 01:49:40 +00:00
def create_train_dataloader():
    """Build the training Dataset and its DataLoader, logging symmaps and size.

    Returns:
        The training ``DataLoader``.
    """
    # Without this declaration, the assignment at the bottom would only create a
    # function-local name and the module-level duration map would never be released.
    global _durations_map

    train_dataset = Dataset(training=True)
    train_dl = _create_dataloader(train_dataset, training=True)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(str(train_dataset.spkr_symmap))
    _logger.info(str(train_dataset.spkr_group_symmap))

    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#duration (train): {str(train_dataset.duration)}.")

    # remove duration map (it gets bloated)
    _durations_map = {}

    return train_dl
def create_val_dataloader():
    """Build the validation Dataset and its DataLoader, logging symmaps and size.

    Returns:
        The validation ``DataLoader``.
    """
    # Without this declaration, the assignment at the bottom would only create a
    # function-local name and the module-level duration map would never be released.
    global _durations_map

    val_dataset = Dataset(training=False)
    val_dl = _create_dataloader(val_dataset, training=False)

    _logger.info(str(val_dataset.phone_symmap))
    _logger.info(str(val_dataset.spkr_symmap))
    _logger.info(str(val_dataset.spkr_group_symmap))

    _logger.info(f"#samples (val): {len(val_dataset)}.")
    _logger.info(f"#duration (val): {str(val_dataset.duration)}.")

    # remove duration map (it gets bloated)
    _durations_map = {}

    return val_dl
2023-08-02 21:53:35 +00:00
2024-11-11 22:32:08 +00:00
# to-do: build this from the two functions above, then create the subtrain dataset
2023-08-02 21:53:35 +00:00
def create_train_val_dataloader():
    """Build the training and validation DataLoaders and log dataset statistics.

    Returns:
        A ``(train_dl, val_dl)`` tuple.
    """
    # Without this declaration, the assignment at the bottom would only create a
    # function-local name and the module-level duration map would never be released.
    global _durations_map

    train_dataset, val_dataset = create_datasets()

    train_dl = _create_dataloader(train_dataset, training=True)
    val_dl = _create_dataloader(val_dataset, training=False)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
    _logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')

    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")

    _logger.info(f"#duration (train): {str(train_dataset.duration)}.")
    _logger.info(f"#duration (val): {str(val_dataset.duration)}.")

    # remove duration map (it gets bloated)
    _durations_map = {}

    return train_dl, val_dl
2023-08-02 21:53:35 +00:00
2024-09-18 21:43:57 +00:00
# parse metadata from a numpy artifact (.enc/.dac) and normalize it
2024-09-09 14:57:32 +00:00
def process_artifact_metadata(artifact):
    """Extract and normalize the metadata stored in a quantized audio artifact.

    Args:
        artifact: the mapping loaded from an ``.enc``/``.dac`` numpy file. Only its
            ``"metadata"`` entry is read.

    Returns:
        A new dict holding whichever of these could be copied or derived, in this order:
        - ``text``: transcription;
        - ``phonemes``: phonemization, cleaned of espeak ``(xx)`` language markers;
        - ``language``;
        - ``similar``: top-k similar utterances;
        - ``duration``: seconds, as a float when given explicitly, otherwise
          ``original_length / sample_rate``.

        If ``phonemes`` is absent but ``text`` is present, the text is phonemized via
        ``encode_phns`` using the stored language, falling back to ``"en"``.
    """
    source = artifact["metadata"]
    metadata = {}

    # fields copied through unchanged: transcription, phonemization,
    # language (for sampling / input creation) and top-k similar utterances
    for field in ("text", "phonemes", "language", "similar"):
        if field in source:
            metadata[field] = source[field]

    # duration is used for culling / sorting the dataset
    if "duration" in source:
        metadata["duration"] = float(source["duration"])
    # otherwise derive it from sample count / sample rate
    elif "original_length" in source and "sample_rate" in source:
        metadata["duration"] = source["original_length"] / source["sample_rate"]

    # rephonemize if required.
    # The language lookup must not index metadata["language"] unconditionally:
    # that raises KeyError when no language was stored.
    if "phonemes" not in metadata and "text" in metadata:
        metadata["phonemes"] = encode_phns(metadata["text"], language=metadata.get("language") or "en")

    # clean up language-switch markers emitted by espeak, for example:
    #   Sonnenküste Update => zˈɔnənkˌystə (en)ˈʌpdeɪt(de)
    if "phonemes" in metadata:
        phonemes = metadata["phonemes"].replace("(en)", "")
        if "language" in metadata:
            phonemes = phonemes.replace(f"({metadata['language']})", "")
        metadata["phonemes"] = re.sub(r'\([a-z]{2}\)', "", phonemes)

    return metadata
2024-09-18 21:43:57 +00:00
# yucky, but I would like to have the LibriTTS-R utterances remapped to their LibriSpeech counterpart
# to-do: allow this to be adjusted without having to regenerate metadata / HDF5 by remapping name during dataloader creation
def remap_speaker_name(name):
    """Return the speaker name used for metadata files and HDF5 keys.

    This is currently the identity mapping. Remapping LibriTTS-R speakers onto
    "LibriVox" names is intentionally disabled so the LibriSpeech portion of the
    dataset is not added.
    """
    # disabled remap, kept for reference:
    #   if "LibriTTS-R" in name:
    #       name = name.replace("LibriTTS-R", "LibriVox")
    remapped = name
    return remapped
2023-08-27 00:53:23 +00:00
# parse the dataset into metadata that is easier to sample from
2024-11-11 22:32:08 +00:00
def create_dataset_metadata(skip_existing=False):
    """Write per-speaker JSON metadata for every configured dataset directory.

    Directories come from ``cfg.dataset.training``, ``cfg.dataset.validation`` and
    ``cfg.dataset.noise``, each processed in sorted order. For every speaker directory:
    - each quantized artifact is loaded;
    - its metadata is extracted with ``process_artifact_metadata``, with ``duration``
      replaced by one derived from the code length;
    - the result is merged into ``{cfg.metadata_dir}/{speaker}.json``.

    The JSON file is only rewritten when at least one utterance was processed.

    Args:
        skip_existing: leave utterance IDs that already have an entry untouched.
    """
    symmap = get_phone_symmap()
    root = str(cfg.data_dir)
    metadata_root = str(cfg.metadata_dir)

    cfg.metadata_dir.mkdir(parents=True, exist_ok=True)

    def add(dir, type="training", audios=True, texts=True):
        # `type` and `texts` mirror the helper in create_dataset_hdf5; neither affects the output here
        name = str(dir).replace(root, "")
        speaker_name = remap_speaker_name(name)

        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
        metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
        metadata = json_read(metadata_path, default={})

        if not os.path.isdir(f'{root}/{name}/'):
            return

        # utterance IDs = file names with the artifact / metadata extension stripped
        utt_ids = {
            entry_name.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "")
            for entry_name in os.listdir(f'{root}/{name}/')
        }

        touched = False
        for utt_id in tqdm(utt_ids, desc=f"Processing {name}", disable=True):
            try:
                quant_path = Path(f'{root}/{name}/{utt_id}{_get_artifact_extension()}')
                if audios and not quant_path.exists():
                    continue
                if skip_existing and utt_id in metadata:
                    continue

                touched = True
                entry = metadata.setdefault(utt_id, {})

                if not audios:
                    continue

                artifact = np.load(quant_path, allow_pickle=True)[()]
                codes = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
                utterance_metadata = process_artifact_metadata(artifact)
                # duration is always re-derived from the number of code frames
                # (stored durations were found to be malformed for LibriTTS-R)
                utterance_metadata["duration"] = codes.shape[0] / cfg.dataset.frames_per_second

                entry.update(utterance_metadata)
            except Exception as e:
                tqdm.write(f'Error while processing {utt_id}: {e}')

        if touched:
            json_write(metadata, metadata_path)

    # training
    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
        add(data_dir, type="training")

    # validation
    for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
        add(data_dir, type="validation")

    # noise
    for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
        add(data_dir, type="noise", texts=False)
2023-08-27 00:53:23 +00:00
2023-08-19 14:50:07 +00:00
# parse yaml to create an hdf5 file
2023-08-27 00:53:23 +00:00
def create_dataset_hdf5(skip_existing=True):
    """Pack the configured training/validation/noise directories into the HDF5 dataset.

    For every speaker directory, each utterance is stored under
    ``{type}/{speaker}/{id}``:
    - quantized codes go in an ``audio`` dataset;
    - tokenized phonemes go in a ``text`` dataset (not for noise);
    - its metadata is written as HDF5 attributes and into the speaker's JSON metadata file.

    The phone symmap is (re)written as the top-level ``symmap`` dataset.

    The HDF5 handle is always closed, even when an error or interruption escapes
    partway through, so the file is not left open.

    Args:
        skip_existing: skip utterances whose group already exists in the HDF5 file.
    """
    cfg.dataset.use_hdf5 = True
    cfg.load_hdf5(write=True)
    hf = cfg.hdf5

    try:
        symmap = get_phone_symmap()

        root = str(cfg.data_dir)
        metadata_root = str(cfg.metadata_dir)

        def add(dir, type="training", audios=True, texts=True, verbose=False):
            name = str(dir)
            name = name.replace(root, "")
            speaker_name = remap_speaker_name(name)

            metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
            metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

            try:
                metadata = json_read(metadata_path, default={})
            except Exception as e:
                # a corrupt metadata file skips this speaker rather than aborting the whole export
                _logger.warning(f"Error while reading metadata {metadata_path}: {e}")
                return

            if not os.path.isdir(f'{root}/{name}/'):
                return

            files = os.listdir(f'{root}/{name}/')
            # grab IDs for every file
            ids = {file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files}

            # One-off fix-up kept for reference: re-tokenizes the stored "text" datasets
            # from the metadata phonemes (useful after switching to a different tokenizer).
            # for id, entry in tqdm(metadata.items(), desc=f"Processing {name}"):
            #     key = f'{type}/{speaker_name}/{id}'
            #     if key not in hf:
            #         continue
            #     group = hf[key]
            #     if "phonemes" not in entry:
            #         continue
            #     if "text" not in group:
            #         continue
            #     txt = entry["phonemes"]
            #     phn = "".join(txt)
            #     phn = cfg.tokenizer.encode(phn)
            #     phn = np.array(phn).astype(np.uint8)
            #     del group["text"]
            #     group.create_dataset('text', data=phn, compression='lzf')

            for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
                try:
                    quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_artifact_extension()}') if audios else True
                    text_exists = os.path.exists(f'{root}/{name}/{id}{_get_metadata_extension()}') if texts else True

                    if not quant_exists:
                        continue

                    key = f'{type}/{speaker_name}/{id}'

                    if skip_existing and key in hf:
                        continue

                    group = hf.create_group(key) if key not in hf else hf[key]

                    if id not in metadata:
                        metadata[id] = {}

                    utterance_metadata = {}

                    # audio
                    if audios:
                        artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()]
                        qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)

                        utterance_metadata = process_artifact_metadata(artifact)

                        if "audio" not in group:
                            group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')

                    # text
                    # this is a relic from when I did have the quantized audio and phoneme transcription separate
                    # to-do: ensure I can remove this block
                    if texts:
                        if not utterance_metadata and text_exists:
                            utterance_metadata = json_read(f'{root}/{name}/{id}{_get_metadata_extension()}')

                        phn = "".join(utterance_metadata["phonemes"])
                        phn = cfg.tokenizer.encode(phn)
                        # NOTE(review): uint8 assumes every token id < 256 — confirm against the tokenizer vocab
                        phn = np.array(phn).astype(np.uint8)

                        if "text" not in group:
                            group.create_dataset('text', data=phn, compression='lzf')

                    for k, v in utterance_metadata.items():
                        group.attrs[k] = v
                        metadata[id][k] = v

                except Exception as e:
                    tqdm.write(f'Error while processing {id}: {e}')

            json_write(metadata, metadata_path)

        # training
        for data_dir in tqdm(cfg.dataset.training, desc="Processing Training"):
            add(data_dir, type="training")

        # validation
        for data_dir in tqdm(cfg.dataset.validation, desc='Processing Validation'):
            add(data_dir, type="validation")

        # noise
        for data_dir in tqdm(cfg.dataset.noise, desc='Processing Noise'):
            add(data_dir, type="noise", texts=False)

        # write symmap
        if "symmap" in hf:
            del hf['symmap']

        hf.create_dataset('symmap', data=json_stringify(symmap))
    finally:
        hf.close()
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Save trained model to path.")
    parser.add_argument("--action", type=str)
    parser.add_argument("--tasks", type=str)
    # unknown args are tolerated so other config flags can be passed alongside
    args, unknown = parser.parse_known_args()

    setup_logging()

    cfg.dataset.workers = 1

    if args.action == "hdf5":
        create_dataset_hdf5()
    elif args.action == "list-dataset":
        dataset = []
        for group in os.listdir(cfg.data_dir):
            for name in os.listdir(cfg.data_dir / group):
                if len(os.listdir(cfg.data_dir / group / name)) == 0:
                    continue
                dataset.append(f'{group}/{name}')

        _logger.info(json_stringify(dataset))
    elif args.action == "metadata":
        create_dataset_metadata()
    elif args.action == "sample":
        train_dl, val_dl = create_train_val_dataloader()

        samples = {
            "training": [next(iter(train_dl)), next(iter(train_dl))],
            #"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
            "validation": [next(iter(val_dl)), next(iter(val_dl))],
        }

        Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)

        for k, v in samples.items():
            for i in range(len(v)):
                for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
                    try:
                        decode_to_file(v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu")
                    except Exception as e:
                        _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
                    try:
                        decode_to_file(v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu")
                    except Exception as e:
                        _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
                    #v[i]['proms'][j] = v[i]['proms'][j].shape
                    #v[i]['resps'][j] = v[i]['resps'][j].shape

        for k, v in samples.items():
            for i in range(len(v)):
                _logger.info(f'{k}[{i}]: {v[i]}')
    elif args.action == "validate":
        # create_train_val_dataloader returns exactly (train_dl, val_dl)
        train_dl, val_dl = create_train_val_dataloader()
        dataset = train_dl.dataset

        missing = []
        symmap = get_phone_symmap()

        for index in tqdm(range(len(dataset)), desc="Processing dataset..."):
            if dataset.sampler_type == "group":
                spkr_group = dataset.spkr_groups[index]
                #spkr_group_id = dataset.spkr_group_symmap[spkr_group]
                spkr_name = dataset.spkr_samplers[spkr_group].sample()
                spkr_id = dataset.spkr_symmap[spkr_name]
                path = dataset.samplers[spkr_name].sample()
            elif dataset.sampler_type == "speaker":
                spkr_name = dataset.spkrs[index]
                spkr_id = dataset.spkr_symmap[spkr_name]
                path = dataset.samplers[spkr_name].sample()
                spkr_group = dataset.get_speaker_group(path)
                #spkr_group_id = dataset.spkr_group_symmap[spkr_group]
            else:
                path = dataset.paths[index]
                spkr_name = dataset.get_speaker(path)
                spkr_id = dataset.spkr_symmap[spkr_name]
                spkr_group = dataset.get_speaker_group(path)
                #spkr_group_id = dataset.spkr_group_symmap[spkr_group]

            if cfg.dataset.use_hdf5:
                key = _get_hdf5_path(path)
                if key not in cfg.hdf5:
                    continue
                metadata = {f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items()}
            else:
                _, metadata = _load_artifact(path, return_metadata=True)

            phonemes = metadata["phonemes"]

            # report each phoneme missing from the symmap once, with the first utterance it appeared in
            for i, phone in enumerate(phonemes):
                if phone in symmap:
                    continue
                if phone in missing:
                    continue

                _logger.info(f"{path} | {phonemes}[{i}] | {phone}")
                missing.append(phone)

        _logger.info(f"Missing tokens: {missing}")
    elif args.action == "tasks":
        cfg.dataset.tasks_list = args.tasks.split(",")

        # create_train_val_dataloader returns exactly (train_dl, val_dl)
        train_dl, val_dl = create_train_val_dataloader()
        batch = next(iter(train_dl))

        for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
            if task not in cfg.dataset.tasks_list:
                continue

            _logger.info(f'{text} {task} {cfg.model.resp_levels}')
            _logger.info(f'{proms.shape} {resps.shape}')

            tokens = 0
            tokens += sum([text.shape[0] for text in batch["text"]])
            tokens += sum([resps.shape[0] for resps in batch["resps"]])
            _logger.info(f'{tokens}')

            decode_to_file(proms, f"./data/{task}.proms.wav", device="cpu")
            decode_to_file(resps, f"./data/{task}.resps.wav", device="cpu")
            break