diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 30 |
1 files changed, 26 insertions, 4 deletions
@@ -2354,6 +2354,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { } void ggml_free(struct ggml_context * ctx) { + if (ctx == NULL) { + return; + } + // make this function thread safe ggml_critical_section_start(); @@ -4362,6 +4366,23 @@ struct ggml_tensor * ggml_cpy( return ggml_cpy_impl(ctx, a, b); } +struct ggml_tensor * ggml_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_type type) { + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); + ggml_format_name(result, "%s (copy)", a->name); + + result->op = GGML_OP_CPY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = result; + + return result; +} + // ggml_cont static struct ggml_tensor * ggml_cont_impl( @@ -14871,7 +14892,7 @@ size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tenso return i; } -static struct ggml_hash_set ggml_hash_set_new(size_t size) { +struct ggml_hash_set ggml_hash_set_new(size_t size) { size = ggml_hash_size(size); struct ggml_hash_set result; result.size = size; @@ -16620,7 +16641,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { return GGML_EXIT_SUCCESS; } -struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { +struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) { if (n_threads <= 0) { n_threads = GGML_DEFAULT_N_THREADS; } @@ -16682,14 +16703,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_MUL_MAT_ID: { + cur = 0; const struct ggml_tensor * src0 = node->src[2]; const struct ggml_tensor * src1 = node->src[1]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { - cur = ggml_row_size(vec_dot_type, ggml_nelements(src1)); + cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); } const int n_as = ggml_get_op_params_i32(node, 1); - cur = GGML_PAD(cur, sizeof(int64_t)); // align + cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows } break; |