From 2cd7b7268887be6035e0154a52869a5649b0b5b1 Mon Sep 17 00:00:00 2001
From: NtTestAlert <nttestalert@protonmail.com>
Date: Sat, 1 Apr 2023 15:08:31 +0200
Subject: [PATCH] feat: support .flac voice files

---
 tortoise/utils/audio.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py
index 74060d4..e3885e5 100755
--- a/tortoise/utils/audio.py
+++ b/tortoise/utils/audio.py
@@ -2,6 +2,7 @@ import os
 from glob import glob
 
 import librosa
+import soundfile as sf
 import torch
 import torchaudio
 import numpy as np
@@ -24,6 +25,9 @@ def load_audio(audiopath, sampling_rate):
     elif audiopath[-4:] == '.mp3':
         audio, lsr = librosa.load(audiopath, sr=sampling_rate)
         audio = torch.FloatTensor(audio)
+    elif audiopath[-5:] == '.flac':
+        audio, lsr = sf.read(audiopath)
+        audio = torch.FloatTensor(audio)
     else:
         assert False, f"Unsupported audio format provided: {audiopath[-4:]}"
 
@@ -85,7 +89,7 @@ def get_voices(extra_voice_dirs=[], load_latents=True):
         for sub in subs:
             subj = os.path.join(d, sub)
             if os.path.isdir(subj):
-                voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3'))
+                voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.flac'))
                 if load_latents:
                     voices[sub] = voices[sub] + list(glob(f'{subj}/*.pth'))
     return voices