diff options
author | Elton Kola <eltonkola@gmail.com> | 2024-05-14 03:30:30 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-14 17:30:30 +1000 |
commit | efc8f767c8c8c749a245dd96ad4e2f37c164b54c (patch) | |
tree | 4c2910c9f7b3fcd27fbc1d04dd5188be962a8a5d /examples/llama.android/llama/src/main/cpp | |
parent | e0f556186b6e1f2b7032a1479edf5e89e2b1bd86 (diff) |
move ndk code to a new library (#6951)
Diffstat (limited to 'examples/llama.android/llama/src/main/cpp')
-rw-r--r-- | examples/llama.android/llama/src/main/cpp/CMakeLists.txt | 49 | ||||
-rw-r--r-- | examples/llama.android/llama/src/main/cpp/llama-android.cpp | 443 |
2 files changed, 492 insertions, 0 deletions
diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt new file mode 100644 index 00000000..42ebaad4 --- /dev/null +++ b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt @@ -0,0 +1,49 @@ +# For more information about using CMake with Android Studio, read the +# documentation: https://d.android.com/studio/projects/add-native-code.html. +# For more examples on how to use CMake, see https://github.com/android/ndk-samples. + +# Sets the minimum CMake version required for this project. +cmake_minimum_required(VERSION 3.22.1) + +# Declares the project name. The project name can be accessed via ${ PROJECT_NAME}, +# Since this is the top level CMakeLists.txt, the project name is also accessible +# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level +# build script scope). +project("llama-android") + +include(FetchContent) +FetchContent_Declare( + llama + GIT_REPOSITORY https://github.com/ggerganov/llama.cpp + GIT_TAG master +) + +# Also provides "common" +FetchContent_MakeAvailable(llama) + +# Creates and names a library, sets it as either STATIC +# or SHARED, and provides the relative paths to its source code. +# You can define multiple libraries, and CMake builds them for you. +# Gradle automatically packages shared libraries with your APK. +# +# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define +# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME} +# is preferred for the same purpose. +# +# In order to load a library into your app from Java/Kotlin, you must call +# System.loadLibrary() and pass the name of the library defined here; +# for GameActivity/NativeActivity derived applications, the same library name must be +# used in the AndroidManifest.xml file. +add_library(${CMAKE_PROJECT_NAME} SHARED + # List C/C++ source files with relative paths to this CMakeLists.txt. + llama-android.cpp) + +# Specifies libraries CMake should link to your target library. You +# can link libraries from various origins, such as libraries defined in this +# build script, prebuilt third-party libraries, or Android system libraries. +target_link_libraries(${CMAKE_PROJECT_NAME} + # List libraries link to the target library + llama + common + android + log) diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp new file mode 100644 index 00000000..874158ef --- /dev/null +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -0,0 +1,443 @@ +#include <android/log.h> +#include <jni.h> +#include <iomanip> +#include <math.h> +#include <string> +#include <unistd.h> +#include "llama.h" +#include "common/common.h" + +// Write C++ code here. +// +// Do not forget to dynamically load the C++ library into your application. +// +// For instance, +// +// In MainActivity.java: +// static { +// System.loadLibrary("llama-android"); +// } +// +// Or, in MainActivity.kt: +// companion object { +// init { +// System.loadLibrary("llama-android") +// } +// } + +#define TAG "llama-android.cpp" +#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) +#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) + +jclass la_int_var; +jmethodID la_int_var_value; +jmethodID la_int_var_inc; + +std::string cached_token_chars; + +bool is_valid_utf8(const char * string) { + if (!string) { + return true; + } + + const unsigned char * bytes = (const unsigned char *)string; + int num; + + while (*bytes != 0x00) { + if ((*bytes & 0x80) == 0x00) { + // U+0000 to U+007F + num = 1; + } else if ((*bytes & 0xE0) == 0xC0) { + // U+0080 to U+07FF + num = 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // U+0800 to U+FFFF + num = 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // U+10000 to U+10FFFF + num = 4; + } else { + return false; + } + + bytes += 1; + for (int i = 1; i < num; ++i) { + if ((*bytes & 0xC0) != 0x80) { + return false; + } + bytes += 1; + } + } + + return true; +} + +static void log_callback(ggml_log_level level, const char * fmt, void * data) { + if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data); + else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data); + else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data); + else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data); +} + +extern "C" +JNIEXPORT jlong JNICALL +Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) { + llama_model_params model_params = llama_model_default_params(); + + auto path_to_model = env->GetStringUTFChars(filename, 0); + LOGi("Loading model from %s", path_to_model); + + auto model = llama_load_model_from_file(path_to_model, model_params); + env->ReleaseStringUTFChars(filename, path_to_model); + + if (!model) { + LOGe("load_model() failed"); + env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed"); + return 0; + } + + return reinterpret_cast<jlong>(model); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { + llama_free_model(reinterpret_cast<llama_model *>(model)); +} + +extern "C" +JNIEXPORT jlong JNICALL +Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) { + auto model = reinterpret_cast<llama_model *>(jmodel); + + if (!model) { + LOGe("new_context(): model cannot be null"); + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null"); + return 0; + } + + int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2)); + LOGi("Using %d threads", n_threads); + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.seed = 1234; + ctx_params.n_ctx = 2048; + ctx_params.n_threads = n_threads; + ctx_params.n_threads_batch = n_threads; + + llama_context * context = llama_new_context_with_model(model, ctx_params); + + if (!context) { + LOGe("llama_new_context_with_model() returned null)"); + env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), + "llama_new_context_with_model() returned null)"); + return 0; + } + + return reinterpret_cast<jlong>(context); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) { + llama_free(reinterpret_cast<llama_context *>(context)); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) { + llama_backend_free(); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) { + llama_log_set(log_callback, NULL); +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_android_llama_cpp_LLamaAndroid_bench_1model( + JNIEnv *env, + jobject, + jlong context_pointer, + jlong model_pointer, + jlong batch_pointer, + jint pp, + jint tg, + jint pl, + jint nr + ) { + auto pp_avg = 0.0; + auto tg_avg = 0.0; + auto pp_std = 0.0; + auto tg_std = 0.0; + + const auto context = reinterpret_cast<llama_context *>(context_pointer); + const auto model = reinterpret_cast<llama_model *>(model_pointer); + const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); + + const int n_ctx = llama_n_ctx(context); + + LOGi("n_ctx = %d", n_ctx); + + int i, j; + int nri; + for (nri = 0; nri < nr; nri++) { + LOGi("Benchmark prompt processing (pp)"); + + llama_batch_clear(*batch); + + const int n_tokens = pp; + for (i = 0; i < n_tokens; i++) { + llama_batch_add(*batch, 0, i, { 0 }, false); + } + + batch->logits[batch->n_tokens - 1] = true; + llama_kv_cache_clear(context); + + const auto t_pp_start = ggml_time_us(); + if (llama_decode(context, *batch) != 0) { + LOGi("llama_decode() failed during prompt processing"); + } + const auto t_pp_end = ggml_time_us(); + + // bench text generation + + LOGi("Benchmark text generation (tg)"); + + llama_kv_cache_clear(context); + const auto t_tg_start = ggml_time_us(); + for (i = 0; i < tg; i++) { + + llama_batch_clear(*batch); + for (j = 0; j < pl; j++) { + llama_batch_add(*batch, 0, i, { j }, true); + } + + LOGi("llama_decode() text generation: %d", i); + if (llama_decode(context, *batch) != 0) { + LOGi("llama_decode() failed during text generation"); + } + } + + const auto t_tg_end = ggml_time_us(); + + llama_kv_cache_clear(context); + + const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; + const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; + + const auto speed_pp = double(pp) / t_pp; + const auto speed_tg = double(pl * tg) / t_tg; + + pp_avg += speed_pp; + tg_avg += speed_tg; + + pp_std += speed_pp * speed_pp; + tg_std += speed_tg * speed_tg; + + LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); + } + + pp_avg /= double(nr); + tg_avg /= double(nr); + + if (nr > 1) { + pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); + tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); + } else { + pp_std = 0; + tg_std = 0; + } + + char model_desc[128]; + llama_model_desc(model, model_desc, sizeof(model_desc)); + + const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0; + const auto model_n_params = double(llama_model_n_params(model)) / 1e9; + + const auto backend = "(Android)"; // TODO: What should this be? + + std::stringstream result; + result << std::setprecision(2); + result << "| model | size | params | backend | test | t/s |\n"; + result << "| --- | --- | --- | --- | --- | --- |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; + + return env->NewStringUTF(result.str().c_str()); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { + llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); +} + +extern "C" +JNIEXPORT jlong JNICALL +Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { + + // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. + + llama_batch *batch = new llama_batch { + 0, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + 0, + 0, + 0, + }; + + if (embd) { + batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); + } else { + batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); + } + + batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); + for (int i = 0; i < n_tokens; ++i) { + batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return reinterpret_cast<jlong>(batch); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { + llama_backend_init(); +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) { + return env->NewStringUTF(llama_print_system_info()); +} + +extern "C" +JNIEXPORT jint JNICALL +Java_android_llama_cpp_LLamaAndroid_completion_1init( + JNIEnv *env, + jobject, + jlong context_pointer, + jlong batch_pointer, + jstring jtext, + jint n_len + ) { + + cached_token_chars.clear(); + + const auto text = env->GetStringUTFChars(jtext, 0); + const auto context = reinterpret_cast<llama_context *>(context_pointer); + const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); + + const auto tokens_list = llama_tokenize(context, text, 1); + + auto n_ctx = llama_n_ctx(context); + auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); + + LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); + + if (n_kv_req > n_ctx) { + LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); + } + + for (auto id : tokens_list) { + LOGi("%s", llama_token_to_piece(context, id).c_str()); + } + + llama_batch_clear(*batch); + + // evaluate the initial prompt + for (auto i = 0; i < tokens_list.size(); i++) { + llama_batch_add(*batch, tokens_list[i], i, { 0 }, false); + } + + // llama_decode will output logits only for the last token of the prompt + batch->logits[batch->n_tokens - 1] = true; + + if (llama_decode(context, *batch) != 0) { + LOGe("llama_decode() failed"); + } + + env->ReleaseStringUTFChars(jtext, text); + + return batch->n_tokens; +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_android_llama_cpp_LLamaAndroid_completion_1loop( + JNIEnv * env, + jobject, + jlong context_pointer, + jlong batch_pointer, + jint n_len, + jobject intvar_ncur +) { + const auto context = reinterpret_cast<llama_context *>(context_pointer); + const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); + const auto model = llama_get_model(context); + + if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); + if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); + if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); + + auto n_vocab = llama_n_vocab(model); + auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); + + std::vector<llama_token_data> candidates; + candidates.reserve(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // sample the most likely token + const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); + + const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); + if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { + return env->NewStringUTF(""); + } + + auto new_token_chars = llama_token_to_piece(context, new_token_id); + cached_token_chars += new_token_chars; + + jstring new_token = nullptr; + if (is_valid_utf8(cached_token_chars.c_str())) { + new_token = env->NewStringUTF(cached_token_chars.c_str()); + LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id); + cached_token_chars.clear(); + } else { + new_token = env->NewStringUTF(""); + } + + llama_batch_clear(*batch); + llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true); + + env->CallVoidMethod(intvar_ncur, la_int_var_inc); + + if (llama_decode(context, *batch) != 0) { + LOGe("llama_decode() returned null"); + } + + return new_token; +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { + llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); +} |