cleanup
This commit is contained in:
parent
7c71f7239c
commit
cb273b8428
47
src/utils.py
47
src/utils.py
|
@ -1274,6 +1274,15 @@ def optimize_training_settings( **kwargs ):
|
||||||
if vram > (k-1):
|
if vram > (k-1):
|
||||||
return v
|
return v
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
if settings['gpus'] > get_device_count():
|
||||||
|
settings['gpus'] = get_device_count()
|
||||||
|
messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}")
|
||||||
|
|
||||||
|
if settings['gpus'] <= 1:
|
||||||
|
settings['gpus'] = 1
|
||||||
|
else:
|
||||||
|
messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.")
|
||||||
|
|
||||||
# assuming you have equal GPUs
|
# assuming you have equal GPUs
|
||||||
vram = get_device_vram() * settings['gpus']
|
vram = get_device_vram() * settings['gpus']
|
||||||
|
@ -1303,14 +1312,14 @@ def optimize_training_settings( **kwargs ):
|
||||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||||
|
|
||||||
if settings['bitsandbytes']:
|
if settings['bitsandbytes']:
|
||||||
messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
|
messages.append("! EXPERIMENTAL ! BitsAndBytes requested.")
|
||||||
|
|
||||||
if settings['half_p']:
|
if settings['half_p']:
|
||||||
if settings['bitsandbytes']:
|
if settings['bitsandbytes']:
|
||||||
settings['half_p'] = False
|
settings['half_p'] = False
|
||||||
messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
|
messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
|
||||||
else:
|
else:
|
||||||
messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
|
messages.append("! EXPERIMENTAL ! Half Precision requested.")
|
||||||
if not os.path.exists(get_halfp_model_path()):
|
if not os.path.exists(get_halfp_model_path()):
|
||||||
convert_to_halfp()
|
convert_to_halfp()
|
||||||
|
|
||||||
|
@ -1343,12 +1352,21 @@ def save_training_settings( **kwargs ):
|
||||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
||||||
|
|
||||||
iterations_per_epoch = int(settings['iterations'] / settings['epochs'])
|
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
||||||
|
|
||||||
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
|
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
|
||||||
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
|
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
|
||||||
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
|
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
|
||||||
|
|
||||||
|
iterations_per_epoch = int(iterations_per_epoch)
|
||||||
|
|
||||||
|
if settings['print_rate'] < 1:
|
||||||
|
settings['print_rate'] = 1
|
||||||
|
if settings['save_rate'] < 1:
|
||||||
|
settings['save_rate'] = 1
|
||||||
|
if settings['validation_rate'] < 1:
|
||||||
|
settings['validation_rate'] = 1
|
||||||
|
|
||||||
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||||
|
|
||||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||||
|
@ -1809,9 +1827,7 @@ def save_args_settings():
|
||||||
|
|
||||||
# super kludgy )`;
|
# super kludgy )`;
|
||||||
def import_generate_settings(file="./config/generate.json"):
|
def import_generate_settings(file="./config/generate.json"):
|
||||||
global GENERATE_SETTINGS_ARGS
|
res = {
|
||||||
|
|
||||||
defaults = {
|
|
||||||
'text': None,
|
'text': None,
|
||||||
'delimiter': None,
|
'delimiter': None,
|
||||||
'emotion': None,
|
'emotion': None,
|
||||||
|
@ -1836,19 +1852,9 @@ def import_generate_settings(file="./config/generate.json"):
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, _ = read_generate_settings(file, read_latents=False)
|
settings, _ = read_generate_settings(file, read_latents=False)
|
||||||
|
if settings is not None:
|
||||||
res = []
|
res.update(settings)
|
||||||
if GENERATE_SETTINGS_ARGS is not None:
|
return res
|
||||||
for k in GENERATE_SETTINGS_ARGS:
|
|
||||||
if k not in defaults:
|
|
||||||
continue
|
|
||||||
res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k])
|
|
||||||
else:
|
|
||||||
for k in defaults:
|
|
||||||
res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k])
|
|
||||||
|
|
||||||
return tuple(res)
|
|
||||||
|
|
||||||
|
|
||||||
def reset_generation_settings():
|
def reset_generation_settings():
|
||||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||||
|
@ -1978,7 +1984,8 @@ def deduce_autoregressive_model(voice=None):
|
||||||
if os.path.isdir(dir):
|
if os.path.isdir(dir):
|
||||||
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
||||||
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
|
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
|
||||||
return names[-1]
|
if len(names) > 0:
|
||||||
|
return names[-1]
|
||||||
|
|
||||||
if args.autoregressive_model != "auto":
|
if args.autoregressive_model != "auto":
|
||||||
return args.autoregressive_model
|
return args.autoregressive_model
|
||||||
|
|
23
src/webui.py
23
src/webui.py
|
@ -27,6 +27,7 @@ GENERATE_SETTINGS = {}
|
||||||
TRANSCRIBE_SETTINGS = {}
|
TRANSCRIBE_SETTINGS = {}
|
||||||
EXEC_SETTINGS = {}
|
EXEC_SETTINGS = {}
|
||||||
TRAINING_SETTINGS = {}
|
TRAINING_SETTINGS = {}
|
||||||
|
GENERATE_SETTINGS_ARGS = []
|
||||||
|
|
||||||
PRESETS = {
|
PRESETS = {
|
||||||
'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
|
'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
|
||||||
|
@ -144,6 +145,18 @@ def history_view_results( voice ):
|
||||||
gr.Dropdown.update(choices=sorted(files))
|
gr.Dropdown.update(choices=sorted(files))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def import_generate_settings_proxy( file=None ):
|
||||||
|
global GENERATE_SETTINGS_ARGS
|
||||||
|
settings = import_generate_settings( file )
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for k in GENERATE_SETTINGS_ARGS:
|
||||||
|
res.append(settings[k] if k in settings else None)
|
||||||
|
print(GENERATE_SETTINGS_ARGS)
|
||||||
|
print(settings)
|
||||||
|
print(res)
|
||||||
|
return tuple(res)
|
||||||
|
|
||||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
||||||
return voice
|
return voice
|
||||||
|
@ -221,7 +234,10 @@ def import_training_settings_proxy( voice ):
|
||||||
if k not in settings:
|
if k not in settings:
|
||||||
continue
|
continue
|
||||||
output[k] = settings[k]
|
output[k] = settings[k]
|
||||||
|
|
||||||
output = list(output.values())
|
output = list(output.values())
|
||||||
|
print(list(TRAINING_SETTINGS.keys()))
|
||||||
|
print(output)
|
||||||
messages.append(f"Imported training settings: {injson}")
|
messages.append(f"Imported training settings: {injson}")
|
||||||
|
|
||||||
return output[:-1] + ["\n".join(messages)]
|
return output[:-1] + ["\n".join(messages)]
|
||||||
|
@ -250,7 +266,7 @@ def history_copy_settings( voice, file ):
|
||||||
def setup_gradio():
|
def setup_gradio():
|
||||||
global args
|
global args
|
||||||
global ui
|
global ui
|
||||||
|
|
||||||
if not args.share:
|
if not args.share:
|
||||||
def noop(function, return_value=None):
|
def noop(function, return_value=None):
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
@ -273,6 +289,7 @@ def setup_gradio():
|
||||||
autoregressive_models = get_autoregressive_models()
|
autoregressive_models = get_autoregressive_models()
|
||||||
dataset_list = get_dataset_list()
|
dataset_list = get_dataset_list()
|
||||||
|
|
||||||
|
global GENERATE_SETTINGS_ARGS
|
||||||
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
|
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
|
||||||
for i in range(len(GENERATE_SETTINGS_ARGS)):
|
for i in range(len(GENERATE_SETTINGS_ARGS)):
|
||||||
arg = GENERATE_SETTINGS_ARGS[i]
|
arg = GENERATE_SETTINGS_ARGS[i]
|
||||||
|
@ -639,7 +656,7 @@ def setup_gradio():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
copy_button.click(import_generate_settings,
|
copy_button.click(import_generate_settings_proxy,
|
||||||
inputs=audio_in, # JSON elements cannot be used as inputs
|
inputs=audio_in, # JSON elements cannot be used as inputs
|
||||||
outputs=generate_settings
|
outputs=generate_settings
|
||||||
)
|
)
|
||||||
|
@ -738,7 +755,7 @@ def setup_gradio():
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.isfile('./config/generate.json'):
|
if os.path.isfile('./config/generate.json'):
|
||||||
ui.load(import_generate_settings, inputs=None, outputs=generate_settings)
|
ui.load(import_generate_settings_proxy, inputs=None, outputs=generate_settings)
|
||||||
|
|
||||||
if args.check_for_updates:
|
if args.check_for_updates:
|
||||||
ui.load(check_for_updates)
|
ui.load(check_for_updates)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user