2021-09-14 23:43:16 +00:00
import os
import pathlib
import random
2021-10-21 22:45:19 +00:00
import sys
2021-10-24 15:09:34 +00:00
from warnings import warn
2021-09-14 23:43:16 +00:00
import torch
import torch . utils . data
import torch . nn . functional as F
import torchaudio
from audio2numpy import open_audio
from tqdm import tqdm
from data . audio . wav_aug import WavAugmentor
2021-10-21 22:45:19 +00:00
from data . util import find_files_of_type , is_wav_file , is_audio_file , load_paths_from_cache
2021-09-14 23:43:16 +00:00
from models . tacotron2 . taco_utils import load_wav_to_torch
from utils . util import opt_get
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a mono float tensor at ``sampling_rate``, clipped to [-1, 1].

    The decoder is chosen from the (case-insensitive) file extension: ``.wav`` goes through
    ``load_wav_to_torch``, ``.mp3`` through ``pyfastmp3decoder`` and anything else through
    ``audio2numpy.open_audio``.

    Args:
        audiopath: Path of the file to load (``str`` or path-like).
        sampling_rate: Desired sample rate. Audio decoded at a different rate is resampled with
            nearest-neighbour interpolation (cheap, but it does not low-pass filter).

    Returns:
        A float tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: if the file decodes to zero samples.
    """
    path_str = str(audiopath)
    lowered = path_str.lower()
    if lowered.endswith('.wav'):
        audio, lsr = load_wav_to_torch(path_str)
    elif lowered.endswith('.mp3'):
        # https://github.com/neonbjb/pyfastmp3decoder - Definitely worth it.
        # Imported here so the package is only required when an mp3 file is actually loaded.
        from pyfastmp3decoder.mp3decoder import load_mp3
        audio, lsr = load_mp3(path_str, sampling_rate)
        audio = torch.FloatTensor(audio)
    else:
        audio, lsr = open_audio(path_str)
        audio = torch.FloatTensor(audio)

    # Remove any channel data, keeping only the first channel. A 2-D array whose first dimension is
    # small (< 5) is treated as (channels, samples), otherwise as (samples, channels).
    if len(audio.shape) > 1:
        if audio.shape[0] < 5:
            audio = audio[0]
        else:
            assert audio.shape[1] < 5, f"{audiopath} has an unexpected audio shape {tuple(audio.shape)}"
            audio = audio[:, 0]

    if audio.numel() == 0:
        # Resampling and the min/max report below cannot operate on an empty signal.
        raise ValueError(f"{audiopath} contains no audio samples")

    if lsr != sampling_rate:
        # NOTE: no warning is emitted when upsampling from a lower native rate.
        # Index with [0, 0] instead of squeeze() so that a one-sample result stays 1-D.
        audio = F.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=sampling_rate / lsr,
                              mode='nearest', recompute_scale_factor=False)[0, 0]

    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but
    # might not be in some edge cases, where we should squawk.
    # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds; the check
    # is symmetric so large negative excursions are reported too. Audio with no negative samples at all is
    # also reported, since it is likely not a signed signal centred on zero.
    if torch.any(audio.abs() > 2) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)
    return audio.unsqueeze(0)
class UnsupervisedAudioDataset(torch.utils.data.Dataset):
    """Map-style dataset of audio clips for unsupervised training.

    Each item is a dict with:
        'clip': the audio at ``index``, zero-padded or randomly cropped to ``pad_to`` samples when
            ``pad_to`` is set (returned whole otherwise).
        'path': the file the clip was read from.
        'resampled_clip' (only when ``resample_clip`` is enabled): a second window from the same file.
        'alt_clips' / 'num_alt_clips' (only when ``extra_samples`` > 0): up to ``extra_samples`` clips
            of ``extra_sample_len`` samples taken from other audio files in the same directory, stacked
            and zero-filled when fewer files exist; 'num_alt_clips' counts the real (non-zero-fill) ones.

    Options read from ``opt``:
        path, cache_path: passed to ``load_paths_from_cache`` to build the file list.
        sampling_rate (default 22050): sample rate all audio is loaded at.
        pad_to_seconds / pad_to_samples (default None): fixed clip length; pad_to_samples takes
            precedence when both are given.
        resample_clip (default False): also emit 'resampled_clip'.
        extra_samples (default 0), extra_sample_length (seconds, default 2): related-clip settings.
    """

    def __init__(self, opt):
        path = opt['path']
        # Required (KeyError if missing); in particular it must be given when several paths are listed.
        cache_path = opt['cache_path']
        self.audiopaths = load_paths_from_cache(path, cache_path)

        # Parse options
        self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
        self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
        if self.pad_to is not None:
            self.pad_to *= self.sampling_rate
        self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
        if self.pad_to is not None:
            # seconds * rate can be fractional; slicing, F.pad and random.randint all need an int.
            self.pad_to = int(round(self.pad_to))

        # "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds.
        # There are no guarantees that "resampled_clip" is different from "clip": in fact, if "clip" is no
        # longer than pad_to samples (or pad_to is unset), the two are identical.
        self.should_resample_clip = opt_get(opt, ['resample_clip'], False)

        # "Extra samples" are other audio clips pulled from audio files in the same directory as the 'clip' file.
        self.extra_samples = int(opt_get(opt, ['extra_samples'], 0))
        # Given in seconds; stored in samples (as an int, since it is used for slicing and padding).
        self.extra_sample_len = int(round(opt_get(opt, ['extra_sample_length'], 2) * self.sampling_rate))

    def get_audio_for_index(self, index):
        """Load the audio file at ``index``; returns ``(audio, audiopath)``."""
        audiopath = self.audiopaths[index]
        audio = load_audio(audiopath, self.sampling_rate)
        return audio, audiopath

    def get_related_audio_for_index(self, index):
        """Load up to ``extra_samples`` random other audio files from the directory of file ``index``.

        Each related clip is zero-padded or randomly cropped to ``extra_sample_len`` samples. When fewer
        related files exist, the result is filled up with all-zero clips.

        Returns:
            ``(None, 0)`` when ``extra_samples`` <= 0, otherwise ``(stacked_clips, num_real_clips)`` where
            ``stacked_clips`` holds ``extra_samples`` clips stacked along a new leading dimension.
        """
        if self.extra_samples <= 0:
            return None, 0
        audiopath = self.audiopaths[index]
        # NOTE(review): the 'img' type argument is passed through to the project helper; the
        # is_audio_file qualifier is what restricts the listing to audio files - confirm.
        related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0]
        assert audiopath in related_files
        assert len(related_files) < 50000  # Sanity check to ensure we aren't loading "related files" that aren't actually related.
        related_files.remove(audiopath)
        if len(related_files) == 0:
            print(f"No related files for {audiopath}")
        # Pick a random subset of siblings rather than always the first ones in directory order.
        random.shuffle(related_files)

        related_clips = []
        for related_file in related_files[:self.extra_samples]:
            rel_clip = load_audio(related_file, self.sampling_rate)
            gap = rel_clip.shape[-1] - self.extra_sample_len
            if gap < 0:
                rel_clip = F.pad(rel_clip, pad=(0, -gap))
            elif gap > 0:
                rand_start = random.randint(0, gap)
                rel_clip = rel_clip[:, rand_start:rand_start + self.extra_sample_len]
            related_clips.append(rel_clip)
        actual_extra_samples = len(related_clips)

        # Fill with silence so every item has the same number of alt clips (needed to batch them).
        while len(related_clips) < self.extra_samples:
            related_clips.append(torch.zeros(1, self.extra_sample_len))
        return torch.stack(related_clips, dim=0), actual_extra_samples

    def _crop_or_pad(self, audio, skew):
        """Return ``audio`` zero-padded or randomly cropped to ``self.pad_to`` samples (as-is if unset).

        ``skew`` (-1, 0 or 1) shifts the random crop start by ``skew * gap // 2`` before clamping, where
        ``gap`` is the number of surplus samples, biasing the window towards the start or end of the file.
        """
        if self.pad_to is None:
            return audio
        length = audio.shape[-1]
        if length <= self.pad_to:
            return F.pad(audio, (0, self.pad_to - length))
        gap = length - self.pad_to
        # Valid window starts are 0..gap inclusive (start + pad_to <= length).
        start = min(max(random.randint(0, gap) + skew * gap // 2, 0), gap)
        return audio[:, start:start + self.pad_to]

    def __getitem__(self, index):
        try:
            audio_norm, filename = self.get_audio_for_index(index)
            alt_files, actual_samples = self.get_related_audio_for_index(index)
        except Exception:
            # Best-effort loading: report the bad file and serve the next one instead, wrapping at the end
            # so the last index does not run past the dataset.
            print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
            return self[(index + 1) % len(self)]

        # When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing
        # their influence on one another.
        skew = [-1, 1] if self.should_resample_clip else [0]
        # To increase variability, which skew is applied to the clip and resampled_clip is randomized.
        random.shuffle(skew)
        clips = [self._crop_or_pad(audio_norm, sk) for sk in skew]

        output = {
            'clip': clips[0],
            'path': filename,
        }
        if self.should_resample_clip:
            output['resampled_clip'] = clips[1]
        if self.extra_samples > 0:
            output['alt_clips'] = alt_files
            output['num_alt_clips'] = actual_samples
        return output

    def __len__(self):
        return len(self.audiopaths)
if __name__ == '__main__':
    # Manual smoke test: builds a dataset and dataloader from the hard-coded config below and writes each
    # item's 'clip' and 'resampled_clip' to .wav files in the current directory.
    # NOTE(review): 'path' and 'cache_path' point at one developer's network share / local drive; edit
    # them before running.
    params = {
        'mode': 'unsupervised_audio',  # presumably selects UnsupervisedAudioDataset in create_dataset - verify
        'path': ['\\\\192.168.5.3\\rtx3080_audio\\split\\cleaned\\books0'],
        'cache_path': 'E:\\audio\\remote-cache3.pth',
        'sampling_rate': 22050,
        'pad_to_samples': 40960,
        'phase': 'train',
        'n_workers': 1,
        'batch_size': 16,
        'extra_samples': 4,
        'resample_clip': True,  # Must stay True: the loop below reads b['resampled_clip'].
    }
    from data import create_dataset, create_dataloader, util
    ds = create_dataset(params)
    dl = create_dataloader(ds, params)
    # NOTE(review): loop nesting was reconstructed from a flattened source; `i` is assumed to count
    # batches and `b_` to index items within a batch - confirm against the original file.
    i = 0
    for b in tqdm(dl):
        for b_ in range(b['clip'].shape[0]):
            #pass
            torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
            torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
        i += 1