forked from camenduru/ai-voice-cloning
Very, very, VERY, barebones integration with Bark (documentation soon)
This commit is contained in:
parent
faa8da12d7
commit
b6440091fb
456
src/utils.py
456
src/utils.py
|
@ -31,6 +31,7 @@ import music_tag
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -65,6 +66,7 @@ MIN_TRAINING_DURATION = 0.6
|
||||||
MAX_TRAINING_DURATION = 11.6097505669
|
MAX_TRAINING_DURATION = 11.6097505669
|
||||||
|
|
||||||
VALLE_ENABLED = False
|
VALLE_ENABLED = False
|
||||||
|
BARK_ENABLED = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vall_e.emb.qnt import encode as valle_quantize
|
from vall_e.emb.qnt import encode as valle_quantize
|
||||||
|
@ -76,11 +78,98 @@ try:
|
||||||
|
|
||||||
VALLE_ENABLED = True
|
VALLE_ENABLED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if False: # args.tts_backend == "vall-e":
|
||||||
|
raise e
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if VALLE_ENABLED:
|
if VALLE_ENABLED:
|
||||||
TTSES.append('vall-e')
|
TTSES.append('vall-e')
|
||||||
|
|
||||||
|
try:
|
||||||
|
from bark.generation import SAMPLE_RATE as BARK_SAMPLE_RATE, ALLOWED_PROMPTS, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic, load_codec_model
|
||||||
|
from bark.api import generate_audio as bark_generate_audio
|
||||||
|
from encodec.utils import convert_audio
|
||||||
|
|
||||||
|
from scipy.io.wavfile import write as write_wav
|
||||||
|
|
||||||
|
BARK_ENABLED = True
|
||||||
|
except Exception as e:
|
||||||
|
if False: # args.tts_backend == "bark":
|
||||||
|
raise e
|
||||||
|
pass
|
||||||
|
|
||||||
|
if BARK_ENABLED:
|
||||||
|
TTSES.append('bark')
|
||||||
|
class Bark_TTS():
|
||||||
|
def __init__(self, small=False):
|
||||||
|
self.input_sample_rate = BARK_SAMPLE_RATE
|
||||||
|
self.output_sample_rate = args.output_sample_rate
|
||||||
|
|
||||||
|
preload_models(
|
||||||
|
text_use_gpu=True,
|
||||||
|
coarse_use_gpu=True,
|
||||||
|
fine_use_gpu=True,
|
||||||
|
codec_use_gpu=True,
|
||||||
|
|
||||||
|
text_use_small=small,
|
||||||
|
coarse_use_small=small,
|
||||||
|
fine_use_small=small,
|
||||||
|
|
||||||
|
force_reload=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_voice( self, voice, device='cuda' ):
|
||||||
|
transcription_json = f'./training/{voice}/whisper.json'
|
||||||
|
if not os.path.exists(transcription_json):
|
||||||
|
raise f"Transcription for voice not found: {voice}"
|
||||||
|
|
||||||
|
transcriptions = json.load(open(transcription_json, 'r', encoding="utf-8"))
|
||||||
|
candidates = []
|
||||||
|
for file in transcriptions:
|
||||||
|
result = transcriptions[file]
|
||||||
|
for segment in result['segments']:
|
||||||
|
entry = (
|
||||||
|
file.replace(".wav", f"_{pad(segment['id'], 4)}.wav"),
|
||||||
|
segment['end'] - segment['start'],
|
||||||
|
segment['text']
|
||||||
|
)
|
||||||
|
candidates.append(entry)
|
||||||
|
|
||||||
|
candidates.sort(key=lambda x: x[1])
|
||||||
|
candidate = random.choice(candidates)
|
||||||
|
audio_filepath = f'./training/{voice}/audio/{candidate[0]}'
|
||||||
|
text = candidate[-1]
|
||||||
|
|
||||||
|
print("Using as reference:", audio_filepath, text)
|
||||||
|
|
||||||
|
# Load and pre-process the audio waveform
|
||||||
|
model = load_codec_model(use_gpu=True)
|
||||||
|
wav, sr = torchaudio.load(audio_filepath)
|
||||||
|
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||||
|
wav = wav.unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# Extract discrete codes from EnCodec
|
||||||
|
with torch.no_grad():
|
||||||
|
encoded_frames = model.encode(wav)
|
||||||
|
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze().cpu().numpy() # [n_q, T]
|
||||||
|
|
||||||
|
# get seconds of audio
|
||||||
|
seconds = wav.shape[-1] / model.sample_rate
|
||||||
|
# generate semantic tokens
|
||||||
|
semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
|
||||||
|
|
||||||
|
output_path = './modules/bark/bark/assets/prompts/' + voice.replace("/", "_") + '.npz'
|
||||||
|
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
|
||||||
|
|
||||||
|
def inference( self, text, voice, text_temp=0.7, waveform_temp=0.7 ):
|
||||||
|
if not os.path.exists('./modules/bark/bark/assets/prompts/' + voice + '.npz'):
|
||||||
|
self.create_voice( voice )
|
||||||
|
voice = voice.replace("/", "_")
|
||||||
|
if voice not in ALLOWED_PROMPTS:
|
||||||
|
ALLOWED_PROMPTS.add( voice )
|
||||||
|
|
||||||
|
return (bark_generate_audio(text, history_prompt=voice, text_temp=text_temp, waveform_temp=waveform_temp), BARK_SAMPLE_RATE)
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
@ -96,6 +185,9 @@ training_state = None
|
||||||
|
|
||||||
current_voice = None
|
current_voice = None
|
||||||
|
|
||||||
|
def cleanup_voice_name( name ):
|
||||||
|
return name.split("/")[-1]
|
||||||
|
|
||||||
def resample( waveform, input_rate, output_rate=44100 ):
|
def resample( waveform, input_rate, output_rate=44100 ):
|
||||||
# mono-ize
|
# mono-ize
|
||||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
|
@ -121,6 +213,291 @@ def generate(**kwargs):
|
||||||
return generate_tortoise(**kwargs)
|
return generate_tortoise(**kwargs)
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend == "vall-e":
|
||||||
return generate_valle(**kwargs)
|
return generate_valle(**kwargs)
|
||||||
|
if args.tts_backend == "bark":
|
||||||
|
return generate_bark(**kwargs)
|
||||||
|
|
||||||
|
def generate_bark(**kwargs):
|
||||||
|
parameters = {}
|
||||||
|
parameters.update(kwargs)
|
||||||
|
|
||||||
|
voice = parameters['voice']
|
||||||
|
progress = parameters['progress'] if 'progress' in parameters else None
|
||||||
|
if parameters['seed'] == 0:
|
||||||
|
parameters['seed'] = None
|
||||||
|
|
||||||
|
usedSeed = parameters['seed']
|
||||||
|
|
||||||
|
global args
|
||||||
|
global tts
|
||||||
|
|
||||||
|
unload_whisper()
|
||||||
|
unload_voicefixer()
|
||||||
|
|
||||||
|
if not tts:
|
||||||
|
# should check if it's loading or unloaded, and load it if it's unloaded
|
||||||
|
if tts_loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
if progress is not None:
|
||||||
|
progress(0, "Initializing TTS...")
|
||||||
|
load_tts()
|
||||||
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
do_gc()
|
||||||
|
|
||||||
|
voice_samples = None
|
||||||
|
conditioning_latents = None
|
||||||
|
sample_voice = None
|
||||||
|
|
||||||
|
voice_cache = {}
|
||||||
|
|
||||||
|
def get_settings( override=None ):
|
||||||
|
settings = {
|
||||||
|
'voice': parameters['voice'],
|
||||||
|
'text_temp': float(parameters['temperature']),
|
||||||
|
'waveform_temp': float(parameters['temperature']),
|
||||||
|
}
|
||||||
|
|
||||||
|
# could be better to just do a ternary on everything above, but i am not a professional
|
||||||
|
selected_voice = voice
|
||||||
|
if override is not None:
|
||||||
|
if 'voice' in override:
|
||||||
|
selected_voice = override['voice']
|
||||||
|
|
||||||
|
for k in override:
|
||||||
|
if k not in settings:
|
||||||
|
continue
|
||||||
|
settings[k] = override[k]
|
||||||
|
|
||||||
|
return settings
|
||||||
|
|
||||||
|
if not parameters['delimiter']:
|
||||||
|
parameters['delimiter'] = "\n"
|
||||||
|
elif parameters['delimiter'] == "\\n":
|
||||||
|
parameters['delimiter'] = "\n"
|
||||||
|
|
||||||
|
if parameters['delimiter'] and parameters['delimiter'] != "" and parameters['delimiter'] in parameters['text']:
|
||||||
|
texts = parameters['text'].split(parameters['delimiter'])
|
||||||
|
else:
|
||||||
|
texts = split_and_recombine_text(parameters['text'])
|
||||||
|
|
||||||
|
full_start_time = time.time()
|
||||||
|
|
||||||
|
outdir = f"{args.results_folder}/{voice}/"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
audio_cache = {}
|
||||||
|
|
||||||
|
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
idx_cache = {}
|
||||||
|
for i, file in enumerate(os.listdir(outdir)):
|
||||||
|
filename = os.path.basename(file)
|
||||||
|
extension = os.path.splitext(filename)[1]
|
||||||
|
if extension != ".json" and extension != ".wav":
|
||||||
|
continue
|
||||||
|
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
|
||||||
|
if match and len(match) > 0:
|
||||||
|
key = int(match[0])
|
||||||
|
idx_cache[key] = True
|
||||||
|
|
||||||
|
if len(idx_cache) > 0:
|
||||||
|
keys = sorted(list(idx_cache.keys()))
|
||||||
|
idx = keys[-1] + 1
|
||||||
|
|
||||||
|
idx = pad(idx, 4)
|
||||||
|
|
||||||
|
def get_name(line=0, candidate=0, combined=False):
|
||||||
|
name = f"{idx}"
|
||||||
|
if combined:
|
||||||
|
name = f"{name}_combined"
|
||||||
|
elif len(texts) > 1:
|
||||||
|
name = f"{name}_{line}"
|
||||||
|
if parameters['candidates'] > 1:
|
||||||
|
name = f"{name}_{candidate}"
|
||||||
|
return name
|
||||||
|
|
||||||
|
def get_info( voice, settings = None, latents = True ):
|
||||||
|
info = {}
|
||||||
|
info.update(parameters)
|
||||||
|
|
||||||
|
info['time'] = time.time()-full_start_time
|
||||||
|
info['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
info['progress'] = None
|
||||||
|
del info['progress']
|
||||||
|
|
||||||
|
if info['delimiter'] == "\n":
|
||||||
|
info['delimiter'] = "\\n"
|
||||||
|
|
||||||
|
if settings is not None:
|
||||||
|
for k in settings:
|
||||||
|
if k in info:
|
||||||
|
info[k] = settings[k]
|
||||||
|
return info
|
||||||
|
|
||||||
|
INFERENCING = True
|
||||||
|
for line, cut_text in enumerate(texts):
|
||||||
|
progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
|
||||||
|
print(f"{progress.msg_prefix} Generating line: {cut_text}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# do setting editing
|
||||||
|
match = re.findall(r'^(\{.+\}) (.+?)$', cut_text)
|
||||||
|
override = None
|
||||||
|
if match and len(match) > 0:
|
||||||
|
match = match[0]
|
||||||
|
try:
|
||||||
|
override = json.loads(match[0])
|
||||||
|
cut_text = match[1].strip()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||||
|
|
||||||
|
settings = get_settings( override=override )
|
||||||
|
|
||||||
|
gen = tts.inference(cut_text, **settings )
|
||||||
|
|
||||||
|
run_time = time.time()-start_time
|
||||||
|
print(f"Generating line took {run_time} seconds")
|
||||||
|
|
||||||
|
if not isinstance(gen, list):
|
||||||
|
gen = [gen]
|
||||||
|
|
||||||
|
for j, g in enumerate(gen):
|
||||||
|
wav, sr = g
|
||||||
|
name = get_name(line=line, candidate=j)
|
||||||
|
|
||||||
|
settings['text'] = cut_text
|
||||||
|
settings['time'] = run_time
|
||||||
|
settings['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# save here in case some error happens mid-batch
|
||||||
|
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||||
|
write_wav(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', sr, wav)
|
||||||
|
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
|
audio_cache[name] = {
|
||||||
|
'audio': wav,
|
||||||
|
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
del gen
|
||||||
|
do_gc()
|
||||||
|
INFERENCING = False
|
||||||
|
|
||||||
|
for k in audio_cache:
|
||||||
|
audio = audio_cache[k]['audio']
|
||||||
|
|
||||||
|
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
|
||||||
|
if volume_adjust is not None:
|
||||||
|
audio = volume_adjust(audio)
|
||||||
|
|
||||||
|
audio_cache[k]['audio'] = audio
|
||||||
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
|
output_voices = []
|
||||||
|
for candidate in range(parameters['candidates']):
|
||||||
|
if len(texts) > 1:
|
||||||
|
audio_clips = []
|
||||||
|
for line in range(len(texts)):
|
||||||
|
name = get_name(line=line, candidate=candidate)
|
||||||
|
audio = audio_cache[name]['audio']
|
||||||
|
audio_clips.append(audio)
|
||||||
|
|
||||||
|
name = get_name(candidate=candidate, combined=True)
|
||||||
|
audio = torch.cat(audio_clips, dim=-1)
|
||||||
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
|
audio = audio.squeeze(0).cpu()
|
||||||
|
audio_cache[name] = {
|
||||||
|
'audio': audio,
|
||||||
|
'settings': get_info(voice=voice),
|
||||||
|
'output': True
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
name = get_name(candidate=candidate)
|
||||||
|
audio_cache[name]['output'] = True
|
||||||
|
|
||||||
|
|
||||||
|
if args.voice_fixer:
|
||||||
|
if not voicefixer:
|
||||||
|
progress(0, "Loading voicefix...")
|
||||||
|
load_voicefixer()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fixed_cache = {}
|
||||||
|
for name in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||||
|
del audio_cache[name]['audio']
|
||||||
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||||
|
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
|
||||||
|
voicefixer.restore(
|
||||||
|
input=path,
|
||||||
|
output=fixed,
|
||||||
|
cuda=get_device_name() == "cuda" and args.voice_fixer_use_cuda,
|
||||||
|
#mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
fixed_cache[f'{name}_fixed'] = {
|
||||||
|
'settings': audio_cache[name]['settings'],
|
||||||
|
'output': True
|
||||||
|
}
|
||||||
|
audio_cache[name]['output'] = False
|
||||||
|
|
||||||
|
for name in fixed_cache:
|
||||||
|
audio_cache[name] = fixed_cache[name]
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print("\nFailed to run Voicefixer")
|
||||||
|
|
||||||
|
for name in audio_cache:
|
||||||
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
|
if args.prune_nonfinal_outputs:
|
||||||
|
audio_cache[name]['pruned'] = True
|
||||||
|
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
continue
|
||||||
|
|
||||||
|
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
|
if not args.embed_output_metadata:
|
||||||
|
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||||
|
|
||||||
|
if args.embed_output_metadata:
|
||||||
|
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||||
|
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
|
||||||
|
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||||
|
metadata.save()
|
||||||
|
|
||||||
|
if sample_voice is not None:
|
||||||
|
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
||||||
|
|
||||||
|
info = get_info(voice=voice, latents=False)
|
||||||
|
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||||
|
|
||||||
|
info['seed'] = usedSeed
|
||||||
|
if 'latents' in info:
|
||||||
|
del info['latents']
|
||||||
|
|
||||||
|
os.makedirs('./config/', exist_ok=True)
|
||||||
|
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(info, indent='\t') )
|
||||||
|
|
||||||
|
stats = [
|
||||||
|
[ parameters['seed'], "{:.3f}".format(info['time']) ]
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
sample_voice,
|
||||||
|
output_voices,
|
||||||
|
stats,
|
||||||
|
)
|
||||||
|
|
||||||
def generate_valle(**kwargs):
|
def generate_valle(**kwargs):
|
||||||
parameters = {}
|
parameters = {}
|
||||||
|
@ -289,9 +666,9 @@ def generate_valle(**kwargs):
|
||||||
settings['datetime'] = datetime.now().isoformat()
|
settings['datetime'] = datetime.now().isoformat()
|
||||||
|
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
#torchaudio.save(f'{outdir}/{voice}_{name}.wav', wav.cpu(), sr)
|
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||||
soundfile.write(f'{outdir}/{voice}_{name}.wav', wav.cpu()[0,0], sr)
|
soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
|
||||||
wav, sr = torchaudio.load(f'{outdir}/{voice}_{name}.wav')
|
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
'audio': wav,
|
'audio': wav,
|
||||||
|
@ -310,7 +687,7 @@ def generate_valle(**kwargs):
|
||||||
audio = volume_adjust(audio)
|
audio = volume_adjust(audio)
|
||||||
|
|
||||||
audio_cache[k]['audio'] = audio
|
audio_cache[k]['audio'] = audio
|
||||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
output_voices = []
|
output_voices = []
|
||||||
for candidate in range(parameters['candidates']):
|
for candidate in range(parameters['candidates']):
|
||||||
|
@ -323,7 +700,7 @@ def generate_valle(**kwargs):
|
||||||
|
|
||||||
name = get_name(candidate=candidate, combined=True)
|
name = get_name(candidate=candidate, combined=True)
|
||||||
audio = torch.cat(audio_clips, dim=-1)
|
audio = torch.cat(audio_clips, dim=-1)
|
||||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
audio = audio.squeeze(0).cpu()
|
audio = audio.squeeze(0).cpu()
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
|
@ -348,8 +725,8 @@ def generate_valle(**kwargs):
|
||||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
path = f'{outdir}/{voice}_{name}.wav'
|
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||||
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
|
||||||
voicefixer.restore(
|
voicefixer.restore(
|
||||||
input=path,
|
input=path,
|
||||||
output=fixed,
|
output=fixed,
|
||||||
|
@ -373,13 +750,13 @@ def generate_valle(**kwargs):
|
||||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
if args.prune_nonfinal_outputs:
|
if args.prune_nonfinal_outputs:
|
||||||
audio_cache[name]['pruned'] = True
|
audio_cache[name]['pruned'] = True
|
||||||
os.remove(f'{outdir}/{voice}_{name}.wav')
|
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
if not args.embed_output_metadata:
|
if not args.embed_output_metadata:
|
||||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||||
|
|
||||||
if args.embed_output_metadata:
|
if args.embed_output_metadata:
|
||||||
|
@ -387,7 +764,7 @@ def generate_valle(**kwargs):
|
||||||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
|
||||||
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||||
metadata.save()
|
metadata.save()
|
||||||
|
|
||||||
|
@ -415,8 +792,6 @@ def generate_valle(**kwargs):
|
||||||
stats,
|
stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_tortoise(**kwargs):
|
def generate_tortoise(**kwargs):
|
||||||
parameters = {}
|
parameters = {}
|
||||||
parameters.update(kwargs)
|
parameters.update(kwargs)
|
||||||
|
@ -698,7 +1073,7 @@ def generate_tortoise(**kwargs):
|
||||||
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||||
}
|
}
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, tts.output_sample_rate)
|
||||||
|
|
||||||
del gen
|
del gen
|
||||||
do_gc()
|
do_gc()
|
||||||
|
@ -712,7 +1087,7 @@ def generate_tortoise(**kwargs):
|
||||||
audio = volume_adjust(audio)
|
audio = volume_adjust(audio)
|
||||||
|
|
||||||
audio_cache[k]['audio'] = audio
|
audio_cache[k]['audio'] = audio
|
||||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{k}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
output_voices = []
|
output_voices = []
|
||||||
for candidate in range(parameters['candidates']):
|
for candidate in range(parameters['candidates']):
|
||||||
|
@ -725,7 +1100,7 @@ def generate_tortoise(**kwargs):
|
||||||
|
|
||||||
name = get_name(candidate=candidate, combined=True)
|
name = get_name(candidate=candidate, combined=True)
|
||||||
audio = torch.cat(audio_clips, dim=-1)
|
audio = torch.cat(audio_clips, dim=-1)
|
||||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
audio = audio.squeeze(0).cpu()
|
audio = audio.squeeze(0).cpu()
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
|
@ -750,8 +1125,8 @@ def generate_tortoise(**kwargs):
|
||||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
path = f'{outdir}/{voice}_{name}.wav'
|
path = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||||
fixed = f'{outdir}/{voice}_{name}_fixed.wav'
|
fixed = f'{outdir}/{cleanup_voice_name(voice)}_{name}_fixed.wav'
|
||||||
voicefixer.restore(
|
voicefixer.restore(
|
||||||
input=path,
|
input=path,
|
||||||
output=fixed,
|
output=fixed,
|
||||||
|
@ -775,13 +1150,13 @@ def generate_tortoise(**kwargs):
|
||||||
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
if 'output' not in audio_cache[name] or not audio_cache[name]['output']:
|
||||||
if args.prune_nonfinal_outputs:
|
if args.prune_nonfinal_outputs:
|
||||||
audio_cache[name]['pruned'] = True
|
audio_cache[name]['pruned'] = True
|
||||||
os.remove(f'{outdir}/{voice}_{name}.wav')
|
os.remove(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
output_voices.append(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
if not args.embed_output_metadata:
|
if not args.embed_output_metadata:
|
||||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/{cleanup_voice_name(voice)}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
f.write(json.dumps(audio_cache[name]['settings'], indent='\t') )
|
||||||
|
|
||||||
if args.embed_output_metadata:
|
if args.embed_output_metadata:
|
||||||
|
@ -789,7 +1164,7 @@ def generate_tortoise(**kwargs):
|
||||||
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
if 'pruned' in audio_cache[name] and audio_cache[name]['pruned']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{name}.wav")
|
metadata = music_tag.load_file(f"{outdir}/{cleanup_voice_name(voice)}_{name}.wav")
|
||||||
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
metadata['lyrics'] = json.dumps(audio_cache[name]['settings'])
|
||||||
metadata.save()
|
metadata.save()
|
||||||
|
|
||||||
|
@ -1096,9 +1471,9 @@ class TrainingState():
|
||||||
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
|
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
|
||||||
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
|
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
|
||||||
|
|
||||||
# 'ar.loss.nll', 'nar.loss.nll',
|
'ar.loss.nll', 'nar.loss.nll',
|
||||||
# 'ar-half.loss.nll', 'nar-half.loss.nll',
|
'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||||
# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||||
]
|
]
|
||||||
|
|
||||||
keys['accuracies'] = [
|
keys['accuracies'] = [
|
||||||
|
@ -1464,14 +1839,14 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
|
||||||
return_code = training_state.process.wait()
|
return_code = training_state.process.wait()
|
||||||
training_state = None
|
training_state = None
|
||||||
|
|
||||||
def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
|
def update_training_dataplot(x_min=None, x_max=None, y_min=None, y_max=None, config_path=None):
|
||||||
global training_state
|
global training_state
|
||||||
losses = None
|
losses = None
|
||||||
lrs = None
|
lrs = None
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
|
|
||||||
x_lim = [ 0, x_lim ]
|
x_lim = [ x_min, x_max ]
|
||||||
y_lim = [ 0, y_lim ]
|
y_lim = [ y_min, y_max ]
|
||||||
|
|
||||||
if not training_state:
|
if not training_state:
|
||||||
if config_path:
|
if config_path:
|
||||||
|
@ -1490,23 +1865,23 @@ def update_training_dataplot(x_lim=None, y_lim=None, config_path=None):
|
||||||
losses = gr.LinePlot.update(
|
losses = gr.LinePlot.update(
|
||||||
value = pd.DataFrame(training_state.statistics['loss']),
|
value = pd.DataFrame(training_state.statistics['loss']),
|
||||||
x_lim=x_lim, y_lim=y_lim,
|
x_lim=x_lim, y_lim=y_lim,
|
||||||
x="epoch", y="value",
|
x="it", y="value", # x="epoch",
|
||||||
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
title="Loss Metrics", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
width=500, height=350
|
width=500, height=350
|
||||||
)
|
)
|
||||||
if len(training_state.statistics['lr']) > 0:
|
if len(training_state.statistics['lr']) > 0:
|
||||||
lrs = gr.LinePlot.update(
|
lrs = gr.LinePlot.update(
|
||||||
value = pd.DataFrame(training_state.statistics['lr']),
|
value = pd.DataFrame(training_state.statistics['lr']),
|
||||||
x_lim=x_lim, y_lim=y_lim,
|
x_lim=x_lim,
|
||||||
x="epoch", y="value",
|
x="it", y="value", # x="epoch",
|
||||||
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
title="Learning Rate", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
width=500, height=350
|
width=500, height=350
|
||||||
)
|
)
|
||||||
if len(training_state.statistics['grad_norm']) > 0:
|
if len(training_state.statistics['grad_norm']) > 0:
|
||||||
grad_norms = gr.LinePlot.update(
|
grad_norms = gr.LinePlot.update(
|
||||||
value = pd.DataFrame(training_state.statistics['grad_norm']),
|
value = pd.DataFrame(training_state.statistics['grad_norm']),
|
||||||
x_lim=x_lim, y_lim=y_lim,
|
x_lim=x_lim,
|
||||||
x="epoch", y="value",
|
x="it", y="value", # x="epoch",
|
||||||
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
title="Gradient Normals", color="type", tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
width=500, height=350
|
width=500, height=350
|
||||||
)
|
)
|
||||||
|
@ -1649,13 +2024,13 @@ def whisper_transcribe( file, language=None ):
|
||||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||||
if whisper_vad:
|
if whisper_vad:
|
||||||
# omits a considerable amount of the end
|
# omits a considerable amount of the end
|
||||||
"""
|
|
||||||
if args.whisper_batchsize > 1:
|
if args.whisper_batchsize > 1:
|
||||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
||||||
else:
|
else:
|
||||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||||
"""
|
"""
|
||||||
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
|
||||||
|
"""
|
||||||
else:
|
else:
|
||||||
result = whisper_model.transcribe(file)
|
result = whisper_model.transcribe(file)
|
||||||
|
|
||||||
|
@ -1717,7 +2092,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
||||||
os.makedirs(f'{indir}/audio/', exist_ok=True)
|
os.makedirs(f'{indir}/audio/', exist_ok=True)
|
||||||
|
|
||||||
TARGET_SAMPLE_RATE = 22050
|
TARGET_SAMPLE_RATE = 22050
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend != "tortoise":
|
||||||
TARGET_SAMPLE_RATE = 24000
|
TARGET_SAMPLE_RATE = 24000
|
||||||
if tts:
|
if tts:
|
||||||
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
||||||
|
@ -1735,7 +2110,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
||||||
try:
|
try:
|
||||||
result = whisper_transcribe(file, language=language)
|
result = whisper_transcribe(file, language=language)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to transcribe:", file)
|
print("Failed to transcribe:", file, e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
results[basename] = result
|
results[basename] = result
|
||||||
|
@ -1802,7 +2177,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
||||||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
|
|
||||||
TARGET_SAMPLE_RATE = 22050
|
TARGET_SAMPLE_RATE = 22050
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend != "tortoise":
|
||||||
TARGET_SAMPLE_RATE = 24000
|
TARGET_SAMPLE_RATE = 24000
|
||||||
if tts:
|
if tts:
|
||||||
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
||||||
|
@ -1934,8 +2309,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
lines = { 'training': [], 'validation': [] }
|
lines = { 'training': [], 'validation': [] }
|
||||||
segments = {}
|
segments = {}
|
||||||
|
|
||||||
# I'm not sure how the VALL-E implementation decides what's validation and what's not
|
if args.tts_backend != "tortoise":
|
||||||
if args.tts_backend == "vall-e":
|
|
||||||
text_length = 0
|
text_length = 0
|
||||||
audio_length = 0
|
audio_length = 0
|
||||||
|
|
||||||
|
@ -3008,6 +3382,10 @@ def load_tts( restart=False,
|
||||||
|
|
||||||
print(f"Loading VALL-E... (Config: {valle_model})")
|
print(f"Loading VALL-E... (Config: {valle_model})")
|
||||||
tts = VALLE_TTS(config=args.valle_model)
|
tts = VALLE_TTS(config=args.valle_model)
|
||||||
|
elif args.tts_backend == "bark":
|
||||||
|
|
||||||
|
print(f"Loading Bark...")
|
||||||
|
tts = Bark_TTS(small=args.low_vram)
|
||||||
|
|
||||||
print("Loaded TTS, ready for generation.")
|
print("Loaded TTS, ready for generation.")
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
|
40
src/webui.py
40
src/webui.py
|
@ -167,6 +167,10 @@ def reset_generate_settings_proxy():
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
|
if args.tts_backend == "bark":
|
||||||
|
global tts
|
||||||
|
tts.create_voice( voice )
|
||||||
|
return voice
|
||||||
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
||||||
return voice
|
return voice
|
||||||
|
|
||||||
|
@ -222,13 +226,13 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len
|
||||||
print("Processing:", voice)
|
print("Processing:", voice)
|
||||||
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
|
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
"""
|
|
||||||
|
|
||||||
if slice_audio:
|
if slice_audio:
|
||||||
for voice in voices:
|
for voice in voices:
|
||||||
print("Processing:", voice)
|
print("Processing:", voice)
|
||||||
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
|
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
"""
|
||||||
|
|
||||||
for voice in voices:
|
for voice in voices:
|
||||||
print("Processing:", voice)
|
print("Processing:", voice)
|
||||||
|
@ -400,12 +404,13 @@ def setup_gradio():
|
||||||
outputs=GENERATE_SETTINGS["mic_audio"],
|
outputs=GENERATE_SETTINGS["mic_audio"],
|
||||||
)
|
)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
preset = None
|
||||||
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
|
GENERATE_SETTINGS["candidates"] = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates", visible=args.tts_backend=="tortoise")
|
||||||
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed")
|
GENERATE_SETTINGS["seed"] = gr.Number(value=0, precision=0, label="Seed", visible=args.tts_backend!="tortoise")
|
||||||
|
|
||||||
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast" )
|
preset = gr.Radio( ["Ultra Fast", "Fast", "Standard", "High Quality"], label="Preset", type="value", value="Ultra Fast", visible=args.tts_backend=="tortoise" )
|
||||||
|
|
||||||
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples")
|
GENERATE_SETTINGS["num_autoregressive_samples"] = gr.Slider(value=16, minimum=2, maximum=512, step=1, label="Samples", visible=args.tts_backend!="bark")
|
||||||
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
|
GENERATE_SETTINGS["diffusion_iterations"] = gr.Slider(value=30, minimum=0, maximum=512, step=1, label="Iterations", visible=args.tts_backend=="tortoise")
|
||||||
|
|
||||||
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
|
GENERATE_SETTINGS["temperature"] = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
|
||||||
|
@ -490,7 +495,7 @@ def setup_gradio():
|
||||||
merger_button = gr.Button(value="Run Merger")
|
merger_button = gr.Button(value="Run Merger")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
merger_output = gr.TextArea(label="Console Output", max_lines=8)
|
merger_output = gr.TextArea(label="Console Output", max_lines=8)
|
||||||
with gr.Tab("Training"):
|
with gr.Tab("Training", visible=args.tts_backend != "bark"):
|
||||||
with gr.Tab("Prepare Dataset"):
|
with gr.Tab("Prepare Dataset"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -586,8 +591,10 @@ def setup_gradio():
|
||||||
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
training_graph_x_lim = gr.Number(label="X Limit", precision=0, value=0)
|
training_graph_x_min = gr.Number(label="X Min", precision=0, value=0)
|
||||||
training_graph_y_lim = gr.Number(label="Y Limit", precision=0, value=0)
|
training_graph_x_max = gr.Number(label="X Max", precision=0, value=0)
|
||||||
|
training_graph_y_min = gr.Number(label="Y Min", precision=0, value=0)
|
||||||
|
training_graph_y_max = gr.Number(label="Y Max", precision=0, value=0)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_training_button = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
|
@ -597,7 +604,7 @@ def setup_gradio():
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
training_loss_graph = gr.LinePlot(label="Training Metrics",
|
||||||
x="epoch",
|
x="it", # x="epoch",
|
||||||
y="value",
|
y="value",
|
||||||
title="Loss Metrics",
|
title="Loss Metrics",
|
||||||
color="type",
|
color="type",
|
||||||
|
@ -606,7 +613,7 @@ def setup_gradio():
|
||||||
height=350,
|
height=350,
|
||||||
)
|
)
|
||||||
training_lr_graph = gr.LinePlot(label="Training Metrics",
|
training_lr_graph = gr.LinePlot(label="Training Metrics",
|
||||||
x="epoch",
|
x="it", # x="epoch",
|
||||||
y="value",
|
y="value",
|
||||||
title="Learning Rate",
|
title="Learning Rate",
|
||||||
color="type",
|
color="type",
|
||||||
|
@ -615,7 +622,7 @@ def setup_gradio():
|
||||||
height=350,
|
height=350,
|
||||||
)
|
)
|
||||||
training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
|
training_grad_norm_graph = gr.LinePlot(label="Training Metrics",
|
||||||
x="epoch",
|
x="it", # x="epoch",
|
||||||
y="value",
|
y="value",
|
||||||
title="Gradient Normals",
|
title="Gradient Normals",
|
||||||
color="type",
|
color="type",
|
||||||
|
@ -765,6 +772,7 @@ def setup_gradio():
|
||||||
inputs=show_experimental_settings,
|
inputs=show_experimental_settings,
|
||||||
outputs=experimental_column
|
outputs=experimental_column
|
||||||
)
|
)
|
||||||
|
if preset:
|
||||||
preset.change(fn=update_presets,
|
preset.change(fn=update_presets,
|
||||||
inputs=preset,
|
inputs=preset,
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -860,8 +868,10 @@ def setup_gradio():
|
||||||
training_output.change(
|
training_output.change(
|
||||||
fn=update_training_dataplot,
|
fn=update_training_dataplot,
|
||||||
inputs=[
|
inputs=[
|
||||||
training_graph_x_lim,
|
training_graph_x_min,
|
||||||
training_graph_y_lim,
|
training_graph_x_max,
|
||||||
|
training_graph_y_min,
|
||||||
|
training_graph_y_max,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
training_loss_graph,
|
training_loss_graph,
|
||||||
|
@ -874,8 +884,10 @@ def setup_gradio():
|
||||||
view_losses.click(
|
view_losses.click(
|
||||||
fn=update_training_dataplot,
|
fn=update_training_dataplot,
|
||||||
inputs=[
|
inputs=[
|
||||||
training_graph_x_lim,
|
training_graph_x_min,
|
||||||
training_graph_y_lim,
|
training_graph_x_max,
|
||||||
|
training_graph_y_min,
|
||||||
|
training_graph_y_max,
|
||||||
training_configs,
|
training_configs,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
|
Loading…
Reference in New Issue
Block a user