tweaks to try and get deepspeed quantized inferencing, validating bitsandbytes and deepspeed quantization, nothing seems to work

This commit is contained in:
mrq 2023-10-12 22:21:43 -05:00
parent 08bae355eb
commit 65f500083d
7 changed files with 44 additions and 16 deletions

View File

@ -322,6 +322,7 @@ class DeepSpeed:
zero_optimization_level: int = 0 zero_optimization_level: int = 0
use_compression_training: bool = False use_compression_training: bool = False
compression_bits: int = 8 compression_bits: int = 8
inferencing: bool = False
@cached_property @cached_property
def ds_cfg(self): def ds_cfg(self):
@ -363,7 +364,7 @@ class DeepSpeed:
"quantize_verbose": True, "quantize_verbose": True,
"quantization_type": "symmetric", "quantization_type": "symmetric",
"rounding": "nearest", "rounding": "nearest",
"quantize_weight_in_forward": True, "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{ "fp16_mixed_quantize":{
"enabled": False, "enabled": False,
"quantize_change_ratio": 1 "quantize_change_ratio": 1
@ -377,6 +378,35 @@ class DeepSpeed:
"quantization_period": 0 "quantization_period": 0
}, },
"modules": [ "modules": [
# "^.+?$"
"blocks", # for transformer-based models
"retnet", # for RetNets-based models
]
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
"quantizer_kernel": True,
"schedule_offset": 0,
"quantize_groups": 64,
"quantize_verbose": True,
"quantization_type": "symmetric",
"rounding": "nearest",
"quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
}
},
"different_groups": {
"aq1": {
"params": {
"bits": self.compression_bits,
},
"modules": [
# "^.+?$"
"blocks", # for transformer-based models "blocks", # for transformer-based models
"retnet", # for RetNets-based models "retnet", # for RetNets-based models
] ]

View File

@ -382,7 +382,7 @@ class Dataset(_Dataset):
resps = _load_quants(path) resps = _load_quants(path)
spkr_group = self.get_speaker_group(path) spkr_group = self.get_speaker_group(path)
lang = self.lang_symmap[ self.get_language(spkr_group) ] lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
# append additional prompts in an attempt to artifically increase lengths / offer new data # append additional prompts in an attempt to artifically increase lengths / offer new data
if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append: if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:

View File

@ -90,7 +90,7 @@ def load_engines():
model.load_state_dict(state, strict=cfg.trainer.strict_loading) model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# deepspeed inferencing # deepspeed inferencing
if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"): if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
engine_class = _Engine engine_class = _Engine
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

View File

@ -36,7 +36,7 @@ class TTS():
if amp is None: if amp is None:
amp = cfg.inference.amp amp = cfg.inference.amp
if dtype is None: if dtype is None or dtype == "auto":
dtype = cfg.inference.weight_dtype dtype = cfg.inference.weight_dtype
if device is None: if device is None:
device = cfg.device device = cfg.device
@ -64,7 +64,7 @@ class TTS():
model.load_state_dict(state) model.load_state_dict(state)
if deepspeed_available: if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
return model return model
@ -88,6 +88,7 @@ class TTS():
else: else:
self.load_models() self.load_models()
if self.dtype != torch.int8:
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

View File

@ -23,7 +23,7 @@ def get_model(cfg):
) )
model._cfg = cfg model._cfg = cfg
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model return model

View File

@ -421,7 +421,7 @@ class Base(nn.Module):
logits = [ logit[-1:] for logit in logits ] logits = [ logit[-1:] for logit in logits ]
devices = [ logit.device for logit in logits ] devices = [ logit.device for logit in logits ]
logits = [ logit.cpu() for logit in logits ] logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# perform repetition penalizing # perform repetition penalizing
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]

View File

@ -14,7 +14,8 @@ if cfg.bitsandbytes.enabled:
Linear = bnb.nn.Linear8bitLt Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes.embedding: if cfg.bitsandbytes.embedding:
Embedding = bnb.nn.StableEmbedding Embedding = bnb.nn.modules.Embedding
"""
Embedding.forward = lambda self, input: ( self.norm(F.embedding( Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input, input,
self.weight, self.weight,
@ -24,6 +25,7 @@ if cfg.bitsandbytes.enabled:
self.scale_grad_by_freq, self.scale_grad_by_freq,
self.sparse, self.sparse,
)).to(self.weight.dtype) ) )).to(self.weight.dtype) )
"""
if cfg.bitsandbytes.enabled: if cfg.bitsandbytes.enabled:
@ -62,11 +64,6 @@ def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ): def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k: with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
return func( self, k, *args, **kwargs ) return func( self, k, *args, **kwargs )
"""
if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
return func( self, input.to(torch.int32), *args, **kwargs )
return func( self, input, *args, **kwargs )
"""
return wrapper return wrapper
Embedding.forward = autocast_forward(Embedding.forward) Embedding.forward = autocast_forward(Embedding.forward)