cleanup
This commit is contained in:
parent
7c71f7239c
commit
cb273b8428
47
src/utils.py
47
src/utils.py
|
@ -1275,6 +1275,15 @@ def optimize_training_settings( **kwargs ):
|
|||
return v
|
||||
return 1
|
||||
|
||||
if settings['gpus'] > get_device_count():
|
||||
settings['gpus'] = get_device_count()
|
||||
messages.append(f"GPU count exceeds defacto GPU count, clamping to: {settings['gpus']}")
|
||||
|
||||
if settings['gpus'] <= 1:
|
||||
settings['gpus'] = 1
|
||||
else:
|
||||
messages.append(f"! EXPERIMENTAL ! Multi-GPU training is extremely particular, expect issues.")
|
||||
|
||||
# assuming you have equal GPUs
|
||||
vram = get_device_vram() * settings['gpus']
|
||||
batch_ratio = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||
|
@ -1303,14 +1312,14 @@ def optimize_training_settings( **kwargs ):
|
|||
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||
|
||||
if settings['bitsandbytes']:
|
||||
messages.append("BitsAndBytes requested. Please note this is ! EXPERIMENTAL !")
|
||||
messages.append("! EXPERIMENTAL ! BitsAndBytes requested.")
|
||||
|
||||
if settings['half_p']:
|
||||
if settings['bitsandbytes']:
|
||||
settings['half_p'] = False
|
||||
messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
|
||||
else:
|
||||
messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
|
||||
messages.append("! EXPERIMENTAL ! Half Precision requested.")
|
||||
if not os.path.exists(get_halfp_model_path()):
|
||||
convert_to_halfp()
|
||||
|
||||
|
@ -1343,12 +1352,21 @@ def save_training_settings( **kwargs ):
|
|||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
|
||||
|
||||
iterations_per_epoch = int(settings['iterations'] / settings['epochs'])
|
||||
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
||||
|
||||
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
|
||||
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
|
||||
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
|
||||
|
||||
iterations_per_epoch = int(iterations_per_epoch)
|
||||
|
||||
if settings['print_rate'] < 1:
|
||||
settings['print_rate'] = 1
|
||||
if settings['save_rate'] < 1:
|
||||
settings['save_rate'] = 1
|
||||
if settings['validation_rate'] < 1:
|
||||
settings['validation_rate'] = 1
|
||||
|
||||
settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
|
||||
|
||||
settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
|
@ -1809,9 +1827,7 @@ def save_args_settings():
|
|||
|
||||
# super kludgy )`;
|
||||
def import_generate_settings(file="./config/generate.json"):
|
||||
global GENERATE_SETTINGS_ARGS
|
||||
|
||||
defaults = {
|
||||
res = {
|
||||
'text': None,
|
||||
'delimiter': None,
|
||||
'emotion': None,
|
||||
|
@ -1836,19 +1852,9 @@ def import_generate_settings(file="./config/generate.json"):
|
|||
}
|
||||
|
||||
settings, _ = read_generate_settings(file, read_latents=False)
|
||||
|
||||
res = []
|
||||
if GENERATE_SETTINGS_ARGS is not None:
|
||||
for k in GENERATE_SETTINGS_ARGS:
|
||||
if k not in defaults:
|
||||
continue
|
||||
res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k])
|
||||
else:
|
||||
for k in defaults:
|
||||
res.append(defaults[k] if not settings or k not in settings or not settings[k] is None else settings[k])
|
||||
|
||||
return tuple(res)
|
||||
|
||||
if settings is not None:
|
||||
res.update(settings)
|
||||
return res
|
||||
|
||||
def reset_generation_settings():
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
|
@ -1978,7 +1984,8 @@ def deduce_autoregressive_model(voice=None):
|
|||
if os.path.isdir(dir):
|
||||
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
||||
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
|
||||
return names[-1]
|
||||
if len(names) > 0:
|
||||
return names[-1]
|
||||
|
||||
if args.autoregressive_model != "auto":
|
||||
return args.autoregressive_model
|
||||
|
|
21
src/webui.py
21
src/webui.py
|
@ -27,6 +27,7 @@ GENERATE_SETTINGS = {}
|
|||
TRANSCRIBE_SETTINGS = {}
|
||||
EXEC_SETTINGS = {}
|
||||
TRAINING_SETTINGS = {}
|
||||
GENERATE_SETTINGS_ARGS = []
|
||||
|
||||
PRESETS = {
|
||||
'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
|
||||
|
@ -144,6 +145,18 @@ def history_view_results( voice ):
|
|||
gr.Dropdown.update(choices=sorted(files))
|
||||
)
|
||||
|
||||
def import_generate_settings_proxy( file=None ):
|
||||
global GENERATE_SETTINGS_ARGS
|
||||
settings = import_generate_settings( file )
|
||||
|
||||
res = []
|
||||
for k in GENERATE_SETTINGS_ARGS:
|
||||
res.append(settings[k] if k in settings else None)
|
||||
print(GENERATE_SETTINGS_ARGS)
|
||||
print(settings)
|
||||
print(res)
|
||||
return tuple(res)
|
||||
|
||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
||||
return voice
|
||||
|
@ -221,7 +234,10 @@ def import_training_settings_proxy( voice ):
|
|||
if k not in settings:
|
||||
continue
|
||||
output[k] = settings[k]
|
||||
|
||||
output = list(output.values())
|
||||
print(list(TRAINING_SETTINGS.keys()))
|
||||
print(output)
|
||||
messages.append(f"Imported training settings: {injson}")
|
||||
|
||||
return output[:-1] + ["\n".join(messages)]
|
||||
|
@ -273,6 +289,7 @@ def setup_gradio():
|
|||
autoregressive_models = get_autoregressive_models()
|
||||
dataset_list = get_dataset_list()
|
||||
|
||||
global GENERATE_SETTINGS_ARGS
|
||||
GENERATE_SETTINGS_ARGS = list(inspect.signature(generate_proxy).parameters.keys())[:-1]
|
||||
for i in range(len(GENERATE_SETTINGS_ARGS)):
|
||||
arg = GENERATE_SETTINGS_ARGS[i]
|
||||
|
@ -639,7 +656,7 @@ def setup_gradio():
|
|||
)
|
||||
|
||||
|
||||
copy_button.click(import_generate_settings,
|
||||
copy_button.click(import_generate_settings_proxy,
|
||||
inputs=audio_in, # JSON elements cannot be used as inputs
|
||||
outputs=generate_settings
|
||||
)
|
||||
|
@ -738,7 +755,7 @@ def setup_gradio():
|
|||
)
|
||||
|
||||
if os.path.isfile('./config/generate.json'):
|
||||
ui.load(import_generate_settings, inputs=None, outputs=generate_settings)
|
||||
ui.load(import_generate_settings_proxy, inputs=None, outputs=generate_settings)
|
||||
|
||||
if args.check_for_updates:
|
||||
ui.load(check_for_updates)
|
||||
|
|
Loading…
Reference in New Issue
Block a user