summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-sycl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-sycl.cpp')
-rw-r--r--ggml/src/ggml-sycl.cpp32
1 files changed, 24 insertions, 8 deletions
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 36518ff9..d8eb86c2 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -1723,7 +1723,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
});
});
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
}
@@ -2075,8 +2075,8 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
// GGML_SYCL_DEBUG("current device index %d\n", id);
src_ptr = (char *) extra->data_device[id];
} else {
- // GGML_SYCL_DEBUG("GGML_ASSERT(false)\n");
- GGML_ASSERT(false);
+ // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
+ GGML_ABORT("fatal error");
}
char * dst_ptr = (char *) dst;
@@ -2163,7 +2163,7 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_te
default:
// TODO: k-quants
fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
break;
}
}
@@ -2192,7 +2192,7 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
}
@@ -2476,7 +2476,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
case GGML_TYPE_Q6_K:
return 64;
default:
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
}
@@ -3101,7 +3101,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
} else {
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
if (convert_src1_to_q8_1 && !src1_is_contiguous) {
@@ -3896,7 +3896,7 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
(void) dst;
@@ -3981,6 +3981,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
ggml_sycl_func_t func;
switch (tensor->op) {
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ func = ggml_sycl_op_conv_transpose_1d;
+ break;
case GGML_OP_REPEAT:
func = ggml_sycl_repeat;
break;
@@ -4105,6 +4108,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_ARGSORT:
func = ggml_sycl_argsort;
break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ func = ggml_sycl_op_timestep_embedding;
+ break;
default:
return false;
}
@@ -5090,6 +5096,15 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
switch (op->op) {
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
+ } break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
@@ -5213,6 +5228,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_TIMESTEP_EMBEDDING:
return true;
default:
return false;