swap the order in which momentum and parameters are updated in ops.cu
This commit is contained in:
parent
c5582724d5
commit
8618bed001
25
csrc/ops.cu
25
csrc/ops.cu
|
@ -120,8 +120,6 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||||
case MOMENTUM:
|
case MOMENTUM:
|
||||||
case RMSPROP:
|
case RMSPROP:
|
||||||
case ADAGRAD:
|
case ADAGRAD:
|
||||||
case LION:
|
|
||||||
|
|
||||||
if(max_unorm > 0.0f)
|
if(max_unorm > 0.0f)
|
||||||
{
|
{
|
||||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||||
|
@ -132,6 +130,18 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
break;
|
break;
|
||||||
|
case LION:
|
||||||
|
// in lion, the momentum update after the parameter update
|
||||||
|
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
|
||||||
|
if(max_unorm > 0.0f)
|
||||||
|
{
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||||
|
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,7 +174,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
||||||
case MOMENTUM:
|
case MOMENTUM:
|
||||||
case RMSPROP:
|
case RMSPROP:
|
||||||
case ADAGRAD:
|
case ADAGRAD:
|
||||||
case LION:
|
|
||||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
@ -172,6 +181,16 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
||||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
break;
|
break;
|
||||||
|
case LION:
|
||||||
|
// in lion, the momentum update happens after the parameter update
|
||||||
|
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||||
|
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||||
|
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user