summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu50
1 files changed, 47 insertions, 3 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 248cb2c4..5346b9e0 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -31,6 +31,9 @@
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
@@ -424,6 +427,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
#define MUL_MAT_SRC1_COL_STRIDE 128
#define MAX_STREAMS 8
@@ -6258,6 +6265,41 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
}
}
+void ggml_cuda_set_peer_access(const int n_tokens) {
+ static bool peer_access_enabled = false;
+
+ const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+ if (peer_access_enabled == enable_peer_access) {
+ return;
+ }
+
+#ifdef NDEBUG
+ for (int id = 0; id < g_device_count; ++id) {
+ CUDA_CHECK(ggml_cuda_set_device(id));
+
+ for (int id_other = 0; id_other < g_device_count; ++id_other) {
+ if (id == id_other) {
+ continue;
+ }
+ if (id != g_main_device && id_other != g_main_device) {
+ continue;
+ }
+
+ int canAccessPeer;
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, id, id_other));
+ if (enable_peer_access) {
+ CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
+ } else {
+ CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
+ }
+ }
+ }
+#endif // NDEBUG
+
+ peer_access_enabled = enable_peer_access;
+}
+
static void ggml_cuda_op_mul_mat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
const bool convert_src1_to_q8_1) {
@@ -6282,6 +6324,8 @@ static void ggml_cuda_op_mul_mat(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
+ ggml_cuda_set_peer_access(ne11);
+
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
@@ -7010,7 +7054,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, false, true, false);
}
-void ggml_cuda_set_main_device(int main_device) {
+void ggml_cuda_set_main_device(const int main_device) {
if (main_device >= g_device_count) {
fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
main_device, g_device_count, g_main_device);
@@ -7024,11 +7068,11 @@ void ggml_cuda_set_main_device(int main_device) {
}
}
-void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
g_mul_mat_q = mul_mat_q;
}
-void ggml_cuda_set_scratch_size(size_t scratch_size) {
+void ggml_cuda_set_scratch_size(const size_t scratch_size) {
g_scratch_size = scratch_size;
}