diff --git a/.github/workflows/native_s3_llama.yml b/.github/workflows/native_s3_llama.yml deleted file mode 100644 index 2900a591179..00000000000 --- a/.github/workflows/native_s3_llama.yml +++ /dev/null @@ -1,204 +0,0 @@ -name: Native S3 llama.cpp - -on: - workflow_dispatch: - -jobs: - build-llamacpp-jni-osx: - runs-on: macos-13 - steps: - - uses: actions/checkout@v4 - - name: Set up JDK 17 - uses: actions/setup-java@v4 - with: - distribution: 'corretto' - java-version: 17 - - uses: actions/cache@v4 - with: - path: ~/.gradle/caches - key: ${{ runner.os }}-gradle-${{ hashFiles('*/build.gradle.kts', 'engines/**/build.gradle.kts', 'extensions/**/build.gradle.kts') }} - restore-keys: | - ${{ runner.os }}-gradle- - - name: Release JNI prep - run: | - ./gradlew :engines:llama:compileJNI - ./gradlew -Pjni :engines:llama:test -Dnightly=true - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - run: | - LLAMACPP_VERSION="$(awk -F '=' '/llamacpp/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)" - aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" - - build-llamacpp-jni-linux: - runs-on: ubuntu-latest - container: centos:centos7 - steps: - - name: Install Environment - run: | - yum -y update - yum -y install centos-release-scl-rh epel-release perl-core - yum -y install devtoolset-7 git patch cmake3 libstdc++-static - ln -s /usr/bin/cmake3 /usr/bin/cmake - pip3 install awscli --upgrade - - uses: actions/checkout@v3 - - name: Set up JDK 17 - uses: actions/setup-java@v3 - with: - distribution: 'corretto' - java-version: 17 - - name: Release JNI prep - run: | - export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin - ./gradlew :engines:llama:compileJNI - ./gradlew -Pjni :engines:llama:test -Dnightly=true - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - run: | - LLAMACPP_VERSION="$(awk -F '=' '/llamacpp/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)" - aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" - - build-llamacpp-jni-windows: - runs-on: windows-latest - steps: - - name: Install Environment - run: | - choco install -y mingw - - uses: actions/checkout@v4 - - name: Set up JDK 17 - uses: actions/setup-java@v4 - with: - distribution: 'corretto' - java-version: 17 - - uses: actions/cache@v4 - with: - path: ~/.gradle/caches - key: ${{ runner.os }}-gradle-${{ hashFiles('*/build.gradle.kts', 'engines/**/build.gradle.kts', 'extensions/**/build.gradle.kts') }} - restore-keys: | - ${{ runner.os }}-gradle- - - name: Release CPU JNI - shell: cmd - run: | - call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" amd64 - gradlew :engines:llama:compileJNI - gradlew -Pjni :engines:llama:test -Dnightly=true - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - 
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - shell: bash - run: | - LLAMACPP_VERSION="$(awk -F '=' '/llamacpp/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)" - aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" - - build-llamacpp-jni-arm64-osx: - if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: macos-latest-xlarge - steps: - - uses: actions/checkout@v4 - - name: Set up JDK 17 - uses: actions/setup-java@v4 - with: - java-version: 17 - distribution: corretto - architecture: aarch64 - - uses: actions/cache@v4 - with: - path: ~/.gradle/caches - key: ${{ runner.os }}-gradle-${{ hashFiles('*/build.gradle.kts', 'engines/**/build.gradle.kts', 'extensions/**/build.gradle.kts') }} - restore-keys: | - ${{ runner.os }}-gradle- - - name: Release JNI prep - run: | - ./gradlew :engines:llama:compileJNI - ./gradlew -Pjni :engines:llama:test -Dnightly=true - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - run: | - LLAMACPP_VERSION="$(awk -F '=' '/llamacpp/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)" - aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" - - create-aarch64-runner: - if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, scheduler ] - steps: - - name: Create new Graviton instance - id: create_aarch64 - run: | - cd /home/ubuntu/djl_benchmark_script/scripts - token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \ - https://api.github.com/repos/deepjavalibrary/djl/actions/runners/registration-token \ - --fail \ - | jq '.token' | tr -d '"' ) - ./start_instance.sh action_graviton $token djl - outputs: - aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }} - - build-llamacpp-jni-aarch64: - if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, aarch64 ] - timeout-minutes: 30 - needs: create-aarch64-runner - container: amazonlinux:2 - steps: - - name: Install Environment - run: | - yum -y update - yum -y groupinstall "Development Tools" - yum -y install patch perl-IPC-Cmd cmake3 - ln -s /usr/bin/cmake3 /usr/bin/cmake - pip3 install awscli --upgrade - - uses: actions/checkout@v3 - - name: Set up JDK 17 - uses: actions/setup-java@v3 - with: - java-version: 17 - distribution: corretto - architecture: aarch64 - - name: Release JNI prep - run: | - ./gradlew :engines:llama:compileJNI - ./gradlew -Pjni :engines:llama:test -Dnightly=true - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - run: | - LLAMACPP_VERSION="$(awk -F '=' '/llamacpp/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)" - aws s3 sync engines/llama/jnilib 
s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" - - stop-runners: - if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, scheduler ] - needs: [ create-aarch64-runner, build-llamacpp-jni-aarch64 ] - steps: - - name: Stop all instances - run: | - cd /home/ubuntu/djl_benchmark_script/scripts - instance_id=${{ needs.create-aarch64-runner.outputs.aarch64_instance_id }} - ./stop_instance.sh $instance_id diff --git a/bom/build.gradle.kts b/bom/build.gradle.kts index 23f996462ee..72bbfb353f4 100644 --- a/bom/build.gradle.kts +++ b/bom/build.gradle.kts @@ -24,7 +24,6 @@ dependencies { api("ai.djl.fasttext:fasttext-engine:${version}") api("ai.djl.hadoop:hadoop:${version}") api("ai.djl.huggingface:tokenizers:${version}") - api("ai.djl.llama:llama:${version}") api("ai.djl.ml.lightgbm:lightgbm:${version}") api("ai.djl.ml.xgboost:xgboost-gpu:${version}") api("ai.djl.ml.xgboost:xgboost:${version}") diff --git a/engines/llama/.gitignore b/engines/llama/.gitignore deleted file mode 100644 index 3428b3b2f53..00000000000 --- a/engines/llama/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -jnilib/ -llama.cpp/ -models/ diff --git a/engines/llama/CMakeLists.txt b/engines/llama/CMakeLists.txt deleted file mode 100644 index d1fc8131db8..00000000000 --- a/engines/llama/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -cmake_minimum_required(VERSION 3.12 FATAL_ERROR) - -project(djl_llama CXX) - -set(CMAKE_POSITION_INDEPENDENT_CODE ON) -set(BUILD_SHARED_LIBS ON) - -set(JAVA_AWT_LIBRARY NotNeeded) -set(JAVA_AWT_INCLUDE_PATH NotNeeded) -find_package(JNI REQUIRED) - -add_subdirectory(llama.cpp) -include(build-args.cmake) -add_library(djl_llama SHARED src/main/native/ai_djl_llama.cpp) - -target_include_directories(djl_llama PRIVATE - ${JNI_INCLUDE_DIRS} - src/main/native - llama.cpp - llama.cpp/common - build/include) -target_link_libraries(djl_llama PRIVATE common llama ${LLAMA_EXTRA_LIBS}) -target_compile_features(djl_llama PRIVATE cxx_std_11) diff --git a/engines/llama/build-args.cmake b/engines/llama/build-args.cmake deleted file mode 100644 index dee0db659cd..00000000000 --- a/engines/llama/build-args.cmake +++ /dev/null @@ -1,639 +0,0 @@ -if (APPLE) - set(LLAMA_METAL_DEFAULT ON) -else() - set(LLAMA_METAL_DEFAULT OFF) -endif() - -# general -option(LLAMA_NATIVE "llama: enable -march=native flag" ON) - -# instruction set specific -if (LLAMA_NATIVE) - set(INS_ENB OFF) -else() - set(INS_ENB ON) -endif() - -option(LLAMA_AVX "llama: enable AVX" ${INS_ENB}) -option(LLAMA_AVX2 "llama: enable AVX2" ${INS_ENB}) -option(LLAMA_AVX512 "llama: enable AVX512" OFF) -option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) -option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) -option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) -# in MSVC F16C is implied with AVX2/AVX512 -if (NOT MSVC) - option(LLAMA_F16C "llama: enable F16C" ${INS_ENB}) -endif() - -# 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_BLAS "llama: use BLAS" OFF) -set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use CUDA" OFF) -#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) -option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) -option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) -set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for 
dmmv CUDA kernels") -set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") -option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) -set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") -set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING - "llama: max. batch size for using peer access") -option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) -option(LLAMA_CLBLAST "llama: use CLBlast" OFF) -option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) -option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) -option(LLAMA_MPI "llama: use MPI" OFF) -option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) - - -# -# Compile flags -# - -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_CXX_STANDARD_REQUIRED true) -set(CMAKE_C_STANDARD 11) -set(CMAKE_C_STANDARD_REQUIRED true) -set(THREADS_PREFER_PTHREAD_FLAG ON) -find_package(Threads REQUIRED) -include(CheckCXXCompilerFlag) - -# enable libstdc++ assertions for debug builds -if (CMAKE_SYSTEM_NAME MATCHES "Linux") - add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>) -endif() - -if (NOT MSVC) - if (LLAMA_SANITIZE_THREAD) - add_compile_options(-fsanitize=thread) - link_libraries(-fsanitize=thread) - endif() - - if (LLAMA_SANITIZE_ADDRESS) - add_compile_options(-fsanitize=address -fno-omit-frame-pointer) - link_libraries(-fsanitize=address) - endif() - - if (LLAMA_SANITIZE_UNDEFINED) - add_compile_options(-fsanitize=undefined) - link_libraries(-fsanitize=undefined) - endif() -endif() - -if (APPLE AND LLAMA_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate) - if (ACCELERATE_FRAMEWORK) - message(STATUS "Accelerate framework found") - - add_compile_definitions(GGML_USE_ACCELERATE) - add_compile_definitions(ACCELERATE_NEW_LAPACK) - add_compile_definitions(ACCELERATE_LAPACK_ILP64) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) - else() - message(WARNING "Accelerate framework not found") - endif() -endif() - -if (LLAMA_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) - - message(STATUS "Metal framework found") - set(GGML_HEADERS_METAL ggml-metal.h) - set(GGML_SOURCES_METAL ggml-metal.m) - - add_compile_definitions(GGML_USE_METAL) - if (LLAMA_METAL_NDEBUG) - add_compile_definitions(GGML_METAL_NDEBUG) - endif() - - # get full path to the file - #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") - - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} - ${FOUNDATION_LIBRARY} - ${METAL_FRAMEWORK} - ${METALKIT_FRAMEWORK} - ) -endif() -if (LLAMA_BLAS) - if (LLAMA_STATIC) - set(BLA_STATIC ON) - endif() - if (${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.22) - set(BLA_SIZEOF_INTEGER 8) - endif() - - set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) - find_package(BLAS) - - if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - - if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
- # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_package(PkgConfig REQUIRED) - if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") - pkg_check_modules(DepBLAS REQUIRED blas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") - pkg_check_modules(DepBLAS REQUIRED openblas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") - pkg_check_modules(DepBLAS REQUIRED blis) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") - pkg_check_modules(DepBLAS REQUIRED blas-atlas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") - pkg_check_modules(DepBLAS REQUIRED flexiblas_api) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") - # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS REQUIRED mkl-sdl) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") - # this doesn't provide pkg-config - # suggest to assign BLAS_INCLUDE_DIRS on your own - if ("${NVHPC_VERSION}" STREQUAL "") - message(WARNING "Better to set NVHPC_VERSION") - else() - set(DepBLAS_FOUND ON) - set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") - endif() - endif() - if (DepBLAS_FOUND) - set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" - " detected by pkgconfig, trying to find cblas.h from possible paths...") - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - /opt/homebrew/opt/openblas/include - /usr/local/opt/openblas/include - /usr/include/x86_64-linux-gnu/openblas/include - ) - endif() - endif() - - message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") - add_compile_options(${BLAS_LINKER_FLAGS}) - add_compile_definitions(GGML_USE_OPENBLAS) - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) - add_compile_definitions(GGML_BLAS_USE_MKL) - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) - - else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") - endif() -endif() - -if (LLAMA_QKK_64) - add_compile_definitions(GGML_QKK_64) -endif() - -if (LLAMA_CUBLAS) - cmake_minimum_required(VERSION 3.17) - - find_package(CUDAToolkit) - if (CUDAToolkit_FOUND) - message(STATUS "cuBLAS found") - - enable_language(CUDA) - - set(GGML_HEADERS_CUDA ggml-cuda.h) - set(GGML_SOURCES_CUDA ggml-cuda.cu) - - add_compile_definitions(GGML_USE_CUBLAS) -# if (LLAMA_CUDA_CUBLAS) -# add_compile_definitions(GGML_CUDA_CUBLAS) -# endif() - if (LLAMA_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (LLAMA_CUDA_FORCE_MMQ) - add_compile_definitions(GGML_CUDA_FORCE_MMQ) - endif() - add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) - if (DEFINED LLAMA_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - add_compile_definitions(GGML_CUDA_F16) - endif() - add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) - add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE}) - - if (LLAMA_STATIC) - if (WIN32) - # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas
library - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) - else () - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) - endif() - else() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) - endif() - - if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - # 52 == lowest CUDA 12 standard - # 60 == f16 CUDA intrinsics - # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics - else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics - #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work - endif() - endif() - message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") - - else() - message(WARNING "cuBLAS not found") - endif() -endif() - -if (LLAMA_MPI) - cmake_minimum_required(VERSION 3.10) - find_package(MPI) - if (MPI_C_FOUND) - message(STATUS "MPI found") - set(GGML_HEADERS_MPI ggml-mpi.h) - set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) - add_compile_definitions(GGML_USE_MPI) - add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) - if (NOT MSVC) - add_compile_options(-Wno-cast-qual) - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) - # Even if you're only using the C header, C++ programs may bring in MPI - # C++ functions, so more linkage is needed - if (MPI_CXX_FOUND) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) - endif() - else() - message(WARNING "MPI not found") - endif() -endif() - -if (LLAMA_CLBLAST) - find_package(CLBlast) - if (CLBlast_FOUND) - message(STATUS "CLBlast found") - - set(GGML_HEADERS_OPENCL ggml-opencl.h) - set(GGML_SOURCES_OPENCL ggml-opencl.cpp) - - add_compile_definitions(GGML_USE_CLBLAST) - - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) - else() - message(WARNING "CLBlast not found") - endif() -endif() - -if (LLAMA_HIPBLAS) - list(APPEND CMAKE_PREFIX_PATH /opt/rocm) - - if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") - message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") - endif() - if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") - message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") - endif() - - find_package(hip) - find_package(hipblas) - find_package(rocblas) - - if (${hipblas_FOUND} AND ${hip_FOUND}) - message(STATUS "HIP and hipBLAS found") - add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) - add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) - if (BUILD_SHARED_LIBS) - set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() - if (LLAMA_CUDA_FORCE_DMMV) - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) - endif() - if (LLAMA_CUDA_FORCE_MMQ) - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ) - endif() - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) - target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) - set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) - target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host 
roc::rocblas roc::hipblas) - - if (LLAMA_STATIC) - message(FATAL_ERROR "Static linking not supported for HIP/ROCm") - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm) - else() - message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") - endif() -endif() - -function(get_flags CCID CCVER) - set(C_FLAGS "") - set(CXX_FLAGS "") - - if (CCID MATCHES "Clang") - set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return) - set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi) - - if ( - (CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR - (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0) - ) - set(C_FLAGS ${C_FLAGS} -Wdouble-promotion) - endif() - elseif (CCID STREQUAL "GNU") - set(C_FLAGS -Wdouble-promotion) - set(CXX_FLAGS -Wno-array-bounds) - - if (CCVER VERSION_GREATER_EQUAL 7.1.0) - set(CXX_FLAGS ${CXX_FLAGS} -Wno-format-truncation) - endif() - if (CCVER VERSION_GREATER_EQUAL 8.1.0) - set(CXX_FLAGS ${CXX_FLAGS} -Wextra-semi) - endif() - endif() - - set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) - set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE) -endfunction() - -if (LLAMA_ALL_WARNINGS) - if (NOT MSVC) - set(WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) - set(C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes - -Werror=implicit-int -Werror=implicit-function-declaration) - set(CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn) - - set(C_FLAGS ${WARNING_FLAGS} ${C_FLAGS}) - set(CXX_FLAGS ${WARNING_FLAGS} ${CXX_FLAGS}) - - get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) - - add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>" - "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>") - else() - # todo : msvc - set(C_FLAGS "") - set(CXX_FLAGS "") - endif() -endif() - -if (LLAMA_CUBLAS) - set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math) - if (NOT MSVC) - set(CUDA_FLAGS ${CUDA_FLAGS} -Wno-pedantic) - endif() - - if (LLAMA_ALL_WARNINGS AND NOT MSVC) - set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) - if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "") - set(NVCC_CMD ${NVCC_CMD} -ccbin ${CMAKE_CUDA_HOST_COMPILER}) - endif() - - execute_process( - COMMAND ${NVCC_CMD} -Xcompiler --version - OUTPUT_VARIABLE CUDA_CCFULLVER - ERROR_QUIET - ) - - if (NOT CUDA_CCFULLVER MATCHES clang) - set(CUDA_CCID "GNU") - execute_process( - COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion" - OUTPUT_VARIABLE CUDA_CCVER - ERROR_QUIET - ) - else() - if (CUDA_CCFULLVER MATCHES Apple) - set(CUDA_CCID "AppleClang") - else() - set(CUDA_CCID "Clang") - endif() - string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER}) - endif() - - message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") - - get_flags(${CUDA_CCID} ${CUDA_CCVER}) - list(JOIN GF_CXX_FLAGS " " CUDA_CXX_FLAGS) # pass host compiler flags as a single argument - if (NOT CUDA_CXX_FLAGS STREQUAL "") - set(CUDA_FLAGS ${CUDA_FLAGS} -Xcompiler ${CUDA_CXX_FLAGS}) - endif() - endif() - - add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>") -endif() - -if (WIN32) - add_compile_definitions(_CRT_SECURE_NO_WARNINGS) - - if (BUILD_SHARED_LIBS) - set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) - endif() -endif() - -if (LLAMA_LTO) - include(CheckIPOSupported) - check_ipo_supported(RESULT result OUTPUT output) - if (result) - set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) - else() - message(WARNING "IPO is not supported: ${output}") - endif() -endif() - -# this version of Apple ld64 is buggy -execute_process( - COMMAND
${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v - ERROR_VARIABLE output - OUTPUT_QUIET -) -if (output MATCHES "dyld-1015\.7") - add_compile_definitions(HAVE_BUGGY_APPLE_LINKER) -endif() - -# Architecture specific -# TODO: probably these flags need to be tweaked on some architectures -# feel free to update the Makefile for your architecture and send a pull request or issue -message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") -if (MSVC) - string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR) - message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}") -else () - set(CMAKE_GENERATOR_PLATFORM_LWR "") -endif () - -if (NOT MSVC) - if (LLAMA_STATIC) - add_link_options(-static) - if (MINGW) - add_link_options(-static-libgcc -static-libstdc++) - endif() - endif() - if (LLAMA_GPROF) - add_compile_options(-pg) - endif() -endif() - -if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64")) - message(STATUS "ARM detected") - if (MSVC) - add_compile_definitions(__ARM_NEON) - add_compile_definitions(__ARM_FEATURE_FMA) - add_compile_definitions(__ARM_FEATURE_DOTPROD) - # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16 - add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead - else() - check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E) - if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") - add_compile_options(-mfp16-format=ieee) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") - # Raspberry Pi 1, Zero - add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") - # Raspberry Pi 2 - add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") - # Raspberry Pi 3, 4, Zero 2 (32-bit) - add_compile_options(-mno-unaligned-access) - endif() - endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" ) - message(STATUS "x86 detected") - if (MSVC) - # instruction set detection for MSVC only - if (LLAMA_NATIVE) - include(${llama.cpp_SOURCE_DIR}/cmake/FindSIMD.cmake) - endif () - if (LLAMA_AVX512) - add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>) - add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>) - # MSVC has no compile-time flags enabling specific - # AVX512 extensions, neither it defines the - # macros corresponding to the extensions. - # Do it manually.
- if (LLAMA_AVX512_VBMI) - add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>) - add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>) - endif() - if (LLAMA_AVX512_VNNI) - add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>) - add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>) - endif() - elseif (LLAMA_AVX2) - add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>) - add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>) - elseif (LLAMA_AVX) - add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>) - add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>) - endif() - else() - if (LLAMA_NATIVE) - add_compile_options(-march=native) - endif() - if (LLAMA_F16C) - add_compile_options(-mf16c) - endif() - if (LLAMA_FMA) - add_compile_options(-mfma) - endif() - if (LLAMA_AVX) - add_compile_options(-mavx) - endif() - if (LLAMA_AVX2) - add_compile_options(-mavx2) - endif() - if (LLAMA_AVX512) - add_compile_options(-mavx512f) - add_compile_options(-mavx512bw) - endif() - if (LLAMA_AVX512_VBMI) - add_compile_options(-mavx512vbmi) - endif() - if (LLAMA_AVX512_VNNI) - add_compile_options(-mavx512vnni) - endif() - endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") - message(STATUS "PowerPC detected") - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") - add_compile_options(-mcpu=powerpc64le) - else() - add_compile_options(-mcpu=native -mtune=native) - #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) - endif() -else() - message(STATUS "Unknown architecture") -endif() - -if (MINGW) - # Target Windows 8 for PrefetchVirtualMemory - add_compile_definitions(_WIN32_WINNT=0x602) -endif() - -# -# POSIX conformance -# - -# clock_gettime came in POSIX.1b (1993) -# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional -# posix_memalign came in POSIX.1-2001 / SUSv3 -# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) -add_compile_definitions(_XOPEN_SOURCE=600) - -# Somehow in OpenBSD whenever POSIX conformance is specified -# some string functions rely on locale_t availability, -# which was introduced in POSIX.1-2008, forcing us to go higher -if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") - remove_definitions(-D_XOPEN_SOURCE=600) - add_compile_definitions(_XOPEN_SOURCE=700) -endif() - -# Data types, macros and functions related to controlling CPU affinity and -# some memory allocation are available on Linux through GNU extensions in libc -if (CMAKE_SYSTEM_NAME MATCHES "Linux") - add_compile_definitions(_GNU_SOURCE) -endif() - -# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, -# and on macOS its availability depends on enabling Darwin extensions -# similarly on DragonFly, enabling BSD extensions is necessary -if ( - CMAKE_SYSTEM_NAME MATCHES "Darwin" OR - CMAKE_SYSTEM_NAME MATCHES "iOS" OR - CMAKE_SYSTEM_NAME MATCHES "tvOS" OR - CMAKE_SYSTEM_NAME MATCHES "DragonFly" -) - add_compile_definitions(_DARWIN_C_SOURCE) -endif() - -# alloca is a non-standard interface that is not visible on BSDs when -# POSIX conformance is specified, but not all of them provide a clean way -# to enable it in such cases -if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") - add_compile_definitions(__BSD_VISIBLE) -endif() -if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") - add_compile_definitions(_NETBSD_SOURCE) -endif() -if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") - add_compile_definitions(_BSD_SOURCE) -endif() diff --git a/engines/llama/build.cmd b/engines/llama/build.cmd deleted file mode 100644 index 83ccf65198c..00000000000 --- a/engines/llama/build.cmd +++ /dev/null @@ -1,23 +0,0 @@ -@rem https://chocolatey.org/docs/installation#install-with-cmdexe -@rem to
install rust java etc.. -@rem choco install jdk17 -y - -set VERSION="%1" - -if exist "llama.cpp" ( - echo Found "llama.cpp" -) else ( - git clone https://github.com/ggerganov/llama.cpp.git -b %VERSION% -) - -if exist build rd /q /s build -md build\classes -cd build -javac -classpath "%2" -sourcepath ..\src\main\java\ ..\src\main\java\ai\djl\llama\jni\LlamaLibrary.java -h include -d classes -cmake .. -cmake --build . --config Release - -@rem for nightly ci -md jnilib\win-x86_64 -copy Release\djl_llama.dll jnilib\win-x86_64\ -copy bin\Release\llama.dll jnilib\win-x86_64\ diff --git a/engines/llama/build.gradle.kts b/engines/llama/build.gradle.kts deleted file mode 100644 index 68ada8e8dd6..00000000000 --- a/engines/llama/build.gradle.kts +++ /dev/null @@ -1,114 +0,0 @@ -plugins { - ai.djl.javaProject - ai.djl.cppFormatter - ai.djl.publish -} - -group = "ai.djl.llama" - -dependencies { - api(project(":api")) - - testImplementation(project(":testing")) - testImplementation(libs.slf4j.simple) -} - -tasks { - compileJava { dependsOn(processResources) } - - processResources { - val path = "${project.projectDir}/build/resources/main" - inputs.properties(mapOf("djl_version" to libs.versions.djl.get(), "llamacpp_version" to libs.versions.llamacpp.get())) - outputs.dir("$path/native/lib") - doLast { - val llamacpp = libs.versions.llamacpp.get() - val djl = libs.versions.djl.get() - var url = "https://publish.djl.ai/llama/$llamacpp/jnilib/$djl" - val files = listOf( - "linux-x86_64/libdjl_llama.so", - "linux-x86_64/libllama.so", - "linux-aarch64/libdjl_llama.so", - "linux-aarch64/libllama.so", - "osx-x86_64/libdjl_llama.dylib", - "osx-x86_64/libllama.dylib", - "osx-x86_64/ggml-metal.metal", - "osx-aarch64/libdjl_llama.dylib", - "osx-aarch64/libllama.dylib", - "osx-aarch64/ggml-metal.metal", - "win-x86_64/djl_llama.dll", - "win-x86_64/llama.dll" - ) - val jnilibDir = project.projectDir / "jnilib/$djl" - files.forEach { - val file = jnilibDir / it - if (file.exists()) - project.logger.lifecycle("prebuilt or cached file found for $it") - else if (!project.hasProperty("jni")) { - project.logger.lifecycle("Downloading $url/$it") - file.parentFile.mkdirs() - "$url/$it".url into file - } - } - copy { - from(jnilibDir) - into("$path/native/lib") - } - - // write properties - val propFile = file("$path/native/lib/llama.properties") - propFile.text = "version=$llamacpp-$version\n" - - url = "https://mlrepo.djl.ai/model/nlp/text_generation/ai/djl/huggingface/gguf/models.json.gz" - val prefix = File("$path/nlp/text_generation") - val file = prefix / "ai.djl.huggingface.gguf.json" - if (file.exists()) - project.logger.lifecycle("gguf index file already exists") - else { - project.logger.lifecycle("Downloading gguf index file") - file.parentFile.mkdirs() - url.url gzipInto file - } - } - } - - publishing { - publications { - named("maven") { - pom { - name = "DJL NLP utilities for Llama.cpp" - description = "Deep Java Library (DJL) NLP utilities for llama.cpp" - url = "http://www.djl.ai/engines/${project.name}" - } - } - } - } - - register("compileJNI") { - doFirst { - val cp = configurations.runtimeClasspath.get().resolve().joinToString(":") - if ("mac" in os || "linux" in os) { - val arch = if (arch == "amd64") "x86_64" else arch - exec { - commandLine("bash", "build.sh", libs.versions.llamacpp.get(), arch, cp) - } - } else - exec { - commandLine("${project.projectDir}/build.cmd", libs.versions.llamacpp.get(), cp) - } - - // for ci to upload to S3 - val ciDir = project.projectDir / 
"jnilib/${libs.versions.djl.get()}/" - copy { - from(project.projectDir / "build/jnilib") - into(ciDir) - } - delete("$home/.djl.ai/llama") - } - } - - clean { - doFirst { - delete("$home/.djl.ai/llama") - } - } -} \ No newline at end of file diff --git a/engines/llama/build.sh b/engines/llama/build.sh deleted file mode 100755 index 1b6e7d4e1fa..00000000000 --- a/engines/llama/build.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env bash - -set -e -WORK_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -NUM_PROC=1 -if [[ -n $(command -v nproc) ]]; then - NUM_PROC=$(nproc) -elif [[ -n $(command -v sysctl) ]]; then - NUM_PROC=$(sysctl -n hw.ncpu) -fi -PLATFORM=$(uname | tr '[:upper:]' '[:lower:]') - -VERSION=$1 -ARCH=$2 -CLASSPATH=$3 - -pushd $WORK_DIR -if [ ! -d "llama.cpp" ]; then - git clone https://github.com/ggerganov/llama.cpp.git -b $VERSION -fi - -if [ ! -d "build" ]; then - mkdir build -fi -cd build - -rm -rf classes -mkdir classes -javac -classpath $CLASSPATH -sourcepath ../src/main/java/:../../../api/src/main/java ../src/main/java/ai/djl/llama/jni/LlamaLibrary.java -h include -d classes -cmake .. -cmake --build . --config Release -- -j "${NUM_PROC}" - -popd - -# for nightly ci -if [[ $PLATFORM == 'darwin' ]]; then - mkdir -p build/jnilib/osx-$ARCH - cp -f build/libdjl_llama.dylib build/jnilib/osx-$ARCH/ - cp -f build/llama.cpp/libllama.dylib build/jnilib/osx-$ARCH/ - cp -f llama.cpp/ggml-metal.metal build/jnilib/osx-$ARCH/ -elif [[ $PLATFORM == 'linux' ]]; then - mkdir -p build/jnilib/linux-$ARCH - cp -f build/libdjl_llama.so build/jnilib/linux-$ARCH/ - cp -f build/llama.cpp/libllama.so build/jnilib/linux-$ARCH/ -fi diff --git a/engines/llama/gradlew b/engines/llama/gradlew deleted file mode 120000 index 343e0d2caa4..00000000000 --- a/engines/llama/gradlew +++ /dev/null @@ -1 +0,0 @@ -../../gradlew \ No newline at end of file diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java deleted file mode 100644 index 75fdf5a5d8c..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ - -package ai.djl.llama.engine; - -import ai.djl.Device; -import ai.djl.Model; -import ai.djl.engine.Engine; -import ai.djl.engine.EngineException; -import ai.djl.llama.jni.LibUtils; -import ai.djl.ndarray.NDManager; -import ai.djl.util.Platform; -import ai.djl.util.passthrough.PassthroughNDManager; - -/** The {@code LlamaEngine} is an implementation of the {@link Engine} based on the llama.cpp. 
*/ -public final class LlamaEngine extends Engine { - - public static final String ENGINE_NAME = "Llama"; - static final int RANK = 10; - - private Engine alternativeEngine; - private boolean initialized; - - private LlamaEngine() { - try { - LibUtils.loadLibrary(); - } catch (EngineException e) { // NOPMD - throw e; - } catch (Throwable t) { - throw new EngineException("Failed to load llama.cpp native library", t); - } - } - - static Engine newInstance() { - return new LlamaEngine(); - } - - /** {@inheritDoc} */ - @Override - public Engine getAlternativeEngine() { - if (!initialized && !Boolean.getBoolean("ai.djl.llama.disable_alternative")) { - Engine engine = Engine.getInstance(); - if (engine.getRank() < getRank()) { - // alternativeEngine should not have the same rank as Llama - alternativeEngine = engine; - } - initialized = true; - } - return alternativeEngine; - } - - /** {@inheritDoc} */ - @Override - public String getEngineName() { - return ENGINE_NAME; - } - - /** {@inheritDoc} */ - @Override - public int getRank() { - return RANK; - } - - /** {@inheritDoc} */ - @Override - public String getVersion() { - Platform platform = Platform.detectPlatform("llama"); - return platform.getVersion(); - } - - /** {@inheritDoc} */ - @Override - public boolean hasCapability(String capability) { - return false; - } - - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device) { - return new LlamaModel(name, newBaseManager(device)); - } - - /** {@inheritDoc} */ - @Override - public NDManager newBaseManager() { - return newBaseManager(null); - } - - /** {@inheritDoc} */ - @Override - public NDManager newBaseManager(Device device) { - return PassthroughNDManager.INSTANCE; - } - - /** {@inheritDoc} */ - @Override - public String toString() { - return getEngineName() + ':' + getVersion() + ", " + getEngineName() + ':' + getVersion(); - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java deleted file mode 100644 index ca5cc646498..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.engine; - -import ai.djl.engine.Engine; -import ai.djl.engine.EngineProvider; - -/** {@code LlamaEngineProvider} is the Llama implementation of {@link EngineProvider}. 
*/ -public class LlamaEngineProvider implements EngineProvider { - - /** {@inheritDoc} */ - @Override - public String getEngineName() { - return LlamaEngine.ENGINE_NAME; - } - - /** {@inheritDoc} */ - @Override - public int getEngineRank() { - return LlamaEngine.RANK; - } - - /** {@inheritDoc} */ - @Override - public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = LlamaEngine.newInstance(); - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java deleted file mode 100644 index 4b4d332fc9f..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.engine; - -import ai.djl.llama.jni.InputParameters; - -import com.google.gson.annotations.SerializedName; - -import java.util.Map; - -/** A class that holds input data for the Llama model. */ -public class LlamaInput { - - private String inputs; - private String prefix; - private String suffix; - private Parameters parameters; - - /** - * Returns the input prompt. - * - * @return the input prompt - */ - public String getInputs() { - return inputs; - } - - /** - * Sets the input prompt. - * - * @param inputs the input prompt - */ - public void setInputs(String inputs) { - this.inputs = inputs; - } - - /** - * Returns the prompt prefix. - * - * @return the prompt prefix - */ - public String getPrefix() { - return prefix; - } - - /** - * Sets the prompt prefix. - * - * @param prefix the prompt prefix - */ - public void setPrefix(String prefix) { - this.prefix = prefix; - } - - /** - * Returns the prompt suffix. - * - * @return the prompt suffix - */ - public String getSuffix() { - return suffix; - } - - /** - * Sets the prompt suffix. - * - * @param suffix the prompt suffix - */ - public void setSuffix(String suffix) { - this.suffix = suffix; - } - - /** - * Returns the input parameters. - * - * @return the input parameters - */ - public Parameters getParameters() { - if (parameters == null) { - parameters = new Parameters(); - } - return parameters; - } - - /** - * Sets the input parameters. - * - * @param parameters the input parameters - */ - public void setParameters(Parameters parameters) { - this.parameters = parameters; - } - - /** The input parameters class.
*/ - public static final class Parameters { - - @SerializedName("max_new_tokens") - private int nPredict; - - @SerializedName("number_keep") - private int nKeep; - - @SerializedName("number_probabilities") - private int nProbs; - - @SerializedName("top_k") - private int topK; - - @SerializedName("top_p") - private float topP; - - @SerializedName("tfs_z") - private float tfsZ; - - @SerializedName("typical_p") - private float typicalP; - - @SerializedName("temperature") - private float temperature; - - @SerializedName("repeat_penalty") - private float repeatPenalty; - - @SerializedName("repeat_last_n") - private int repeatLastN; - - @SerializedName("frequency_penalty") - private float frequencyPenalty; - - @SerializedName("presence_penalty") - private float presencePenalty; - - @SerializedName("penalize_nl") - private boolean penalizeNl; - - @SerializedName("ignore_eos") - private boolean ignoreEos; - - @SerializedName("mirostat") - private int mirostat; - - @SerializedName("mirostat_tau") - private float mirostatTau; - - @SerializedName("mirostat_eta") - private float mirostatEta; - - @SerializedName("number_beams") - private int nBeams; - - @SerializedName("seed") - private int seed; - - @SerializedName("logit_bias") - private Map<Integer, Float> logitBias; - - @SerializedName("grammar") - private String grammar; - - @SerializedName("anti_prompt") - private String[] antiPrompt; - - /** - * Sets the max new tokens. - * - * @param maxNewTokens the max new tokens - */ - public void setMaxNewTokens(int maxNewTokens) { - this.nPredict = maxNewTokens; - } - - /** - * Sets the number of tokens to keep. - * - * @param nKeep the number of tokens to keep - */ - public void setNumberKeep(int nKeep) { - this.nKeep = nKeep; - } - - /** - * Sets the number of probabilities. - * - * @param nProbs the number of probabilities - */ - public void setNumberProbabilities(int nProbs) { - this.nProbs = nProbs; - } - - /** - * Sets the top K. - * - * @param topK the top K - */ - public void setTopK(int topK) { - this.topK = topK; - } - - /** - * Sets the top P. - * - * @param topP the top P - */ - public void setTopP(float topP) { - this.topP = topP; - } - - /** - * Sets the tfs Z. - * - * @param tfsZ the tfs Z - */ - public void setTfsZ(float tfsZ) { - this.tfsZ = tfsZ; - } - - /** - * Sets the typical P. - * - * @param typicalP the typical P - */ - public void setTypicalP(float typicalP) { - this.typicalP = typicalP; - } - - /** - * Sets the temperature. - * - * @param temperature the temperature - */ - public void setTemperature(float temperature) { - this.temperature = temperature; - } - - /** - * Sets the repeat penalty. - * - * @param repeatPenalty the repeat penalty - */ - public void setRepeatPenalty(float repeatPenalty) { - this.repeatPenalty = repeatPenalty; - } - - /** - * Sets the repeat last N. - * - * @param repeatLastN the repeat last N - */ - public void setRepeatLastN(int repeatLastN) { - this.repeatLastN = repeatLastN; - } - - /** - * Sets the frequency penalty. - * - * @param frequencyPenalty the frequency penalty - */ - public void setFrequencyPenalty(float frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - } - - /** - * Sets the presence penalty. - * - * @param presencePenalty the presence penalty - */ - public void setPresencePenalty(float presencePenalty) { - this.presencePenalty = presencePenalty; - } - - /** - * Sets whether to penalize newlines. - * - * @param penalizeNl whether to penalize newlines - */ - public void setPenalizeNl(boolean penalizeNl) { - this.penalizeNl = penalizeNl; - } - - /** - * Sets whether to ignore EOS.
- * - * @param ignoreEos whether to ignore EOS - */ - public void setIgnoreEos(boolean ignoreEos) { - this.ignoreEos = ignoreEos; - } - - /** - * Sets the mirostat. - * - * @param mirostat the mirostat - */ - public void setMirostat(int mirostat) { - this.mirostat = mirostat; - } - - /** - * Sets the mirostat TAU. - * - * @param mirostatTau the mirostat TAU - */ - public void setMirostatTau(float mirostatTau) { - this.mirostatTau = mirostatTau; - } - - /** - * Sets the mirostat ETA. - * - * @param mirostatEta the mirostat ETA - */ - public void setMirostatEta(float mirostatEta) { - this.mirostatEta = mirostatEta; - } - - /** - * Sets the number of beams. - * - * @param nBeams the number of beams - */ - public void setNumberBeams(int nBeams) { - this.nBeams = nBeams; - } - - /** - * Sets the seed. - * - * @param seed the seed - */ - public void setSeed(int seed) { - this.seed = seed; - } - - /** - * Sets the logit bias. - * - * @param logitBias the logit bias - */ - public void setLogitBias(Map<Integer, Float> logitBias) { - this.logitBias = logitBias; - } - - /** - * Sets the grammar template. - * - * @param grammar the grammar template - */ - public void setGrammar(String grammar) { - this.grammar = grammar; - } - - /** - * Sets the anti-prompt. - * - * @param antiPrompt the anti-prompt - */ - public void setAntiPrompt(String[] antiPrompt) { - this.antiPrompt = antiPrompt; - } - - /** - * Returns the {@link InputParameters} object. - * - * @return the {@link InputParameters} object - */ - public InputParameters toInputParameters() { - setDefaultValue(); - return new InputParameters( - nPredict, - nKeep, - nProbs, - topK, - topP, - tfsZ, - typicalP, - temperature, - repeatPenalty, - repeatLastN, - frequencyPenalty, - presencePenalty, - penalizeNl, - ignoreEos, - mirostat, - mirostatTau, - mirostatEta, - nBeams, - seed, - logitBias, - grammar, - antiPrompt); - } - - private void setDefaultValue() { - if (nPredict == 0) { - nPredict = -1; - } - if (topK == 0) { - topK = 40; - } - if (topP == 0) { - topP = 0.95f; - } - if (tfsZ == 0) { - tfsZ = 1f; - } - if (typicalP == 0) { - typicalP = 1f; - } - if (temperature == 0) { - temperature = 0.8f; - } - if (repeatPenalty == 0) { - repeatPenalty = 1.10f; - } - if (repeatLastN == 0) { - repeatLastN = 64; - } - } - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java deleted file mode 100644 index 0ff3c6d70c0..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License.
*/ -package ai.djl.llama.engine; - -import ai.djl.BaseModel; -import ai.djl.Model; -import ai.djl.llama.jni.LlamaLibrary; -import ai.djl.llama.jni.ModelParameters; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.DataType; -import ai.djl.nn.Blocks; - -import java.io.FileNotFoundException; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Map; - -/** {@code LlamaModel} is the llama.cpp implementation of {@link Model}. */ -public class LlamaModel extends BaseModel { - - private long handle = -1; - - /** - * Constructs a new Model on a given device. - * - * @param name the model name - * @param manager the {@link NDManager} to hold the NDArray - */ - LlamaModel(String name, NDManager manager) { - super(name); - this.manager = manager; - this.manager.setName("llamaModel"); - dataType = DataType.FLOAT32; - } - - /** {@inheritDoc} */ - @Override - public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException { - setModelDir(modelPath); - wasLoaded = true; - if (block != null) { - throw new UnsupportedOperationException("Llama does not support dynamic blocks"); - } - - if (prefix == null) { - prefix = modelName; - } - - // search for .gguf file with prefix, folder name or "model.gguf" - Path modelFile = findModelFile(prefix, modelDir.toFile().getName(), "model.gguf"); - if (modelFile == null) { - throw new FileNotFoundException(".gguf file not found in: " + modelPath); - } - - ModelParameters param = new ModelParameters(options); - handle = LlamaLibrary.loadModel(modelFile.toString(), param); - block = Blocks.identityBlock(); - } - - long getHandle() { - return handle; - } - - private Path findModelFile(String... prefixes) { - if (Files.isRegularFile(modelDir)) { - Path file = modelDir; - modelDir = modelDir.getParent(); - String fileName = file.toFile().getName(); - if (fileName.endsWith(".gguf")) { - modelName = fileName.substring(0, fileName.length() - 5); - } else { - modelName = fileName; - } - return file; - } - for (String prefix : prefixes) { - Path modelFile = modelDir.resolve(prefix); - if (Files.isRegularFile(modelFile)) { - return modelFile; - } - if (!prefix.endsWith(".gguf")) { - modelFile = modelDir.resolve(prefix + ".gguf"); - if (Files.isRegularFile(modelFile)) { - return modelFile; - } - } - } - return null; - } - - /** {@inheritDoc} */ - @Override - public void close() { - if (handle == -1) { - return; - } - LlamaLibrary.delete(handle); - handle = -1; - super.close(); - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java deleted file mode 100644 index c8d3692b160..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License.
*/ -package ai.djl.llama.engine; - -import ai.djl.inference.streaming.IteratorBytesSupplier; -import ai.djl.llama.jni.InputParameters; -import ai.djl.llama.jni.LlamaLibrary; -import ai.djl.llama.jni.Token; -import ai.djl.llama.jni.TokenIterator; -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.ndarray.BytesSupplier; -import ai.djl.ndarray.NDList; -import ai.djl.translate.NoBatchifyTranslator; -import ai.djl.translate.TranslatorContext; -import ai.djl.util.JsonUtils; - -import java.util.Iterator; - -/** Built-in {@code Translator} that provides preprocessing and postprocessing for llama.cpp. */ -public class LlamaTranslator<I, O> implements NoBatchifyTranslator<I, O> { - - private long handle; - - /** {@inheritDoc} */ - @Override - public void prepare(TranslatorContext ctx) { - LlamaModel model = (LlamaModel) ctx.getModel(); - handle = model.getHandle(); - } - - /** {@inheritDoc} */ - @Override - public NDList processInput(TranslatorContext ctx, I input) { - if (input instanceof String) { - ctx.setAttachment("out", generate((String) input)); - } else if (input instanceof LlamaInput) { - ctx.setAttachment("out", generate((LlamaInput) input)); - } else if (input instanceof Input) { - String prompt = ((Input) input).getData().getAsString(); - TokenIterator it = generate(prompt); - Output output = new Output(); - output.add(new IteratorBytesSupplier(new OutputIterator(it))); - ctx.setAttachment("out", output); - } - return new NDList(); - } - - /** {@inheritDoc} */ - @Override - @SuppressWarnings("unchecked") - public O processOutput(TranslatorContext ctx, NDList list) { - return (O) ctx.getAttachment("out"); - } - - private TokenIterator generate(String input) { - LlamaInput in = JsonUtils.GSON.fromJson(input, LlamaInput.class); - return generate(in); - } - - private TokenIterator generate(LlamaInput in) { - InputParameters param = in.getParameters().toInputParameters(); - String prefix = in.getPrefix(); - String suffix = in.getSuffix(); - String inputs = in.getInputs(); - if (prefix != null && suffix != null) { - LlamaLibrary.infill(handle, prefix, suffix, param); - } else if (inputs != null && !inputs.isEmpty()) { - LlamaLibrary.generate(handle, inputs, param); - } else { - throw new IllegalArgumentException("Unsupported input format"); - } - return new TokenIterator(handle); - } - - private static final class OutputIterator implements Iterator<BytesSupplier> { - - private TokenIterator it; - - public OutputIterator(TokenIterator it) { - this.it = it; - } - - /** {@inheritDoc} */ - @Override - public boolean hasNext() { - return it.hasNext(); - } - - /** {@inheritDoc} */ - @Override - public BytesSupplier next() { - Token token = it.next(); - return BytesSupplier.wrap(JsonUtils.GSON.toJson(token) + "\n"); - } - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java deleted file mode 100644 index 089b5055b51..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file.
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.engine; - -import ai.djl.Model; -import ai.djl.llama.jni.TokenIterator; -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorFactory; -import ai.djl.util.Pair; - -import java.io.Serializable; -import java.lang.reflect.Type; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -/** A {@link TranslatorFactory} that creates a {@link LlamaTranslator} instance. */ -public class LlamaTranslatorFactory implements TranslatorFactory, Serializable { - - private static final long serialVersionUID = 1L; - - private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>(); - - static { - SUPPORTED_TYPES.add(new Pair<>(String.class, TokenIterator.class)); - SUPPORTED_TYPES.add(new Pair<>(LlamaInput.class, TokenIterator.class)); - SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); - } - - /** {@inheritDoc} */ - @Override - public Set<Pair<Type, Type>> getSupportedTypes() { - return SUPPORTED_TYPES; - } - - /** {@inheritDoc} */ - @Override - public boolean isSupported(Class<?> input, Class<?> output) { - return true; - } - - /** {@inheritDoc} */ - @Override - public <I, O> Translator<I, O> newInstance( - Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) { - return new LlamaTranslator<>(); - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java deleted file mode 100644 index 226e7a6ddb8..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ - -/** Contains classes to interface with the underlying Llama Engine. */ -package ai.djl.llama.engine; diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java deleted file mode 100644 index d13abc5ef90..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java +++ /dev/null @@ -1,314 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.jni; - -import java.util.Map; - -/** A class that holds input parameters.
*/ -@SuppressWarnings({"PMD.UnusedPrivateField", "PMD.UnusedAssignment"}) -public class InputParameters { - - private int nPredict; - private int nKeep; - private int nProbs; - private int topK; - private float topP; - private float tfsZ; - private float typicalP; - private float temperature; - private float repeatPenalty; - private int repeatLastN; - private float frequencyPenalty; - private float presencePenalty; - private boolean penalizeNl; - private boolean ignoreEos; - private int mirostat; - private float mirostatTau; - private float mirostatEta; - private int nBeams; - private int seed; - private Map<Integer, Float> logitBias; - private String grammar; - private String[] antiPrompt; - - /** - * Constructs a new {@code InputParameters} instance. - * - * @param nPredict the max new tokens - * @param nKeep the number of tokens to keep - * @param nProbs the number of probabilities - * @param topK the top K - * @param topP the top P - * @param tfsZ the tfs Z - * @param typicalP the typical P - * @param temperature the temperature - * @param repeatPenalty the repeat penalty - * @param repeatLastN the repeat last N - * @param frequencyPenalty the frequency penalty - * @param presencePenalty the presence penalty - * @param penalizeNl whether to penalize newlines - * @param ignoreEos whether to ignore EOS - * @param mirostat the mirostat - * @param mirostatTau the mirostat TAU - * @param mirostatEta the mirostat ETA - * @param nBeams the number of beams - * @param seed the seed - * @param logitBias the logit bias - * @param grammar the grammar - * @param antiPrompt the anti-prompt - */ - public InputParameters( - int nPredict, - int nKeep, - int nProbs, - int topK, - float topP, - float tfsZ, - float typicalP, - float temperature, - float repeatPenalty, - int repeatLastN, - float frequencyPenalty, - float presencePenalty, - boolean penalizeNl, - boolean ignoreEos, - int mirostat, - float mirostatTau, - float mirostatEta, - int nBeams, - int seed, - Map<Integer, Float> logitBias, - String grammar, - String[] antiPrompt) { - this.nPredict = nPredict; - this.nKeep = nKeep; - this.nProbs = nProbs; - this.topK = topK; - this.topP = topP; - this.tfsZ = tfsZ; - this.typicalP = typicalP; - this.temperature = temperature; - this.repeatPenalty = repeatPenalty; - this.repeatLastN = repeatLastN; - this.frequencyPenalty = frequencyPenalty; - this.presencePenalty = presencePenalty; - this.penalizeNl = penalizeNl; - this.ignoreEos = ignoreEos; - this.mirostat = mirostat; - this.mirostatTau = mirostatTau; - this.mirostatEta = mirostatEta; - this.nBeams = nBeams; - this.seed = seed; - this.logitBias = logitBias; - this.grammar = grammar; - this.antiPrompt = antiPrompt; - } - - /** - * Returns the max new tokens. - * - * @return the max new tokens - */ - public int getMaxNewTokens() { - return nPredict; - } - - /** - * Returns the number of tokens to keep. - * - * @return the number of tokens to keep - */ - public int getNumberKeep() { - return nKeep; - } - - /** - * Returns the number of probabilities. - * - * @return the number of probabilities - */ - public int getNumberProbabilities() { - return nProbs; - } - - /** - * Returns the top K. - * - * @return the top K - */ - public int getTopK() { - return topK; - } - - /** - * Returns the top P. - * - * @return the top P - */ - public float getTopP() { - return topP; - } - - /** - * Returns the TfsZ. - * - * @return the TfsZ - */ - public float getTfsZ() { - return tfsZ; - } - - /** - * Returns the typical P. - * - * @return the typical P - */ - public float getTypicalP() { - return typicalP; - } - - /** - * Returns the temperature.
- * - * @return the temperature - */ - public float getTemperature() { - return temperature; - } - - /** - * Returns the repeat penalty. - * - * @return the repeat penalty - */ - public float getRepeatPenalty() { - return repeatPenalty; - } - - /** - * Returns the repeat last N. - * - * @return the repeat last N - */ - public int getRepeatLastN() { - return repeatLastN; - } - - /** - * Returns the frequency penalty. - * - * @return the frequency penalty - */ - public float getFrequencyPenalty() { - return frequencyPenalty; - } - - /** - * Returns the presence penalty. - * - * @return the presence penalty - */ - public float getPresencePenalty() { - return presencePenalty; - } - - /** - * Returns whether to penalize newlines. - * - * @return whether to penalize newlines - */ - public boolean isPenalizeNl() { - return penalizeNl; - } - - /** - * Returns {@code true} if EOS should be ignored. - * - * @return {@code true} if EOS should be ignored - */ - public boolean isIgnoreEos() { - return ignoreEos; - } - - /** - * Returns the mirostat. - * - * @return the mirostat - */ - public int getMirostat() { - return mirostat; - } - - /** - * Returns the mirostat TAU. - * - * @return the mirostat TAU - */ - public float getMirostatTau() { - return mirostatTau; - } - - /** - * Returns the mirostat ETA. - * - * @return the mirostat ETA - */ - public float getMirostatEta() { - return mirostatEta; - } - - /** - * Returns the number of beams. - * - * @return the number of beams - */ - public int getNumberBeams() { - return nBeams; - } - - /** - * Returns the seed. - * - * @return the seed - */ - public int getSeed() { - return seed; - } - - /** - * Returns the logit bias. - * - * @return the logit bias - */ - public Map<Integer, Float> getLogitBias() { - return logitBias; - } - - /** - * Returns the grammar template. - * - * @return the grammar template - */ - public String getGrammar() { - return grammar; - } - - /** - * Returns the anti-prompt. - * - * @return the anti-prompt - */ - public String[] getAntiPrompt() { - return antiPrompt; - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java deleted file mode 100644 index d51a4fe2e5e..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.jni; - -import ai.djl.util.ClassLoaderUtils; -import ai.djl.util.Platform; -import ai.djl.util.Utils; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardCopyOption; -import java.util.ArrayList; -import java.util.List; - -/** Utilities for finding the llama.cpp native binary on the System.
*/ -public final class LibUtils { - - private static final Logger logger = LoggerFactory.getLogger(LibUtils.class); - - private static final String LIB_NAME = System.mapLibraryName("djl_llama"); - private static final String LLAMA_NAME = System.mapLibraryName("llama"); - - private LibUtils() {} - - /** Loads llama.cpp native library. */ - public static void loadLibrary() { - List libs = new ArrayList<>(3); - libs.add(LLAMA_NAME); - libs.add(LIB_NAME); - if (System.getProperty("os.name").startsWith("Mac")) { - libs.add("ggml-metal.metal"); - } - Path dir = copyJniLibraryFromClasspath(libs.toArray(new String[0])); - logger.debug("Loading llama.cpp library from: {}", dir); - - for (int i = 0; i < 2; ++i) { - String lib = libs.get(i); - String path = dir.resolve(lib).toString(); - logger.debug("Loading native library: {}", path); - String nativeHelper = System.getProperty("ai.djl.llama.native_helper"); - if (nativeHelper != null && !nativeHelper.isEmpty()) { - ClassLoaderUtils.nativeLoad(nativeHelper, path); - } else { - System.load(path); // NOPMD - } - } - } - - private static Path copyJniLibraryFromClasspath(String... libs) { - Path cacheDir = Utils.getEngineCacheDir("llama"); - Platform platform = Platform.detectPlatform("llama"); - String classifier = platform.getClassifier(); - String version = platform.getVersion(); - Path dir = cacheDir.resolve(version + '-' + classifier); - Path path = dir.resolve(LIB_NAME); - logger.debug("Using cache dir: {}", dir); - if (Files.exists(path)) { - return dir.toAbsolutePath(); - } - - Path tmp = null; - try { - Files.createDirectories(cacheDir); - tmp = Files.createTempDirectory(cacheDir, "tmp"); - - for (String libName : libs) { - String libPath = "native/lib/" + classifier + "/" + libName; - logger.info("Extracting {} to cache ...", libPath); - try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) { - Path target = tmp.resolve(libName); - Files.copy(is, target, StandardCopyOption.REPLACE_EXISTING); - } - } - Utils.moveQuietly(tmp, dir); - return dir.toAbsolutePath(); - } catch (IOException e) { - throw new IllegalStateException("Cannot copy jni files", e); - } finally { - if (tmp != null) { - Utils.deleteQuietly(tmp); - } - } - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java deleted file mode 100644 index 5d40fa29830..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.jni; - -/** Native library for llama.cpp. 
*/ -@SuppressWarnings("MissingJavadocMethod") -public final class LlamaLibrary { - - private LlamaLibrary() {} - - public static native long loadModel(String filePath, ModelParameters param); - - public static native void generate(long handle, String prompt, InputParameters param); - - public static native void infill( - long handle, String prefix, String suffix, InputParameters param); - - public static native Token getNext(long handle, long count, long pos); - - public static native float[] embed(long handle, String prompt); - - public static native int[] encode(long handle, String prompt); - - public static native byte[] decodeBytes(long handle, int[] tokens); - - public static native void delete(long handle); -} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java deleted file mode 100644 index e3e440474a8..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.jni; - -import java.util.Map; - -/** A class holds llama.cpp model loading parameters. */ -@SuppressWarnings("PMD.SingularField") -public final class ModelParameters { - - private int nThreads; - private int nCtx; - private int nBatch; - private int nGpuLayers; - private int mainGpu; - private float ropeFreqBase; - private float ropeFreqScale; - private boolean mulMatQ; - private boolean f16Kv; - private boolean logitsAll; - private boolean vocabOnly; - private boolean useMmap; - private boolean useMlock; - private boolean embedding; - private boolean memoryF16; - private boolean memTest; - private boolean numa; - private boolean verbosePrompt; - private float[] tensorSplit; - private String loraAdapter; - private String loraBase; - - /** - * Constructs a new {@code ModelParameters} instance. 
-     *
-     * @param options the model loading options
-     */
-    public ModelParameters(Map<String, ?> options) {
-        nThreads = intValue(options, "number_threads", Runtime.getRuntime().availableProcessors());
-        nCtx = intValue(options, "max_context_length", 512);
-        nBatch = intValue(options, "max_rolling_batch", 512);
-        nGpuLayers = intValue(options, "number_gpu_layers", -1);
-        mainGpu = intValue(options, "tensor_parallel_degree", 0);
-        ropeFreqBase = floatValue(options, "rope_freq_base");
-        ropeFreqScale = floatValue(options, "rope_freq_scale");
-        f16Kv = booleanValue(options, "f16_kv");
-        mulMatQ = booleanValue(options, "mulmat_q", true);
-        logitsAll = booleanValue(options, "logits_all");
-        vocabOnly = booleanValue(options, "vocab_only");
-        useMmap = booleanValue(options, "use_mmap", true);
-        useMlock = booleanValue(options, "use_mlock");
-        embedding = booleanValue(options, "embedding");
-        memoryF16 = booleanValue(options, "memory_f16", true);
-        memTest = booleanValue(options, "mem_test");
-        numa = booleanValue(options, "numa");
-        verbosePrompt = booleanValue(options, "verbose_prompt");
-        String val = stringValue(options, "tensor_split");
-        if (val != null && !val.isEmpty()) {
-            String[] tokens = val.split(",");
-            tensorSplit = new float[tokens.length];
-            for (int i = 0; i < tokens.length; ++i) {
-                tensorSplit[i] = Float.parseFloat(tokens[i].trim());
-            }
-        }
-        loraAdapter = stringValue(options, "lora_adapter");
-        loraBase = stringValue(options, "lora_base");
-    }
-
-    private static int intValue(Map<String, ?> arguments, String key, int def) {
-        Object value = arguments.get(key);
-        if (value == null) {
-            return def;
-        }
-        return (int) Double.parseDouble(value.toString());
-    }
-
-    private static float floatValue(Map<String, ?> arguments, String key) {
-        Object value = arguments.get(key);
-        if (value == null) {
-            return 0f;
-        }
-        return (float) Double.parseDouble(value.toString());
-    }
-
-    private static boolean booleanValue(Map<String, ?> arguments, String key) {
-        return booleanValue(arguments, key, false);
-    }
-
-    private static boolean booleanValue(Map<String, ?> arguments, String key, boolean def) {
-        Object value = arguments.get(key);
-        if (value == null) {
-            return def;
-        }
-        return Boolean.parseBoolean(value.toString());
-    }
-
-    private static String stringValue(Map<String, ?> arguments, String key) {
-        Object value = arguments.get(key);
-        if (value == null) {
-            return null;
-        }
-        return value.toString();
-    }
-}
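The option keys above are read from DJL model-loading options, which in practice arrive through Criteria. The following is a minimal sketch under that assumption; the model path and option values are illustrative only, not defaults to copy.

import ai.djl.llama.jni.TokenIterator;
import ai.djl.repository.zoo.Criteria;

import java.nio.file.Paths;

final class LlamaOptionsSketch {

    static Criteria<String, TokenIterator> buildCriteria() {
        // Each optOption key matches a key parsed by ModelParameters above;
        // values are passed as strings and go through the int/float/boolean helpers.
        return Criteria.builder()
                .setTypes(String.class, TokenIterator.class)
                .optModelPath(Paths.get("models/model.gguf")) // placeholder path
                .optEngine("Llama")
                .optOption("max_context_length", "2048") // nCtx
                .optOption("number_gpu_layers", "35") // nGpuLayers
                .optOption("tensor_split", "0.5,0.5") // parsed into a float[] above
                .build();
    }
}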
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/Token.java b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java
deleted file mode 100644
index b8d74306b56..00000000000
--- a/engines/llama/src/main/java/ai/djl/llama/jni/Token.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.llama.jni;
-
-import ai.djl.util.JsonUtils;
-
-import java.nio.charset.StandardCharsets;
-import java.util.Map;
-
-/** The output token class. */
-public final class Token {
-
-    private int token;
-    private String text;
-    private Map<Integer, Float> probabilities;
-    transient long count;
-    transient long pos;
-    transient boolean hasNext;
-
-    /**
-     * Constructs a new {@code Token} instance.
-     *
-     * @param token the token id
-     * @param generated the token text
-     * @param probabilities the token probabilities
-     * @param count the generated token count
-     * @param pos the token index
-     * @param hasNext has more tokens
-     */
-    public Token(
-            int token,
-            byte[] generated,
-            Map<Integer, Float> probabilities,
-            long count,
-            long pos,
-            boolean hasNext) {
-        this.token = token;
-        this.text = new String(generated, StandardCharsets.UTF_8);
-        this.probabilities = probabilities;
-        this.count = count;
-        this.pos = pos;
-        this.hasNext = hasNext;
-    }
-
-    /**
-     * Returns the token id.
-     *
-     * @return the token id
-     */
-    public int getToken() {
-        return token;
-    }
-
-    /**
-     * Returns the token text.
-     *
-     * @return the token text
-     */
-    public String getText() {
-        return text;
-    }
-
-    /**
-     * Returns the token probabilities.
-     *
-     * @return the token probabilities
-     */
-    public Map<Integer, Float> getProbabilities() {
-        return probabilities;
-    }
-
-    /** {@inheritDoc} */
-    @Override
-    public String toString() {
-        return JsonUtils.GSON.toJson(this) + '\n';
-    }
-}
diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java
deleted file mode 100644
index cab6575d8f7..00000000000
--- a/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.llama.jni;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Iterator;
-import java.util.NoSuchElementException;
-import java.util.concurrent.atomic.AtomicBoolean;
-
-/** An iterator class that holds generated tokens. */
-public class TokenIterator implements Iterator<Token> {
-
-    private static final Logger logger = LoggerFactory.getLogger(TokenIterator.class);
-
-    private static AtomicBoolean active = new AtomicBoolean();
-
-    private long handle;
-    private long count;
-    private long pos;
-    private boolean hasNext;
-
-    /**
-     * Constructs a new {@code TokenIterator} instance.
- * - * @param handle the llama.cpp handle - */ - public TokenIterator(long handle) { - this.handle = handle; - hasNext = true; - if (!active.compareAndSet(false, true)) { - active.set(true); - logger.warn("Previous inference has been reset"); - } - } - - /** {@inheritDoc} */ - @Override - public boolean hasNext() { - return hasNext; - } - - /** {@inheritDoc} */ - @Override - public Token next() { - if (!hasNext) { - throw new NoSuchElementException(); - } - Token token = LlamaLibrary.getNext(handle, count, pos); - count = token.count; - pos = token.pos; - hasNext = token.hasNext; - if (!hasNext) { - active.set(false); - } - return token; - } -} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java deleted file mode 100644 index 6f429aceda2..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -/** Contains classes to interface with the native llama.cpp code. */ -package ai.djl.llama.jni; diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java deleted file mode 100644 index 91b6e55050a..00000000000 --- a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.llama.zoo; - -import ai.djl.Application; -import ai.djl.repository.Repository; -import ai.djl.repository.zoo.ModelLoader; -import ai.djl.repository.zoo.ModelZoo; -import ai.djl.util.ClassLoaderUtils; -import ai.djl.util.JsonUtils; -import ai.djl.util.Utils; - -import com.google.gson.reflect.TypeToken; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.io.InputStream; -import java.io.Reader; -import java.io.Writer; -import java.lang.reflect.Type; -import java.net.URI; -import java.net.URL; -import java.nio.file.Files; -import java.nio.file.Path; -import java.time.Duration; -import java.util.Collection; -import java.util.Collections; -import java.util.Map; -import java.util.Set; -import java.util.zip.GZIPInputStream; - -/** LlamaModelZoo is a repository that contains llama.cpp models. 
 */
-public class LlamaModelZoo extends ModelZoo {
-
-    private static final Logger logger = LoggerFactory.getLogger(LlamaModelZoo.class);
-
-    private static final String REPO = "https://mlrepo.djl.ai/";
-    private static final Repository REPOSITORY = Repository.newInstance("gguf", REPO);
-    private static final String GROUP_ID = "ai.djl.huggingface.gguf";
-
-    private static final long ONE_DAY = Duration.ofDays(1).toMillis();
-
-    private volatile boolean initialized; // NOPMD
-
-    LlamaModelZoo() {}
-
-    /** {@inheritDoc} */
-    @Override
-    public String getGroupId() {
-        return GROUP_ID;
-    }
-
-    /** {@inheritDoc} */
-    @Override
-    public Set<String> getSupportedEngines() {
-        return Collections.singleton("Llama");
-    }
-
-    /** {@inheritDoc} */
-    @Override
-    public Collection<ModelLoader> getModelLoaders() {
-        init();
-        return super.getModelLoaders();
-    }
-
-    /** {@inheritDoc} */
-    @Override
-    public ModelLoader getModelLoader(String name) {
-        init();
-        return super.getModelLoader(name);
-    }
-
-    private void init() {
-        if (!initialized) {
-            synchronized (LlamaModelZoo.class) {
-                if (!initialized) {
-                    Application app = Application.NLP.TEXT_GENERATION;
-                    Map<String, ModelDetail> map = listModels(app);
-                    for (Map.Entry<String, ModelDetail> entry : map.entrySet()) {
-                        String artifactId = entry.getKey();
-                        Map<String, String> gguf = entry.getValue().getGguf();
-                        if (gguf != null) {
-                            for (String key : gguf.keySet()) {
-                                addModel(REPOSITORY.model(app, GROUP_ID, artifactId, "0.0.1", key));
-                            }
-                        }
-                    }
-                    initialized = true;
-                }
-            }
-        }
-    }
-
-    private Map<String, ModelDetail> listModels(Application app) {
-        try {
-            String path = "model/" + app.getPath() + "/ai/djl/huggingface/gguf/";
-            Path dir = Utils.getCacheDir().resolve("cache/repo/" + path);
-            if (Files.notExists(dir)) {
-                Files.createDirectories(dir);
-            } else if (!Files.isDirectory(dir)) {
-                logger.warn("Failed to initialize cache directory: {}", dir);
-                return Collections.emptyMap();
-            }
-            Type type = new TypeToken<Map<String, ModelDetail>>() {}.getType();
-
-            Path file = dir.resolve("models.json");
-            if (Files.exists(file)) {
-                long lastModified = Files.getLastModifiedTime(file).toMillis();
-                if (Utils.isOfflineMode() || System.currentTimeMillis() - lastModified < ONE_DAY) {
-                    try (Reader reader = Files.newBufferedReader(file)) {
-                        return JsonUtils.GSON.fromJson(reader, type);
-                    }
-                }
-            }
-
-            URL url = URI.create(REPO).resolve(path + "models.json.gz").toURL();
-            Path tmp = Files.createTempFile(dir, "models", ".tmp");
-            try (GZIPInputStream gis = new GZIPInputStream(Utils.openUrl(url))) {
-                String json = Utils.toString(gis);
-                try (Writer writer = Files.newBufferedWriter(tmp)) {
-                    writer.write(json);
-                }
-                Utils.moveQuietly(tmp, file);
-                return JsonUtils.GSON.fromJson(json, type);
-            } catch (IOException e) {
-                logger.warn("Failed to download Huggingface gguf index: {}", app);
-                if (Files.exists(file)) {
-                    try (Reader reader = Files.newBufferedReader(file)) {
-                        return JsonUtils.GSON.fromJson(reader, type);
-                    }
-                }
-
-                String resource = app.getPath() + "/" + GROUP_ID + ".json";
-                try (InputStream is = ClassLoaderUtils.getResourceAsStream(resource)) {
-                    String json = Utils.toString(is);
-                    try (Writer writer = Files.newBufferedWriter(tmp)) {
-                        writer.write(json);
-                    }
-                    Utils.moveQuietly(tmp, file);
-                    return JsonUtils.GSON.fromJson(json, type);
-                }
-            } finally {
-                Utils.deleteQuietly(tmp);
-            }
-        } catch (IOException e) {
-            logger.warn("Failed to load gguf index file", e);
-        }
-
-        return Collections.emptyMap();
-    }
-
-    private static final class ModelDetail {
-
-        private Map<String, String> gguf;
-
-        public Map<String, String> getGguf() {
-            return gguf;
-        }
-
-        public void setGguf(Map<String, String> gguf) {
-            this.gguf = gguf;
-        }
-    }
-}
diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java
deleted file mode 100644
index ba2b04722c1..00000000000
--- a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.llama.zoo;
-
-import ai.djl.repository.zoo.ModelZoo;
-import ai.djl.repository.zoo.ZooProvider;
-
-/**
- * A Huggingface llama.cpp model zoo provider that implements the {@link
- * ai.djl.repository.zoo.ZooProvider} interface.
- */
-public class LlamaZooProvider implements ZooProvider {
-
-    /** {@inheritDoc} */
-    @Override
-    public ModelZoo getModelZoo() {
-        return new LlamaModelZoo();
-    }
-}
diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java
deleted file mode 100644
index a9c1df64cd0..00000000000
--- a/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java
+++ /dev/null
@@ -1,14 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-/** Contains the built-in {@link ai.djl.llama.zoo.LlamaModelZoo}. */
-package ai.djl.llama.zoo;
diff --git a/engines/llama/src/main/javadoc/overview.html b/engines/llama/src/main/javadoc/overview.html
deleted file mode 100644
index 05dec7d0bd4..00000000000
--- a/engines/llama/src/main/javadoc/overview.html
+++ /dev/null
@@ -1,14 +0,0 @@
-
-
-
-
-

This document is the API specification for the Deep Java Library (DJL) Llama Engine.

- -

- The Llama Engine module contains the Llama.cpp implementation of the DJL EngineProvider. - See here for more details. -
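As a rough usage orientation for readers of this overview, the engine was typically consumed through the DJL model zoo API. A minimal sketch follows; the djl:// artifact coordinates are illustrative only and assume the ai.djl.huggingface.gguf group registered by LlamaModelZoo above.

import ai.djl.inference.Predictor;
import ai.djl.llama.jni.TokenIterator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

public final class LlamaZooSketch {

    public static void main(String[] args) throws Exception {
        Criteria<String, TokenIterator> criteria =
                Criteria.builder()
                        .setTypes(String.class, TokenIterator.class)
                        .optEngine("Llama")
                        // Illustrative coordinates, not a guaranteed index entry.
                        .optModelUrls("djl://ai.djl.huggingface.gguf/TinyLlama/TinyLlama-1.1B-Chat-v0.6")
                        .build();
        try (ZooModel<String, TokenIterator> model = criteria.loadModel();
                Predictor<String, TokenIterator> predictor = model.newPredictor()) {
            predictor.predict("{\"inputs\": \"Hello\"}")
                    .forEachRemaining(t -> System.out.print(t.getText()));
        }
    }
}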

- - - diff --git a/engines/llama/src/main/native/ai_djl_llama.cpp b/engines/llama/src/main/native/ai_djl_llama.cpp deleted file mode 100644 index 1d6072751f2..00000000000 --- a/engines/llama/src/main/native/ai_djl_llama.cpp +++ /dev/null @@ -1,1025 +0,0 @@ -#include -#include -#include -#include - -#include "ai_djl_llama_jni_LlamaLibrary.h" -#include "common.h" -#include "grammar-parser.h" -#include "llama.h" -#include "sampling.h" - -// classes -static jclass c_lib_utils = 0; -static jclass c_model_params = 0; -static jclass c_input_params = 0; -static jclass c_token = 0; -static jclass c_standard_charsets = 0; -static jclass c_string = 0; -static jclass c_hash_map = 0; -static jclass c_map = 0; -static jclass c_set = 0; -static jclass c_entry = 0; -static jclass c_integer = 0; -static jclass c_float = 0; -static jclass c_logger = 0; -static jclass c_engine_exception = 0; - -// constructors -static jmethodID cc_token = 0; -static jmethodID cc_hash_map = 0; -static jmethodID cc_integer = 0; -static jmethodID cc_float = 0; - -// methods -static jmethodID m_get_bytes = 0; -static jmethodID m_entry_set = 0; -static jmethodID m_set_iterator = 0; -static jmethodID m_iterator_has_next = 0; -static jmethodID m_iterator_next = 0; -static jmethodID m_entry_key = 0; -static jmethodID m_entry_value = 0; -static jmethodID m_map_put = 0; -static jmethodID m_int_value = 0; -static jmethodID m_float_value = 0; -static jmethodID m_log_debug = 0; -static jmethodID m_log_info = 0; -static jmethodID m_log_warn = 0; -static jmethodID m_log_error = 0; - -// fields -static jfieldID f_logger = 0; -// inference parameters -static jfieldID f_n_predict = 0; -static jfieldID f_n_keep = 0; -static jfieldID f_n_probs = 0; -static jfieldID f_logit_bias = 0; -static jfieldID f_top_k = 0; -static jfieldID f_top_p = 0; -static jfieldID f_tfs_z = 0; -static jfieldID f_typical_p = 0; -static jfieldID f_temperature = 0; -static jfieldID f_repeat_penalty = 0; -static jfieldID f_repeat_last_n = 0; -static jfieldID f_frequency_penalty = 0; -static jfieldID f_presence_penalty = 0; -static jfieldID f_penalize_nl = 0; -static jfieldID f_ignore_eos = 0; -static jfieldID f_mirostat = 0; -static jfieldID f_mirostat_tau = 0; -static jfieldID f_mirostat_eta = 0; -static jfieldID f_n_beams = 0; -static jfieldID f_grammar = 0; -static jfieldID f_antiprompt = 0; -static jfieldID f_infer_seed = 0; -// model parameters -static jfieldID f_n_threads = 0; -static jfieldID f_n_ctx = 0; -static jfieldID f_n_batch = 0; -static jfieldID f_n_gpu_layers = 0; -static jfieldID f_main_gpu = 0; -static jfieldID f_tensor_split = 0; -static jfieldID f_rope_freq_base = 0; -static jfieldID f_rope_freq_scale = 0; -static jfieldID f_mul_mat_q = 0; -static jfieldID f_f16_kv = 0; -static jfieldID f_logits_all = 0; -static jfieldID f_vocab_only = 0; -static jfieldID f_use_mmap = 0; -static jfieldID f_use_mlock = 0; -static jfieldID f_embedding = 0; -static jfieldID f_lora_adapter = 0; -static jfieldID f_lora_base = 0; -static jfieldID f_memory_f16 = 0; -static jfieldID f_mem_test = 0; -static jfieldID f_numa = 0; -static jfieldID f_verbose_prompt = 0; -// log level -static jfieldID f_utf_8 = 0; -// objects -static jobject o_utf_8 = 0; -static jobject o_logger = 0; - -static JavaVM *g_vm = nullptr; - -static void null_log_callback(enum ggml_log_level level, const char *text, void *user_data) {} - -JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { - JNIEnv *env = 0; - - if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) { - return JNI_ERR; - 
} - - log_disable(); - llama_log_set(null_log_callback, nullptr); - - // find classes - c_input_params = env->FindClass("ai/djl/llama/jni/InputParameters"); - c_model_params = env->FindClass("ai/djl/llama/jni/ModelParameters"); - c_lib_utils = env->FindClass("ai/djl/llama/jni/LibUtils"); - c_token = env->FindClass("ai/djl/llama/jni/Token"); - c_engine_exception = env->FindClass("ai/djl/engine/EngineException"); - c_logger = env->FindClass("org/slf4j/Logger"); - c_standard_charsets = env->FindClass("java/nio/charset/StandardCharsets"); - c_string = env->FindClass("java/lang/String"); - c_hash_map = env->FindClass("java/util/HashMap"); - c_map = env->FindClass("java/util/Map"); - c_set = env->FindClass("java/util/Set"); - c_entry = env->FindClass("java/util/Map$Entry"); - c_integer = env->FindClass("java/lang/Integer"); - c_float = env->FindClass("java/lang/Float"); - - // create references - c_input_params = (jclass) env->NewGlobalRef(c_input_params); - c_model_params = (jclass) env->NewGlobalRef(c_model_params); - c_lib_utils = (jclass) env->NewGlobalRef(c_lib_utils); - c_token = (jclass) env->NewGlobalRef(c_token); - c_engine_exception = (jclass) env->NewGlobalRef(c_engine_exception); - c_logger = (jclass) env->NewGlobalRef(c_logger); - c_string = (jclass) env->NewGlobalRef(c_string); - c_hash_map = (jclass) env->NewGlobalRef(c_hash_map); - c_map = (jclass) env->NewGlobalRef(c_map); - c_set = (jclass) env->NewGlobalRef(c_set); - c_entry = (jclass) env->NewGlobalRef(c_entry); - c_integer = (jclass) env->NewGlobalRef(c_integer); - c_float = (jclass) env->NewGlobalRef(c_float); - - // find constructors - cc_token = env->GetMethodID(c_token, "", "(I[BLjava/util/Map;JJZ)V"); - cc_hash_map = env->GetMethodID(c_hash_map, "", "()V"); - cc_integer = env->GetMethodID(c_integer, "", "(I)V"); - cc_float = env->GetMethodID(c_float, "", "(F)V"); - - // find methods - m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B"); - m_entry_set = env->GetMethodID(c_map, "entrySet", "()Ljava/util/Set;"); - m_entry_key = env->GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;"); - m_entry_value = env->GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;"); - m_map_put = env->GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); - m_int_value = env->GetMethodID(c_integer, "intValue", "()I"); - m_float_value = env->GetMethodID(c_float, "floatValue", "()F"); - m_log_debug = env->GetMethodID(c_logger, "debug", "(Ljava/lang/String;)V"); - m_log_info = env->GetMethodID(c_logger, "info", "(Ljava/lang/String;)V"); - m_log_warn = env->GetMethodID(c_logger, "warn", "(Ljava/lang/String;)V"); - m_log_error = env->GetMethodID(c_logger, "error", "(Ljava/lang/String;)V"); - - // find fields - f_logger = env->GetStaticFieldID(c_lib_utils, "logger", "Lorg/slf4j/Logger;"); - - f_n_predict = env->GetFieldID(c_input_params, "nPredict", "I"); - f_n_keep = env->GetFieldID(c_input_params, "nKeep", "I"); - f_n_probs = env->GetFieldID(c_input_params, "nProbs", "I"); - f_logit_bias = env->GetFieldID(c_input_params, "logitBias", "Ljava/util/Map;"); - f_top_k = env->GetFieldID(c_input_params, "topK", "I"); - f_top_p = env->GetFieldID(c_input_params, "topP", "F"); - f_tfs_z = env->GetFieldID(c_input_params, "tfsZ", "F"); - f_typical_p = env->GetFieldID(c_input_params, "typicalP", "F"); - f_temperature = env->GetFieldID(c_input_params, "temperature", "F"); - f_repeat_penalty = env->GetFieldID(c_input_params, "repeatPenalty", "F"); - f_repeat_last_n = env->GetFieldID(c_input_params, 
"repeatLastN", "I"); - f_frequency_penalty = env->GetFieldID(c_input_params, "frequencyPenalty", "F"); - f_presence_penalty = env->GetFieldID(c_input_params, "presencePenalty", "F"); - f_penalize_nl = env->GetFieldID(c_input_params, "penalizeNl", "Z"); - f_ignore_eos = env->GetFieldID(c_input_params, "ignoreEos", "Z"); - f_mirostat = env->GetFieldID(c_input_params, "mirostat", "I"); - f_mirostat_tau = env->GetFieldID(c_input_params, "mirostatTau", "F"); - f_mirostat_eta = env->GetFieldID(c_input_params, "mirostatEta", "F"); - f_n_beams = env->GetFieldID(c_input_params, "nBeams", "I"); - f_grammar = env->GetFieldID(c_input_params, "grammar", "Ljava/lang/String;"); - f_antiprompt = env->GetFieldID(c_input_params, "antiPrompt", "[Ljava/lang/String;"); - f_infer_seed = env->GetFieldID(c_input_params, "seed", "I"); - - f_n_threads = env->GetFieldID(c_model_params, "nThreads", "I"); - f_n_ctx = env->GetFieldID(c_model_params, "nCtx", "I"); - f_n_batch = env->GetFieldID(c_model_params, "nBatch", "I"); - f_n_gpu_layers = env->GetFieldID(c_model_params, "nGpuLayers", "I"); - f_main_gpu = env->GetFieldID(c_model_params, "mainGpu", "I"); - f_tensor_split = env->GetFieldID(c_model_params, "tensorSplit", "[F"); - f_rope_freq_base = env->GetFieldID(c_model_params, "ropeFreqBase", "F"); - f_rope_freq_scale = env->GetFieldID(c_model_params, "ropeFreqScale", "F"); - f_mul_mat_q = env->GetFieldID(c_model_params, "mulMatQ", "Z"); - f_f16_kv = env->GetFieldID(c_model_params, "f16Kv", "Z"); - f_logits_all = env->GetFieldID(c_model_params, "logitsAll", "Z"); - f_vocab_only = env->GetFieldID(c_model_params, "vocabOnly", "Z"); - f_use_mmap = env->GetFieldID(c_model_params, "useMmap", "Z"); - f_use_mlock = env->GetFieldID(c_model_params, "useMlock", "Z"); - f_embedding = env->GetFieldID(c_model_params, "embedding", "Z"); - f_lora_adapter = env->GetFieldID(c_model_params, "loraAdapter", "Ljava/lang/String;"); - f_lora_base = env->GetFieldID(c_model_params, "loraBase", "Ljava/lang/String;"); - f_memory_f16 = env->GetFieldID(c_model_params, "memoryF16", "Z"); - f_mem_test = env->GetFieldID(c_model_params, "memTest", "Z"); - f_numa = env->GetFieldID(c_model_params, "numa", "Z"); - f_verbose_prompt = env->GetFieldID(c_model_params, "verbosePrompt", "Z"); - - f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;"); - o_utf_8 = env->NewStringUTF("UTF-8"); - o_utf_8 = (jobject) env->NewGlobalRef(o_utf_8); - o_logger = env->GetStaticObjectField(c_lib_utils, f_logger); - o_logger = (jobject) env->NewGlobalRef(o_logger); - - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - return JNI_ERR; - } - - return JNI_VERSION_1_1; -} - -JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { - JNIEnv *env = 0; - - if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) { - return; - } - - env->DeleteGlobalRef(c_input_params); - env->DeleteGlobalRef(c_model_params); - env->DeleteGlobalRef(c_token); - env->DeleteGlobalRef(c_string); - env->DeleteGlobalRef(c_hash_map); - env->DeleteGlobalRef(c_map); - env->DeleteGlobalRef(c_set); - env->DeleteGlobalRef(c_entry); - env->DeleteGlobalRef(c_integer); - env->DeleteGlobalRef(c_float); - env->DeleteGlobalRef(c_logger); - env->DeleteGlobalRef(c_engine_exception); - - env->DeleteGlobalRef(o_utf_8); -} - -static void log(JNIEnv *env, enum ggml_log_level level, const char *text) { - jstring java_text = env->NewStringUTF(text); - - switch (level) { - case GGML_LOG_LEVEL_ERROR: - env->CallVoidMethod(o_logger, m_log_error, java_text); - break; 
- case GGML_LOG_LEVEL_WARN: - env->CallVoidMethod(o_logger, m_log_warn, java_text); - break; - case GGML_LOG_LEVEL_INFO: - env->CallVoidMethod(o_logger, m_log_info, java_text); - break; - default: - env->CallVoidMethod(o_logger, m_log_debug, java_text); - break; - } - env->DeleteLocalRef(java_text); -} - -static void log(JNIEnv *env, enum ggml_log_level level, std::string text) { log(env, level, text.c_str()); } - -static std::string parse_jstring(JNIEnv *env, jstring java_string) { - const jbyteArray string_bytes = (jbyteArray) env->CallObjectMethod(java_string, m_get_bytes, o_utf_8); - - size_t length = (size_t) env->GetArrayLength(string_bytes); - jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr); - - std::string string = std::string((char *) byte_elements, length); - - env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT); - env->DeleteLocalRef(string_bytes); - - return string; -} - -static int parse_jinteger(JNIEnv *env, jobject java_integer) { - if (!java_integer) return 0; - return env->CallIntMethod(java_integer, m_int_value); -} - -static float parse_jfloat(JNIEnv *env, jobject java_float) { - if (!java_float) return 0; - return env->CallFloatMethod(java_float, m_float_value); -} - -static jbyteArray parse_jbytes(JNIEnv *env, std::string string) { - jsize len = string.size(); - jbyteArray bytes = env->NewByteArray(len); - env->SetByteArrayRegion(bytes, 0, len, reinterpret_cast(string.c_str())); - return bytes; -} - -// completion token output with probabilities -struct completion_token_output { - struct token_prob { - llama_token tok; - float prob; - }; - - std::vector probs; - llama_token tok; -}; - -static size_t common_part(const std::vector &a, const std::vector &b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) { - } - return i; -} - -enum stop_type { - STOP_FULL, - STOP_PARTIAL, -}; - -static bool ends_with(const std::string &str, const std::string &suffix) { - return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} - -static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { - if (!text.empty() && !stop.empty()) { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) { - return text.size() - char_index - 1; - } - } - } - } - return std::string::npos; -} - -template -static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += llama_token_to_piece(ctx, *begin); - } - return ret; -} - -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) { - std::string out = token == -1 ? 
"" : llama_token_to_piece(ctx, token); - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; - } - return out; -} - -struct jllama_context { - bool has_next_token = false; - std::string generated_text; - std::vector generated_token_probs; - - size_t num_prompt_tokens = 0; - size_t num_tokens_predicted = 0; - size_t n_past = 0; - size_t n_remain = 0; - - std::string prompt; - std::vector embd; - std::vector last_n_tokens; - - llama_model *model = nullptr; - llama_context *ctx = nullptr; - gpt_params params; - llama_sampling_context ctx_sampling; - int n_ctx; - - grammar_parser::parse_state parsed_grammar; - llama_grammar *grammar = nullptr; - - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - std::string stopping_word; - int32_t multibyte_pending = 0; - - std::mutex mutex; - - std::unique_lock lock() { return std::unique_lock(mutex); } - - ~jllama_context() { - if (ctx) { - llama_free(ctx); - ctx = nullptr; - } - if (model) { - llama_free_model(model); - model = nullptr; - } - if (grammar) { - llama_grammar_free(grammar); - grammar = nullptr; - } - } - - void rewind() { - params.antiprompt.clear(); - params.sparams.grammar.clear(); - num_prompt_tokens = 0; - num_tokens_predicted = 0; - generated_text = ""; - generated_text.reserve(n_ctx); - generated_token_probs.clear(); - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - multibyte_pending = 0; - n_remain = 0; - n_past = 0; - - if (grammar != nullptr) { - llama_grammar_free(grammar); - grammar = nullptr; - ctx_sampling = *llama_sampling_init(params.sparams); - } - } - - bool loadModel(const gpt_params ¶ms_) { - params = params_; - std::tie(model, ctx) = llama_init_from_gpt_params(params); - if (model == nullptr) { - return false; - } - n_ctx = llama_n_ctx(ctx); - last_n_tokens.resize(n_ctx); - std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); - return true; - } - - std::vector tokenize(std::string prompt, bool add_bos) const { - return ::llama_tokenize(ctx, prompt, add_bos); - } - - bool loadGrammar(JNIEnv *env) { - if (!params.sparams.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.sparams.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - log(env, GGML_LOG_LEVEL_ERROR, "grammar parse error"); - return false; - } - grammar_parser::print_grammar(stderr, parsed_grammar); - - { - auto it = params.sparams.logit_bias.find(llama_token_eos(model)); - if (it != params.sparams.logit_bias.end() && it->second == -INFINITY) { - log(env, GGML_LOG_LEVEL_WARN, "EOS token is disabled, which will cause most grammars to fail"); - } - } - - std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - } - ctx_sampling = *llama_sampling_init(params.sparams); - return true; - } - - void loadInfill(JNIEnv *env) { - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - - auto prefix_tokens = tokenize(params.input_prefix, false); - auto suffix_tokens = 
tokenize(params.input_suffix, false); - const int space_token = 29871; - if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); - prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); - prefix_tokens.push_back(llama_token_middle(model)); - auto prompt_tokens = prefix_tokens; - - num_prompt_tokens = prompt_tokens.size(); - - if (params.n_keep < 0) { - params.n_keep = (int) num_prompt_tokens; - } - params.n_keep = std::min(params.n_ctx - 4, params.n_keep); - - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t) params.n_ctx) { - // todo we probably want to cut from both sides - const int n_left = (params.n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert( - new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); - - truncated = true; - prompt_tokens = new_tokens; - } else { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } - - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - embd = prompt_tokens; - - if (n_past == num_prompt_tokens) { - // we have to evaluate at least 1 token to generate logits. - n_past--; - } - - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - - has_next_token = true; - } - - void loadPrompt(JNIEnv *env) { - auto prompt_tokens = tokenize(prompt, true); // always add BOS - - num_prompt_tokens = prompt_tokens.size(); - - if (params.n_keep < 0) { - params.n_keep = (int) num_prompt_tokens; - } - params.n_keep = std::min(n_ctx - 4, params.n_keep); - - // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t) n_ctx) { - const int n_left = (n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert( - new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); - - log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); - - truncated = true; - prompt_tokens = new_tokens; - } else { - const size_t ps = num_prompt_tokens; - std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); - } - - // compare the evaluated prompt with the new prompt - n_past = common_part(embd, prompt_tokens); - - embd = prompt_tokens; - if (n_past == num_prompt_tokens) { - // we have to evaluate at least 1 token to generate logits. 
- n_past--; - } - - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - - has_next_token = true; - } - - void beginCompletion() { - // number of tokens to keep when resetting context - n_remain = params.n_predict; - llama_set_rng_seed(ctx, params.seed); - } - - completion_token_output nextToken(JNIEnv *env) { - completion_token_output result; - result.tok = -1; - - if (embd.size() >= (size_t) n_ctx) { - // Shift context - - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left / 2; - - llama_kv_cache_seq_rm(ctx, 0, params.n_keep + 1, params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - - for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) { - embd[i - n_discard] = embd[i]; - } - embd.resize(embd.size() - n_discard); - - n_past -= n_discard; - - truncated = true; - log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); - } - - bool tg = true; - while (n_past < embd.size()) { - int n_eval = (int) embd.size() - n_past; - tg = n_eval == 1; - if (n_eval > params.n_batch) { - n_eval = params.n_batch; - } - - if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) { - log(env, GGML_LOG_LEVEL_ERROR, "failed to eval n_eval=" + std::to_string(n_eval)); - has_next_token = false; - return result; - } - n_past += n_eval; - } - - if (params.n_predict == 0) { - has_next_token = false; - result.tok = llama_token_eos(model); - return result; - } - - { - // out of user input, sample next token - result.tok = llama_sampling_sample(&ctx_sampling, ctx, NULL); - - llama_token_data_array candidates_p = {ctx_sampling.cur.data(), ctx_sampling.cur.size(), false}; - - const int32_t n_probs = params.sparams.n_probs; - if (params.sparams.temp <= 0 && n_probs > 0) { - // For llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &candidates_p); - } - - for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) { - result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); - } - - llama_sampling_accept(&ctx_sampling, ctx, result.tok, true); - if (tg) { - num_tokens_predicted++; - } - } - - // add it to the context - embd.push_back(result.tok); - // decrement remaining sampling budget - --n_remain; - - if (!embd.empty() && embd.back() == llama_token_eos(model)) { - // stopping_word = llama_token_to_piece(ctx, embd.back()); - has_next_token = false; - stopped_eos = true; - return result; - } - - has_next_token = params.n_predict == -1 || n_remain != 0; - return result; - } - - size_t findStoppingStrings(const std::string &text, const size_t last_token_size, const stop_type type) { - size_t stop_pos = std::string::npos; - for (const std::string &word : params.antiprompt) { - size_t pos; - if (type == STOP_FULL) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - pos = text.find(word, from_pos); - } else { - pos = find_partial_stop_string(word, text); - } - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == STOP_FULL) { - stopping_word = word; - stopped_word = true; - has_next_token = false; - } - stop_pos = pos; - } - } - return stop_pos; - } - - completion_token_output doCompletion(JNIEnv *env) { - auto token_with_probs = nextToken(env); - - const std::string token_text = token_with_probs.tok == -1 ? 
"" : llama_token_to_piece(ctx, token_with_probs.tok); - generated_text += token_text; - - if (params.sparams.n_probs > 0) { - generated_token_probs.push_back(token_with_probs); - } - - if (multibyte_pending > 0) { - multibyte_pending -= token_text.size(); - } else if (token_text.size() == 1) { - const char c = token_text[0]; - // 2-byte characters: 110xxxxx 10xxxxxx - if ((c & 0xE0) == 0xC0) { - multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - } else if ((c & 0xF0) == 0xE0) { - multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - } else if ((c & 0xF8) == 0xF0) { - multibyte_pending = 3; - } else { - multibyte_pending = 0; - } - } - - if (multibyte_pending > 0 && !has_next_token) { - has_next_token = true; - n_remain++; - } - - if (!has_next_token && n_remain == 0) { - stopped_limit = true; - } - - return token_with_probs; - } - - std::vector getEmbedding(JNIEnv *env) { - static const int n_embd = llama_n_embd(model); - if (!params.embedding) { - log(env, GGML_LOG_LEVEL_ERROR, "embedding disabled"); - return std::vector(n_embd, 0.0f); - } - const float *data = llama_get_embeddings(ctx); - std::vector embedding(data, data + n_embd); - return embedding; - } -}; - -static gpt_params parse_model_params(JNIEnv *env, jobject jparams, jstring java_file_path) { - gpt_params params; - - params.model = parse_jstring(env, java_file_path); - params.n_threads = env->GetIntField(jparams, f_n_threads); - params.n_ctx = env->GetIntField(jparams, f_n_ctx); - params.n_batch = env->GetIntField(jparams, f_n_batch); - params.n_gpu_layers = env->GetIntField(jparams, f_n_gpu_layers); - params.main_gpu = env->GetIntField(jparams, f_main_gpu); - params.rope_freq_base = env->GetFloatField(jparams, f_rope_freq_base); - params.rope_freq_scale = env->GetFloatField(jparams, f_rope_freq_scale); - params.mul_mat_q = env->GetBooleanField(jparams, f_mul_mat_q); - params.embedding = env->GetBooleanField(jparams, f_embedding); - params.escape = env->GetIntField(jparams, f_n_predict); - params.use_mmap = env->GetBooleanField(jparams, f_use_mmap); - params.use_mlock = env->GetBooleanField(jparams, f_use_mlock); - params.numa = env->GetBooleanField(jparams, f_numa); - params.verbose_prompt = env->GetBooleanField(jparams, f_verbose_prompt); - - if (params.model_alias == "unknown") { - params.model_alias = params.model; - } - - return params; -} - -static void setup_infer_params(JNIEnv *env, jllama_context *llama, jobject jparams) { - auto ¶ms = llama->params; - - params.seed = env->GetIntField(jparams, f_infer_seed); - params.n_predict = env->GetIntField(jparams, f_n_predict); - params.n_keep = env->GetIntField(jparams, f_n_keep); - - auto &sparams = params.sparams; - - sparams.top_k = env->GetIntField(jparams, f_top_k); - sparams.top_p = env->GetFloatField(jparams, f_top_p); - sparams.tfs_z = env->GetFloatField(jparams, f_tfs_z); - sparams.typical_p = env->GetFloatField(jparams, f_typical_p); - sparams.temp = env->GetFloatField(jparams, f_temperature); - sparams.penalty_repeat = env->GetFloatField(jparams, f_repeat_penalty); - sparams.n_prev = env->GetIntField(jparams, f_repeat_last_n); - sparams.penalty_freq = env->GetFloatField(jparams, f_frequency_penalty); - sparams.penalty_present = env->GetFloatField(jparams, f_presence_penalty); - sparams.penalize_nl = env->GetBooleanField(jparams, f_penalize_nl); - sparams.mirostat = env->GetIntField(jparams, f_mirostat); - sparams.mirostat_tau = env->GetFloatField(jparams, f_mirostat_tau); - sparams.mirostat_eta = 
env->GetFloatField(jparams, f_mirostat_eta); - sparams.n_probs = env->GetIntField(jparams, f_n_probs); - - jstring j_grammar = (jstring) env->GetObjectField(jparams, f_grammar); - if (j_grammar != nullptr) { - sparams.grammar = parse_jstring(env, j_grammar); - env->DeleteLocalRef(j_grammar); - if (!llama->loadGrammar(env)) { - env->ThrowNew(c_engine_exception, "could not load grammar"); - } - } - - sparams.logit_bias.clear(); - jboolean ignore_eos = env->GetBooleanField(jparams, f_ignore_eos); - if (ignore_eos) { - sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY; - } - - jobject logit_bias = env->GetObjectField(jparams, f_logit_bias); - if (logit_bias != nullptr) { - jobject entry_set = env->CallObjectMethod(logit_bias, m_entry_set); - jobject iterator = env->CallObjectMethod(entry_set, m_set_iterator); - while (env->CallBooleanMethod(iterator, m_iterator_has_next)) { - jobject entry = env->CallObjectMethod(iterator, m_iterator_next); - jobject key = env->CallObjectMethod(entry, m_entry_key); - jobject value = env->CallObjectMethod(entry, m_entry_value); - - int tok = parse_jinteger(env, key); - float bias = parse_jfloat(env, value); - sparams.logit_bias[tok] = bias; - - env->DeleteLocalRef(entry); - env->DeleteLocalRef(key); - env->DeleteLocalRef(value); - } - } - - params.antiprompt.clear(); - jobjectArray antiprompt = (jobjectArray) env->GetObjectField(jparams, f_antiprompt); - if (antiprompt != nullptr) { - jsize array_length = env->GetArrayLength(antiprompt); - for (jsize i = 0; i < array_length; i++) { - jstring java_string = (jstring) env->GetObjectArrayElement(antiprompt, i); - if (java_string != nullptr) { - std::string string = parse_jstring(env, java_string); - params.antiprompt.push_back(string); - env->DeleteLocalRef(java_string); - } - } - } - - llama->ctx_sampling = *llama_sampling_init(params.sparams); -} - -static void setup_answering(JNIEnv *env, jllama_context *llama, jstring prompt, jobject params) { - llama->prompt = parse_jstring(env, prompt); - llama->params.input_prefix = ""; - llama->params.input_suffix = ""; - setup_infer_params(env, llama, params); -} - -static void setup_infilling(JNIEnv *env, jllama_context *llama, jstring prefix, jstring suffix, jobject params) { - llama->prompt = ""; - llama->params.input_prefix = parse_jstring(env, prefix); - llama->params.input_suffix = parse_jstring(env, suffix); - setup_infer_params(env, llama, params); -} - -JNIEXPORT jlong JNICALL Java_ai_djl_llama_jni_LlamaLibrary_loadModel( - JNIEnv *env, jclass clazz, jstring file_path, jobject jparams) { - gpt_params params = parse_model_params(env, jparams, file_path); - - jllama_context *llama = new jllama_context; - llama_backend_init(false); - - if (!llama->loadModel(params)) { - env->ThrowNew(c_engine_exception, "could not load model from given file path"); - return 0; - } - - return reinterpret_cast(llama); -} - -JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_generate( - JNIEnv *env, jclass clazz, jlong handle, jstring prompt, jobject params) { - auto *llama = reinterpret_cast(handle); - - llama->rewind(); - llama_reset_timings(llama->ctx); - setup_answering(env, llama, prompt, params); - - llama->loadPrompt(env); - llama->beginCompletion(); -} - -JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_infill( - JNIEnv *env, jclass clazz, jlong handle, jstring prefix, jstring suffix, jobject params) { - auto *llama = reinterpret_cast(handle); - - llama->rewind(); - - llama_reset_timings(llama->ctx); - - setup_infilling(env, llama, prefix, suffix, 
-JNIEXPORT jobject JNICALL Java_ai_djl_llama_jni_LlamaLibrary_getNext(
-        JNIEnv *env, jclass clazz, jlong handle, jlong sent_count, jlong sent_token_probs_index) {
-    auto *llama = reinterpret_cast<jllama_context *>(handle);
-
-    completion_token_output token_with_probs;
-    while (llama->has_next_token) {
-        token_with_probs = llama->doCompletion(env);
-        if (token_with_probs.tok >= 0 && llama->multibyte_pending <= 0) {
-            break;
-        }
-    }
-    const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
-
-    size_t pos = std::min((size_t) sent_count, llama->generated_text.size());
-
-    const std::string str_test = llama->generated_text.substr(pos);
-    bool is_stop_full = false;
-    size_t stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_FULL);
-    if (stop_pos != std::string::npos) {
-        is_stop_full = true;
-        llama->generated_text.erase(llama->generated_text.begin() + pos + stop_pos, llama->generated_text.end());
-        pos = std::min((size_t) sent_count, llama->generated_text.size());
-    } else {
-        is_stop_full = false;
-        stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL);
-    }
-
-    std::string to_send;
-    if (stop_pos == std::string::npos ||
-            // Send rest of the text if we are at the end of the generation
-            (!llama->has_next_token && !is_stop_full && stop_pos > 0)) {
-        to_send = llama->generated_text.substr(pos, std::string::npos);
-
-        sent_count += to_send.size();
-        std::vector<completion_token_output> probs_output = {};
-
-        if (llama->params.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
-            size_t probs_pos = std::min((size_t) sent_token_probs_index, llama->generated_token_probs.size());
-            size_t probs_stop_pos =
-                    std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                        llama->generated_token_probs.begin() + probs_pos,
-                        llama->generated_token_probs.begin() + probs_stop_pos);
-            }
-            sent_token_probs_index = probs_stop_pos;
-        }
-    } else {
-        to_send = "";
-    }
-
-    jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
-    for (const auto &tp : token_with_probs.probs) {
-        jobject jtoken = env->NewObject(c_integer, cc_integer, tp.tok);
-        jobject jprob = env->NewObject(c_float, cc_float, tp.prob);
-        env->CallObjectMethod(o_probabilities, m_map_put, jtoken, jprob);
-    }
-
-    jbyteArray jbytes = parse_jbytes(env, to_send);
-    return env->NewObject(c_token, cc_token, token_with_probs.tok, jbytes, o_probabilities, sent_count,
-            sent_token_probs_index, llama->has_next_token);
-}
-
-JNIEXPORT jfloatArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_embed(
-        JNIEnv *env, jclass clazz, jlong handle, jstring java_prompt) {
-    auto *llama = reinterpret_cast<jllama_context *>(handle);
-
-    llama->rewind();
-    llama_reset_timings(llama->ctx);
-    llama->prompt = parse_jstring(env, java_prompt);
-    llama->params.n_predict = 0;
-    llama->loadPrompt(env);
-    llama->beginCompletion();
-    llama->doCompletion(env);
-
-    static const int n_embd = llama_n_embd(llama->model);
-    const float *data = llama_get_embeddings(llama->ctx);
-    std::vector<float> embedding(data, data + n_embd);
-
-    jfloatArray java_embedding = env->NewFloatArray(embedding.size());
-    env->SetFloatArrayRegion(java_embedding, 0, embedding.size(), reinterpret_cast<const jfloat *>(embedding.data()));
-
-    return java_embedding;
-}
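-
-// Tokenizer round-trip helpers. decodeBytes returns raw UTF-8 bytes via parse_jbytes rather than
-// a jstring, presumably to avoid mangling byte sequences that are not valid JNI modified UTF-8.
-// Hypothetical Java-side usage, with the bindings as declared in ai.djl.llama.jni.LlamaLibrary:
-//   int[] ids = LlamaLibrary.encode(handle, "Hello world");
-//   byte[] bytes = LlamaLibrary.decodeBytes(handle, ids);
-//   String roundTrip = new String(bytes, StandardCharsets.UTF_8);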
-JNIEXPORT jintArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_encode(
-        JNIEnv *env, jclass clazz, jlong handle, jstring jprompt) {
-    auto *llama = reinterpret_cast<jllama_context *>(handle);
-
-    std::string prompt = parse_jstring(env, jprompt);
-    std::vector<llama_token> tokens = llama->tokenize(prompt, false);
-
-    jintArray java_tokens = env->NewIntArray(tokens.size());
-    env->SetIntArrayRegion(java_tokens, 0, tokens.size(), reinterpret_cast<const jint *>(tokens.data()));
-
-    return java_tokens;
-}
-
-JNIEXPORT jbyteArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_decodeBytes(
-        JNIEnv *env, jclass clazz, jlong handle, jintArray java_tokens) {
-    auto *llama = reinterpret_cast<jllama_context *>(handle);
-
-    jsize length = env->GetArrayLength(java_tokens);
-    jint *elements = env->GetIntArrayElements(java_tokens, nullptr);
-    std::vector<llama_token> tokens(elements, elements + length);
-    std::string text = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend());
-
-    env->ReleaseIntArrayElements(java_tokens, elements, 0);
-
-    return parse_jbytes(env, text);
-}
-
-JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_delete(JNIEnv *env, jclass clazz, jlong handle) {
-    auto *llama = reinterpret_cast<jllama_context *>(handle);
-    delete llama;
-}
diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
deleted file mode 100644
index d2f8ca8e42c..00000000000
--- a/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider
+++ /dev/null
@@ -1 +0,0 @@
-ai.djl.llama.engine.LlamaEngineProvider
diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider
deleted file mode 100644
index 92f6245340f..00000000000
--- a/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider
+++ /dev/null
@@ -1 +0,0 @@
-ai.djl.llama.zoo.LlamaZooProvider
diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java
deleted file mode 100644
index 429cd569392..00000000000
--- a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.llama.engine;
-
-import ai.djl.llama.engine.LlamaInput.Parameters;
-import ai.djl.llama.jni.InputParameters;
-import ai.djl.util.JsonUtils;
-
-import org.testng.Assert;
-import org.testng.annotations.Test;
-
-import java.io.IOException;
-import java.io.Reader;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.util.Map;
-
-public class LlamaInputTest {
-
-    @Test
-    public void testInputParameters() throws IOException {
-        Path file = Paths.get("src/test/resources/inputs.json");
-        try (Reader reader = Files.newBufferedReader(file)) {
-            LlamaInput in = JsonUtils.GSON.fromJson(reader, LlamaInput.class);
-            checkParameters(in);
-        }
-
-        Parameters param = new Parameters();
-        LlamaInput in = new LlamaInput();
-        in.setInputs("prompt");
-        in.setPrefix("prefix");
-        in.setSuffix("suffix");
-        in.setParameters(param);
-        param.setMaxNewTokens(2);
-        param.setNumberKeep(2);
-        param.setNumberProbabilities(2);
-        param.setTopK(2);
-        param.setTopP(2f);
-        param.setTfsZ(2f);
-        param.setTypicalP(2f);
-        param.setTemperature(2f);
-        param.setRepeatPenalty(2f);
-        param.setRepeatLastN(2);
-        param.setFrequencyPenalty(2f);
-        param.setPresencePenalty(2f);
-        param.setPenalizeNl(true);
-        param.setIgnoreEos(true);
-        param.setMirostat(2);
-        param.setMirostatTau(2f);
-        param.setMirostatEta(2f);
-        param.setNumberBeams(5);
-        param.setSeed(2);
-        Map<Integer, Float> logitBias = Map.of(2, 0.4f, 3, 0.5f);
-        param.setLogitBias(logitBias);
-        param.setGrammar("grammar");
-        param.setAntiPrompt(new String[] {"User: "});
-        checkParameters(in);
-    }
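-
-    // Shared assertions: every value set above (or parsed from inputs.json) must survive the
-    // Parameters -> InputParameters conversion unchanged.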
-    private void checkParameters(LlamaInput in) {
-        InputParameters param = in.getParameters().toInputParameters();
-        Assert.assertEquals(param.getMaxNewTokens(), 2);
-        Assert.assertEquals(param.getNumberKeep(), 2);
-        Assert.assertEquals(param.getNumberProbabilities(), 2);
-        Assert.assertEquals(param.getTopK(), 2);
-        Assert.assertEquals(param.getTopP(), 2f);
-        Assert.assertEquals(param.getTfsZ(), 2f);
-        Assert.assertEquals(param.getTypicalP(), 2f);
-        Assert.assertEquals(param.getTemperature(), 2f);
-        Assert.assertEquals(param.getRepeatPenalty(), 2f);
-        Assert.assertEquals(param.getRepeatLastN(), 2);
-        Assert.assertEquals(param.getFrequencyPenalty(), 2f);
-        Assert.assertEquals(param.getPresencePenalty(), 2f);
-        Assert.assertTrue(param.isPenalizeNl());
-        Assert.assertTrue(param.isIgnoreEos());
-        Assert.assertEquals(param.getMirostat(), 2);
-        Assert.assertEquals(param.getMirostatTau(), 2f);
-        Assert.assertEquals(param.getMirostatEta(), 2f);
-        Assert.assertEquals(param.getNumberBeams(), 5);
-        Assert.assertEquals(param.getSeed(), 2);
-        Map<Integer, Float> logitBias = param.getLogitBias();
-        Assert.assertNotNull(logitBias);
-        Assert.assertEquals(logitBias.size(), 2);
-        Assert.assertEquals(logitBias.get(2), 0.4f);
-        Assert.assertNotNull(param.getGrammar());
-        Assert.assertEquals(param.getAntiPrompt()[0], "User: ");
-    }
-}
diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java
deleted file mode 100644
index 99e97b352b0..00000000000
--- a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.llama.engine;
-
-import ai.djl.ModelException;
-import ai.djl.engine.Engine;
-import ai.djl.engine.StandardCapabilities;
-import ai.djl.inference.Predictor;
-import ai.djl.llama.jni.Token;
-import ai.djl.llama.jni.TokenIterator;
-import ai.djl.modality.Input;
-import ai.djl.modality.Output;
-import ai.djl.ndarray.NDManager;
-import ai.djl.repository.zoo.Criteria;
-import ai.djl.repository.zoo.ZooModel;
-import ai.djl.testing.TestRequirements;
-import ai.djl.training.util.DownloadUtils;
-import ai.djl.translate.TranslateException;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.testng.Assert;
-import org.testng.annotations.AfterClass;
-import org.testng.annotations.BeforeClass;
-import org.testng.annotations.Test;
-
-import java.io.IOException;
-import java.net.URI;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-
-public class LlamaTest {
-
-    private static final Logger logger = LoggerFactory.getLogger(LlamaTest.class);
-
-    @BeforeClass
-    public void setUp() {
-        System.setProperty("DJL_CACHE_DIR", "build/cache");
-    }
-
-    @AfterClass
-    public void tearDown() {
-        System.clearProperty("DJL_CACHE_DIR");
-    }
-
-    @Test
-    public void testLlamaVersion() {
-        Engine engine = Engine.getEngine("Llama");
-        Assert.assertEquals(engine.getVersion(), "b1696-" + Engine.getDjlVersion());
-        Assert.assertNotNull(engine.toString());
-        Assert.assertEquals(engine.getRank(), 10);
-        Assert.assertFalse(engine.hasCapability(StandardCapabilities.CUDA));
-        Assert.assertNull(engine.getAlternativeEngine());
-        try (NDManager manager = engine.newBaseManager()) {
-            Assert.assertNotNull(manager);
-        }
-    }
-
-    @Test
-    public void testLlama() throws TranslateException, ModelException, IOException {
-        TestRequirements.nightly();
-        downloadModel();
-        Path path = Paths.get("models");
-        Criteria<String, TokenIterator> criteria =
-                Criteria.builder()
-                        .setTypes(String.class, TokenIterator.class)
-                        .optModelPath(path)
-                        .optModelName("tinyllama-1.1b-1t-openorca.Q4_K_M")
-                        .optEngine("Llama")
-                        .optOption("number_gpu_layers", "43")
-                        .optTranslatorFactory(new LlamaTranslatorFactory())
-                        .build();
-
-        String prompt =
-                "{\"inputs\": \"<|im_start|>system\n"
-                        + "{system_message}<|im_end|>\n"
-                        + "<|im_start|>user\n"
-                        + "{prompt}<|im_end|>\n"
-                        + "<|im_start|>assistant\", \"parameters\": {\"max_new_tokens\": 10}}";
-        try (ZooModel<String, TokenIterator> model = criteria.loadModel();
-                Predictor<String, TokenIterator> predictor = model.newPredictor()) {
-            TokenIterator it = predictor.predict(prompt);
-            StringBuilder sb = new StringBuilder();
-            while (it.hasNext()) {
-                Token token = it.next();
-                Assert.assertNotNull(token.getText());
-                Assert.assertTrue(token.getToken() >= 0);
-                Assert.assertNotNull(token.getProbabilities());
-                sb.append(token.getText());
-                logger.info("{}", token);
-            }
-            Assert.assertTrue(sb.length() > 1);
-        }
-    }
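-
-    // Infill mode: the request carries separate prefix/suffix strings instead of a single
-    // prompt, and the engine generates the code that belongs between them.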
.optEngine("Llama") - .optTranslatorFactory(new LlamaTranslatorFactory()) - .build(); - - String prompt = - "{\n" - + " \"prefix\":\"def remove_non_ascii(s: str) -> str:\n\",\n" - + " \"suffix\":\"\n return result\n\",\n" - + " \"parameters\":{\n" - + " \"max_new_tokens\": 10" - + " }\n" - + "}"; - try (ZooModel model = criteria.loadModel(); - Predictor predictor = model.newPredictor()) { - Input in = new Input(); - in.add("data", prompt); - Output out = predictor.predict(in); - Assert.assertNotNull(out.getData().getAsString()); - } - } - - private void downloadModel() throws IOException { - String url = - "https://resources.djl.ai/test-models/gguf/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf"; - Path dir = Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf"); - DownloadUtils.download(URI.create(url).toURL(), dir, null); - } -} diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java deleted file mode 100644 index b2ee786419f..00000000000 --- a/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -/** Contains test classes for llama engine. */ -package ai.djl.llama.engine; diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java deleted file mode 100644 index fab7bacb9e3..00000000000 --- a/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. 
- */
-package ai.djl.llama.zoo;
-
-import ai.djl.repository.zoo.ModelLoader;
-import ai.djl.repository.zoo.ModelZoo;
-import ai.djl.util.Utils;
-
-import org.testng.Assert;
-import org.testng.annotations.Test;
-
-import java.nio.file.Paths;
-import java.util.Collection;
-
-public class LlamaModelZooTest {
-
-    @Test
-    public void testLlamaModelZoo() {
-        System.setProperty("DJL_CACHE_DIR", "build/cache");
-        Utils.deleteQuietly(Paths.get("build/cache/cache"));
-        try {
-            ModelZoo zoo = ModelZoo.getModelZoo("ai.djl.huggingface.gguf");
-            Collection<ModelLoader> models = zoo.getModelLoaders();
-            Assert.assertFalse(models.isEmpty());
-            Assert.assertEquals(zoo.getSupportedEngines().size(), 1);
-            ModelLoader loader = zoo.getModelLoader("TinyLlama/TinyLlama-1.1B-Chat-v0.6");
-            Assert.assertNotNull(loader);
-
-            ModelZoo llamaModelZoo = new LlamaModelZoo();
-            Assert.assertFalse(llamaModelZoo.getModelLoaders().isEmpty());
-        } finally {
-            System.clearProperty("DJL_CACHE_DIR");
-        }
-    }
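-
-    // Offline mode: with ai.djl.offline=true the zoo cannot refresh the remote model list, so
-    // this checks that loaders remain discoverable (presumably from previously cached metadata).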
-    @Test
-    public void testOffLine() {
-        System.setProperty("DJL_CACHE_DIR", "build/cache");
-        System.setProperty("ai.djl.offline", "true");
-        Utils.deleteQuietly(Paths.get("build/cache/cache"));
-        try {
-            // static variables cannot be initialized properly if we use LlamaModelZoo() directly
-            ModelZoo.getModelZoo("ai.djl.huggingface.gguf");
-
-            ModelZoo zoo = new LlamaModelZoo();
-            Assert.assertFalse(zoo.getModelLoaders().isEmpty());
-        } finally {
-            System.clearProperty("DJL_CACHE_DIR");
-            System.clearProperty("ai.djl.offline");
-        }
-    }
-}
diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java
deleted file mode 100644
index 145b2ddcca9..00000000000
--- a/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java
+++ /dev/null
@@ -1,14 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-/** Contains test classes for llama model zoo. */
-package ai.djl.llama.zoo;
diff --git a/engines/llama/src/test/resources/inputs.json b/engines/llama/src/test/resources/inputs.json
deleted file mode 100644
index ab77386e1b6..00000000000
--- a/engines/llama/src/test/resources/inputs.json
+++ /dev/null
@@ -1,33 +0,0 @@
-{
-  "prefix": "def remove_non_ascii(s: str) -> str:",
-  "suffix": " return result",
-  "parameters": {
-    "max_new_tokens": 2,
-    "number_keep": 2,
-    "number_probabilities": 2,
-    "top_k": 2,
-    "top_p": 2,
-    "tfs_z": 2,
-    "typical_p": 2,
-    "temperature": 2,
-    "repeat_penalty": 2,
-    "repeat_last_n": 2,
-    "frequency_penalty": 2,
-    "presence_penalty": 2,
-    "penalize_nl": true,
-    "ignore_eos": true,
-    "mirostat": 2,
-    "mirostat_tau": 2,
-    "mirostat_eta": 2,
-    "number_beams": 5,
-    "seed": 2,
-    "logit_bias": {
-      "2": 0.4,
-      "5": 0.6
-    },
-    "grammar": "root ::= (expr \"=\" term \"\\n\")+\nexpr ::= term ([-+*/] term)*\nterm ::= [0-9]",
-    "anti_prompt": [
-      "User: "
-    ]
-  }
-}
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 878b145793f..2b6062bb6f7 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -11,7 +11,6 @@ onnxruntime = "1.18.0"
 xgboost = "2.0.3"
 lightgbm = "3.2.110"
 tensorrt = "8.4.1"
-llamacpp = "b1696"
 fasttext = "0.9.2"
 sentencepiece = "0.2.0"
 tokenizers = "0.19.1"
diff --git a/settings.gradle.kts b/settings.gradle.kts
index fd1c6bdd50b..0e907a2ada9 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -7,7 +7,6 @@ plugins {
 include(":api")
 include(":basicdataset")
 include(":djl-zero")
-include(":engines:llama")
 include(":engines:ml:lightgbm")
 include(":engines:ml:xgboost")
 include(":engines:mxnet:jnarator")