#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Buffers touched by this shader:
//   x      - values being updated in place (read and written)
//   grad   - per-element gradient (read only)
//   gradm  - running average of the gradient (read and written)
//   gradv  - running average of the squared gradient (read and written)
//   params - seven scalar coefficients, unpacked in main()
layout (binding = 0) buffer X {A_TYPE x[];};
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
layout (binding = 2) buffer GM {A_TYPE gradm[];};
layout (binding = 3) buffer GV {A_TYPE gradv[];};
layout (binding = 4) readonly buffer P {float params[7];};

// Per-element optimizer step: refreshes the two moving averages stored in
// gradm/gradv from grad, then applies a weight-decayed update to x.
// The arithmetic has the shape of an AdamW update.
//
// `p` and `A_TYPE` are not declared in this file.
// NOTE(review): presumably they come from the included headers - confirm.
void main() {
    // Flatten the 3D invocation id into a single element index.
    // 262144 == 512 * 512.
    const uvec3 gid = gl_GlobalInvocationID;
    const uint idx = gid.z * 262144 + gid.y * 512 + gid.x;

    // Invocations at or beyond p.KX do nothing.
    if (idx >= p.KX) {
        return;
    }

    // Unpack the coefficient block.
    const float lr           = params[0];
    const float decay_m      = params[1];
    const float decay_v      = params[2];
    const float epsilon      = params[3];
    const float weight_decay = params[4];
    // NOTE(review): params[5] and params[6] look like host-precomputed
    // bias-correction factors - confirm against the caller.
    const float scale_m      = params[5];
    const float scale_v      = params[6];

    const float grad_i = grad[idx];

    // Exponential moving averages of the gradient and its square.
    const float m_new = gradm[idx]*decay_m + grad_i*(1.0f - decay_m);
    const float v_new = gradv[idx]*decay_v + grad_i*grad_i*(1.0f - decay_v);

    gradm[idx] = m_new;
    gradv[idx] = v_new;

    // Scaled moments; epsilon is added after the square root.
    const float m_hat = m_new*scale_m;
    const float v_hat = sqrt(v_new*scale_v) + epsilon;

    // Shrink the current value by (1 - lr*weight_decay), then subtract the step.
    const float shrunk = x[idx]*(1.0f - lr*weight_decay);
    x[idx] = shrunk - lr*m_hat/v_hat;
}