oops
This commit is contained in:
parent
dbd34b6430
commit
6cea840710
|
@ -563,7 +563,7 @@ class Evaluation:
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class DeepSpeed:
|
class DeepSpeed:
|
||||||
zero_optimization_level: int = 0 # doesn't seem to work
|
zero_optimization_level: int = 0
|
||||||
use_compression_training: bool = False # cope
|
use_compression_training: bool = False # cope
|
||||||
compression_bits: int = 8 # cope
|
compression_bits: int = 8 # cope
|
||||||
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
|
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
|
||||||
|
@ -697,24 +697,16 @@ class DeepSpeed:
|
||||||
} if self.use_compression_training else None,
|
} if self.use_compression_training else None,
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": self.zero_optimization_level,
|
"stage": self.zero_optimization_level,
|
||||||
|
"allgather_partitions": True,
|
||||||
"contiguous_gradients": True,
|
"contiguous_gradients": True,
|
||||||
"overlap_comm": True,
|
"overlap_comm": True,
|
||||||
"reduce_scatter": True,
|
"reduce_scatter": True,
|
||||||
"reduce_bucket_size": 5e8,
|
#"reduce_bucket_size": 5e8,
|
||||||
"allgather_bucket_size": 5e8,
|
#"allgather_bucket_size": 5e8,
|
||||||
"sub_group_size": 5e8,
|
#"sub_group_size": 5e8,
|
||||||
"round_robin_gradients": True,
|
#"zero_quantized_weights": self.use_compression_training,
|
||||||
"offload_optimizer": {
|
#"zero_hpz_partition_size": world_size(),
|
||||||
"device": "cpu",
|
#"zero_quantized_gradients": self.use_compression_training,
|
||||||
"pin_memory": True
|
|
||||||
},
|
|
||||||
"offload_param": {
|
|
||||||
"device": "cpu",
|
|
||||||
"pin_memory": True
|
|
||||||
},
|
|
||||||
"zero_quantized_weights": self.use_compression_training,
|
|
||||||
"zero_hpz_partition_size": world_size(),
|
|
||||||
"zero_quantized_gradients": self.use_compression_training,
|
|
||||||
} if self.zero_optimization_level > 0 else None,
|
} if self.zero_optimization_level > 0 else None,
|
||||||
"comms_logger": {
|
"comms_logger": {
|
||||||
"enabled": False
|
"enabled": False
|
||||||
|
|
|
@ -1424,7 +1424,7 @@ class Base_V2(nn.Module):
|
||||||
hidden_states = output.hidden_states
|
hidden_states = output.hidden_states
|
||||||
|
|
||||||
if self.use_streamlined_calc_loss:
|
if self.use_streamlined_calc_loss:
|
||||||
logits = head( output.logits )
|
logits = self.audio_decoder( output.logits )
|
||||||
else:
|
else:
|
||||||
logits = [ logit for logit in output.logits ]
|
logits = [ logit for logit in output.logits ]
|
||||||
grouped_logits = {}
|
grouped_logits = {}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user