diff --git a/webui.sh b/webui.sh index aa4f875c..ff410e15 100755 --- a/webui.sh +++ b/webui.sh @@ -109,6 +109,10 @@ if echo "$gpu_info" | grep -q "Navi" then export HSA_OVERRIDE_GFX_VERSION=10.3.0 fi +if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +then + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" +fi for preq in "${GIT}" "${python_cmd}" do @@ -170,10 +174,6 @@ then else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." - printf "\n%s\n" "${delimiter}" - if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] - then - export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" - fi + printf "\n%s\n" "${delimiter}" exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" fi