diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4cd18a28..d82466e0 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -15911,11 +15911,14 @@ static void ggml_compute_forward_get_rows_f16(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        if (i01 >= 0 && i01 < ne01) {
+            ggml_fp16_to_fp32_row(
+                    (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                    (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+        } else { // row index out of range: zero the dst row; never address src0 with a bad i01
+            memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
+        }
 
-        ggml_fp16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
     }
 }
 
@@ -15952,11 +15955,13 @@ static void ggml_compute_forward_get_rows_bf16(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
-
-        ggml_bf16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+        if (i01 >= 0 && i01 < ne01) {
+            ggml_bf16_to_fp32_row(
+                    (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                    (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+        } else { // row index out of range: zero the dst row; never address src0 with a bad i01
+            memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
+        }
     }
 }
 
@@ -15993,11 +15998,13 @@ static void ggml_compute_forward_get_rows_f32(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
-
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
-                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+        if (i01 >= 0 && i01 < ne01) {
+            ggml_vec_cpy_f32(nc,
+                    (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+                    (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+        } else { // row index out of range: zero the dst row instead of copying from src0
+            memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
+        }
     }
 }
 