// path: root/ggml/src/vulkan-shaders/opt_step_adamw.comp
// blob: e0214fe7645c217e00ec77558b87beb5e6407be0 (plain)
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

// Workgroup shape: 512 invocations along x, 1 along y and z. main() uses the
// same 512 as the y stride when it flattens the invocation ID.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Storage buffers used by main(). A_TYPE is not defined in this file
// (presumably supplied by types.comp or a build-time define -- TODO confirm).
// binding 0: x     -- read and overwritten in place by main().
layout (binding = 0) buffer X {A_TYPE x[];};
// binding 1: grad  -- read-only input, one value read per invocation.
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
// binding 2: gradm -- running average of grad, read then overwritten by main().
layout (binding = 2) buffer GM {A_TYPE gradm[];};
// binding 3: gradv -- running average of grad*grad, read then overwritten by main().
layout (binding = 3) buffer GV {A_TYPE gradv[];};
// binding 4: seven read-only scalars, unpacked by main() in this order:
// alpha, beta1, beta2, eps, wd, beta1h, beta2h.
layout (binding = 4) readonly buffer P {float params[7];};

// AdamW optimizer step, one element per invocation: updates the two running
// gradient averages (gradm, gradv) and the value x at the same index, in place.
void main() {
    // Flatten the 3D invocation ID into a linear element index.
    // Strides: y advances by 512, z by 262144 (= 512 * 512).
    const uvec3 gid = gl_GlobalInvocationID;
    const uint idx = gid.x + gid.y * 512 + gid.z * 262144;

    // Invocations whose index is at or past p.KX do nothing.
    // NOTE(review): p is declared outside this file (presumably the push
    // constants from generic_head.comp, with KX as the element count) -- confirm.
    if (idx >= p.KX) {
        return;
    }

    // Scalar hyper-parameters, in the order they are stored in params[].
    const float alpha  = params[0];
    const float beta1  = params[1];
    const float beta2  = params[2];
    const float eps    = params[3];
    const float wd     = params[4];
    // beta1h / beta2h scale the averaged moments before use -- presumably
    // host-computed bias-correction factors; TODO confirm against the caller.
    const float beta1h = params[5];
    const float beta2h = params[6];

    const float g = grad[idx];

    // Exponential moving averages of the gradient and of the squared gradient.
    const float one_minus_beta1 = 1.0f - beta1;
    const float one_minus_beta2 = 1.0f - beta2;
    const float m_new = gradm[idx]*beta1 + g*one_minus_beta1;
    const float v_new = gradv[idx]*beta2 + g*g*one_minus_beta2;

    // Persist the updated averages before touching x.
    gradm[idx] = m_new;
    gradv[idx] = v_new;

    // Scaled first moment, and the denominator built from the second moment;
    // eps is added after the square root.
    const float m_hat = m_new*beta1h;
    const float denom = sqrt(v_new*beta2h) + eps;

    // Weight decay multiplies x directly (decoupled from the gradient term),
    // then the scaled-moment term is subtracted.
    const float decay = 1.0f - alpha*wd;
    const float delta = alpha*m_hat/denom;
    x[idx] = x[idx]*decay - delta;
}