summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEngininja2 <139037756+Engininja2@users.noreply.github.com>2024-05-18 02:05:17 -0600
committerGitHub <noreply@github.com>2024-05-18 10:05:17 +0200
commitd233b507cd19fcc2d8d8963ecc6a3eb7a33f2ecc (patch)
tree95e87a3f3ddce289dd28482989529e81946898ca
parent0f98acfac6cc561dc57586bfff778405e42b576b (diff)
cuda : add half2 __shfl_xor() for ROCm 5.5 (#7263)
-rw-r--r--ggml-cuda/common.cuh14
1 files changed, 14 insertions, 0 deletions
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index b6f0bc36..784792ba 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -315,6 +315,20 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}
+
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+// __shfl_xor() for half2 was added in ROCm 5.6
+static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
+ typedef union half2_b32 {
+ half2 val;
+ int b32;
+ } half2_b32_t;
+ half2_b32_t tmp;
+ tmp.val = var;
+ tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
+ return tmp.val;
+}
+#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
#endif // defined(GGML_USE_HIPBLAS)
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL