2022-01-28 06:19:29 +00:00
import torch
import torchaudio
2022-02-04 05:18:21 +00:00
import numpy as np
from scipy . io . wavfile import read
2022-01-28 06:19:29 +00:00
def load_wav_to_torch(full_path):
    """Read a WAV file and return its samples as a float32 tensor.

    Integer PCM data is divided by the magnitude of its most negative value
    (2**31 for int32, 2**15 for int16) so the result lies in [-1, 1).
    Floating-point data is passed through unscaled.

    Args:
        full_path: Path (or file object) accepted by ``scipy.io.wavfile.read``.

    Returns:
        A ``(audio, sampling_rate)`` tuple. ``audio`` is a float32 tensor with
        the same shape as the array returned by ``read``; ``sampling_rate`` is
        the rate reported by the file.

    Raises:
        NotImplementedError: If the sample dtype is neither int16, int32 nor a
            floating-point type.
    """
    sampling_rate, data = read(full_path)
    if data.dtype == np.int32:
        norm_fix = 2 ** 31
    elif data.dtype == np.int16:
        norm_fix = 2 ** 15
    elif np.issubdtype(data.dtype, np.floating):
        # float16 / float32 / float64 samples need no rescaling.
        norm_fix = 1.
    else:
        # NotImplemented is a comparison sentinel, not an exception class;
        # NotImplementedError is the exception that carries this message.
        raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
    # astype() returns a fresh array, so the tensor does not alias `data`.
    audio = torch.from_numpy(data.astype(np.float32)) / norm_fix
    return (audio, sampling_rate)
def load_audio(audiopath, sampling_rate):
    """Load a .wav or .mp3 file as a mono tensor at the requested rate.

    The file is decoded, reduced to its first channel, resampled to
    ``sampling_rate`` when the decoded rate differs, clamped to [-1, 1] and
    returned with a leading dimension of size 1.

    Args:
        audiopath: Path to the audio file. The extension (matched
            case-insensitively) selects the decoder.
        sampling_rate: Desired output sampling rate.

    Returns:
        A float tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: If the extension is not .wav/.mp3, or if the decoded
            audio is 2-D and neither dimension is smaller than 5 (so no
            channel axis can be identified).
    """
    lowered = str(audiopath).lower()
    if lowered.endswith('.wav'):
        audio, lsr = load_wav_to_torch(audiopath)
    elif lowered.endswith('.mp3'):
        # https://github.com/neonbjb/pyfastmp3decoder - Definitely worth it.
        from pyfastmp3decoder.mp3decoder import load_mp3
        audio, lsr = load_mp3(audiopath, sampling_rate)
        audio = torch.FloatTensor(audio)
    else:
        # Fail here with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported audio format (expected .wav or .mp3): {audiopath}"
        )

    # Remove any channel data. The channel axis is taken to be whichever
    # dimension is small (< 5); the first channel is kept.
    if len(audio.shape) > 1:
        if audio.shape[0] < 5:
            audio = audio[0]
        else:
            if audio.shape[1] >= 5:
                raise ValueError(
                    f"Cannot identify channel axis for {audiopath}: "
                    f"shape={tuple(audio.shape)}"
                )
            audio = audio[:, 0]

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Check some assumptions about audio range. This should be automatically
    # fixed in load_wav_to_torch, but might not be in some edge cases, where
    # we should squawk.
    # '2' is arbitrarily chosen since it seems like audio will often
    # "overdrive" the [-1,1] bounds.
    if torch.any(audio > 2) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)

    return audio.unsqueeze(0)