mrq 2025-03-07 18:57:25 -06:00
parent dbd34b6430
commit 6cea840710
2 changed files with 9 additions and 17 deletions

View File

@@ -563,7 +563,7 @@ class Evaluation:
 @dataclass()
 class DeepSpeed:
-	zero_optimization_level: int = 0 # doesn't seem to work
+	zero_optimization_level: int = 0
 	use_compression_training: bool = False # cope
 	compression_bits: int = 8 # cope
 	inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
@@ -697,24 +697,16 @@ class DeepSpeed:
 			} if self.use_compression_training else None,
 			"zero_optimization": {
 				"stage": self.zero_optimization_level,
-				"allgather_partitions": True,
 				"contiguous_gradients": True,
 				"overlap_comm": True,
 				"reduce_scatter": True,
-				"reduce_bucket_size": 5e8,
-				"allgather_bucket_size": 5e8,
-				"sub_group_size": 5e8,
-				"round_robin_gradients": True,
-				"offload_optimizer": {
-					"device": "cpu",
-					"pin_memory": True
-				},
-				"offload_param": {
-					"device": "cpu",
-					"pin_memory": True
-				},
-				"zero_quantized_weights": self.use_compression_training,
-				"zero_hpz_partition_size": world_size(),
-				"zero_quantized_gradients": self.use_compression_training,
+				#"reduce_bucket_size": 5e8,
+				#"allgather_bucket_size": 5e8,
+				#"sub_group_size": 5e8,
+				#"zero_quantized_weights": self.use_compression_training,
+				#"zero_hpz_partition_size": world_size(),
+				#"zero_quantized_gradients": self.use_compression_training,
 			} if self.zero_optimization_level > 0 else None,
 			"comms_logger": {
 				"enabled": False

View File

@@ -1424,7 +1424,7 @@ class Base_V2(nn.Module):
 		hidden_states = output.hidden_states
 		if self.use_streamlined_calc_loss:
-			logits = head( output.logits )
+			logits = self.audio_decoder( output.logits )
 		else:
 			logits = [ logit for logit in output.logits ]
 			grouped_logits = {}
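
Condensed, self-contained stand-in for the branch above. Only the control flow and the audio_decoder name come from the diff; the module, tensor shapes, and values are placeholders and the real Base_V2 decoder differs.

import torch
import torch.nn as nn

audio_decoder = nn.Identity()          # stands in for the model's shared decoder head
output_logits = torch.randn(2, 4, 16)  # placeholder for output.logits
use_streamlined_calc_loss = True

if use_streamlined_calc_loss:
    # streamlined path: pass the raw output through the shared decoder in one call
    logits = audio_decoder(output_logits)
else:
    # fallback path: keep a plain list of logits to be grouped afterwards
    logits = [logit for logit in output_logits]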