2021-08-05 11:57:04 +00:00
import os
import random
import torch
import torch . utils . data
2021-08-06 04:14:49 +00:00
import torchaudio
2021-08-05 11:57:04 +00:00
from tqdm import tqdm
2021-08-06 04:14:49 +00:00
from data . audio . wav_aug import WavAugmentor
2021-08-09 05:23:13 +00:00
from data . util import find_files_of_type , is_wav_file
2021-08-05 11:57:04 +00:00
from models . tacotron2 . taco_utils import load_wav_to_torch
2021-08-06 04:14:49 +00:00
from utils . util import opt_get
2021-08-05 11:57:04 +00:00
class WavfileDataset(torch.utils.data.Dataset):
    """Dataset that serves wav files as tensors, optionally with two augmented clips.

    Options read from ``opt``:
        path: A directory or a list of directories that are searched for wav files.
        cache_path: File in which the discovered file list is stored with
            ``torch.save`` and from which it is re-read on later runs. Required.
        sampling_rate: Target sample rate, fetched via ``opt_get`` with 24000 as
            the fallback argument.
        pad_to_seconds: If given, every returned ``'clip'`` is zero-padded or
            truncated to this many seconds. Fetched with ``None`` as fallback.
        do_augmentation: If truthy, two random windows of 2 seconds are cut from
            every file, augmented and returned as ``'clip1'`` / ``'clip2'``.
            Fetched with ``False`` as fallback.
    """

    # Upper bound on consecutive attempts to load one file before giving up.
    MAX_LOAD_ATTEMPTS = 5

    def __init__(self, opt):
        """Build (or load from cache) the list of audio files and parse options."""
        super().__init__()
        path = opt['path']
        # Original author's note: will fail when multiple paths are specified, must
        # be specified in this case. As written, a missing key raises KeyError.
        cache_path = opt['cache_path']
        if not isinstance(path, list):
            path = [path]

        if os.path.exists(cache_path):
            self.audiopaths = torch.load(cache_path)
        else:
            print("Building cache..")
            self.audiopaths = []
            for p in path:
                self.audiopaths.extend(find_files_of_type('img', p, qualifier=is_wav_file)[0])
            torch.save(self.audiopaths, cache_path)

        # Parse options
        self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
        self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
        if self.pad_to is not None:
            # Convert seconds to a whole number of samples. Without int(), a
            # fractional 'pad_to_seconds' yields a float that can be used neither
            # as a padding amount nor as a slice bound in __getitem__.
            self.pad_to = int(self.pad_to * self.sampling_rate)

        self.augment = opt_get(opt, ['do_augmentation'], False)
        if self.augment:
            # The "window size" of the produced clips, in samples (2 seconds of audio).
            self.window = 2 * self.sampling_rate
            self.augmentor = WavAugmentor()

    def get_audio_for_index(self, index):
        """Load, resample and clip the file at ``index``.

        Returns:
            Tuple ``(audio, audiopath)``. ``audio`` has a leading dimension added
            by ``unsqueeze(0)`` (presumably shape ``(1, samples)``; this assumes
            ``load_wav_to_torch`` returns a 1-D tensor - TODO confirm).
        """
        audiopath = self.audiopaths[index]
        audio, sampling_rate = load_wav_to_torch(audiopath)
        if sampling_rate != self.sampling_rate:
            if sampling_rate < self.sampling_rate:
                print(f'{audiopath} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.sampling_rate}. This is not a good idea.')
            # Nearest-neighbour resampling to the requested rate.
            audio = torch.nn.functional.interpolate(
                audio.unsqueeze(0).unsqueeze(1),
                scale_factor=self.sampling_rate / sampling_rate,
                mode='nearest',
                recompute_scale_factor=False,
            ).squeeze()

        # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch,
        # but might not be in some edge cases, where we should squawk.
        # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
        if torch.any(audio > 2) or not torch.any(audio < 0):
            print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
        audio.clip_(-1, 1)
        audio = audio.unsqueeze(0)
        return audio, audiopath

    def _load_with_retries(self, index):
        """Load the file at ``index``, retrying up to ``MAX_LOAD_ATTEMPTS`` times.

        HACK (kept from the original): loading was reported to fail at random with
        corrupted-looking filenames, so a failed load is simply retried. Unlike the
        original unbounded ``while`` loop with a bare ``except``, the retry is
        bounded and only catches ``Exception``, so a genuinely unreadable file or a
        Ctrl-C no longer hangs the process.

        Raises:
            IndexError: If ``index`` is out of range (not retried).
            RuntimeError: If every attempt failed; chained to the last error.
        """
        # Looked up outside the try block so an invalid index fails immediately.
        audiopath = self.audiopaths[index]
        last_error = None
        for _ in range(self.MAX_LOAD_ATTEMPTS):
            try:
                return self.get_audio_for_index(index)
            except Exception as err:
                last_error = err
                print(f"Failed to load {index} {audiopath}")
        raise RuntimeError(
            f"Giving up on {audiopath} after {self.MAX_LOAD_ATTEMPTS} failed load attempts"
        ) from last_error

    def _random_augmented_window(self, audio):
        """Cut a random ``self.window``-sample slice from ``audio`` and augment it."""
        start = random.randint(0, audio.shape[1] - self.window)
        clip = audio[:, start:start + self.window]
        return self.augmentor.augment(clip, self.sampling_rate)

    def __getitem__(self, index):
        """Return the sample for ``index``.

        Returns:
            Dict with ``'clip'`` (the audio, padded/truncated when ``pad_to`` is
            set) and ``'path'`` (the file it came from). With augmentation enabled
            it also holds ``'clip1'`` and ``'clip2'``, two augmented random windows.

        Raises:
            RuntimeError: If the file cannot be loaded, or if augmentation is
                enabled and no file in the dataset is at least two windows long.
        """
        audio_norm, filename = self._load_with_retries(index)

        if self.augment:
            min_samples = self.window * 2
            skipped = 0
            # Files that are too short are skipped in favour of the next index. This adds a bit of
            # bias and ideally we'd filter the dataset rather than do this. Done as a bounded loop
            # instead of recursion so a long run of short files cannot exhaust the stack.
            while audio_norm.shape[1] < min_samples:
                skipped += 1
                if skipped >= len(self):
                    raise RuntimeError(
                        f"No audio file with at least {min_samples} samples found in the dataset"
                    )
                index = (index + 1) % len(self)
                audio_norm, filename = self._load_with_retries(index)
            clip1 = self._random_augmented_window(audio_norm)
            clip2 = self._random_augmented_window(audio_norm)

        # This is required when training to make sure all clips align.
        if self.pad_to is not None:
            if audio_norm.shape[-1] <= self.pad_to:
                audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))
            else:
                audio_norm = audio_norm[:, :self.pad_to]

        output = {
            'clip': audio_norm,
            'path': filename,
        }
        if self.augment:
            output.update({
                'clip1': clip1[0, :].unsqueeze(0),
                'clip2': clip2[0, :].unsqueeze(0),
            })
        return output

    def __len__(self):
        """Return the number of audio files known to the dataset."""
        return len(self.audiopaths)
if __name__ == '__main__':
    # Ad-hoc smoke test: build the dataset and a dataloader from hard-coded
    # local paths and iterate over every batch once.
    from data import create_dataset, create_dataloader, util

    batch_size = 16
    params = dict(
        mode='wavfile_clips',
        path=[
            'E:\\audio\\books-split',
            'E:\\audio\\LibriTTS\\train-clean-360',
            'D:\\data\\audio\\podcasts-split',
        ],
        cache_path='E:\\audio\\clips-cache.pth',
        sampling_rate=22050,
        pad_to_seconds=5,
        phase='train',
        n_workers=0,
        batch_size=batch_size,
        do_augmentation=False,
    )

    dataset = create_dataset(params)
    loader = create_dataloader(dataset, params)

    i = 0
    for batch in tqdm(loader):
        for sample_idx in range(batch_size):
            # Per-sample work is currently disabled; the lines below dump the
            # augmented clips to disk when re-enabled.
            pass
            #torchaudio.save(f'{i}_clip1_{sample_idx}.wav', batch['clip1'][sample_idx], dataset.sampling_rate)
            #torchaudio.save(f'{i}_clip2_{sample_idx}.wav', batch['clip2'][sample_idx], dataset.sampling_rate)
            #i += 1