tweaks to try and get deepspeed quantized inferencing, validating bitsandbytes and deepspeed quantization, nothing seems to work
commit 65f500083d (parent 08bae355eb)
@@ -322,6 +322,7 @@ class DeepSpeed:
 	zero_optimization_level: int = 0
 	use_compression_training: bool = False
 	compression_bits: int = 8
+	inferencing: bool = False

 	@cached_property
 	def ds_cfg(self):
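For orientation, a minimal sketch (class and field names illustrative, not the repo's actual config class) of the pattern this hunk extends: a dataclass flag plus a lazily built, cached DeepSpeed config.

from dataclasses import dataclass
from functools import cached_property

@dataclass
class DeepSpeedSettings:
    # illustrative stand-ins for the fields touched by this hunk
    zero_optimization_level: int = 0
    use_compression_training: bool = False
    compression_bits: int = 8
    inferencing: bool = False  # new flag: opt in to deepspeed.init_inference at load time

    @cached_property
    def ds_cfg(self) -> dict:
        # built on first access, then cached on the instance
        cfg = {"zero_optimization": {"stage": self.zero_optimization_level}}
        if self.use_compression_training:
            cfg["compression_training"] = {"weight_quantization": {"shared_parameters": {"enabled": True}}}
        return cfg

print(DeepSpeedSettings(inferencing=True).ds_cfg)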
@@ -363,7 +364,7 @@ class DeepSpeed:
 				"quantize_verbose": True,
 				"quantization_type": "symmetric",
 				"rounding": "nearest",
-				"quantize_weight_in_forward": True,
+				"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
 				"fp16_mixed_quantize":{
 					"enabled": False,
 					"quantize_change_ratio": 1
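A tiny illustration of the conditional above, assuming the same string-valued weight-dtype setting: MoQ (quantizing weights in the optimizer step) only supports FP16, so in-forward quantization is turned on for every other dtype.

def quantize_weight_in_forward(weight_dtype: str) -> bool:
    # MoQ (quantize during the optimizer step) is FP16-only,
    # so quantize in the forward pass for any other weight dtype.
    return weight_dtype.lower() != "float16"

assert quantize_weight_in_forward("float16") is False
assert quantize_weight_in_forward("bfloat16") is True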
@@ -377,6 +378,35 @@ class DeepSpeed:
 						"quantization_period": 0
 					},
 					"modules": [
 						# "^.+?$"
 						"blocks", # for transformer-based models
 						"retnet", # for RetNets-based models
 					]
 				}
 			}
 		},
+		"activation_quantization": {
+			"shared_parameters":{
+				"enabled": True,
+				"quantizer_kernel": True,
+				"schedule_offset": 0,
+				"quantize_groups": 64,
+				"quantize_verbose": True,
+				"quantization_type": "symmetric",
+				"rounding": "nearest",
+				"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
+				"fp16_mixed_quantize":{
+					"enabled": False,
+					"quantize_change_ratio": 1
+				}
+			},
+			"different_groups": {
+				"aq1": {
+					"params": {
+						"bits": self.compression_bits,
+					},
+					"modules": [
+						# "^.+?$"
+						"blocks", # for transformer-based models
+						"retnet", # for RetNets-based models
+					]
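For reference, a hedged, standalone sketch of a compression_training block shaped like the one being built above. The module patterns ("blocks", "retnet") and the 8-bit setting come from this diff; the remaining keys follow DeepSpeed's documented compression schema, and the surrounding values are illustrative.

compression_bits = 8  # stands in for self.compression_bits above

ds_config = {
    "train_micro_batch_size_per_gpu": 1,  # illustrative; DeepSpeed requires a batch size
    "compression_training": {
        "weight_quantization": {
            "shared_parameters": {
                "enabled": True,
                "quantization_type": "symmetric",
                "rounding": "nearest",
                "quantize_weight_in_forward": True,
            },
            "different_groups": {
                "wq1": {
                    "params": {
                        "start_bits": compression_bits,
                        "target_bits": compression_bits,
                        "quantization_period": 0,
                    },
                    # module-name patterns to quantize, as in the diff above
                    "modules": ["blocks", "retnet"],
                }
            },
        },
        "activation_quantization": {
            "shared_parameters": {
                "enabled": True,
                "quantization_type": "symmetric",
                "range_calibration": "dynamic",
                "schedule_offset": 0,
            },
            "different_groups": {
                "aq1": {
                    "params": {"bits": compression_bits},
                    "modules": ["blocks", "retnet"],
                }
            },
        },
    },
}

# The dict would then be handed to DeepSpeed, e.g.
#   model_engine, *_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)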
@@ -382,7 +382,7 @@ class Dataset(_Dataset):
 		resps = _load_quants(path)

 		spkr_group = self.get_speaker_group(path)
-		lang = self.lang_symmap[ self.get_language(spkr_group) ]
+		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)

 		# append additional prompts in an attempt to artifically increase lengths / offer new data
 		if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
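Illustration only (symbol map assumed): the change above packs the language index into a one-element uint8 tensor so it can be collated and embedded like the other integer inputs.

import torch

lang_symmap = {"en": 0, "ja": 1}                 # hypothetical symbol map
lang_idx = lang_symmap["en"]

lang = torch.tensor([lang_idx]).to(torch.uint8)  # scalar index -> 1-element uint8 tensor
assert lang.dtype == torch.uint8 and lang.shape == (1,)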
@@ -90,7 +90,7 @@ def load_engines():
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)

 		# deepspeed inferencing
-		if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"):
+		if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
 			engine_class = _Engine
 			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
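A hedged sketch of the init_inference call being gated here, assuming DeepSpeed and a CUDA device are available; the toy module stands in for the repo's models, and whether kernels are actually injected depends on the architecture.

import torch
import deepspeed

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())  # toy stand-in

# Only wrap when explicitly enabled (the new cfg.trainer.deepspeed.inferencing flag);
# kernel injection replaces supported layers with DeepSpeed's fused inference kernels.
model = deepspeed.init_inference(
    model=model,
    mp_size=1,                       # no tensor parallelism
    dtype=torch.float16,             # or torch.float32 when running under AMP
    replace_with_kernel_inject=True,
).module                             # unwrap the InferenceEngine back to a bare nn.Module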
@@ -36,7 +36,7 @@ class TTS():

 		if amp is None:
 			amp = cfg.inference.amp
-		if dtype is None:
+		if dtype is None or dtype == "auto":
 			dtype = cfg.inference.weight_dtype
 		if device is None:
 			device = cfg.device
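A small sketch (function name assumed) of what accepting "auto" implies: resolve the string to the configured weight dtype, then to a torch dtype.

import torch
from typing import Optional

def resolve_dtype(name: Optional[str], default: str = "float16") -> torch.dtype:
    # None or "auto" falls back to the configured default, as in the change above
    if name is None or name == "auto":
        name = default
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "int8": torch.int8,
    }[name.lower()]

assert resolve_dtype("auto") is torch.float16
assert resolve_dtype("int8") is torch.int8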
@@ -64,7 +64,7 @@ class TTS():

 			model.load_state_dict(state)

-			if deepspeed_available:
+			if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
 				model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

 			return model
@@ -88,8 +88,9 @@ class TTS():
 		else:
 			self.load_models()

-		self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-		self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+		if self.dtype != torch.int8:
+			self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+			self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

 		self.ar.eval()
 		self.nar.eval()
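Why the guard: nn.Module.to() only accepts floating-point (or complex) dtypes, so an already-quantized int8 model has to be left alone. A minimal sketch, names assumed:

import torch

def place_model(model: torch.nn.Module, device: str, dtype: torch.dtype, amp: bool) -> torch.nn.Module:
    # nn.Module.to() rejects integer dtypes, so skip the cast for int8 (quantized) weights
    if dtype != torch.int8:
        model = model.to(device, dtype=dtype if not amp else torch.float32)
    return model.eval()

place_model(torch.nn.Linear(4, 4), "cpu", torch.int8, amp=False)    # left as-is, no cast attempted
place_model(torch.nn.Linear(4, 4), "cpu", torch.float32, amp=True)  # moved, kept fp32 under AMP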
@@ -23,7 +23,7 @@ def get_model(cfg):
 	)
 	model._cfg = cfg

-	print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

 	return model
@@ -421,7 +421,7 @@ class Base(nn.Module):
 			logits = [ logit[-1:] for logit in logits ]

 		devices = [ logit.device for logit in logits ]
-		logits = [ logit.cpu() for logit in logits ]
+		logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

 		# perform repetition penalizing
 		logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
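Illustration of the dtype-aware move above: the logits are pulled to the CPU for sampling, and several CPU ops are slow or unimplemented for float16, so half-precision logits are promoted to float32 on the way.

import torch

logits = [torch.randn(4, 10, dtype=torch.float16)]  # pretend these just came off the GPU

devices = [logit.device for logit in logits]  # remember origin devices, as the code above does
logits = [
    logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32)
    for logit in logits
]
assert all(logit.dtype == torch.float32 for logit in logits)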
@@ -9,12 +9,13 @@ Linear = torch.nn.Linear

 if cfg.bitsandbytes.enabled:
 	import bitsandbytes as bnb

 	if cfg.bitsandbytes.linear:
 		Linear = bnb.nn.Linear8bitLt

 	if cfg.bitsandbytes.embedding:
-		Embedding = bnb.nn.StableEmbedding
+		Embedding = bnb.nn.modules.Embedding
+		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
 			input,
 			self.weight,
@@ -24,6 +25,7 @@ if cfg.bitsandbytes.enabled:
 			self.scale_grad_by_freq,
 			self.sparse,
 		)).to(self.weight.dtype) )
+		"""


 if cfg.bitsandbytes.enabled:
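A self-contained sketch of the class-swap pattern in this file, assuming bitsandbytes is installed; the flags stand in for cfg.bitsandbytes.*, and the example keeps StableEmbedding to show the general idea even though the diff itself switches to bnb.nn.modules.Embedding.

import torch

Linear = torch.nn.Linear        # defaults: plain torch layers
Embedding = torch.nn.Embedding

USE_BNB_LINEAR = True           # stand-ins for cfg.bitsandbytes.linear / .embedding
USE_BNB_EMBEDDING = True

try:
    import bitsandbytes as bnb
    if USE_BNB_LINEAR:
        Linear = bnb.nn.Linear8bitLt        # 8-bit linear layer
    if USE_BNB_EMBEDDING:
        Embedding = bnb.nn.StableEmbedding  # layer-normed embedding meant for 8-bit optimizers
except ImportError:
    pass  # bitsandbytes missing: keep the torch layers

# downstream code just instantiates Linear(...) / Embedding(...) as usual
layer = Linear(16, 16)
emb = Embedding(100, 16)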
@@ -62,11 +64,6 @@ def autocast_forward( func ):
 	def wrapper( self, input, *args, **kwargs ):
 		with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
 			return func( self, k, *args, **kwargs )
-		"""
-		if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
-			return func( self, input.to(torch.int32), *args, **kwargs )
-		return func( self, input, *args, **kwargs )
-		"""
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)
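The wrapper relies on an autocasts helper defined earlier in this file; a hedged re-implementation of the idea (not the repo's exact code): promote sub-32-bit integer inputs to int32 for the duration of the call, since embedding lookups only accept int32/int64 indices.

import torch
from contextlib import contextmanager

@contextmanager
def autocasts(input, from_dtypes, to_dtype):
    # yield the input cast to `to_dtype` when its dtype is one of `from_dtypes`
    yield input.to(to_dtype) if input.dtype in from_dtypes else input

def autocast_forward(func):
    def wrapper(self, input, *args, **kwargs):
        with autocasts(input, [torch.int16, torch.int8, torch.uint8], torch.int32) as k:
            return func(self, k, *args, **kwargs)
    return wrapper

# patch a plain torch Embedding the same way the adapter patches its aliased class
torch.nn.Embedding.forward = autocast_forward(torch.nn.Embedding.forward)
emb = torch.nn.Embedding(256, 8)
out = emb(torch.tensor([1, 2, 3], dtype=torch.uint8))  # uint8 indices now work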