summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cann/acl_tensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cann/acl_tensor.h')
-rw-r--r--ggml/src/ggml-cann/acl_tensor.h36
1 files changed, 32 insertions, 4 deletions
diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h
index 7d0bf04e..4734a9cb 100644
--- a/ggml/src/ggml-cann/acl_tensor.h
+++ b/ggml/src/ggml-cann/acl_tensor.h
@@ -23,6 +23,9 @@
#ifndef CANN_ACL_TENSOR_H
#define CANN_ACL_TENSOR_H
+#include <algorithm>
+#include <cstring>
+
#include <aclnn/aclnn_base.h>
#include "common.h"
@@ -65,7 +68,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
size_t offset = 0);
/**
- * @brief Creates an ACL tensor from provided parameters.
+ * @brief Template for creating an ACL tensor from provided parameters. typename TYPE
+ * should be size_t or float.
*
* @details This function creates an ACL tensor using the provided data pointer,
* data type, dimensions, strides, format, offset, and additional parameters.
@@ -83,10 +87,34 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
* @return Pointer to the created ACL tensor.
*/
+template<typename TYPE> // TYPE is the type of type_size and of the nb entries (header comment: size_t or float)
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
- size_t type_size, int64_t* ne, size_t* nb,
- int64_t dims, aclFormat format = ACL_FORMAT_ND,
- size_t offset = 0);
+ TYPE type_size, int64_t* ne, TYPE* nb,
+ int64_t dims,
+ aclFormat format = ACL_FORMAT_ND,
+ size_t offset = 0) {
+ int64_t tmp_ne[GGML_MAX_DIMS * 2]; // scratch copy of ne; NOTE(review): no visible check that dims <= GGML_MAX_DIMS * 2
+ int64_t tmp_stride[GGML_MAX_DIMS * 2]; // per-dimension strides expressed in multiples of type_size
+
+ memcpy(tmp_ne, ne, dims * sizeof(int64_t)); // copy the first `dims` extents out of the caller's array
+ for (int i = 0; i < dims; i++) {
+ tmp_stride[i] = nb[i] / type_size; // with TYPE = float this is a floating-point division truncated on assignment
+ }
+
+ std::reverse(tmp_ne, tmp_ne + dims); // presumably ACL expects the opposite dimension order from ne/nb - confirm
+ std::reverse(tmp_stride, tmp_stride + dims); // keep strides aligned with the reversed extents
+
+ int64_t acl_storage_len = 0; // NOTE(review): starts at 0, so a tensor whose extents are all 1 yields length 0 - confirm intended
+ for (int i = 0; i < dims; i++) {
+ acl_storage_len += (ne[i] - 1) * nb[i]; // accumulated in nb units (not divided by type_size, unlike the strides) - verify against aclCreateTensor
+ }
+
+ aclTensor* acl_tensor =
+ aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, // offset is rescaled by type_size as well
+ format, &acl_storage_len, 1, data_ptr); // storage extent is passed as a single value with count 1
+
+ return acl_tensor; // returned as-is; no null check is performed here
+}
/**
* @brief Checks if tensors require broadcasting based on their shapes.