diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 390a1cd7..b0b16dbf 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2353,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute( { GGML_ASSERT(src0->type == GGML_TYPE_F32); - const int sf = dst->op_params[0]; + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; @@ -2376,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); |