diff options
author | Brian <mofosyne@gmail.com> | 2024-05-14 23:10:39 +1000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-14 16:10:39 +0300 |
commit | 1265c670fd8e41e1947352c96c5179adda97fb2c (patch) | |
tree | 8ef6858e5701e5ff6e7a4e656e26589af5067963 /examples/llama.android/llama/src/main/cpp/llama-android.cpp | |
parent | 5e31828d3e35c76ecfee665bc23771a4bec1d130 (diff) |
Revert "move ndk code to a new library (#6951)" (#7282)
This reverts commit efc8f767c8c8c749a245dd96ad4e2f37c164b54c.
Diffstat (limited to 'examples/llama.android/llama/src/main/cpp/llama-android.cpp')
-rw-r--r-- | examples/llama.android/llama/src/main/cpp/llama-android.cpp | 443 |
1 files changed, 0 insertions, 443 deletions
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp deleted file mode 100644 index 874158ef..00000000 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ /dev/null @@ -1,443 +0,0 @@ -#include <android/log.h> -#include <jni.h> -#include <iomanip> -#include <math.h> -#include <string> -#include <unistd.h> -#include "llama.h" -#include "common/common.h" - -// Write C++ code here. -// -// Do not forget to dynamically load the C++ library into your application. -// -// For instance, -// -// In MainActivity.java: -// static { -// System.loadLibrary("llama-android"); -// } -// -// Or, in MainActivity.kt: -// companion object { -// init { -// System.loadLibrary("llama-android") -// } -// } - -#define TAG "llama-android.cpp" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) - -jclass la_int_var; -jmethodID la_int_var_value; -jmethodID la_int_var_inc; - -std::string cached_token_chars; - -bool is_valid_utf8(const char * string) { - if (!string) { - return true; - } - - const unsigned char * bytes = (const unsigned char *)string; - int num; - - while (*bytes != 0x00) { - if ((*bytes & 0x80) == 0x00) { - // U+0000 to U+007F - num = 1; - } else if ((*bytes & 0xE0) == 0xC0) { - // U+0080 to U+07FF - num = 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // U+0800 to U+FFFF - num = 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // U+10000 to U+10FFFF - num = 4; - } else { - return false; - } - - bytes += 1; - for (int i = 1; i < num; ++i) { - if ((*bytes & 0xC0) != 0x80) { - return false; - } - bytes += 1; - } - } - - return true; -} - -static void log_callback(ggml_log_level level, const char * fmt, void * data) { - if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data); - else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data); - else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data); - else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data); -} - -extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) { - llama_model_params model_params = llama_model_default_params(); - - auto path_to_model = env->GetStringUTFChars(filename, 0); - LOGi("Loading model from %s", path_to_model); - - auto model = llama_load_model_from_file(path_to_model, model_params); - env->ReleaseStringUTFChars(filename, path_to_model); - - if (!model) { - LOGe("load_model() failed"); - env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed"); - return 0; - } - - return reinterpret_cast<jlong>(model); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { - llama_free_model(reinterpret_cast<llama_model *>(model)); -} - -extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) { - auto model = reinterpret_cast<llama_model *>(jmodel); - - if (!model) { - LOGe("new_context(): model cannot be null"); - env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null"); - return 0; - } - - int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2)); - LOGi("Using %d threads", n_threads); - - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = 2048; - ctx_params.n_threads = n_threads; - ctx_params.n_threads_batch = n_threads; - - llama_context * context = llama_new_context_with_model(model, ctx_params); - - if (!context) { - LOGe("llama_new_context_with_model() returned null)"); - env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), - "llama_new_context_with_model() returned null)"); - return 0; - } - - return reinterpret_cast<jlong>(context); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) { - llama_free(reinterpret_cast<llama_context *>(context)); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) { - llama_backend_free(); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) { - llama_log_set(log_callback, NULL); -} - -extern "C" -JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_bench_1model( - JNIEnv *env, - jobject, - jlong context_pointer, - jlong model_pointer, - jlong batch_pointer, - jint pp, - jint tg, - jint pl, - jint nr - ) { - auto pp_avg = 0.0; - auto tg_avg = 0.0; - auto pp_std = 0.0; - auto tg_std = 0.0; - - const auto context = reinterpret_cast<llama_context *>(context_pointer); - const auto model = reinterpret_cast<llama_model *>(model_pointer); - const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); - - const int n_ctx = llama_n_ctx(context); - - LOGi("n_ctx = %d", n_ctx); - - int i, j; - int nri; - for (nri = 0; nri < nr; nri++) { - LOGi("Benchmark prompt processing (pp)"); - - llama_batch_clear(*batch); - - const int n_tokens = pp; - for (i = 0; i < n_tokens; i++) { - llama_batch_add(*batch, 0, i, { 0 }, false); - } - - batch->logits[batch->n_tokens - 1] = true; - llama_kv_cache_clear(context); - - const auto t_pp_start = ggml_time_us(); - if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during prompt processing"); - } - const auto t_pp_end = ggml_time_us(); - - // bench text generation - - LOGi("Benchmark text generation (tg)"); - - llama_kv_cache_clear(context); - const auto t_tg_start = ggml_time_us(); - for (i = 0; i < tg; i++) { - - llama_batch_clear(*batch); - for (j = 0; j < pl; j++) { - llama_batch_add(*batch, 0, i, { j }, true); - } - - LOGi("llama_decode() text generation: %d", i); - if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during text generation"); - } - } - - const auto t_tg_end = ggml_time_us(); - - llama_kv_cache_clear(context); - - const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; - const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; - - const auto speed_pp = double(pp) / t_pp; - const auto speed_tg = double(pl * tg) / t_tg; - - pp_avg += speed_pp; - tg_avg += speed_tg; - - pp_std += speed_pp * speed_pp; - tg_std += speed_tg * speed_tg; - - LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); - } - - pp_avg /= double(nr); - tg_avg /= double(nr); - - if (nr > 1) { - pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); - tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); - } else { - pp_std = 0; - tg_std = 0; - } - - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - - const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0; - const auto model_n_params = double(llama_model_n_params(model)) / 1e9; - - const auto backend = "(Android)"; // TODO: What should this be? - - std::stringstream result; - result << std::setprecision(2); - result << "| model | size | params | backend | test | t/s |\n"; - result << "| --- | --- | --- | --- | --- | --- |\n"; - result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; - result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; - - return env->NewStringUTF(result.str().c_str()); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); -} - -extern "C" -JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { - - // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. - - llama_batch *batch = new llama_batch { - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - 0, - 0, - 0, - }; - - if (embd) { - batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); - } else { - batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); - } - - batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); - batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); - for (int i = 0; i < n_tokens; ++i) { - batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); - } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); - - return reinterpret_cast<jlong>(batch); -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { - llama_backend_init(); -} - -extern "C" -JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) { - return env->NewStringUTF(llama_print_system_info()); -} - -extern "C" -JNIEXPORT jint JNICALL -Java_android_llama_cpp_LLamaAndroid_completion_1init( - JNIEnv *env, - jobject, - jlong context_pointer, - jlong batch_pointer, - jstring jtext, - jint n_len - ) { - - cached_token_chars.clear(); - - const auto text = env->GetStringUTFChars(jtext, 0); - const auto context = reinterpret_cast<llama_context *>(context_pointer); - const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); - - const auto tokens_list = llama_tokenize(context, text, 1); - - auto n_ctx = llama_n_ctx(context); - auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); - - LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); - - if (n_kv_req > n_ctx) { - LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); - } - - for (auto id : tokens_list) { - LOGi("%s", llama_token_to_piece(context, id).c_str()); - } - - llama_batch_clear(*batch); - - // evaluate the initial prompt - for (auto i = 0; i < tokens_list.size(); i++) { - llama_batch_add(*batch, tokens_list[i], i, { 0 }, false); - } - - // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; - - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() failed"); - } - - env->ReleaseStringUTFChars(jtext, text); - - return batch->n_tokens; -} - -extern "C" -JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_completion_1loop( - JNIEnv * env, - jobject, - jlong context_pointer, - jlong batch_pointer, - jint n_len, - jobject intvar_ncur -) { - const auto context = reinterpret_cast<llama_context *>(context_pointer); - const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); - const auto model = llama_get_model(context); - - if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); - if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); - if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); - - auto n_vocab = llama_n_vocab(model); - auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); - - std::vector<llama_token_data> candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - // sample the most likely token - const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); - - const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); - if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { - return env->NewStringUTF(""); - } - - auto new_token_chars = llama_token_to_piece(context, new_token_id); - cached_token_chars += new_token_chars; - - jstring new_token = nullptr; - if (is_valid_utf8(cached_token_chars.c_str())) { - new_token = env->NewStringUTF(cached_token_chars.c_str()); - LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id); - cached_token_chars.clear(); - } else { - new_token = env->NewStringUTF(""); - } - - llama_batch_clear(*batch); - llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true); - - env->CallVoidMethod(intvar_ncur, la_int_var_inc); - - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() returned null"); - } - - return new_token; -} - -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); -} |