bitsandbytes-rocm/bitsandbytes/__main__.py
2023-02-16 22:18:27 +01:00

78 lines
1.5 KiB
Python

import os
import sys
from warnings import warn
import torch
HEADER_WIDTH = 60
def print_header(
txt: str, width: int = HEADER_WIDTH, filler: str = "+"
) -> None:
txt = f" {txt} " if txt else ""
print(txt.center(width, filler))
def print_debug_info() -> None:
print(
"\nAbove we output some debug information. Please provide this info when "
f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n"
)
print_header("")
print_header("DEBUG INFORMATION")
print_header("")
print()
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
print_header("")
print_header("DEBUG INFO END")
print_header("")
print(
"""
Running a quick check that:
+ library is importable
+ CUDA function is callable
"""
)
try:
from bitsandbytes.optim import Adam
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()
p1 = p.data.sum().item()
adam = Adam([p])
out = a * p
loss = out.sum()
loss.backward()
adam.step()
p2 = p.data.sum().item()
assert p1 != p2
print("SUCCESS!")
print("Installation was successful!")
sys.exit(0)
except ImportError:
print()
warn(
f"WARNING: {__package__} is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!"
)
print_debug_info()
sys.exit(0)
except Exception as e:
print(e)
print_debug_info()
sys.exit(1)