From 7d43c585dc174bb586775c22c15e5db9242b5b4b Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 3 Mar 2024 20:23:52 +0800 Subject: add some new ops, fix some operators and add batch operations to certain operators. (ggml/747) * cuda: fix group_norm * cuda: add batch inference support for ggml_pad/ggml_upscale * add ggml_arrange * add ggml_timestep_embedding * update ggml_arange/ggml_timestep_embedding tests * cuda: fix im2col * add ggml_arange/ggml_timestep_embbeding support for metal backend * fix some bugs * fix some bugs * Update ggml.h Co-authored-by: Georgi Gerganov * Update ggml-cuda.cu Co-authored-by: Georgi Gerganov * Update ggml-metal.m Co-authored-by: Georgi Gerganov * Update ggml-metal.m Co-authored-by: Georgi Gerganov * Update ggml-metal.metal Co-authored-by: Georgi Gerganov * modify according to the review comments * ggml : fix compile warnings + code style * ggml : normalize compute_forward calls + fix seg fault in debug * minor --------- Co-authored-by: Georgi Gerganov Co-authored-by: slaren --- ggml.h | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'ggml.h') diff --git a/ggml.h b/ggml.h index 0a6d3c05..98cfc7bf 100644 --- a/ggml.h +++ b/ggml.h @@ -454,6 +454,8 @@ extern "C" { GGML_OP_POOL_2D, GGML_OP_UPSCALE, // nearest interpolate GGML_OP_PAD, + GGML_OP_ARANGE, + GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, @@ -1661,6 +1663,15 @@ extern "C" { int p2, int p3); + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 + // timesteps: [N,] + // return: [N, dim] + GGML_API struct ggml_tensor * ggml_timestep_embedding( + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period); + // sort rows enum ggml_sort_order { GGML_SORT_ORDER_ASC, @@ -1672,6 +1683,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + GGML_API struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step); + // top k elements per row GGML_API struct ggml_tensor * ggml_top_k( struct ggml_context * ctx, -- cgit v1.2.3