From 6cea8407108a396a955a140ed20d0a9406e8df8a Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 7 Mar 2025 18:57:25 -0600 Subject: [PATCH] oops --- vall_e/config.py | 24 ++++++++---------------- vall_e/models/base_v2.py | 2 +- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 34379b4..4f403e5 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -563,7 +563,7 @@ class Evaluation: @dataclass() class DeepSpeed: - zero_optimization_level: int = 0 # doesn't seem to work + zero_optimization_level: int = 0 use_compression_training: bool = False # cope compression_bits: int = 8 # cope inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead @@ -697,24 +697,16 @@ class DeepSpeed: } if self.use_compression_training else None, "zero_optimization": { "stage": self.zero_optimization_level, + "allgather_partitions": True, "contiguous_gradients": True, "overlap_comm": True, "reduce_scatter": True, - "reduce_bucket_size": 5e8, - "allgather_bucket_size": 5e8, - "sub_group_size": 5e8, - "round_robin_gradients": True, - "offload_optimizer": { - "device": "cpu", - "pin_memory": True - }, - "offload_param": { - "device": "cpu", - "pin_memory": True - }, - "zero_quantized_weights": self.use_compression_training, - "zero_hpz_partition_size": world_size(), - "zero_quantized_gradients": self.use_compression_training, + #"reduce_bucket_size": 5e8, + #"allgather_bucket_size": 5e8, + #"sub_group_size": 5e8, + #"zero_quantized_weights": self.use_compression_training, + #"zero_hpz_partition_size": world_size(), + #"zero_quantized_gradients": self.use_compression_training, } if self.zero_optimization_level > 0 else None, "comms_logger": { "enabled": False diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 1181590..edc5548 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -1424,7 +1424,7 @@ class Base_V2(nn.Module): hidden_states = output.hidden_states if self.use_streamlined_calc_loss: - logits = head( output.logits ) + logits = self.audio_decoder( output.logits ) else: logits = [ logit for logit in output.logits ] grouped_logits = {}