summaryrefslogtreecommitdiff
path: root/kompute-shaders/op_softmax.comp
diff options
context:
space:
mode:
Diffstat (limited to 'kompute-shaders/op_softmax.comp')
-rw-r--r--kompute-shaders/op_softmax.comp56
1 files changed, 56 insertions, 0 deletions
diff --git a/kompute-shaders/op_softmax.comp b/kompute-shaders/op_softmax.comp
new file mode 100644
index 00000000..7bc9176c
--- /dev/null
+++ b/kompute-shaders/op_softmax.comp
@@ -0,0 +1,56 @@
+// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
+
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x_id = 0) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ float scale;
+ int mask;
+} pcs;
+
+void main() {
+ if (gl_SubgroupInvocationID > 31)
+ return;
+
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
+ const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
+ const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
+ const uint pdst = extra_off + pcs.outOff; // Based from out_
+
+ // parallel max
+ float localMax = uintBitsToFloat(0xFF800000);
+ for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+ localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
+ }
+ float max_ = subgroupMax(localMax);
+
+ // parallel sum
+ float localSum = 0.0f;
+ for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+ const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
+ localSum += exp_psrc0;
+ out_[pdst + i00] = exp_psrc0;
+ }
+
+ const float sum = subgroupAdd(localSum);
+ for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+ out_[pdst + i00] /= sum;
+ }
+}