diff --git a/vall_e/config.py b/vall_e/config.py index 643ed6d..dcd4c9c 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -322,6 +322,7 @@ class DeepSpeed: zero_optimization_level: int = 0 use_compression_training: bool = False compression_bits: int = 8 + inferencing: bool = False @cached_property def ds_cfg(self): @@ -363,7 +364,7 @@ class DeepSpeed: "quantize_verbose": True, "quantization_type": "symmetric", "rounding": "nearest", - "quantize_weight_in_forward": True, + "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16 "fp16_mixed_quantize":{ "enabled": False, "quantize_change_ratio": 1 @@ -377,6 +378,35 @@ class DeepSpeed: "quantization_period": 0 }, "modules": [ + # "^.+?$" + "blocks", # for transformer-based models + "retnet", # for RetNets-based models + ] + } + } + }, + "activation_quantization": { + "shared_parameters":{ + "enabled": True, + "quantizer_kernel": True, + "schedule_offset": 0, + "quantize_groups": 64, + "quantize_verbose": True, + "quantization_type": "symmetric", + "rounding": "nearest", + "quantize_weight_in_forward": cfg.trainer.weight_dtype.lower() != "float16", # MoQ (quantize in optimization step) weight quantization is only supported for FP16 + "fp16_mixed_quantize":{ + "enabled": False, + "quantize_change_ratio": 1 + } + }, + "different_groups": { + "aq1": { + "params": { + "bits": self.compression_bits, + }, + "modules": [ + # "^.+?$" "blocks", # for transformer-based models "retnet", # for RetNets-based models ] diff --git a/vall_e/data.py b/vall_e/data.py index 4c5a0aa..d780b23 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -382,7 +382,7 @@ class Dataset(_Dataset): resps = _load_quants(path) spkr_group = self.get_speaker_group(path) - lang = self.lang_symmap[ self.get_language(spkr_group) ] + lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8) # append additional prompts in an attempt to artifically increase lengths / offer new data if cfg.experimental and cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append: diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 30a6de8..faf0e32 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -90,7 +90,7 @@ def load_engines(): model.load_state_dict(state, strict=cfg.trainer.strict_loading) # deepspeed inferencing - if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"): + if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"): engine_class = _Engine model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module diff --git a/vall_e/inference.py b/vall_e/inference.py index e39d13e..12e5672 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -36,7 +36,7 @@ class TTS(): if amp is None: amp = cfg.inference.amp - if dtype is None: + if dtype is None or dtype == "auto": dtype = cfg.inference.weight_dtype if device is None: device = cfg.device @@ -64,7 +64,7 @@ class TTS(): model.load_state_dict(state) - if deepspeed_available: + if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing: model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module return model @@ -88,8 +88,9 @@ class TTS(): else: self.load_models() - self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) - self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) + if self.dtype != torch.int8: + self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) + self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.ar.eval() self.nar.eval() diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index e9728ec..38071e1 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -23,7 +23,7 @@ def get_model(cfg): ) model._cfg = cfg - print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") return model diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c90ddf9..c385ad0 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -421,7 +421,7 @@ class Base(nn.Module): logits = [ logit[-1:] for logit in logits ] devices = [ logit.device for logit in logits ] - logits = [ logit.cpu() for logit in logits ] + logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] # perform repetition penalizing logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index dc16236..62ac50e 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -9,12 +9,13 @@ Linear = torch.nn.Linear if cfg.bitsandbytes.enabled: import bitsandbytes as bnb - + if cfg.bitsandbytes.linear: Linear = bnb.nn.Linear8bitLt if cfg.bitsandbytes.embedding: - Embedding = bnb.nn.StableEmbedding + Embedding = bnb.nn.modules.Embedding + """ Embedding.forward = lambda self, input: ( self.norm(F.embedding( input, self.weight, @@ -24,6 +25,7 @@ if cfg.bitsandbytes.enabled: self.scale_grad_by_freq, self.sparse, )).to(self.weight.dtype) ) + """ if cfg.bitsandbytes.enabled: @@ -62,11 +64,6 @@ def autocast_forward( func ): def wrapper( self, input, *args, **kwargs ): with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k: return func( self, k, *args, **kwargs ) - """ - if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8: - return func( self, input.to(torch.int32), *args, **kwargs ) - return func( self, input, *args, **kwargs ) - """ return wrapper Embedding.forward = autocast_forward(Embedding.forward)