summary | refs | log | tree | commit | diff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- ggml-metal.m | 10
1 files changed, 8 insertions, 2 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 390a1cd7..b0b16dbf 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -2353,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
- const int sf = dst->op_params[0];
+ const float sf0 = (float)ne0/src0->ne[0];
+ const float sf1 = (float)ne1/src0->ne[1];
+ const float sf2 = (float)ne2/src0->ne[2];
+ const float sf3 = (float)ne3/src0->ne[3];
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
@@ -2376,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);