path: root/ggml-opencl.cpp
author    slaren <slarengh@gmail.com>  2024-01-12 20:07:38 +0100
committer GitHub <noreply@github.com>  2024-01-12 20:07:38 +0100
commit    e7e4df031b9e29d4b55a4e0b0295187f6b213db1 (patch)
tree      93211b7800be3c2c5f9eb1d55f3b7b3acdc56c9b /ggml-opencl.cpp
parent    584d674be622fbf1578694ada6e62eebedbfd377 (diff)
llama : ggml-backend integration (#4766)
* llama : ggml-backend integration
* ggml-backend : add names to buffers
* fix unmap after loading
* batched-bench : add tensor_split param
* llama : check for null tensor_split
* ggml-backend : increase GGML_MAX_BACKENDS
* improve graph splitting, partial fix for --no-kv-offload
* cuda : add ggml-backend split buffer support
* cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)
* ggml : fix null backend dereference (#4807)
* ggml : fix null backend dereference
* ggml : also check ggml_backend_is_cpu
* test-backend-ops : check buffer allocation failures
* llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)
* ggml : fix mul_mat_id work size
* llama : rewrite session kv load/set without graphs
* minor
* llama : only initialize used backends, free backends on context free
* llama : abort ctx if cuda backend init fails
* llama : rewrite lora with ggml-backend and compute on CPU

  ggml-ci
* llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer
* opencl : add ggml-backend buffer type
* cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf)
* llama : on Metal, by default offload the full model

  ggml-ci
* metal : page align the data ptr (#4854)
* Apply suggestions from code review

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* cuda : fix split buffer free
* address review comments
* llama-bench : add split-mode parameter
* fix whitespace
* opencl : fix double initialization
* server : add --split-mode parameter
* use async copy and compute to improve multi-gpu performance

  ggml-ci
* use async memcpys to copy the graph outputs to the CPU
* fix opencl
* use a host buffer for the cpu compute buffer for faster copies to the gpu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
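For ggml-opencl.cpp itself, the main addition is the OpenCL ggml-backend buffer type; as the comment in supports_backend notes, its buffers are consumed by the CPU backend rather than by a standalone OpenCL backend. The following is a minimal usage sketch, not part of this commit: the wrapper function and size are illustrative, while ggml_backend_opencl_buffer_type comes from this patch and the ggml_backend_* helpers from ggml-backend.h.

    // Hypothetical illustration, not committed code: allocate a device buffer from the
    // new OpenCL buffer type. init_tensor (below) places a cl_mem sub-buffer in
    // tensor->extra, and set_tensor uploads data with clEnqueueWriteBuffer.
    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-opencl.h"

    static ggml_backend_buffer_t example_alloc_opencl_weights(size_t nbytes) {
        ggml_backend_buffer_type_t buft = ggml_backend_opencl_buffer_type();

        // may return nullptr if clCreateBuffer fails (see alloc_buffer in the diff)
        ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, nbytes);

        // tensors are then placed in buf with ggml-alloc and uploaded with
        // ggml_backend_tensor_set(tensor, host_data, 0, ggml_nbytes(tensor));
        // when no longer needed: ggml_backend_buffer_free(buf);
        return buf;
    }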
Diffstat (limited to 'ggml-opencl.cpp')
-rw-r--r--  ggml-opencl.cpp  335
1 file changed, 321 insertions, 14 deletions
diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp
index 496f9cdc..2bb93638 100644
--- a/ggml-opencl.cpp
+++ b/ggml-opencl.cpp
@@ -1,5 +1,6 @@
#include "ggml.h"
#include "ggml-opencl.h"
+#include "ggml-backend-impl.h"
#include <array>
#include <atomic>
@@ -10,7 +11,7 @@
#include <sstream>
#include <vector>
-#define CL_TARGET_OPENCL_VERSION 110
+#define CL_TARGET_OPENCL_VERSION 120
#include <clblast.h>
#if defined(_MSC_VER)
@@ -929,6 +930,12 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
}
void ggml_cl_init(void) {
+ static bool initialized = false;
+ if (initialized) {
+ return;
+ }
+ initialized = true;
+
cl_int err;
struct cl_device;
@@ -1483,8 +1490,8 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
} else {
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
}
- cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
- cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
+ cl_mem d_Y = src1->backend == GGML_BACKEND_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
+ cl_mem d_D = dst->backend == GGML_BACKEND_GPU ? (cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
size_t x_offset = 0;
@@ -1501,7 +1508,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
// copy src1 to device
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
+ if (src1->backend == GGML_BACKEND_CPU) {
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
+ }
CL_CHECK(clFinish(queue));
@@ -1522,8 +1531,10 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
- CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
+ if (dst->backend == GGML_BACKEND_CPU) {
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+ CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
+ }
}
}
}
@@ -1532,8 +1543,12 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
if (src0->backend != GGML_BACKEND_GPU) {
ggml_cl_pool_free(d_X, x_size);
}
- ggml_cl_pool_free(d_Y, y_size);
- ggml_cl_pool_free(d_D, d_size);
+ if (src1->backend != GGML_BACKEND_GPU) {
+ ggml_cl_pool_free(d_Y, y_size);
+ }
+ if (dst->backend != GGML_BACKEND_GPU) {
+ ggml_cl_pool_free(d_D, d_size);
+ }
}
static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
@@ -1598,6 +1613,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
}
+ // FIXME: convert on device
+
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
// convert src1 to fp16
// TODO: use multiple threads
@@ -1643,11 +1660,13 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host, then convert to float
- CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
-
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
-
- ggml_fp16_to_fp32_row(tmp, d, d_ne);
+ if (dst->backend == GGML_BACKEND_CPU) {
+ CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+ ggml_fp16_to_fp32_row(tmp, d, d_ne);
+ } else {
+ // FIXME: convert dst to fp32 on device
+ }
}
}
}
@@ -1801,7 +1820,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
}
-bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
@@ -1895,3 +1914,291 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
tensor->extra = dst;
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
}
+
+// ggml-backend
+
+// buffer
+
+struct ggml_backend_opencl_buffer_context {
+ ~ggml_backend_opencl_buffer_context() {
+ if (buffer) {
+ clReleaseMemObject(buffer);
+ }
+ for (auto * sub_buffer : sub_buffers) {
+ clReleaseMemObject(sub_buffer);
+ }
+ }
+
+ cl_mem buffer;
+ std::vector<cl_mem> sub_buffers;
+};
+
+static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
+
+static const char * ggml_backend_opencl_buffer_get_name(ggml_backend_buffer_t buffer) {
+ return "OpenCL";
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+ delete ctx;
+}
+
+static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
+ return cl_ptr_base;
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ tensor->extra = tensor->view_src->extra;
+ } else {
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+ cl_buffer_region region = {(size_t)((char *)tensor->data - (char *)cl_ptr_base), ggml_nbytes(tensor)};
+ cl_int err;
+ cl_mem sub_buffer = clCreateSubBuffer(ctx->buffer, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+ CL_CHECK(err);
+ ctx->sub_buffers.push_back(sub_buffer);
+ tensor->extra = sub_buffer;
+ }
+ tensor->backend = GGML_BACKEND_GPU;
+}
+
+static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ cl_mem tensor_buffer = (cl_mem) tensor->extra;
+ CL_CHECK(clEnqueueWriteBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
+ CL_CHECK(clFinish(queue));
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ cl_mem tensor_buffer = (cl_mem) tensor->extra;
+ CL_CHECK(clEnqueueReadBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
+ CL_CHECK(clFinish(queue));
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+ CL_CHECK(clEnqueueFillBuffer(queue, ctx->buffer, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
+ CL_CHECK(clFinish(queue));
+}
+
+static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+ for (auto * sub_buffer : ctx->sub_buffers) {
+ clReleaseMemObject(sub_buffer);
+ }
+ ctx->sub_buffers.clear();
+}
+
+static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
+ /* .get_name = */ ggml_backend_opencl_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_opencl_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_opencl_buffer_clear,
+ /* .reset = */ ggml_backend_opencl_buffer_reset,
+};
+
+// buffer type
+
+static const char * ggml_backend_opencl_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
+ return "OpenCL";
+
+ GGML_UNUSED(buffer_type);
+}
+
+static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {
+ ggml_cl_init();
+
+ cl_int err;
+ cl_mem mem = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err);
+ if (err != CL_SUCCESS) {
+ fprintf(stderr, "%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
+ return nullptr;
+ }
+
+ ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context{mem, {}};
+
+ return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
+ // FIXME: not thread safe, device may not be initialized yet
+ static cl_uint alignment = -1;
+ if (alignment == (cl_uint)-1) {
+ ggml_cl_init();
+ clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
+ }
+ return alignment;
+
+ GGML_UNUSED(buffer_type);
+}
+
+static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buffer_type, ggml_backend_t backend) {
+ //return ggml_backend_is_opencl(backend); // opencl must be used through the cpu backend
+ return ggml_backend_is_cpu(backend);
+
+ GGML_UNUSED(buffer_type);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_opencl_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL,
+ /* .supports_backend = */ ggml_backend_opencl_buffer_type_supports_backend,
+ /* .is_host = */ NULL,
+};
+
+
+ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
+ static ggml_backend_buffer_type buffer_type = {
+ /* .iface = */ ggml_backend_opencl_buffer_type_interface,
+ /* .context = */ nullptr,
+ };
+
+ return &buffer_type;
+}
+
+#if 0
+// host buffer type
+
+static const char * ggml_backend_opencl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return "CL_Host";
+
+ GGML_UNUSED(buft);
+}
+
+static const char * ggml_backend_opencl_host_buffer_name(ggml_backend_buffer_t buffer) {
+ return "CL_Host";
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_cl_host_free(buffer->context);
+}
+
+static ggml_backend_buffer_t ggml_backend_opencl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ void * ptr = ggml_cl_host_malloc(size);
+
+ if (ptr == nullptr) {
+ // fallback to cpu buffer
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+ }
+
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.get_name = ggml_backend_opencl_host_buffer_name;
+ buffer->iface.free_buffer = ggml_backend_opencl_host_buffer_free_buffer;
+
+ return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type() {
+ static struct ggml_backend_buffer_type ggml_backend_opencl_buffer_type_host = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_opencl_host_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_opencl_host_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+ /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+ },
+ /* .context = */ nullptr,
+ };
+
+ return &ggml_backend_opencl_buffer_type_host;
+}
+
+// backend
+
+static const char * ggml_backend_opencl_name(ggml_backend_t backend) {
+ return "OpenCL";
+
+ GGML_UNUSED(backend);
+}
+
+static void ggml_backend_opencl_free(ggml_backend_t backend) {
+ GGML_UNUSED(backend);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_opencl_buffer_type();
+
+ GGML_UNUSED(backend);
+}
+
+static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
+ for (int i = 0; i < graph->n_nodes; ++i) {
+ ggml_tensor * node = graph->nodes[i];
+ switch (node->op) {
+ case GGML_OP_MUL_MAT:
+ ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
+ break;
+ case GGML_OP_MUL:
+ ggml_cl_mul(node->src[0], node->src[1], node);
+ break;
+ default:
+ GGML_ASSERT(false);
+ }
+ }
+
+ return true;
+
+ GGML_UNUSED(backend);
+}
+
+static bool ggml_backend_opencl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_MUL_MAT:
+ return ggml_cl_can_mul_mat(op->src[0], op->src[1], op);
+ case GGML_OP_MUL:
+ // return ggml_can_repeat_rows(op->src[1], op->src[0]);
+ return true;
+ default:
+ return false;
+ }
+
+ GGML_UNUSED(backend);
+}
+
+static ggml_backend_i opencl_backend_i = {
+ /* .get_name = */ ggml_backend_opencl_name,
+ /* .free = */ ggml_backend_opencl_free,
+ /* .get_default_buffer_type = */ ggml_backend_opencl_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_opencl_graph_compute,
+ /* .supports_op = */ ggml_backend_opencl_supports_op,
+};
+
+ggml_backend_t ggml_backend_opencl_init() {
+ ggml_backend_t backend = new ggml_backend {
+ /* .interface = */ opencl_backend_i,
+ /* .context = */ nullptr
+ };
+
+ return backend;
+}
+
+bool ggml_backend_is_opencl(ggml_backend_t backend) {
+ return backend && backend->iface.get_name == ggml_backend_opencl_name;
+}
+#endif
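
The backend object above is compiled out with #if 0: in this commit only the buffer type is exposed, and OpenCL is driven through the CPU backend. Purely as a hypothetical sketch of how the disabled backend would be used if it were ever enabled (ggml_backend_graph_compute and ggml_backend_free are the generic ggml-backend entry points; the wrapper function is illustrative and will not link while the block stays disabled):

    // Hypothetical sketch only: ggml_backend_opencl_init and ggml_backend_is_opencl
    // are inside the #if 0 block in this commit.
    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-opencl.h"

    static bool example_compute_with_opencl_backend(struct ggml_cgraph * graph) {
        ggml_backend_t backend = ggml_backend_opencl_init();
        if (backend == NULL || !ggml_backend_is_opencl(backend)) {
            return false;
        }
        // graph_compute above dispatches GGML_OP_MUL_MAT to ggml_cl_mul_mat and
        // GGML_OP_MUL to ggml_cl_mul, asserting on any other op
        bool ok = ggml_backend_graph_compute(backend, graph);
        ggml_backend_free(backend);
        return ok;
    }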