summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders/get_rows.comp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders/get_rows.comp')
-rw-r--r--ggml/src/vulkan-shaders/get_rows.comp17
1 files changed, 12 insertions, 5 deletions
diff --git a/ggml/src/vulkan-shaders/get_rows.comp b/ggml/src/vulkan-shaders/get_rows.comp
index e9ff22ef..ee6b86a1 100644
--- a/ggml/src/vulkan-shaders/get_rows.comp
+++ b/ggml/src/vulkan-shaders/get_rows.comp
@@ -3,6 +3,8 @@
#include "types.comp"
#include "generic_binary_head.comp"
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
void main() {
const uint i00 = gl_GlobalInvocationID.x;
const uint i10 = gl_GlobalInvocationID.y;
@@ -13,14 +15,19 @@ void main() {
return;
}
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+ const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+ const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+#if defined(DATA_A_BF16)
+ FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+#else
+ FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
+ data_d[d_offset + i00] = D_TYPE(v);
#else
- data_d[d_offset + i00] = data_a[a_offset + i00];
+ data_d[d_offset + i00] = D_TYPE(v);
#endif
}