tweaks to try and get deepspeed quantized inferencing working, validating bitsandbytes and deepspeed quantization; nothing seems to work
commit 65f500083d
parent 08bae355eb
@@ -322,6 +322,7 @@ class DeepSpeed:
     zero_optimization_level: int = 0
     use_compression_training: bool = False
     compression_bits: int = 8
+    inferencing: bool = False

     @cached_property
     def ds_cfg(self):
@@ -363,7 +364,7 @@ class DeepSpeed:
             "quantize_verbose": True,
             "quantization_type": "symmetric",
             "rounding": "nearest",
-            "quantize_weight_in_forward": True,
+            "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
             "fp16_mixed_quantize":{
                 "enabled": False,
                 "quantize_change_ratio": 1
@@ -377,6 +378,35 @@ class DeepSpeed:
                     "quantization_period": 0
                 },
                 "modules": [
+                    # "^.+?$"
+                    "blocks", # for transformer-based models
+                    "retnet", # for RetNets-based models
+                ]
+            }
+        }
+    },
+    "activation_quantization": {
+        "shared_parameters":{
+            "enabled": True,
+            "quantizer_kernel": True,
+            "schedule_offset": 0,
+            "quantize_groups": 64,
+            "quantize_verbose": True,
+            "quantization_type": "symmetric",
+            "rounding": "nearest",
+            "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16
+            "fp16_mixed_quantize":{
+                "enabled": False,
+                "quantize_change_ratio": 1
+            }
+        },
+        "different_groups": {
+            "aq1": {
+                "params": {
+                    "bits": self.compression_bits,
+                },
+                "modules": [
+                    # "^.+?$"
                     "blocks", # for transformer-based models
                     "retnet", # for RetNets-based models
                 ]
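For context, a minimal sketch of how a config assembled by `ds_cfg` (including the `compression_training` block touched above) would typically be handed to DeepSpeed. This is not the repo's actual wiring: `model` is a stand-in module, and the assumption is that `ds_cfg` already carries the batch-size/optimizer fields DeepSpeed requires.

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)          # stand-in for the real AR/NAR module
ds_cfg = cfg.trainer.deepspeed.ds_cfg  # the cached_property this commit extends
                                       # (cfg is the repo's global config object)

# Pass the generated dict directly; no ds_config.json on disk is needed.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_cfg,
)
```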
@@ -382,7 +382,7 @@ class Dataset(_Dataset):
         resps = _load_quants(path)

         spkr_group = self.get_speaker_group(path)
-        lang = self.lang_symmap[ self.get_language(spkr_group) ]
+        lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)

         # append additional prompts in an attempt to artifically increase lengths / offer new data
         if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
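The effect of the `lang` change is simply wrapping the language id in a one-element uint8 tensor so it collates like the other batch inputs instead of staying a bare Python int. A tiny illustration with a hypothetical symbol map, not the dataset's real one:

```python
import torch

lang_symmap = {"en": 0, "ja": 1}                       # stand-in mapping
lang = torch.tensor([ lang_symmap["ja"] ]).to(torch.uint8)
print(lang, lang.dtype)                                # tensor([1], dtype=torch.uint8) torch.uint8
```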
@@ -90,7 +90,7 @@ def load_engines():
         model.load_state_dict(state, strict=cfg.trainer.strict_loading)

         # deepspeed inferencing
-        if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"):
+        if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
             engine_class = _Engine
             model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

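The quantized-inference path this commit is chasing would, in principle, be the same `init_inference` call with an int8 dtype. A hedged sketch using only the kwargs already present in the diff; per the commit message, this is exactly what did not end up working, and whether the installed DeepSpeed build supports it for this model is the open question.

```python
import torch
import deepspeed

# `model` is a placeholder for the loaded AR/NAR module, not real repo wiring.
model = deepspeed.init_inference(
    model=model,
    mp_size=1,
    replace_with_kernel_inject=True,
    dtype=torch.int8,   # the assumption being tested here
).module
```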
@@ -36,7 +36,7 @@ class TTS():

         if amp is None:
             amp = cfg.inference.amp
-        if dtype is None:
+        if dtype is None or dtype == "auto":
             dtype = cfg.inference.weight_dtype
         if device is None:
             device = cfg.device
@@ -64,7 +64,7 @@ class TTS():

         model.load_state_dict(state)

-        if deepspeed_available:
+        if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
             model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

         return model
@@ -88,6 +88,7 @@ class TTS():
         else:
             self.load_models()

+        if self.dtype != torch.int8:
             self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
             self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

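The new guard exists because a plain `.to(dtype=...)` cast is not a valid way to get int8 weights: PyTorch rejects non-floating dtypes on modules, so int8 has to come from bitsandbytes or DeepSpeed instead of a cast. A minimal illustration in plain PyTorch, with no quantization backend involved:

```python
import torch

lin = torch.nn.Linear(4, 4)
try:
    lin.to(dtype=torch.int8)
except TypeError as e:
    # nn.Module.to only accepts floating point or complex dtypes
    print(e)
```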
@@ -23,7 +23,7 @@ def get_model(cfg):
     )
     model._cfg = cfg

-    print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+    print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

     return model

@@ -421,7 +421,7 @@ class Base(nn.Module):
             logits = [ logit[-1:] for logit in logits ]

         devices = [ logit.device for logit in logits ]
-        logits = [ logit.cpu() for logit in logits ]
+        logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

         # perform repetition penalizing
         logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
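The fp16 to fp32 cast when moving logits to the CPU is presumably there because some CPU kernels used during sampling are not implemented for half precision. A small illustration on stock PyTorch; which op fails (if any) depends on the installed version:

```python
import torch

logit = torch.randn(8, dtype=torch.float16)   # CPU half tensor, like the moved logits
try:
    torch.multinomial(logit.softmax(dim=-1), 1)
    print("this PyTorch build handles half-precision sampling on CPU")
except RuntimeError as e:
    # commonly a "... not implemented for 'Half'" error on CPU builds
    print(e)
```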
@@ -14,7 +14,8 @@ if cfg.bitsandbytes.enabled:
         Linear = bnb.nn.Linear8bitLt

     if cfg.bitsandbytes.embedding:
-        Embedding = bnb.nn.StableEmbedding
+        Embedding = bnb.nn.modules.Embedding
+        """
         Embedding.forward = lambda self, input: ( self.norm(F.embedding(
             input,
             self.weight,
@@ -24,6 +25,7 @@ if cfg.bitsandbytes.enabled:
             self.scale_grad_by_freq,
             self.sparse,
         )).to(self.weight.dtype) )
+        """


 if cfg.bitsandbytes.enabled:
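On the StableEmbedding to Embedding swap being tested here: as I understand it, bitsandbytes ships both classes, with StableEmbedding layering a LayerNorm on top and keeping its optimizer state in 32-bit, while Embedding stays closer to a drop-in torch.nn.Embedding. A hedged sketch of the two constructions, assuming bitsandbytes is installed:

```python
import bitsandbytes as bnb

emb_stable = bnb.nn.StableEmbedding(256, 64)    # what the code used before this commit
emb_plain  = bnb.nn.modules.Embedding(256, 64)  # what this commit switches to
```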
@@ -62,11 +64,6 @@ def autocast_forward( func ):
     def wrapper( self, input, *args, **kwargs ):
         with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
             return func( self, k, *args, **kwargs )
-        """
-        if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
-            return func( self, input.to(torch.int32), *args, **kwargs )
-        return func( self, input, *args, **kwargs )
-        """
     return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)

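The `autocasts` helper used above is repo code not shown in this diff; conceptually it recasts narrow integer index tensors to int32 for the duration of the wrapped forward so the patched embedding accepts them. A purely illustrative sketch of what such a context manager could look like, not the repo's implementation:

```python
from contextlib import contextmanager
import torch

@contextmanager
def autocasts(input, from_dtypes, to_dtype):
    # Hypothetical helper: yield an int32 copy of the indices when the input
    # dtype is one of the narrow integer types, otherwise pass it through.
    if input.dtype in from_dtypes:
        yield input.to(to_dtype)
    else:
        yield input

ids = torch.tensor([1, 2, 3], dtype=torch.uint8)
with autocasts(ids, [torch.int16, torch.int8, torch.uint8], torch.int32) as k:
    print(k.dtype)   # torch.int32
```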