swap the order in which momentum and parameters are updated in ops.cu
This commit is contained in:
parent
c5582724d5
commit
8618bed001
25
csrc/ops.cu
25
csrc/ops.cu
|
@ -120,8 +120,6 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
|
@ -132,6 +130,18 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update after the parameter update
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,7 +174,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
@ -172,6 +181,16 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update happens after the parameter update
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user