diff --git a/javascript/ui.js b/javascript/ui.js index f94ed081..b1053201 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -218,10 +218,16 @@ function update_token_counter(button_id) { clearTimeout(token_timeout); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); } + function submit_prompt(event, generate_button_id) { if (event.altKey && event.keyCode === 13) { event.preventDefault(); gradioApp().getElementById(generate_button_id).click(); return; } -} \ No newline at end of file +} + +function restart_reload(){ + document.body.innerHTML='

Reloading...

'; + setTimeout(function(){location.reload()},2000) +} diff --git a/modules/scripts.py b/modules/scripts.py index 7c3bd5e7..45230f9a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -162,6 +162,40 @@ class ScriptRunner: return processed + def reload_sources(self): + for si, script in list(enumerate(self.scripts)): + with open(script.filename, "r", encoding="utf8") as file: + args_from = script.args_from + args_to = script.args_to + filename = script.filename + text = file.read() + + from types import ModuleType + + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + +def reload_script_body_only(): + scripts_txt2img.reload_sources() + scripts_img2img.reload_sources() + + +def reload_scripts(basedir): + global scripts_txt2img, scripts_img2img + + scripts_data.clear() + load_scripts(basedir) + + scripts_txt2img = ScriptRunner() + scripts_img2img = ScriptRunner() diff --git a/modules/ui.py b/modules/ui.py index c8f5bb84..78a15d83 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1145,6 +1145,31 @@ def create_ui(wrap_gradio_gpu_call): _js='function(){}' ) + with gr.Row(): + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + + + def reload_scripts(): + modules.scripts.reload_script_body_only() + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[], + _js='function(){}' + ) + + def request_restart(): + settings_interface.gradio_ref.do_restart = True + + restart_gradio.click( + fn=request_restart, + inputs=[], + outputs=[], + _js='function(){restart_reload()}' + ) + if column is not None: column.__exit__() @@ -1170,7 +1195,9 @@ def create_ui(wrap_gradio_gpu_call): css += css_hide_progressbar with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - + + settings_interface.gradio_ref = demo + with gr.Tabs() as tabs: for interface, label, ifid in interfaces: with gr.TabItem(label, id=ifid): @@ -1350,12 +1377,12 @@ for filename in sorted(os.listdir(jsdir)): javascript += f"\n" -def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res +if 'gradio_routes_templates_response' not in globals(): + def template_response(*args, **kwargs): + res = gradio_routes_templates_response(*args, **kwargs) + res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res - -gradio_routes_templates_response = gradio.routes.templates.TemplateResponse -gradio.routes.templates.TemplateResponse = template_response + gradio_routes_templates_response = gradio.routes.templates.TemplateResponse + gradio.routes.templates.TemplateResponse = template_response diff --git a/webui.py b/webui.py index dc72ceb8..63495697 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,9 @@ import os +import threading +import time +import importlib +from modules import devices +from modules.paths import script_path import signal import threading @@ -82,16 +87,34 @@ def webui(): signal.signal(signal.SIGINT, sigint_handler) - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + while 1: + + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + + demo.launch( + share=cmd_opts.share, + server_name="0.0.0.0" if cmd_opts.listen else None, + server_port=cmd_opts.port, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + + while 1: + time.sleep(0.5) + if getattr(demo,'do_restart',False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + print('Reloading Custom Scripts') + modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + print('Reloading modules: modules.ui') + importlib.reload(modules.ui) + print('Restarting Gradio') - demo.launch( - share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, - server_port=cmd_opts.port, - debug=cmd_opts.gradio_debug, - auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, - inbrowser=cmd_opts.autolaunch, - ) if __name__ == "__main__":