diff --git a/.github/workflows/build_triton_and_ft.yml b/.github/workflows/build_triton_and_ft.yml index 2893e1567ff..7d0b9aaf7ff 100644 --- a/.github/workflows/build_triton_and_ft.yml +++ b/.github/workflows/build_triton_and_ft.yml @@ -1,4 +1,4 @@ -name: Build Triton Server and FasterTransformers +name: Build Triton Server on: workflow_dispatch: @@ -6,16 +6,7 @@ on: triton: description: 'triton branch version' required: true - default: 'r23.04' - fastertransformer: - description: 'fastertransformer branch/tag version' - required: true - default: 'main' - is_llama_build: - description: 'whether to build custom llama source' - required: false - type: boolean - default: false + default: 'r23.10' jobs: build-triton: @@ -33,7 +24,8 @@ jobs: - name: Build Triton Binary shell: 'script -q -e -c "bash --noprofile --norc -eo pipefail {0}"' run: | - python3 build.py --enable-logging --enable-metrics --enable-stats --enable-cpu-metrics --endpoint http + pip3 install requests + python3 build.py --enable-logging --enable-metrics --enable-stats --enable-cpu-metrics --enable-gpu --endpoint http - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v2 with: @@ -45,60 +37,3 @@ jobs: aws s3 cp build/install/lib/libtritonserver.so s3://djl-ai/publish/tritonserver/${{ github.event.inputs.triton }}/ aws s3 cp build/install/bin/tritonserver s3://djl-ai/publish/tritonserver/${{ github.event.inputs.triton }}/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tritonserver/${{ github.event.inputs.triton }}/*" - - create-runner: - if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, scheduler ] - steps: - - name: Create new CPU instance - id: create_cpu - run: | - cd /home/ubuntu/djl_benchmark_script/scripts - token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \ - https://api.github.com/repos/deepjavalibrary/djl/actions/runners/registration-token \ - --fail \ - | jq '.token' | tr 
-d '"' ) - ./start_instance.sh action_cpu $token djl - outputs: - cpu_instance_id: ${{ steps.create_cpu.outputs.action_cpu_instance_id }} - - - build-fastertransformer: - if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, cpu ] - container: deepjavalibrary/djl-serving:fastertransformer-nightly - timeout-minutes: 60 - needs: create-runner - steps: - - uses: actions/checkout@v3 - - name: Build FasterTransformers - run: | - tools/scripts/build_ft_deps.sh ${{ github.event.inputs.fastertransformer }} ${{ github.event.inputs.triton }} ${{ github.event.inputs.is_llama_build }} - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-2 - - name: Copy files to S3 with the AWS CLI - if: github.event.inputs.is_llama_build == 'false' - run: | - aws s3 sync /tmp/binaries/ s3://djl-ai/publish/fastertransformer/${{ github.event.inputs.fastertransformer }}/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/fastertransformer/${{ github.event.inputs.fastertransformer }}/*" - - name: Copy files for llama build to S3 with AWS CLI - if: github.event.inputs.is_llama_build == 'true' - run: | - echo "pushing binaries to ft/llama" - aws s3 sync /tmp/binaries/ s3://djl-ai/publish/fastertransformer/llama/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/fastertransformer-llama/${{ github.event.inputs.fastertransformer }}/*" - - stop-runner: - if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, scheduler ] - needs: [ create-runner, build-fastertransformer] - steps: - - name: Stop all instances - run: | - cd /home/ubuntu/djl_benchmark_script/scripts - instance_id=${{ needs.create-runner.outputs.cpu_instance_id }} - ./stop_instance.sh $instance_id diff --git a/.github/workflows/codeql-analysis-java.yml 
b/.github/workflows/codeql-analysis-java.yml index 107efb92286..02119677e3d 100644 --- a/.github/workflows/codeql-analysis-java.yml +++ b/.github/workflows/codeql-analysis-java.yml @@ -36,11 +36,11 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Init gradle run: ./gradlew --no-daemon clean diff --git a/.github/workflows/continuous.yml b/.github/workflows/continuous.yml index 5cfa1503ed5..69621854f5e 100644 --- a/.github/workflows/continuous.yml +++ b/.github/workflows/continuous.yml @@ -21,11 +21,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: @@ -107,11 +107,11 @@ jobs: runs-on: windows-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 45e2177466d..0f9dfe33870 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,7 +3,6 @@ name: Docs on: pull_request: paths: - - "**.ipynb" - "docs/mkdocs.yml" # Publish docs weekly schedule: @@ -15,11 +14,11 @@ jobs: if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + 
java-version: 17 - name: Set up Python3 uses: actions/setup-python@v4 with: @@ -38,10 +37,6 @@ jobs: - name: add mybinder link run: | python3 tools/scripts/add_online_runner.py - - name: run Notebooks - run: | - cd jupyter - bash test_notebook.sh - name: clone demos run: | cd docs @@ -50,6 +45,10 @@ jobs: run: | cd docs git clone https://github.com/deepjavalibrary/djl-serving.git serving + - name: run Notebooks + run: | + cd docs/demos/jupyter + bash test_notebook.sh - name: build docs run: | cd docs diff --git a/.github/workflows/native_jni_s3_paddle.yml b/.github/workflows/native_jni_s3_paddle.yml index 3cb9a62f7c9..f35b54b41c9 100644 --- a/.github/workflows/native_jni_s3_paddle.yml +++ b/.github/workflows/native_jni_s3_paddle.yml @@ -15,11 +15,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -55,11 +55,11 @@ jobs: ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches diff --git a/.github/workflows/native_jni_s3_pytorch.yml b/.github/workflows/native_jni_s3_pytorch.yml index 7bd2edc1677..d0f284093ba 100644 --- a/.github/workflows/native_jni_s3_pytorch.yml +++ b/.github/workflows/native_jni_s3_pytorch.yml @@ -16,11 +16,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -55,11 +55,11 @@ jobs: container: nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04 steps: - uses: 
actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -116,11 +116,11 @@ jobs: ln -s /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -137,6 +137,7 @@ jobs: ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pprecxx11 -Ppt_version=$PYTORCH_VERSION ./gradlew -Pjni -Ppt_version=$PYTORCH_VERSION :integration:test "-Dai.djl.default_engine=PyTorch" ./gradlew :engines:pytorch:pytorch-native:cleanJNI + rm -rf ~/.djl.ai ./gradlew :engines:pytorch:pytorch-native:compileJNI -Pcu11 -Pprecxx11 -Ppt_version=$PYTORCH_VERSION - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v1-node16 @@ -156,11 +157,11 @@ jobs: runs-on: windows-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -205,13 +206,13 @@ jobs: build-pytorch-jni-arm64-macos: if: github.repository == 'deepjavalibrary/djl' - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions/cache@v3 @@ -259,6 +260,7 @@ jobs: aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }} build-pytorch-jni-aarch64: + if: github.repository == 'deepjavalibrary/djl' runs-on: [ self-hosted, aarch64 ] 
container: amazonlinux:2 timeout-minutes: 30 @@ -268,14 +270,14 @@ jobs: run: | yum -y update yum -y groupinstall "Development Tools" - yum -y install patch git cmake3 python3-devel java-11-amazon-corretto + yum -y install patch git cmake3 python3-devel java-17-amazon-corretto-devel ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - name: Release JNI prep run: | - export JAVA_HOME=/usr/lib/jvm/java-11-amazon-corretto.aarch64 - export PATH=$PATH:$JAVA_HOME + export JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.aarch64 + export PATH=$PATH:$JAVA_HOME/bin PYTORCH_VERSION=${{ github.event.inputs.pt_version }} export PYTORCH_VERSION=${PYTORCH_VERSION:-$(cat gradle.properties | awk -F '=' '/pytorch_version/ {print $2}')} echo $PYTORCH_VERSION diff --git a/.github/workflows/native_jni_s3_pytorch_android.yml b/.github/workflows/native_jni_s3_pytorch_android.yml index c3834057c9e..8f09232f785 100644 --- a/.github/workflows/native_jni_s3_pytorch_android.yml +++ b/.github/workflows/native_jni_s3_pytorch_android.yml @@ -7,17 +7,18 @@ on: jobs: build-pytorch-jni-android: + if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest env: NDK_VERSION: "21.1.6352462" steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches diff --git a/.github/workflows/native_jni_s3_tensorrt.yml b/.github/workflows/native_jni_s3_tensorrt.yml index cf39c6e070a..61ef4722e83 100644 --- a/.github/workflows/native_jni_s3_tensorrt.yml +++ b/.github/workflows/native_jni_s3_tensorrt.yml @@ -11,11 +11,11 @@ jobs: - name: Install Environment run: pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - 
java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches diff --git a/.github/workflows/native_publish_mxnet.yml b/.github/workflows/native_publish_mxnet.yml index b6ff595e9cf..d8157bb8d85 100644 --- a/.github/workflows/native_publish_mxnet.yml +++ b/.github/workflows/native_publish_mxnet.yml @@ -14,11 +14,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_publish_paddle.yml b/.github/workflows/native_publish_paddle.yml index bfcdbd42708..0a9ed6761bc 100644 --- a/.github/workflows/native_publish_paddle.yml +++ b/.github/workflows/native_publish_paddle.yml @@ -14,11 +14,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_publish_pytorch.yml b/.github/workflows/native_publish_pytorch.yml index f3ef6496a44..5145d4017b2 100644 --- a/.github/workflows/native_publish_pytorch.yml +++ b/.github/workflows/native_publish_pytorch.yml @@ -14,11 +14,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_publish_tensorflow.yml 
b/.github/workflows/native_publish_tensorflow.yml index b4d95d768ba..d013f3925b4 100644 --- a/.github/workflows/native_publish_tensorflow.yml +++ b/.github/workflows/native_publish_tensorflow.yml @@ -14,11 +14,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_publish_tflite.yml b/.github/workflows/native_publish_tflite.yml index c88d2f00491..4e252af18d8 100644 --- a/.github/workflows/native_publish_tflite.yml +++ b/.github/workflows/native_publish_tflite.yml @@ -14,11 +14,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_s3_fasttext.yml b/.github/workflows/native_s3_fasttext.yml index 948ebabf478..4994c534301 100644 --- a/.github/workflows/native_s3_fasttext.yml +++ b/.github/workflows/native_s3_fasttext.yml @@ -8,11 +8,11 @@ jobs: runs-on: macos-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -46,11 +46,11 @@ jobs: ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 
with: path: ~/.gradle/caches @@ -75,13 +75,13 @@ jobs: build-fasttext-jni-arm64-osx: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions/cache@v3 diff --git a/.github/workflows/native_s3_huggingface.yml b/.github/workflows/native_s3_huggingface.yml index d9ce2d29197..31204dd3d39 100644 --- a/.github/workflows/native_s3_huggingface.yml +++ b/.github/workflows/native_s3_huggingface.yml @@ -8,11 +8,11 @@ jobs: runs-on: macos-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions-rs/toolchain@v1 with: toolchain: stable @@ -45,7 +45,7 @@ jobs: - name: Install Environment run: | yum -y update - yum -y install centos-release-scl-rh epel-release + yum -y install centos-release-scl-rh epel-release perl-core yum -y install devtoolset-7 git patch cmake3 libstdc++-static ln -s /usr/bin/cmake3 /usr/bin/cmake curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -54,11 +54,11 @@ jobs: with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -93,11 +93,11 @@ jobs: with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -124,13 +124,13 @@ jobs: 
build-tokenizers-jni-arm64-osx: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions-rs/toolchain@v1 @@ -155,8 +155,8 @@ jobs: - name: Copy files to S3 with the AWS CLI run: | TOKENIZERS_VERSION="$(cat gradle.properties | awk -F '=' '/tokenizers_version/ {print $2}')" - aws s3 sync extensions/tokenizers/jnilib s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/ - aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*" + /opt/homebrew/bin/aws s3 sync extensions/tokenizers/jnilib s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/ + /opt/homebrew/bin/aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*" create-aarch64-runner: if: github.repository == 'deepjavalibrary/djl' @@ -179,24 +179,23 @@ jobs: runs-on: [ self-hosted, aarch64 ] timeout-minutes: 30 needs: create-aarch64-runner - container: centos:centos7 + container: amazonlinux:2 steps: - name: Install Environment run: | yum -y update - yum -y install centos-release-scl-rh epel-release - yum -y install devtoolset-7 git patch cmake3 libstdc++-static + yum -y groupinstall "Development Tools" + yum -y install patch perl-IPC-Cmd cmake3 ln -s /usr/bin/cmake3 /usr/bin/cmake - curl https://sh.rustup.rs -sSf | sh -s -- -y pip3 install awscli --upgrade - uses: actions-rs/toolchain@v1 with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions/cache@v3 @@ -207,8 +206,6 @@ jobs: 
${{ runner.os }}-gradle- - name: Release JNI prep run: | - source "$HOME/.cargo/env" - export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin ./gradlew :extensions:tokenizers:compileJNI PYTORCH_PRECXX11=true ./gradlew -Pjni :extensions:tokenizers:test - name: Configure AWS Credentials diff --git a/.github/workflows/native_s3_llama.yml b/.github/workflows/native_s3_llama.yml new file mode 100644 index 00000000000..e7f532ea7ac --- /dev/null +++ b/.github/workflows/native_s3_llama.yml @@ -0,0 +1,204 @@ +name: Native S3 llama.cpp + +on: + workflow_dispatch: + +jobs: + build-llamacpp-jni-osx: + runs-on: macos-latest + steps: + - uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + distribution: 'corretto' + java-version: 17 + - uses: actions/cache@v3 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} + restore-keys: | + ${{ runner.os }}-gradle- + - name: Release JNI prep + run: | + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + build-llamacpp-jni-linux: + runs-on: ubuntu-latest + container: centos:centos7 + steps: + - name: Install Environment + run: | + yum -y update + yum -y install centos-release-scl-rh epel-release perl-core + yum -y install devtoolset-7 git patch cmake3 libstdc++-static + ln -s /usr/bin/cmake3 /usr/bin/cmake + pip3 install awscli --upgrade + - 
uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v3 + with: + distribution: 'corretto' + java-version: 17 + - name: Release JNI prep + run: | + export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + build-llamacpp-jni-windows: + runs-on: windows-latest + steps: + - name: Install Environment + run: | + choco install -y mingw + - uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + distribution: 'corretto' + java-version: 17 + - uses: actions/cache@v3 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} + restore-keys: | + ${{ runner.os }}-gradle- + - name: Release CPU JNI + shell: cmd + run: | + call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" amd64 + gradlew :engines:llama:compileJNI + gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + shell: bash + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync 
engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + build-llamacpp-jni-arm64-osx: + if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} + runs-on: macos-latest-xlarge + steps: + - uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: 17 + distribution: corretto + architecture: aarch64 + - uses: actions/cache@v3 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*') }} + restore-keys: | + ${{ runner.os }}-gradle- + - name: Release JNI prep + run: | + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + create-aarch64-runner: + if: github.repository == 'deepjavalibrary/djl' + runs-on: [ self-hosted, scheduler ] + steps: + - name: Create new Graviton instance + id: create_aarch64 + run: | + cd /home/ubuntu/djl_benchmark_script/scripts + token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \ + https://api.github.com/repos/deepjavalibrary/djl/actions/runners/registration-token \ + --fail \ + | jq '.token' | tr -d '"' ) + ./start_instance.sh action_graviton $token djl + outputs: + aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }} + + 
build-llamacpp-jni-aarch64: + if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} + runs-on: [ self-hosted, aarch64 ] + timeout-minutes: 30 + needs: create-aarch64-runner + container: amazonlinux:2 + steps: + - name: Install Environment + run: | + yum -y update + yum -y groupinstall "Development Tools" + yum -y install patch perl-IPC-Cmd cmake3 + ln -s /usr/bin/cmake3 /usr/bin/cmake + pip3 install awscli --upgrade + - uses: actions/checkout@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v3 + with: + java-version: 17 + distribution: corretto + architecture: aarch64 + - name: Release JNI prep + run: | + ./gradlew :engines:llama:compileJNI + ./gradlew -Pjni :engines:llama:test -Dnightly=true + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Copy files to S3 with the AWS CLI + run: | + LLAMACPP_VERSION="$(cat gradle.properties | awk -F '=' '/llamacpp_version/ {print $2}')" + aws s3 sync engines/llama/jnilib s3://djl-ai/publish/llama/${LLAMACPP_VERSION}/jnilib/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/llama/${LLAMACPP_VERSION}/jnilib/*" + + stop-runners: + if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} + runs-on: [ self-hosted, scheduler ] + needs: [ create-aarch64-runner, build-llamacpp-jni-aarch64 ] + steps: + - name: Stop all instances + run: | + cd /home/ubuntu/djl_benchmark_script/scripts + instance_id=${{ needs.create-aarch64-runner.outputs.aarch64_instance_id }} + ./stop_instance.sh $instance_id diff --git a/.github/workflows/native_s3_pytorch.yml b/.github/workflows/native_s3_pytorch.yml index fa3800e45da..4944affad6b 100644 --- a/.github/workflows/native_s3_pytorch.yml +++ b/.github/workflows/native_s3_pytorch.yml @@ -8,11 +8,11 @@ jobs: runs-on: macos-latest steps: - uses: 
actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_s3_pytorch_android.yml b/.github/workflows/native_s3_pytorch_android.yml index 3e03319be02..2720c3ae31d 100644 --- a/.github/workflows/native_s3_pytorch_android.yml +++ b/.github/workflows/native_s3_pytorch_android.yml @@ -10,15 +10,15 @@ jobs: matrix: format: ["armeabi-v7a", "arm64-v8a", "x86" ,"x86_64"] env: - PYTORCH_VERSION: "2.0.1" + PYTORCH_VERSION: "2.1.1" NDK_VERSION: "21.1.6352462" steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Set up Python3 uses: actions/setup-python@v4 with: @@ -48,8 +48,3 @@ jobs: run: | aws s3 cp android_pytorch_tmp/build_android/${{ matrix.format }}_native.zip s3://djl-ai/publish/pytorch/${PYTORCH_VERSION}/android_native/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/pytorch/${PYTORCH_VERSION}/android_native*" -# - name: Upload pytorch src -# uses: actions/upload-artifact@v3 -# with: -# name: pytorch-src-${{ matrix.format }} -# path: android_pytorch_tmp diff --git a/.github/workflows/native_s3_sentencepiece.yml b/.github/workflows/native_s3_sentencepiece.yml index 1e217ed218d..3d786a09066 100644 --- a/.github/workflows/native_s3_sentencepiece.yml +++ b/.github/workflows/native_s3_sentencepiece.yml @@ -9,11 +9,11 @@ jobs: runs-on: macos-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 
with: path: ~/.gradle/caches @@ -49,11 +49,11 @@ jobs: ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -82,11 +82,11 @@ jobs: runs-on: windows-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - uses: actions/cache@v3 with: path: ~/.gradle/caches @@ -112,13 +112,13 @@ jobs: build-sentencepiece-jni-arm64-osx: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} - runs-on: [ self-hosted, ARM64, macOS ] + runs-on: macos-latest-xlarge steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions-rs/toolchain@v1 @@ -164,10 +164,10 @@ jobs: with: toolchain: stable - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions/cache@v3 diff --git a/.github/workflows/native_s3_tensorflow.yml b/.github/workflows/native_s3_tensorflow.yml index 3a119bbc7a1..5cd85047341 100644 --- a/.github/workflows/native_s3_tensorflow.yml +++ b/.github/workflows/native_s3_tensorflow.yml @@ -8,11 +8,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: 
https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/native_s3_tflite.yml b/.github/workflows/native_s3_tflite.yml index a8544baf669..7f5f92a8c9f 100644 --- a/.github/workflows/native_s3_tflite.yml +++ b/.github/workflows/native_s3_tflite.yml @@ -17,11 +17,11 @@ jobs: repository: tensorflow/tensorflow ref: v${{ env.TFLITE_VERSION }} submodules: 'recursive' - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Set up Python3 uses: actions/setup-python@v4 with: @@ -53,7 +53,7 @@ jobs: run: | yum -y update yum -y groupinstall "Development Tools" - yum -y install patch cmake3 unzip which java-11-amazon-corretto + yum -y install patch cmake3 unzip which java-17-amazon-corretto-devel ln -sf /usr/bin/cmake3 /usr/bin/cmake pip3 install awscli --upgrade pip3 install numpy --upgrade @@ -70,7 +70,7 @@ jobs: - name: build package run: | cd tensorflow - export JAVA_HOME=/usr/lib/jvm/java-11-amazon-corretto.x86_64/ + export JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/ curl -L https://github.com/bazelbuild/bazel/releases/download/3.7.2/bazel-3.7.2-installer-linux-x86_64.sh -o bazel.sh --retry 10 bash bazel.sh bazel build -c opt //tensorflow/lite/java:tensorflowlitelib //tensorflow/lite/delegates/flex:delegate diff --git a/.github/workflows/native_s3_xgboost.yml b/.github/workflows/native_s3_xgboost.yml index 3d92e1bd3a8..98b075df43f 100644 --- a/.github/workflows/native_s3_xgboost.yml +++ b/.github/workflows/native_s3_xgboost.yml @@ -34,23 +34,23 @@ jobs: run: | yum -y update yum -y install centos-release-scl-rh epel-release - yum -y install devtoolset-7 git patch libstdc++-static curl python3-devel + yum -y install devtoolset-8 git patch libstdc++-static curl python3-devel curl -L -o cmake.tar.gz 
https://github.com/Kitware/CMake/releases/download/v3.27.0-rc2/cmake-3.27.0-rc2-linux-aarch64.tar.gz tar xvfz cmake.tar.gz ln -sf $PWD/cmake-3.*/bin/cmake /usr/bin/cmake cmake --version pip3 install awscli --upgrade - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Release JNI prep run: | XGBOOST_VERSION=${{ github.event.inputs.xgb_version }} XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')} git clone https://github.com/dmlc/xgboost --recursive -b v"$XGBOOST_VERSION" - export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin + export PATH=$PATH:/opt/rh/devtoolset-8/root/usr/bin cd xgboost/jvm-packages python3 create_jni.py cd ../.. diff --git a/.github/workflows/nightly_android.yml b/.github/workflows/nightly_android.yml index 541d2ba9275..d7aa268f195 100644 --- a/.github/workflows/nightly_android.yml +++ b/.github/workflows/nightly_android.yml @@ -9,17 +9,18 @@ on: jobs: build: + if: github.repository == 'deepjavalibrary/djl' runs-on: macos-latest strategy: matrix: api-level: [ 26 ] steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 - name: Gradle cache uses: gradle/gradle-build-action@v2 - name: run tests diff --git a/.github/workflows/nightly_publish.yml b/.github/workflows/nightly_publish.yml index 64ac23e852c..b5078ef892d 100644 --- a/.github/workflows/nightly_publish.yml +++ b/.github/workflows/nightly_publish.yml @@ -16,15 +16,15 @@ jobs: runs-on: ${{ matrix.operating-system }} strategy: matrix: - operating-system: [ macos-12, ubuntu-latest ] + operating-system: [ macos-13, ubuntu-latest ] steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 
'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: @@ -34,6 +34,9 @@ jobs: ${{ runner.os }}-gradle- - name: check disk space run: df -h + - name: install libomp on macos + if: ${{ runner.os == 'macOS' }} + run: brew install libomp - name: Build with Gradle run: ./gradlew -Dnightly=true build :jacoco:testCodeCoverageReport - name: Upload test results @@ -52,11 +55,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: @@ -76,11 +79,11 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: @@ -103,10 +106,10 @@ jobs: yum -y update yum install -y tar gzip - uses: actions/checkout@v3 - - name: Set up JDK 11 + - name: Set up JDK 17 uses: actions/setup-java@v3 with: - java-version: 11 + java-version: 17 distribution: corretto architecture: aarch64 - uses: actions/cache@v3 @@ -122,11 +125,11 @@ jobs: ./gradlew :integration:test "-Dai.djl.default_engine=OnnxRuntime" ./gradlew :integration:clean - test-cuda-118: + test-cuda-121: if: github.repository == 'deepjavalibrary/djl' runs-on: [ self-hosted, gpu ] container: - image: nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu18.04 + image: nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu20.04 options: --gpus all --runtime=nvidia timeout-minutes: 30 needs: create-runners @@ -137,10 +140,10 @@ jobs: apt-get install -y software-properties-common wget 
locales libfontconfig1 locale-gen en_US.UTF-8 - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: - java-version: 11 + java-version: 17 distribution: corretto - uses: actions/cache@v3 with: @@ -163,14 +166,14 @@ jobs: publish: if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest - needs: [ build, test-pytorch, test-tensorflow, test-aarch64, test-cuda-118 ] + needs: [ build, test-pytorch, test-tensorflow, test-aarch64, test-cuda-121 ] steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: @@ -181,10 +184,8 @@ jobs: - name: Publish to snapshot repository if: ${{ github.event.inputs.mode == '' || github.event.inputs.mode == 'snapshot' }} run: | - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.11.0 -Psnapshot - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.12.1 -Psnapshot ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.13.1 -Psnapshot - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.0.1 -Psnapshot + ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.1.1 -Psnapshot ./gradlew clean engines:ml:xgboost:publish -Pgpu -Psnapshot ./gradlew clean publish -Psnapshot cd bom @@ -197,10 +198,8 @@ jobs: - name: Publish to staging repository if: ${{ github.event.inputs.mode == 'staging' }} run: | - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.11.0 -P${{ github.event.inputs.mode }} - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.12.1 -P${{ github.event.inputs.mode }} ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.13.1 -P${{ 
github.event.inputs.mode }} - ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.0.1 -P${{ github.event.inputs.mode }} + ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.1.1 -P${{ github.event.inputs.mode }} ./gradlew clean engines:ml:xgboost:publish -Pgpu -P${{ github.event.inputs.mode }} ./gradlew clean publish -P${{ github.event.inputs.mode }} cd bom @@ -246,7 +245,7 @@ jobs: stop-runners: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} runs-on: [ self-hosted, scheduler ] - needs: [ create-runners, test-aarch64, test-cuda-118 ] + needs: [ create-runners, test-aarch64, test-cuda-121 ] steps: - name: Stop all instances run: | diff --git a/.github/workflows/no_response.yml b/.github/workflows/no_response.yml index 893ac1eac93..75c1c07ad54 100644 --- a/.github/workflows/no_response.yml +++ b/.github/workflows/no_response.yml @@ -11,6 +11,7 @@ on: jobs: noResponse: + if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - uses: lee-dohm/no-response@v0.5.0 diff --git a/.github/workflows/publish_android_packages.yml b/.github/workflows/publish_android_packages.yml index fea44330f94..4d7d3bfd3ed 100644 --- a/.github/workflows/publish_android_packages.yml +++ b/.github/workflows/publish_android_packages.yml @@ -12,14 +12,15 @@ on: jobs: release-android: + if: github.repository == 'deepjavalibrary/djl' runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: diff --git a/.github/workflows/serving_publish.yml b/.github/workflows/serving_publish.yml index 256a1c2eabb..e45f35ebead 100644 --- a/.github/workflows/serving_publish.yml +++ b/.github/workflows/serving_publish.yml @@ -28,11 +28,11 @@ jobs: 
with: repository: deepjavalibrary/djl-serving ref: ${{ github.event.inputs.serving-branch }} - - name: Set up JDK 11 - uses: actions/setup-java@v3 + - name: Set up JDK 17 + uses: actions/setup-java@v4 with: distribution: 'corretto' - java-version: 11 + java-version: 17 # Enable gradle cache: https://github.com/actions/cache/blob/master/examples.md#java---gradle - uses: actions/cache@v3 with: @@ -74,6 +74,19 @@ jobs: aws s3 cp benchmark/build/distributions/*.deb s3://djl-ai/publish/djl-bench/${DJL_VERSION}/ aws s3 cp benchmark/build/distributions/*.zip s3://djl-ai/publish/djl-bench/${DJL_VERSION}/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/djl-bench/${DJL_VERSION}/*" + - name: Copy awscurl snapshot artifacts to S3 + if: ${{ github.event.inputs.mode == '' || github.event.inputs.mode == 'snapshot' }} + run: | + ./gradlew :awscurl:jar + aws s3 cp awscurl/build/awscurl s3://djl-ai/publish/awscurl/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/awscurl/*" + - name: Copy awscurl staging artifacts to S3 + if: ${{ github.event.inputs.mode == 'staging' }} + run: | + ./gradlew :awscurl:jar + DJL_VERSION=$(cat gradle.properties | awk -F '=' '/djl_version/ {print $2}') + aws s3 cp awscurl/build/awscurl s3://djl-ai/publish/${DJL_VERSION}/awscurl/ + aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/awscurl/${DJL_VERSION}/*" - name: Publish to snapshot repository if: ${{ github.event.inputs.mode == '' || github.event.inputs.mode == 'snapshot' }} run: ./gradlew publish -Psnapshot --refresh-dependencies diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 5b627cfa60b..21634a08872 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,4 +1,5 @@ -## Code of Conduct +# Code of Conduct + This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments. diff --git a/README.md b/README.md index 30975aec73a..1d96f097a7b 100644 --- a/README.md +++ b/README.md @@ -85,30 +85,13 @@ The following pseudocode demonstrates running training: ## Release Notes +* [0.26.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.26.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.26.0)) +* [0.25.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.25.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.25.0)) +* [0.24.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.24.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.24.0)) * [0.23.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.23.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.23.0)) -* [0.22.1](https://github.com/deepjavalibrary/djl/releases/tag/v0.22.1) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.22.1)) -* [0.21.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.21.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.21.0)) -* [0.20.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.20.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.20.0)) -* [0.19.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.19.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.19.0)) -* [0.18.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.18.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.18.0)) -* [0.17.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.17.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.17.0)) -* [0.16.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.16.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.16.0)) -* [0.15.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.15.0) 
([Code](https://github.com/deepjavalibrary/djl/tree/v0.15.0)) -* [0.14.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.14.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.14.0)) -* [0.13.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.13.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.13.0)) -* [0.12.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.12.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.12.0)) -* [0.11.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.11.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.11.0)) -* [0.10.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.10.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.10.0)) -* [0.9.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.9.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.9.0)) -* [0.8.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.8.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.8.0)) -* [0.6.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.6.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.6.0)) -* [0.5.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.5.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.5.0)) -* [0.4.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.4.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.4.0)) -* [0.3.0](https://github.com/deepjavalibrary/djl/releases/tag/v0.3.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.3.0)) -* [0.2.1](https://github.com/deepjavalibrary/djl/releases/tag/v0.2.1) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.2.1)) -* [0.2.0 Initial release](https://github.com/deepjavalibrary/djl/releases/tag/v0.2.0) ([Code](https://github.com/deepjavalibrary/djl/tree/v0.2.0)) - -The release of DJL 0.24.0 is planned for August or September 2023. 
+* [+23 releases](https://github.com/deepjavalibrary/djl/releases) + +The release of DJL 0.27.0 is planned for March 2024. ## Building From Source diff --git a/android/README.md b/android/README.md index 739cd86093b..b7be85388f7 100644 --- a/android/README.md +++ b/android/README.md @@ -16,7 +16,7 @@ In gradle, you can add the 5 modules in your dependencies: ```groovy dependencies { - implementation platform("ai.djl:bom:0.23.0") + implementation platform("ai.djl:bom:0.26.0") implementation "ai.djl:api" implementation "ai.djl.android:core" diff --git a/android/gradle.properties b/android/gradle.properties index 68ad6c12151..8ad177db1bf 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -17,5 +17,5 @@ org.gradle.jvmargs=-Xmx1536m android.useAndroidX=true # Automatically convert third-party libraries to use AndroidX android.enableJetifier=true -djl_version=0.23.0 +djl_version=0.25.0 pytorch_version=1.13.1 diff --git a/android/pytorch-native/README.md b/android/pytorch-native/README.md index 6a955a9e4ce..504acc93cb4 100644 --- a/android/pytorch-native/README.md +++ b/android/pytorch-native/README.md @@ -124,7 +124,7 @@ cd .. 
./gradlew compileAndroidJNI -Ppt_version=${PYTORCH_VERSION} ``` -`jnilib/0.23.0/android` folder will be created after build, and shared library will be uploaded to S3 in CI build +`jnilib/0.26.0/android` folder will be created after build, and shared library will be uploaded to S3 in CI build ## Build PyTorch android library (.aar) and publish to Sonatype snapshot repo @@ -138,7 +138,7 @@ cd ../../../android # To avoid download jni from S3, manually copy them mkdir -p pytorch-native/jnilib -cp -r ../engines/pytorch/pytorch-native/jnilib/0.23.0/android/* pytorch-native/jnilib +cp -r ../engines/pytorch/pytorch-native/jnilib/0.26.0/android/* pytorch-native/jnilib ./gradlew :pytorch-native:assemble # publish to local maven repo (~/.m2 folder) diff --git a/api/README.md b/api/README.md index 85ad22f0188..8c5fe955125 100644 --- a/api/README.md +++ b/api/README.md @@ -35,7 +35,7 @@ You can pull the DJL API from the central Maven repository by including the foll ai.djl api - 0.23.0 + 0.26.0 ``` @@ -45,7 +45,7 @@ For testing the current nightly build, use the following: ai.djl api - 0.24.0-SNAPSHOT + 0.27.0-SNAPSHOT ``` diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index ce9b29ae5ba..597d7d9be02 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -14,11 +14,17 @@ import ai.djl.engine.Engine; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code @@ -30,7 +36,7 @@ * @see The D2L chapter * on GPU devices */ -public final class Device { +public class Device { private static final Map CACHE = new 
ConcurrentHashMap<>(); @@ -39,8 +45,8 @@ public final class Device { private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)"); - private String deviceType; - private int deviceId; + protected String deviceType; + protected int deviceId; /** * Creates a {@code Device} with basic information. @@ -101,6 +107,13 @@ public static Device fromName(String deviceName, Engine engine) { return engine.defaultDevice(); } + if (deviceName.contains("+")) { + String[] split = deviceName.split("\\+"); + List subDevices = + Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList()); + return new MultiDevice(subDevices); + } + Matcher matcher = DEVICE_NAME.matcher(deviceName); if (matcher.matches()) { String deviceType = matcher.group(1); @@ -150,6 +163,15 @@ public boolean isGpu() { return Type.GPU.equals(deviceType); } + /** + * Returns the sub devices if present (such as a {@link MultiDevice}), otherwise this. + * + * @return the sub devices if present (such as a {@link MultiDevice}), otherwise this. + */ + public List getDevices() { + return Collections.singletonList(this); + } + /** {@inheritDoc} */ @Override public String toString() { @@ -214,4 +236,88 @@ public interface Type { String CPU = "cpu"; String GPU = "gpu"; } + + /** A combined {@link Device} representing the composition of multiple other devices. */ + public static class MultiDevice extends Device { + + List devices; + + /** + * Constructs a {@link MultiDevice} with a range of new devices. + * + * @param deviceType the type of the sub-devices + * @param startInclusive the start (inclusive) of the devices range + * @param endExclusive the end (exclusive) of the devices range + */ + public MultiDevice(String deviceType, int startInclusive, int endExclusive) { + this( + IntStream.range(startInclusive, endExclusive) + .mapToObj(i -> Device.of(deviceType, i)) + .collect(Collectors.toList())); + } + + /** + * Constructs a {@link MultiDevice} from sub devices. 
+ * + * @param devices the sub devices + */ + public MultiDevice(Device... devices) { + this(Arrays.asList(devices)); + } + + /** + * Constructs a {@link MultiDevice} from sub devices. + * + * @param devices the sub devices + */ + public MultiDevice(List devices) { + super(null, -1); + devices.sort( + Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER) + .thenComparingInt(Device::getDeviceId)); + this.deviceType = + String.join( + "+", + (Iterable) + () -> + devices.stream() + .map(d -> d.getDeviceType() + d.getDeviceId()) + .iterator()); + this.devices = devices; + } + + /** {@inheritDoc} */ + @Override + public List getDevices() { + return devices; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + MultiDevice that = (MultiDevice) o; + return Objects.equals(devices, that.devices); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), devices); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return deviceType + "()"; + } + } } diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 8a1fc8871ac..a799c70f600 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -59,7 +59,7 @@ public abstract class Engine { private static final Map ALL_ENGINES = new ConcurrentHashMap<>(); - private static final String DEFAULT_ENGINE = initEngine(); + private static String defaultEngine = initEngine(); private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); @@ -69,6 +69,10 @@ public abstract class Engine { private Integer seed; private static synchronized String initEngine() { + if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) { + return 
null; + } + ServiceLoader loaders = ServiceLoader.load(EngineProvider.class); for (EngineProvider provider : loaders) { registerEngine(provider); @@ -80,21 +84,21 @@ private static synchronized String initEngine() { } String def = System.getProperty("ai.djl.default_engine"); - String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); - if (defaultEngine == null || defaultEngine.isEmpty()) { + String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); + if (newDefaultEngine == null || newDefaultEngine.isEmpty()) { int rank = Integer.MAX_VALUE; for (EngineProvider provider : ALL_ENGINES.values()) { if (provider.getEngineRank() < rank) { - defaultEngine = provider.getEngineName(); + newDefaultEngine = provider.getEngineName(); rank = provider.getEngineRank(); } } - } else if (!ALL_ENGINES.containsKey(defaultEngine)) { - throw new EngineException("Unknown default engine: " + defaultEngine); + } else if (!ALL_ENGINES.containsKey(newDefaultEngine)) { + throw new EngineException("Unknown default engine: " + newDefaultEngine); } - logger.debug("Found default engine: {}", defaultEngine); - Ec2Utils.callHome(defaultEngine); - return defaultEngine; + logger.debug("Found default engine: {}", newDefaultEngine); + Ec2Utils.callHome(newDefaultEngine); + return newDefaultEngine; } /** @@ -124,7 +128,7 @@ private static synchronized String initEngine() { * @return the default Engine name */ public static String getDefaultEngineName() { - return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE); + return System.getProperty("ai.djl.default_engine", defaultEngine); } /** @@ -134,7 +138,7 @@ public static String getDefaultEngineName() { * @see EngineProvider */ public static Engine getInstance() { - if (DEFAULT_ENGINE == null) { + if (defaultEngine == null) { throw new EngineException( "No deep learning engine found." 
+ System.lineSeparator() @@ -163,7 +167,29 @@ public static boolean hasEngine(String engineName) { */ public static void registerEngine(EngineProvider provider) { logger.debug("Registering EngineProvider: {}", provider.getEngineName()); - ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); + ALL_ENGINES.put(provider.getEngineName(), provider); + } + + /** + * Returns the default engine. + * + * @return the default engine + */ + public static String getDefaultEngine() { + return defaultEngine; + } + + /** + * Sets the default engine returned by {@link #getInstance()}. + * + * @param engineName the new default engine's name + */ + public static void setDefaultEngine(String engineName) { + // Requires an engine to be loaded (without exception) before being the default + getEngine(engineName); + + logger.debug("Setting new default engine: {}", engineName); + defaultEngine = engineName; } /** @@ -187,7 +213,12 @@ public static Engine getEngine(String engineName) { if (provider == null) { throw new IllegalArgumentException("Deep learning engine not found: " + engineName); } - return provider.getEngine(); + Engine engine = provider.getEngine(); + if (engine == null) { + throw new IllegalStateException( + "The engine " + engineName + " was not able to initialize"); + } + return engine; } /** diff --git a/api/src/main/java/ai/djl/inference/Predictor.java b/api/src/main/java/ai/djl/inference/Predictor.java index 853b30d7a5e..d9b20e3ef9e 100644 --- a/api/src/main/java/ai/djl/inference/Predictor.java +++ b/api/src/main/java/ai/djl/inference/Predictor.java @@ -60,14 +60,13 @@ * * * * @param the input type diff --git a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java index d83c4678f33..d5fdfda878b 100644 --- a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java +++ b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java @@ -14,13 
+14,10 @@ import ai.djl.ndarray.BytesSupplier; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; /** @@ -29,16 +26,14 @@ */ public class PublisherBytesSupplier implements BytesSupplier { - private final List allData; - private final AtomicBoolean completed; private Consumer subscriber; - private final AtomicInteger dataPushed; + private CountDownLatch latch; + private CompletableFuture future; /** Constructs a {@link PublisherBytesSupplier}. */ public PublisherBytesSupplier() { - allData = new ArrayList<>(); - completed = new AtomicBoolean(); - dataPushed = new AtomicInteger(); + latch = new CountDownLatch(1); + future = new CompletableFuture<>(); } /** @@ -48,13 +43,24 @@ public PublisherBytesSupplier() { * @param lastChunk true if this is the last chunk */ public void appendContent(byte[] data, boolean lastChunk) { - synchronized (allData) { - allData.add(data); + if (subscriber == null) { + try { + if (!latch.await(2, TimeUnit.MINUTES)) { + throw new IllegalStateException("Wait for subscriber timeout."); + } + if (subscriber == null) { + // workaround Spotbugs + throw new IllegalStateException("subscriber is not set."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Append content interrupted.", e); + } } + subscriber.accept(data); if (lastChunk) { - completed.set(true); + subscriber.accept(null); + future.complete(null); } - pushData(); } /** @@ -62,69 +68,21 @@ public void appendContent(byte[] data, boolean lastChunk) { * * @param subscriber a consumer function that will receive bytes when new daata is added and * null when completed + * @return a {@code 
CompletableFuture} object */ - public void subscribe(Consumer subscriber) { + public CompletableFuture subscribe(Consumer subscriber) { if (this.subscriber != null) { throw new IllegalStateException( "The PublisherBytesSupplier only allows a single Subscriber"); } this.subscriber = subscriber; - pushData(); - } - - private void pushData() { - if (subscriber == null) { - return; - } - - int dataAvailable; - synchronized (allData) { - dataAvailable = allData.size(); - } - - int sent = dataPushed.getAndSet(dataAvailable); - if (sent < dataAvailable) { - synchronized (this) { - for (; sent < dataAvailable; sent++) { - subscriber.accept(allData.get(sent)); - } - if (completed.get()) { - subscriber.accept(null); - } - } - } - } - - /** Waits until completed before passing thread (BLOCKS THREAD!). */ - @SuppressWarnings("PMD.EmptyControlStatement") - public void waitToRead() { - // Block until complete!!! - while (!completed.get()) { - // Do nothing - } - } - - /** {@inheritDoc} */ - @Override - public byte[] getAsBytes() { - if (!completed.get()) { - throw new IllegalStateException( - "PublisherByteSupplier must be completely filled before reading."); - } - - try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { - for (byte[] data : allData) { - bos.write(data); - } - return bos.toByteArray(); - } catch (IOException e) { - throw new AssertionError("Failed to read BytesSupplier", e); - } + latch.countDown(); + return future; } /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(getAsBytes()); + throw new UnsupportedOperationException("Not supported."); } } diff --git a/api/src/main/java/ai/djl/modality/Classifications.java b/api/src/main/java/ai/djl/modality/Classifications.java index 84025ce07e1..070c0372a7a 100644 --- a/api/src/main/java/ai/djl/modality/Classifications.java +++ b/api/src/main/java/ai/djl/modality/Classifications.java @@ -53,7 +53,7 @@ public class Classifications implements JsonSerializable, Ensembleable 
probabilities; - private int topK; + protected int topK; /** * Constructs a {@code Classifications} using a parallel list of classNames and probabilities. @@ -88,10 +88,18 @@ public Classifications(List classNames, NDArray probabilities) { */ public Classifications(List classNames, NDArray probabilities, int topK) { this.classNames = classNames; - NDArray array = probabilities.toType(DataType.FLOAT64, false); - this.probabilities = - Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); - array.close(); + if (probabilities.getDataType() == DataType.FLOAT32) { + // Avoid converting float32 to float64 as this is not supported on MPS device + this.probabilities = new ArrayList<>(); + for (float prob : probabilities.toFloatArray()) { + this.probabilities.add((double) prob); + } + } else { + NDArray array = probabilities.toType(DataType.FLOAT64, false); + this.probabilities = + Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); + array.close(); + } this.topK = topK; } diff --git a/api/src/main/java/ai/djl/modality/Input.java b/api/src/main/java/ai/djl/modality/Input.java index ecd0679661b..45c6f8161f7 100644 --- a/api/src/main/java/ai/djl/modality/Input.java +++ b/api/src/main/java/ai/djl/modality/Input.java @@ -37,6 +37,7 @@ public class Input { protected Map properties; protected PairList content; + private boolean cancelled; /** Constructs a new {@code Input} instance. */ public Input() { @@ -44,6 +45,24 @@ public Input() { content = new PairList<>(); } + /** + * Returns {@code true} if the input is cancelled. + * + * @return {@code true} if the input is cancelled. + */ + public boolean isCancelled() { + return cancelled; + } + + /** + * Sets the cancelled status. + * + * @param cancelled the cancelled status + */ + public void setCancelled(boolean cancelled) { + this.cancelled = cancelled; + } + /** * Returns the properties of the input. 
* diff --git a/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java b/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java index c7c676fca01..c7d5414da28 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java +++ b/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java @@ -43,7 +43,7 @@ public class CategoryMask implements JsonSerializable { .registerTypeAdapter(CategoryMask.class, new SegmentationSerializer()) .create(); - private List classes; + private transient List classes; private int[][] mask; /** diff --git a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java index 2fd90fe39ec..9d58575af59 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java +++ b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java @@ -48,7 +48,7 @@ public DetectedObjects( List classNames, List probabilities, List boundingBoxes) { super(classNames, probabilities); this.boundingBoxes = boundingBoxes; - setTopK(Integer.MAX_VALUE); + this.topK = Integer.MAX_VALUE; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index a4ebfcb9df1..c31353766d3 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -160,7 +160,7 @@ protected double overlap(double x1, double w1, double x2, double w2) { return right - left; } - private DetectedObjects processFromBoxOutput(NDList list) { + protected DetectedObjects processFromBoxOutput(NDList list) { float[] flattened = list.get(0).toFloatArray(); ArrayList intermediateResults = new ArrayList<>(); int sizeClasses = classes.size(); @@ -280,7 +280,7 @@ public YoloV5Translator build() { } } - private static final class IntermediateResult { + protected static final class 
IntermediateResult { /** * A sortable score for how good the recognition is relative to others. Higher should be diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java new file mode 100644 index 00000000000..d47f7a4a14a --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; + +import java.util.ArrayList; +import java.util.Map; + +/** + * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check + * here: https://github.com/ultralytics/ultralytics + */ +public class YoloV8Translator extends YoloV5Translator { + + private int maxBoxes; + + /** + * Constructs an ImageTranslator with the provided builder. + * + * @param builder the data to build with + */ + protected YoloV8Translator(Builder builder) { + super(builder); + maxBoxes = builder.maxBox; + } + + /** + * Creates a builder to build a {@code YoloV8Translator} with specified arguments. 
+ * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static YoloV8Translator.Builder builder(Map arguments) { + YoloV8Translator.Builder builder = new YoloV8Translator.Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** {@inheritDoc} */ + @Override + protected DetectedObjects processFromBoxOutput(NDList list) { + NDArray rawResult = list.get(0); + NDArray reshapedResult = rawResult.transpose(); + Shape shape = reshapedResult.getShape(); + float[] buf = reshapedResult.toFloatArray(); + int numberRows = Math.toIntExact(shape.get(0)); + int nClasses = Math.toIntExact(shape.get(1)); + + ArrayList intermediateResults = new ArrayList<>(); + // reverse order search in heap; searches through #maxBoxes for optimization when set + for (int i = numberRows - 1; i > numberRows - maxBoxes; --i) { + int index = i * nClasses; + float maxClassProb = -1f; + int maxIndex = -1; + for (int c = 4; c < nClasses; c++) { + float classProb = buf[index + c]; + if (classProb > maxClassProb) { + maxClassProb = classProb; + maxIndex = c; + } + } + + if (maxClassProb > threshold) { + float xPos = buf[index]; // center x + float yPos = buf[index + 1]; // center y + float w = buf[index + 2]; + float h = buf[index + 3]; + Rectangle rect = + new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); + intermediateResults.add( + new IntermediateResult( + classes.get(maxIndex), maxClassProb, maxIndex, rect)); + } + } + return nms(intermediateResults); + } + + /** The builder for {@link YoloV8Translator}. */ + public static class Builder extends YoloV5Translator.Builder { + + private int maxBox = 8400; + + /** + * Builds the translator. 
+ * + * @return the new translator + */ + @Override + public YoloV8Translator build() { + if (pipeline == null) { + addTransform( + array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255)); + } + validate(); + return new YoloV8Translator(this); + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + super.configPostProcess(arguments); + maxBox = ArgumentsUtil.intValue(arguments, "maxBox", 8400); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java new file mode 100644 index 00000000000..b5a4db00d28 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.Translator; + +import java.io.Serializable; +import java.util.Map; + +/** A translatorFactory that creates a {@link YoloV8Translator} instance. 
*/ +public class YoloV8TranslatorFactory extends ObjectDetectionTranslatorFactory + implements Serializable { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + protected Translator buildBaseTranslator( + Model model, Map arguments) { + return YoloV8Translator.builder(arguments).build(); + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/Decoder.java b/api/src/main/java/ai/djl/modality/nlp/Decoder.java index e8081666950..c422665b147 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Decoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Decoder.java @@ -42,6 +42,7 @@ public abstract class Decoder extends AbstractBlock { * @param block the block to be used to decode * @param version the version to use for parameter and metadata serialization */ + @SuppressWarnings("this-escape") public Decoder(byte version, Block block) { super(version); this.block = addChildBlock("Block", block); diff --git a/api/src/main/java/ai/djl/modality/nlp/Encoder.java b/api/src/main/java/ai/djl/modality/nlp/Encoder.java index 4c5a4469388..221626d7559 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Encoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Encoder.java @@ -40,6 +40,7 @@ public abstract class Encoder extends AbstractBlock { * @param version the version to use for parameter and metadata serialization * @param block the encoder block */ + @SuppressWarnings("this-escape") public Encoder(byte version, Block block) { super(version); this.block = addChildBlock("Block", block); diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index 58cc67867c7..24abcb77bb8 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -46,6 +46,7 @@ public class EncoderDecoder extends AbstractBlock { * @param encoder the {@link Encoder} * @param decoder the {@link Decoder} */ + 
@SuppressWarnings("this-escape") public EncoderDecoder(Encoder encoder, Decoder decoder) { super(VERSION); this.encoder = addChildBlock("Encoder", encoder); diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java index af153cb0b23..a65e9cebb4f 100644 --- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java +++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java @@ -38,6 +38,7 @@ public class TrainableTextEmbedding extends AbstractBlock implements TextEmbeddi * * @param wordEmbedding the word embedding to embed each word */ + @SuppressWarnings("this-escape") public TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding) { this.trainableWordEmbedding = addChildBlock("trainableWordEmbedding", wordEmbedding); } diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java index 78f40c5b2f2..07c63428ff6 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java @@ -35,15 +35,13 @@ * policy is setting several thresholds. */ public abstract class SeqBatchScheduler { + private static final Logger logger = LoggerFactory.getLogger(SeqBatchScheduler.class); Predictor predictor; SeqBatcher seqBatcher; - NDManager manager; - SearchConfig config; - Map results; /** @@ -101,7 +99,7 @@ public boolean incrementForward(int count) throws TranslateException { * @return the output token ids * @throws TranslateException if forward fails */ - abstract NDArray inferenceCall() throws TranslateException; + protected abstract NDArray inferenceCall() throws TranslateException; /** * Adds new batch. 
diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java new file mode 100644 index 00000000000..e62167a34b2 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.translator; + +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.BytesSupplier; +import ai.djl.ndarray.NDList; +import ai.djl.translate.Batchifier; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; +import ai.djl.util.PairList; +import ai.djl.util.StringPair; + +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; + +/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */ +public class CrossEncoderServingTranslator implements NoBatchifyTranslator { + + private Translator translator; + private Translator batchTranslator; + + /** + * Constructs a {@code CrossEncoderServingTranslator} instance. 
+ * + * @param translator a {@code Translator} processes question answering input + */ + public CrossEncoderServingTranslator(Translator translator) { + this.translator = translator; + this.batchTranslator = translator.toBatchTranslator(); + } + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws Exception { + translator.prepare(ctx); + batchTranslator.prepare(ctx); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Input input) throws Exception { + PairList content = input.getContent(); + if (content.isEmpty()) { + throw new TranslateException("Input data is empty."); + } + + String contentType = input.getProperty("Content-Type", null); + StringPair pair; + if ("application/json".equals(contentType)) { + String json = input.getData().getAsString(); + try { + JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class); + if (element.isJsonArray()) { + ctx.setAttachment("batch", Boolean.TRUE); + StringPair[] inputs = JsonUtils.GSON.fromJson(json, StringPair[].class); + return batchTranslator.processInput(ctx, inputs); + } + + pair = JsonUtils.GSON.fromJson(json, StringPair.class); + if (pair.getKey() == null || pair.getValue() == null) { + throw new TranslateException("Missing key or value in json."); + } + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } else { + String key = input.getAsString("key"); + String value = input.getAsString("value"); + if (key == null || value == null) { + throw new TranslateException("Missing key or value in input."); + } + pair = new StringPair(key, value); + } + + NDList ret = translator.processInput(ctx, pair); + Batchifier batchifier = translator.getBatchifier(); + if (batchifier != null) { + NDList[] batch = {ret}; + return batchifier.batchify(batch); + } + return ret; + } + + /** {@inheritDoc} */ + @Override + public Output processOutput(TranslatorContext ctx, NDList list) throws 
Exception { + Output output = new Output(); + output.addProperty("Content-Type", "application/json"); + if (ctx.getAttachment("batch") != null) { + output.add(BytesSupplier.wrapAsJson(batchTranslator.processOutput(ctx, list))); + } else { + Batchifier batchifier = translator.getBatchifier(); + if (batchifier != null) { + list = batchifier.unbatchify(list)[0]; + } + output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); + } + return output; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 29a57739aa3..c3df1ef3301 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -53,6 +53,7 @@ public abstract class BaseNDManager implements NDManager { protected AtomicBoolean closed = new AtomicBoolean(false); protected AtomicBoolean capped = new AtomicBoolean(false); + @SuppressWarnings("this-escape") protected BaseNDManager(NDManager parent, Device device) { this.parent = parent; this.device = device == null ? defaultDevice() : device; diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 385c32e88e3..2b4a9df095a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -2344,6 +2344,24 @@ default boolean allClose(NDArray other, double rtol, double atol, boolean equalN */ NDArray atan(); + /** + * Returns the element-wise arc-tangent of this/other choosing the quadrant correctly. + * + *

Examples + * + *

+     * jshell> NDArray x = manager.create(new float[] {0f, 1f});
+     * jshell> NDArray y = manager.create(new float[] {0f, -6f});
+     * jshell> x.atan2(y);
+     * ND: (2) cpu() float64
+     * [0.    , 2.9764]
+     * 
+ * + * @param other The other {@code NDArray} + * @return the result {@code NDArray} + */ + NDArray atan2(NDArray other); + /** * Returns the hyperbolic sine of this {@code NDArray} element-wise. * @@ -3375,6 +3393,48 @@ NDArray stft( boolean normalize, boolean returnComplex); + /** + * Computes the two-dimensional Discrete Fourier Transform. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @param axes Axes over which to compute the 2D-FFT. + * @return The truncated or zero-padded input, transformed along the axes. + */ + NDArray fft2(long[] sizes, long[] axes); + + /** + * Computes the two-dimensional Discrete Fourier Transform along the last 2 axes. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @return The truncated or zero-padded input, transformed along the last two axes + */ + default NDArray fft2(long[] sizes) { + return fft2(sizes, new long[] {-2, -1}); + } + + /** + * Computes the two-dimensional inverse Discrete Fourier Transform. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @param axes Axes over which to compute the 2D-Inverse-FFT. + * @return The truncated or zero-padded input, transformed along the axes. + */ + NDArray ifft2(long[] sizes, long[] axes); + + /** + * Computes the two-dimensional inverse Discrete Fourier Transform along the last 2 axes. + * + * @param sizes Sizes of the transformed axes of the output. Will be zero-padded or trimmed to + * this size. + * @return The truncated or zero-padded input, transformed along the axes. + */ + default NDArray ifft2(long[] sizes) { + return ifft2(sizes, new long[] {-2, -1}); + } + /** * Reshapes this {@code NDArray} to the given {@link Shape}. 
* @@ -4922,6 +4982,22 @@ default NDArray countNonzero(int axis) { */ NDArray erfinv(); + /** + * Returns element-wise gauss error function of the {@code NDArray}. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
+     * jshell> array.erf();
+     * ND: (3) cpu() float32
+     * [0., 0.5, -1]
+     * 
+ * + * @return The gauss error of the {@code NDArray}, element-wise + */ + NDArray erf(); + /** {@inheritDoc} */ @Override default List getResourceNDArrays() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 59047e688c8..9a4ad8db93a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -726,6 +726,12 @@ public NDArray atan() { return getAlternativeArray().atan(); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + return getAlternativeArray().atan2(other); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { @@ -906,6 +912,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { @@ -1188,6 +1206,12 @@ public NDArray erfinv() { return getAlternativeArray().erfinv(); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return getAlternativeArray().erf(); + } + /** {@inheritDoc} */ @Override public NDArray inverse() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrays.java b/api/src/main/java/ai/djl/ndarray/NDArrays.java index 304b803939c..0e1c0922a7b 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrays.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrays.java @@ -1996,4 +1996,23 @@ public static NDArray logicalXor(NDArray a, NDArray b) { public static NDArray erfinv(NDArray input) { return input.erfinv(); } + + /** + * Returns element-wise gauss error function of the {@code NDArray}. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
+     * jshell> array.erf();
+     * ND: (3) cpu() float32
+     * [0., 0.5, -1]
+     * 
+ * + * @param input The input {@code NDArray} + * @return The gauss error of the {@code NDArray}, element-wise + */ + public static NDArray erf(NDArray input) { + return input.erf(); + } } diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index e48c243a3ec..f0069d3f3f3 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -100,12 +100,12 @@ public static NDList decode(NDManager manager, byte[] byteArray) { try { if (byteArray[0] == 'P' && byteArray[1] == 'K') { return decodeNumpy(manager, new ByteArrayInputStream(byteArray)); - } else if (byteArray[0] == (byte) 0x39 + } else if (byteArray[0] == (byte) 0x93 && byteArray[1] == 'N' && byteArray[2] == 'U' && byteArray[3] == 'M') { return new NDList( - NDSerializer.decode(manager, new ByteArrayInputStream(byteArray))); + NDSerializer.decodeNumpy(manager, new ByteArrayInputStream(byteArray))); } else if (byteArray[8] == '{') { return decodeSafetensors(manager, new ByteArrayInputStream(byteArray)); } @@ -144,11 +144,11 @@ public static NDList decode(NDManager manager, InputStream is) { if (magic[0] == 'P' && magic[1] == 'K') { // assume this is npz file return decodeNumpy(manager, pis); - } else if (magic[0] == (byte) 0x39 + } else if (magic[0] == (byte) 0x93 && magic[1] == 'N' && magic[2] == 'U' && magic[3] == 'M') { - return new NDList(NDSerializer.decode(manager, pis)); + return new NDList(NDSerializer.decodeNumpy(manager, pis)); } else if (magic[8] == '{') { return decodeSafetensors(manager, pis); } diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java index 8b0deb23132..764c829c2e6 100644 --- a/api/src/main/java/ai/djl/ndarray/NDScope.java +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -30,6 +30,7 @@ public class NDScope implements AutoCloseable { private IdentityHashMap resources; /** Constructs a new {@code NDScope} instance. 
*/ + @SuppressWarnings("this-escape") public NDScope() { resources = new IdentityHashMap<>(); SCOPE_STACK.get().addLast(this); diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java index b12ac5dd07d..07e56a5ca04 100644 --- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java +++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java @@ -434,7 +434,12 @@ default NDArray toTensor() { if (dim == 3) { result = result.expandDims(0); } - result = result.div(255.0).transpose(0, 3, 1, 2); + // For Apple Silicon MPS it is important not to switch to 64-bit float here + if (result.getDataType() == DataType.FLOAT32) { + result = result.div(255.0f).transpose(0, 3, 1, 2); + } else { + result = result.div(255.0).transpose(0, 3, 1, 2); + } if (dim == 3) { result = result.squeeze(0); } diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 3d58d501293..7ace6880c56 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -105,7 +105,7 @@ * further refine these elements, use {@link Block#freezeParameters(boolean)} to unfreeze them. * * @see this + * href="http://docs.djl.ai/docs/demos/jupyter/tutorial/01_create_your_first_network.html">this * tutorial on creating your first network * @see The * D2L chapter on blocks and pred) { + for (Parameter parameter : getParameters().values()) { + if (pred.test(parameter)) { + parameter.freeze(freeze); + } + } + } + /** * Validates that actual layout matches the expected layout. 
* diff --git a/api/src/main/java/ai/djl/nn/Blocks.java b/api/src/main/java/ai/djl/nn/Blocks.java index 47cd5843543..8abd0a19c91 100644 --- a/api/src/main/java/ai/djl/nn/Blocks.java +++ b/api/src/main/java/ai/djl/nn/Blocks.java @@ -33,12 +33,7 @@ private Blocks() {} * @return a {@link NDList} that contains the inflated {@link ai.djl.ndarray.NDArray} */ public static NDArray batchFlatten(NDArray array) { - long batch = array.size(0); - if (batch == 0) { - // calculate the size of second dimension manually as using -1 would not work here - return array.reshape(batch, array.getShape().slice(1).size()); - } - return array.reshape(batch, -1); + return array.reshape(-1, array.getShape().slice(1).size()); } /** diff --git a/api/src/main/java/ai/djl/nn/ParallelBlock.java b/api/src/main/java/ai/djl/nn/ParallelBlock.java index 4ebe1e8119b..269e52b6b22 100644 --- a/api/src/main/java/ai/djl/nn/ParallelBlock.java +++ b/api/src/main/java/ai/djl/nn/ParallelBlock.java @@ -62,6 +62,7 @@ public ParallelBlock(Function, NDList> function) { * @param function the function to define how the parallel branches are combined * @param blocks the blocks that form each of the parallel branches */ + @SuppressWarnings("this-escape") public ParallelBlock(Function, NDList> function, List blocks) { super(VERSION); this.function = function; @@ -74,6 +75,7 @@ public ParallelBlock(Function, NDList> function, List blocks * @param blocks the array of blocks to add * @return this block */ + @SuppressWarnings("this-escape") public final ParallelBlock addAll(Block... 
blocks) { return addAll(Arrays.asList(blocks)); } diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index f862ee13274..a049c20e2b7 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -89,6 +89,7 @@ public abstract class Convolution extends AbstractBlock { * * @param builder the {@code Builder} that has the necessary configurations */ + @SuppressWarnings("this-escape") public Convolution(ConvolutionBuilder builder) { super(VERSION); kernelShape = builder.kernelShape; diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 667de724e2a..419780a98d1 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -62,6 +62,7 @@ public abstract class Deconvolution extends AbstractBlock { * * @param builder the {@code Builder} that has the necessary configurations */ + @SuppressWarnings("this-escape") public Deconvolution(DeconvolutionBuilder builder) { kernelShape = builder.kernelShape; stride = builder.stride; diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java index d2e0acf8e46..c1c27f57935 100644 --- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java +++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java @@ -38,6 +38,7 @@ public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedE * * @param embedding the value to return for all embeddings */ + @SuppressWarnings("this-escape") public ConstantEmbedding(NDArray embedding) { this.embedding = embedding; freezeParameters(true); diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index d6a937fe9a0..ab6167ced2f 100644 --- 
a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -49,6 +49,7 @@ public abstract class Embedding extends AbstractBlock implements AbstractInde protected Parameter embedding; + @SuppressWarnings("this-escape") protected Embedding(BaseBuilder baseBuilder) { super(VERSION); embeddingSize = baseBuilder.embeddingSize; @@ -91,6 +92,7 @@ protected Embedding(NDArray embedding) { * @param embedding the embedding array * @param format whether to compute row sparse gradient in the backward calculation */ + @SuppressWarnings("this-escape") protected Embedding(NDArray embedding, SparseFormat format) { super(VERSION); numEmbeddings = Math.toIntExact(embedding.getShape().get(0)); diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 530344a8858..d10c0a91eb8 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -62,6 +62,7 @@ public class Linear extends AbstractBlock { private Parameter weight; private Parameter bias; + @SuppressWarnings("this-escape") protected Linear(Builder builder) { super(VERSION); units = builder.units; diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index 8fcb9971330..e70d06a448b 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -41,6 +41,7 @@ public class Prelu extends AbstractBlock { private Parameter alpha; /** Creates a Parametric ReLU Block. 
*/ + @SuppressWarnings("this-escape") public Prelu() { super(VERSION); alpha = diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index 5d69284132e..42ab1036aa8 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -66,6 +66,7 @@ public class LayerNorm extends AbstractBlock { protected Parameter gamma; protected Parameter beta; + @SuppressWarnings("this-escape") protected LayerNorm(Builder builder) { epsilon = builder.epsilon; scale = builder.scale; diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index 3c9bb3f89d7..981e4954e7c 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -58,6 +58,7 @@ public abstract class RecurrentBlock extends AbstractBlock { * * @param builder the {@code Builder} that has the necessary configurations */ + @SuppressWarnings("this-escape") public RecurrentBlock(BaseBuilder builder) { super(VERSION); stateSize = builder.stateSize; diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index a0b49b9430d..cb02a2f4074 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -46,6 +46,7 @@ public class BertMaskedLanguageModelBlock extends AbstractBlock { * @param bertBlock the bert block to create the task for * @param hiddenActivation the activation to use for the hidden layer */ + @SuppressWarnings("this-escape") public BertMaskedLanguageModelBlock( BertBlock bertBlock, Function hiddenActivation) { super(VERSION); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java 
b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java index 549d05b629e..4c3bbdb55b8 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java @@ -29,6 +29,7 @@ public class BertNextSentenceBlock extends AbstractBlock { private Linear binaryClassifier; /** Creates a next sentence block. */ + @SuppressWarnings("this-escape") public BertNextSentenceBlock() { binaryClassifier = addChildBlock( diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index d196ace2782..8d9cec6c01e 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -36,6 +36,7 @@ public class BertPretrainingBlock extends AbstractBlock { * * @param builder a builder with a bert configuration */ + @SuppressWarnings("this-escape") public BertPretrainingBlock(final BertBlock.Builder builder) { this.bertBlock = addChildBlock("Bert", builder.build()); this.mlBlock = diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java index 3b530808bdf..451709c3f74 100644 --- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java @@ -31,6 +31,7 @@ public class PointwiseFeedForwardBlock extends SequentialBlock { * @param activationFunction the activation function to use for the hidden layers (not applied * to output) */ + @SuppressWarnings("this-escape") public PointwiseFeedForwardBlock( List hiddenSizes, int outputSize, diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java index bc251d42e86..f01cb1adc33 100644 --- 
a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java @@ -51,6 +51,7 @@ public class TransformerEncoderBlock extends AbstractBlock { * @param dropoutProbability dropout probability * @param activationFunction activation function */ + @SuppressWarnings("this-escape") public TransformerEncoderBlock( int embeddingSize, int headCount, diff --git a/api/src/main/java/ai/djl/repository/AbstractRepository.java b/api/src/main/java/ai/djl/repository/AbstractRepository.java index 3b83c359aad..c28a3b16887 100644 --- a/api/src/main/java/ai/djl/repository/AbstractRepository.java +++ b/api/src/main/java/ai/djl/repository/AbstractRepository.java @@ -14,13 +14,10 @@ import ai.djl.util.Hex; import ai.djl.util.Progress; +import ai.djl.util.TarUtils; import ai.djl.util.Utils; import ai.djl.util.ZipUtils; -import org.apache.commons.compress.archivers.tar.TarArchiveEntry; -import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; -import org.apache.commons.compress.utils.CloseShieldFilterInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -212,9 +209,9 @@ protected void save(InputStream is, Path tmp, Artifact.Item item, Progress progr if ("zip".equals(extension)) { ZipUtils.unzip(pis, dir); } else if ("tgz".equals(extension)) { - untar(pis, dir, true); + TarUtils.untar(pis, dir, true); } else if ("tar".equals(extension)) { - untar(pis, dir, false); + TarUtils.untar(pis, dir, false); } else { throw new IOException("File type is not supported: " + extension); } @@ -233,36 +230,6 @@ protected void save(InputStream is, Path tmp, Artifact.Item item, Progress progr pis.validateChecksum(item); } - private void untar(InputStream is, Path dir, boolean gzip) throws IOException { - InputStream bis; - if (gzip) { - bis = new GzipCompressorInputStream(new BufferedInputStream(is)); - } else { - 
bis = new BufferedInputStream(is); - } - bis = new CloseShieldFilterInputStream(bis); - try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) { - TarArchiveEntry entry; - while ((entry = tis.getNextTarEntry()) != null) { - String entryName = entry.getName(); - if (entryName.contains("..")) { - throw new IOException("Malicious zip entry: " + entryName); - } - Path file = dir.resolve(entryName).toAbsolutePath(); - if (entry.isDirectory()) { - Files.createDirectories(file); - } else { - Path parentFile = file.getParent(); - if (parentFile == null) { - throw new AssertionError("Parent path should never be null: " + file); - } - Files.createDirectories(parentFile); - Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING); - } - } - } - } - private static Map parseQueryString(URI uri) { try { Map map = new ConcurrentHashMap<>(); diff --git a/api/src/main/java/ai/djl/repository/RemoteRepository.java b/api/src/main/java/ai/djl/repository/RemoteRepository.java index 6b01ce14ef8..52f87afbf82 100644 --- a/api/src/main/java/ai/djl/repository/RemoteRepository.java +++ b/api/src/main/java/ai/djl/repository/RemoteRepository.java @@ -75,7 +75,7 @@ public Metadata locate(MRL mrl) throws IOException { Metadata metadata = JsonUtils.GSON_PRETTY.fromJson(reader, Metadata.class); metadata.init(arguments); Date lastUpdated = metadata.getLastUpdated(); - if (Boolean.getBoolean("offline") + if (Utils.isOfflineMode() || System.currentTimeMillis() - lastUpdated.getTime() < ONE_DAY) { metadata.setRepositoryUri(mrlUri); return metadata; diff --git a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java index 71f394d6d14..a730a57cb54 100644 --- a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java +++ b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java @@ -160,11 +160,10 @@ public Repository newInstance(String name, URI uri) { Path path = Paths.get(parseFilePath(uri)); String 
fileName = path.toFile().getName(); - if (!FilenameUtils.isArchiveFile(fileName)) { - throw new IllegalArgumentException("Only archive file is supported for res URL."); + if (FilenameUtils.isArchiveFile(fileName)) { + fileName = FilenameUtils.getNamePart(fileName); } - fileName = FilenameUtils.getNamePart(fileName); return new JarRepository(name, uri, fileName, queryString); } diff --git a/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java b/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java index 527871067fa..676bab73d75 100644 --- a/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java +++ b/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java @@ -29,6 +29,7 @@ public class DefaultModelZoo extends ModelZoo { private static final Logger logger = LoggerFactory.getLogger(DefaultModelZoo.class); /** Constructs a new {@code LocalModelZoo} instance. */ + @SuppressWarnings("this-escape") public DefaultModelZoo() { String locations = System.getProperty("ai.djl.repository.zoo.location"); if (locations != null) { @@ -41,6 +42,7 @@ public DefaultModelZoo() { * * @param locations a comma separated urls where the models to be loaded from */ + @SuppressWarnings("this-escape") public DefaultModelZoo(String locations) { parseLocation(locations); } diff --git a/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java b/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java index 50b219be509..e903a1677b3 100644 --- a/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java +++ b/api/src/main/java/ai/djl/repository/zoo/ModelZoo.java @@ -32,6 +32,7 @@ public abstract class ModelZoo { private static final Map MODEL_ZOO_MAP = new ConcurrentHashMap<>(); + private static ModelZooResolver resolver; private Map modelLoaders = new ConcurrentHashMap<>(); @@ -86,6 +87,15 @@ protected final void addModel(ModelLoader loader) { modelLoaders.put(loader.getArtifactId(), loader); } + /** + * Sets the {@code ModelZooResolver}. 
+ * + * @param resolver the {@code ModelZooResolver} + */ + public static void setModelZooResolver(ModelZooResolver resolver) { + ModelZoo.resolver = resolver; + } + /** * Refreshes model zoo. * @@ -112,7 +122,14 @@ public static Collection listModelZoo() { * @return the {@code ModelZoo} with the {@code groupId} */ public static ModelZoo getModelZoo(String groupId) { - return MODEL_ZOO_MAP.get(groupId); + ModelZoo zoo = MODEL_ZOO_MAP.get(groupId); + if (zoo == null && resolver != null) { + zoo = resolver.resolve(groupId); + if (zoo != null) { + MODEL_ZOO_MAP.putIfAbsent(groupId, zoo); + } + } + return zoo; } /** diff --git a/api/src/main/java/ai/djl/repository/zoo/ModelZooResolver.java b/api/src/main/java/ai/djl/repository/zoo/ModelZooResolver.java new file mode 100644 index 00000000000..897e122f191 --- /dev/null +++ b/api/src/main/java/ai/djl/repository/zoo/ModelZooResolver.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.repository.zoo; + +/** An interface that resolves external ModelZoo. */ +public interface ModelZooResolver { + + /** + * Returns {@link ModelZoo} based on model zoo group ID. + * + * @param groupId the model zoo group ID. 
+ * @return the resolved {@code ModelZoo} + */ + ModelZoo resolve(String groupId); +} diff --git a/api/src/main/java/ai/djl/training/EasyTrain.java b/api/src/main/java/ai/djl/training/EasyTrain.java index 0af691b8755..6872231ad34 100644 --- a/api/src/main/java/ai/djl/training/EasyTrain.java +++ b/api/src/main/java/ai/djl/training/EasyTrain.java @@ -127,6 +127,8 @@ private static boolean trainSplit( time = System.nanoTime(); batchData.getLabels().put(labels.get(0).getDevice(), labels); batchData.getPredictions().put(preds.get(0).getDevice(), preds); + batchData.getData().put(preds.get(0).getDevice(), data); + batchData.getLoss().put(preds.get(0).getDevice(), lossValue); trainer.addMetric("training-metrics", time); return true; } diff --git a/api/src/main/java/ai/djl/training/ParameterStore.java b/api/src/main/java/ai/djl/training/ParameterStore.java index 7029282c46e..15c83bde8ca 100644 --- a/api/src/main/java/ai/djl/training/ParameterStore.java +++ b/api/src/main/java/ai/djl/training/ParameterStore.java @@ -14,6 +14,7 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.Device.MultiDevice; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.nn.Parameter; @@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices this.parameterServer = parameterServer; deviceMap.clear(); for (int i = 0; i < devices.length; ++i) { + if (devices[i] instanceof MultiDevice) { + throw new IllegalArgumentException( + "The parameter store does not support MultiDevices"); + } if (deviceMap.put(devices[i], i) != null) { throw new IllegalArgumentException("Duplicated devices are not allowed."); } diff --git a/api/src/main/java/ai/djl/training/Trainer.java b/api/src/main/java/ai/djl/training/Trainer.java index eab6ba07f2a..6d79dde3eec 100644 --- a/api/src/main/java/ai/djl/training/Trainer.java +++ b/api/src/main/java/ai/djl/training/Trainer.java @@ -52,14 +52,12 @@ * * * * @see The guide on memory @@ -88,6 +86,7 @@ 
public class Trainer implements AutoCloseable { * @param model the model the trainer will train on * @param trainingConfig the configuration used by the trainer */ + @SuppressWarnings("this-escape") public Trainer(Model model, TrainingConfig trainingConfig) { this.model = model; manager = model.getNDManager().newSubManager(); diff --git a/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java b/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java index c9a5fdf7036..8610f9e92bb 100644 --- a/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java +++ b/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java @@ -77,9 +77,22 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { Pair update = accuracyHelper(labels, predictions); - totalInstances.compute(key, (k, v) -> v + update.getKey()); - correctInstances.compute(key, (k, v) -> v + update.getValue().sum().getLong()); + NDArray value = update.getValue(); + NDArray sum = value.sum(); + long correct = sum.getLong(); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + update.getKey()); + correctInstances.compute(key, (k, v) -> v + correct); + } + value.close(); + sum.close(); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java index 4af9e5de3d1..ab2d554142d 100644 --- a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java +++ b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java @@ -63,10 +63,18 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList 
predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { NDArray boundingBoxError = evaluate(labels, predictions); float update = boundingBoxError.sum().getFloat(); - totalInstances.compute(key, (k, v) -> v + boundingBoxError.size()); - ssdBoxPredictionError.compute(key, (k, v) -> v + update); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + boundingBoxError.size()); + ssdBoxPredictionError.compute(key, (k, v) -> v + update); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java index 6d2c5995601..c373471f6cf 100644 --- a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java +++ b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java @@ -74,6 +74,25 @@ public String getName() { */ public abstract void addAccumulator(String key); + /** + * Updates the evaluator with the given keys based on a {@link NDList} of labels and + * predictions. + * + *

This is a synchronized operation. You should only call it at the end of a batch or epoch. + * + *

This is an alternative to @{link {@link #updateAccumulator(String, NDList, NDList)}} that + * may be more efficient when updating multiple accumulators at once. + * + * @param keys the keys of all the accumulators to update + * @param labels a {@code NDList} of labels + * @param predictions a {@code NDList} of predictions + */ + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { + for (String key : keys) { + updateAccumulator(key, labels, predictions); + } + } + /** * Updates the evaluator with the given key based on a {@link NDList} of labels and predictions. * diff --git a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java index a7fe08b610e..aa12cae628c 100644 --- a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java +++ b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java @@ -67,6 +67,12 @@ public void updateAccumulator(String key, NDList labels, NDList predictions) { evaluator.updateAccumulator(key, getLabels(labels), getPredictions(predictions)); } + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { + evaluator.updateAccumulators(keys, getLabels(labels), getPredictions(predictions)); + } + /** {@inheritDoc} */ @Override public void resetAccumulator(String key) { diff --git a/api/src/main/java/ai/djl/training/listener/AlgebraicListener.java b/api/src/main/java/ai/djl/training/listener/AlgebraicListener.java new file mode 100644 index 00000000000..51b5288e838 --- /dev/null +++ b/api/src/main/java/ai/djl/training/listener/AlgebraicListener.java @@ -0,0 +1,288 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.training.listener; + +import ai.djl.Device; +import ai.djl.Model; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; +import ai.djl.training.Trainer; +import ai.djl.util.NativeResource; +import ai.djl.util.Pair; +import ai.djl.util.PairList; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** {@link TrainingListener} that records algebraic operations as Python code. 
*/ +public class AlgebraicListener extends TrainingListenerAdapter { + + private static AlgebraicListener currentListener; + + private static final Logger logger = LoggerFactory.getLogger(AlgebraicListener.class); + + private final Map nodeMap = new ConcurrentHashMap<>(); + private final Map nodeMapForParameters = new ConcurrentHashMap<>(); + + @SuppressWarnings("PMD.UseConcurrentHashMap") + private final Map losses = new LinkedHashMap<>(); + + @SuppressWarnings("PMD.UseConcurrentHashMap") + private final Map predictions = new LinkedHashMap<>(); + + private Map parameters; + private String outputFile; + private AtomicInteger parametersOpCount = new AtomicInteger(0); + + private int numEpoch; + + /** + * New listener to record algebraic operations into the given file. + * + * @param outputFile file to store output - will be overridden if exist + */ + public AlgebraicListener(String outputFile) { + this.outputFile = outputFile; + } + + /** {@inheritDoc} */ + @Override + public void onEpoch(Trainer trainer) { + numEpoch++; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBatch(Trainer trainer, BatchData batchData) { + writeParameters(trainer.getModel()); + AtomicInteger opCount = new AtomicInteger(parametersOpCount.get()); + for (Device device : batchData.getLabels().keySet()) { + NDList data = batchData.getData().get(device); + NDList preds = batchData.getPredictions().get(device); + NDList labels = batchData.getLabels().get(device); + NDArray loss = batchData.getLoss().get(device); + if (data != null) { + setLeaf(data, "x"); + } + if (preds != null) { + writePredictions(preds, opCount); + } + if (preds != null) { + setLeaf(preds, "prediction"); + } + if (labels != null) { + setLeaf(labels, "label"); + } + if (loss != null) { + writeLoss(loss, opCount); + } + } + nodeMap.clear(); + nodeMap.putAll(nodeMapForParameters); + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBegin(Trainer trainer) { + setCurrentListener(this); + } + + /** 
{@inheritDoc} */ + @Override + public void onTrainingEnd(Trainer trainer) { + try (OutputStream out = Files.newOutputStream(Paths.get(outputFile))) { + describe(out); + } catch (IOException e) { + logger.error("Failed logging algebraic operations", e); + } + parameters.clear(); + predictions.clear(); + losses.clear(); + nodeMap.clear(); + nodeMapForParameters.clear(); + setCurrentListener(null); + } + + private void setLeaf(NDArray x, String name) { + Node node = get(x); + if (node == null) { + return; + } + node.name = name; + node.isLeaf = true; + } + + private void setLeaf(NDList data, String name) { + for (NDArray x : data) { + setLeaf(x, name); + } + } + + private void writePredictions(NDList preds, AtomicInteger opCount) { + String tuple = + preds.stream() + .map(this::getArrayName) + .collect(Collectors.joining(", ", "return tf.tuple([", "])")); + if (preds.size() == 1) { + tuple = "return result"; + } + String python = + preds.stream() + .map(pred -> get(pred).toPythonFunctionBody(opCount, getArrayName(pred))) + .collect(Collectors.joining("\n", "", "\n" + Node.indent(tuple))); + predictions.compute(python, (key, count) -> count == null ? 1 : count + 1); + } + + private String getArrayName(NDArray pred) { + return pred.getName() != null ? pred.getName() : "result"; + } + + private void writeLoss(NDArray loss, AtomicInteger opCount) { + String python = + get(loss).toPythonFunctionBody(opCount, "result") + + "\n" + + Node.indent("return result"); + losses.compute(python, (key, count) -> count == null ? 
1 : count + 1); + } + + private void describe(OutputStream out) throws IOException { + PrintStream writer = new PrintStream(out, true, StandardCharsets.US_ASCII.name()); + writer.println("class MyModel(tf.keras.Model):"); + writer.println(" def __init__(self, **kwargs):"); + writer.println(" super().__init__(**kwargs)"); + for (Entry param : parameters.entrySet()) { + writer.println(Node.indent(param.getKey() + " = tf.Variable(")); + writer.println(Node.indent(Node.indent(param.getValue()))); + writer.println(Node.indent(")")); + } + writer.println(""); + for (Entry pred : predictions.entrySet()) { + writer.println("## " + pred.getValue()); + writer.println(" def call(self, x):"); + writer.println(pred.getKey()); + } + writer.println(""); + for (Entry loss : losses.entrySet()) { + writer.println("## " + loss.getValue()); + writer.println("def loss(label, prediction):"); + writer.println(loss.getKey()); + } + writer.println(""); + writer.println(String.format("# number of epochs was %s", numEpoch)); + writer.println(String.format("# number of prediction functions is %s", predictions.size())); + writer.println(String.format("# number of loss functions is %s", losses.size())); + writer.println(""); + } + + private void writeParameters(Model model) { + if (parameters != null) { + return; + } + parameters = new LinkedHashMap<>(); + for (Pair pair : model.getBlock().getParameters()) { + NDArray array = pair.getValue().getArray(); + + Node init = get(array); + String initialization; + if (pair.getKey().endsWith("Conv2d_weight")) { + int[] perm = {2, 3, 1, 0}; + PairList param = + new PairList<>(Collections.singletonMap("axes", Arrays.toString(perm))); + Node transpose = new Node("_np_transpose", param, init); + transpose.outputShape = + new Shape(IntStream.of(perm).mapToLong(init.outputShape::get).toArray()); + initialization = transpose.toPythonExpression(null, parametersOpCount); + init.outputShape = transpose.outputShape; + } else { + initialization = + 
init.toPythonExpression(null, parametersOpCount) + + (pair.getValue().requiresGradient() + ? "" + : "\n, trainable = False"); + } + String pythonClassVariable = "self._" + pair.getKey(); + parameters.put(pythonClassVariable, initialization); + setLeaf(array, pythonClassVariable); + nodeMapForParameters.put(key(array), init); + } + } + + /** + * Records an algebraic operation that is executed with the given parameters. + * + * @param name the name of the operation + * @param src the input to the operation + * @param dest the output of the operation + * @param param parameters for the operation + */ + public static void record( + String name, NDArray[] src, NDArray[] dest, PairList param) { + if (currentListener != null) { + currentListener.recordInternal(name, src, dest, param); + } + } + + private void recordInternal( + String name, NDArray[] src, NDArray[] dest, PairList param) { + Node n = new Node(name, param != null ? param : new PairList<>(), new Node[src.length]); + int index = 0; + for (NDArray array : src) { + Node node = get(array); + if (node == null) { + node = + new Node( + array.getName() != null + ? 
array.getName() + : "UNKNOWN_ARRAY" + array.getShape(), + new PairList<>()); + nodeMap.put(key(array), n); + node.outputShape = array.getShape(); + } + n.src[index++] = node; + } + for (NDArray array : dest) { + nodeMap.put(key(array), n); + n.outputShape = array.getShape(); + } + } + + private Node get(NDArray array) { + return nodeMap.get(key(array)); + } + + private Object key(NDArray array) { + return ((NativeResource) array).getHandle(); + } + + private static void setCurrentListener(AlgebraicListener algebraicListener) { + currentListener = algebraicListener; + } +} diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java new file mode 100644 index 00000000000..6c013c37715 --- /dev/null +++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java @@ -0,0 +1,281 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.training.listener; + +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; + +import java.time.Duration; + +/** + * Listener that allows the training to be stopped early if the validation loss is not improving, or + * if time has expired.
+ * + *

Usage: Add this listener to the training config, and add it as the last one. + * + *

+ *  new DefaultTrainingConfig(...)
+ *        .addTrainingListeners(EarlyStoppingListener.builder()
+ *                .setEpochPatience(1)
+ *                .setEarlyStopPctImprovement(1)
+ *                .setMaxDuration(Duration.ofMinutes(42))
+ *                .setMinEpochs(1)
+ *                .build()
+ *        );
+ * 
+ * + *

Then surround the fit with a try catch that catches the {@link + * EarlyStoppingListener.EarlyStoppedException}.
+ * Example: + * + *

+ * try {
+ *   EasyTrain.fit(trainer, 5, trainDataset, testDataset);
+ * } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ *   // handle early stopping
+ *   log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
+ * }
+ * 
+ * + *
+ * Note: Ensure that Metrics are set on the trainer. + */ +public final class EarlyStoppingListener implements TrainingListener { + private final double objectiveSuccess; + + private final int minEpochs; + private final long maxMillis; + private final double earlyStopPctImprovement; + private final int epochPatience; + + private long startTimeMills; + private double prevLoss; + private int numberOfEpochsWithoutImprovements; + + private EarlyStoppingListener( + double objectiveSuccess, + int minEpochs, + long maxMillis, + double earlyStopPctImprovement, + int earlyStopPatience) { + this.objectiveSuccess = objectiveSuccess; + this.minEpochs = minEpochs; + this.maxMillis = maxMillis; + this.earlyStopPctImprovement = earlyStopPctImprovement; + this.epochPatience = earlyStopPatience; + } + + /** {@inheritDoc} */ + @Override + public void onEpoch(Trainer trainer) { + int currentEpoch = trainer.getTrainingResult().getEpoch(); + // stopping criteria + final double loss = getLoss(trainer.getTrainingResult()); + if (currentEpoch >= minEpochs) { + if (loss < objectiveSuccess) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "validation loss %s < objectiveSuccess %s", + loss, objectiveSuccess)); + } + long elapsedMillis = System.currentTimeMillis() - startTimeMills; + if (elapsedMillis >= maxMillis) { + throw new EarlyStoppedException( + currentEpoch, + String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis)); + } + // consider early stopping? 
+ if (Double.isFinite(prevLoss)) { + double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0; + boolean improved = loss <= goalImprovement; // false if any NANs + if (improved) { + numberOfEpochsWithoutImprovements = 0; + } else { + numberOfEpochsWithoutImprovements++; + if (numberOfEpochsWithoutImprovements >= epochPatience) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "failed to achieve %s%% improvement %s times in a row", + earlyStopPctImprovement, epochPatience)); + } + } + } + } + if (Double.isFinite(loss)) { + prevLoss = loss; + } + } + + private static double getLoss(TrainingResult trainingResult) { + Float vLoss = trainingResult.getValidateLoss(); + if (vLoss != null) { + return vLoss; + } + Float tLoss = trainingResult.getTrainLoss(); + if (tLoss == null) { + return Double.NaN; + } + return tLoss; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBatch(Trainer trainer, BatchData batchData) { + // do nothing + } + + /** {@inheritDoc} */ + @Override + public void onValidationBatch(Trainer trainer, BatchData batchData) { + // do nothing + } + + /** {@inheritDoc} */ + @Override + public void onTrainingBegin(Trainer trainer) { + this.startTimeMills = System.currentTimeMillis(); + this.prevLoss = Double.NaN; + this.numberOfEpochsWithoutImprovements = 0; + } + + /** {@inheritDoc} */ + @Override + public void onTrainingEnd(Trainer trainer) { + // do nothing + } + + /** + * Creates a builder to build a {@link EarlyStoppingListener}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** A builder for a {@link EarlyStoppingListener}. */ + public static final class Builder { + private final double objectiveSuccess; + private int minEpochs; + private long maxMillis; + private double earlyStopPctImprovement; + private int epochPatience; + + /** Constructs a {@link Builder} with default values. 
*/ + public Builder() { + this.objectiveSuccess = 0; + this.minEpochs = 0; + this.maxMillis = Long.MAX_VALUE; + this.earlyStopPctImprovement = 0; + this.epochPatience = 0; + } + + /** + * Set the minimum # epochs, defaults to 0. + * + * @param minEpochs the minimum # epochs + * @return this builder + */ + public Builder optMinEpochs(int minEpochs) { + this.minEpochs = minEpochs; + return this; + } + + /** + * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms. + * + * @param duration the maximum duration a training run should take + * @return this builder + */ + public Builder optMaxDuration(Duration duration) { + this.maxMillis = duration.toMillis(); + return this; + } + + /** + * Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE. + * + * @param maxMillis the maximum # milliseconds a training run should take + * @return this builder + */ + public Builder optMaxMillis(int maxMillis) { + this.maxMillis = maxMillis; + return this; + } + + /** + * Consider early stopping if not x% improvement, defaults to 0. + * + * @param earlyStopPctImprovement the percentage improvement to consider early stopping, + * must be between 0 and 100. + * @return this builder + */ + public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) { + this.earlyStopPctImprovement = earlyStopPctImprovement; + return this; + } + + /** + * Stop if insufficient improvement for x epochs in a row, defaults to 0. + * + * @param epochPatience the number of epochs without improvement to consider stopping, must + * be greater than 0. + * @return this builder + */ + public Builder optEpochPatience(int epochPatience) { + this.epochPatience = epochPatience; + return this; + } + + /** + * Builds a {@link EarlyStoppingListener} with the specified values. 
+ * + * @return a new {@link EarlyStoppingListener} + */ + public EarlyStoppingListener build() { + return new EarlyStoppingListener( + objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience); + } + } + + /** + * Thrown when training is stopped early, the message will contain the reason why it is stopped + * early. + */ + public static class EarlyStoppedException extends RuntimeException { + private static final long serialVersionUID = 1L; + private final int stopEpoch; + + /** + * Constructs an {@link EarlyStoppedException} with the specified message and epoch. + * + * @param stopEpoch the epoch at which training was stopped early + * @param message the message/reason why training was stopped early + */ + public EarlyStoppedException(int stopEpoch, String message) { + super(message); + this.stopEpoch = stopEpoch; + } + + /** + * Gets the epoch at which training was stopped early. + * + * @return the epoch at which training was stopped early. + */ + public int getStopEpoch() { + return stopEpoch; + } + } +} diff --git a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java index 1dbfe4117cd..2556a026259 100644 --- a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java @@ -144,9 +144,7 @@ private void updateEvaluators(Trainer trainer, BatchData batchData, String[] acc for (Device device : batchData.getLabels().keySet()) { NDList labels = batchData.getLabels().get(device); NDList predictions = batchData.getPredictions().get(device); - for (String accumulator : accumulators) { - evaluator.updateAccumulator(accumulator, labels, predictions); - } + evaluator.updateAccumulators(accumulators, labels, predictions); } } } diff --git a/api/src/main/java/ai/djl/training/listener/Node.java b/api/src/main/java/ai/djl/training/listener/Node.java new file mode 
100644 index 00000000000..8dfa569e26a --- /dev/null +++ b/api/src/main/java/ai/djl/training/listener/Node.java @@ -0,0 +1,463 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.training.listener; + +import ai.djl.ndarray.types.Shape; +import ai.djl.util.PairList; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +/** One node of the computational graph. */ +class Node { + + String name; + final Node[] src; + final PairList param; + boolean isLeaf; + Shape outputShape; + + public Node(String name, PairList param, Node... 
src) { + this.name = name; + this.param = param; + this.src = src; + } + + String toPythonExpression(Map locals, AtomicInteger opCount) { + return toPythonExpression(locals, opCount, false) + " # " + outputShape; + } + + String toPythonExpression(Map locals, AtomicInteger opCount, boolean useLocals) { + if (isLeaf) { + return name; + } + if (useLocals && locals != null && locals.containsKey(this)) { + return locals.get(this); + } + switch (name) { + case "pick": + { + Object[][] args = {{0}, {1, "indices"}, {"axis", "batch_dims"}}; + return format("tf.gather", args, locals, opCount); + } + case "where": + { + Object[][] args = {{0}, {1, "x"}, {2, "y"}}; + return format("tf.where", args, locals, opCount); + } + case "_npi_slice": + { + Object[][] args = { + {0}, {"begin", "begin"}, {"end", "end"}, {"step", "strides"} + }; + return format("tf.strided_slice", args, locals, opCount); + } + case "_npi_concatenate": + { + Object[][] args = {{-1}, {"axis", "axis"}}; + return format("tf.concat", args, locals, opCount); + } + case "_np_squeeze": + { + Object[][] args = {{0}, {"axis", "axis"}}; + return format("tf.squeeze", args, locals, opCount); + } + case "_npi_stack": + { + Object[][] args = {{-1}, {"axis", "axis"}}; + return format("tf.stack", args, locals, opCount); + } + case "_npi_split": + { + Object[][] args = { + {0}, {"axis", "axis"}, {"num_outputs", "num_or_size_splits"} + }; + return format("tf.split", args, locals, opCount); + } + case "_npi_swapaxes": + { + Object[][] args = {{0}, {"dim1", "axis1"}, {"dim2", "axis2"}}; + return format("tf.experimental.numpy.swapaxes", args, locals, opCount); + } + case "_np_repeat": + { + Object[][] args = {{0}, {"repeats", "repeats"}, {"axis", "axis"}}; + return format("tf.repeat", args, locals, opCount); + } + case "_npi_copyto": + { + return src[0].toPythonExpression(locals, opCount, true); + } + case "_npi_expand_dims": + { + Object[][] args = {{0}, {"axis", "axis"}}; + return format("tf.expand_dims", args, locals, 
opCount); + } + case "_npx_log_softmax": + { + Object[][] args = {{0}, {"axis", "axis"}}; + return format("tf.nn.log_softmax", args, locals, opCount); + } + case "_npi_zeros": + { + Object[][] args = {{"shape", "shape"}, {"dtype", "dtype", "tf.dtypes.%s"}}; + return format("tf.zeros", args, locals, opCount); + } + case "_npi_ones": + { + Object[][] args = {{"shape", "shape"}, {"dtype", "dtype", "tf.dtypes.%s"}}; + return format("tf.ones", args, locals, opCount); + } + case "_npi_normal": + { + Object[][] args = { + {"size", "shape"}, + {"loc", "mean"}, + {"scale", "stddev"}, + {"dtype", "dtype", "tf.dtypes.%s"} + }; + return format("tf.random.normal", args, locals, opCount); + } + case "_npi_uniform": + { + Object[][] args = { + {"low", "minval"}, + {"high", "maxval"}, + {"shape", "shape"}, + {"dtype", "dtype", "tf.dtypes.%s"} + }; + return format("tf.random.uniform", args, locals, opCount); + } + case "_np_reshape": + { + Object[][] args = {{0}, {"newshape", "shape"}}; + return format("tf.reshape", args, locals, opCount); + } + case "_np_transpose": + { + Object[][] args = {{0}, {"axes", "perm"}}; + return format("tf.transpose", args, locals, opCount); + } + case "_npx_activation": + { + Object[][] args = {{0}}; + String op = + this.param.get("act_type").toString().replace("softrelu", "softplus"); + return format("tf.nn." + op, args, locals, opCount); + } + case "_npx_convolution": + { + String padding = "(0, 0)".equals(this.param.get("pad")) ? "'VALID'" : "'SAME'"; + Object[][] args = { + {0}, + {1, "filters"}, + {"stride", "strides"}, + {"pad", "padding", padding}, + {"dilate", "dilations"}, + {null, "data_format", "'NCHW'"} + }; + return addBias( + format("tf.nn.convolution", args, locals, opCount), + true, + locals, + opCount); + } + case "_npx_pooling": + { + if ("True".equals(this.param.get("global_pool"))) { + String op = + "avg".equals(this.param.get("pool_type")) + ? 
"reduce_mean" + : "reduce_max"; + Object[][] args = {{0}, {null, "axis", "[2, 3]"}}; + return format("tf." + op, args, locals, opCount); + } + String padding = "(0, 0)".equals(this.param.get("pad")) ? "'VALID'" : "'SAME'"; + String poolingType = + "avg".equals(this.param.get("pool_type")) ? "'AVG'" : "'MAX'"; + Object[][] args = { + {0}, + {"kernel", "window_shape"}, + {"pool_type", "pooling_type", poolingType}, + {"stride", "strides"}, + {"pad", "padding", padding}, + {"dilate", "dilations"}, + {null, "data_format", "'NCHW'"} + }; + return format("tf.nn.pool", args, locals, opCount); + } + case "_npx_batch_norm": + { + Object[][] args = { + {0}, + {1, "scale"}, + {2, "offset"}, + {3, "mean"}, + {4, "variance"}, + {"eps", "epsilon"}, + {null, "is_training", "True"}, + {"momentum", "exponential_avg_factor"}, + {null, "data_format", "'NCHW'"} + }; + return format("tf.compat.v1.nn.fused_batch_norm", args, locals, opCount); + } + + case "_npx_embedding": + { + Object[][] args = { + {0, "ids"}, + {1, "params"} + }; + return format("tf.nn.embedding_lookup", args, locals, opCount); + } + case "_npx_fully_connected": + { + Object[][] args = {{0}, {1, "b"}, {null, "transpose_b", "True"}}; + return addBias( + format("tf.matmul", args, locals, opCount), false, locals, opCount); + } + case "_npi_matmul": + { + Object[][] args = {{0}, {1}}; + return addBias( + format("tf.matmul", args, locals, opCount), false, locals, opCount); + } + case "_npi_not_equal_scalar": + { + Object[][] args = {{0}, {"scalar", "y"}}; + return format("tf.not_equal", args, locals, opCount); + } + case "_rdiv_scalar": + { + Object[][] args = {{0}, {"scalar", "y"}}; + return format("tf.divide", args, locals, opCount); + } + case "_npi_add_scalar": + { + Object[][] args = {{0}, {"scalar", "y"}}; + return format("tf.add", args, locals, opCount); + } + case "_npi_add": + { + Object[][] args = {{0}, {1}}; + return format("tf.add", args, locals, opCount); + } + case "_npi_subtract": + { + Object[][] args = 
{{0}, {1}}; + return format("tf.subtract", args, locals, opCount); + } + case "_npi_mean": + { + Object[][] args = {{0}, {"axis", "axis"}, {"keepdims", "keepdims"}}; + return format("tf.reduce_mean", args, locals, opCount); + } + case "gammaln": + { + Object[][] args = {{0}}; + return format("tf.gammaln", args, locals, opCount); + } + case "_np_sum": + { + Object[][] args = {{0}, {"axis", "axis"}, {"keepdims", "keepdims"}}; + return format("tf.reduce_sum", args, locals, opCount); + } + case "_npi_maximum_scalar": + { + Object[][] args = {{0}, {"scalar", "y"}}; + return format("tf.maximum", args, locals, opCount); + } + case "_npi_multiply_scalar": + { + Object[][] args = {{0}, {"scalar", "y"}}; + return format("tf.multiply", args, locals, opCount); + } + case "_npi_multiply": + { + Object[][] args = {{0}, {1}}; + return format("tf.multiply", args, locals, opCount); + } + case "_npi_true_divide": + { + Object[][] args = {{0}, {1}}; + return format("tf.divide", args, locals, opCount); + } + case "_npi_greater": + { + Object[][] args = {{0}, {1}}; + return format("tf.greater", args, locals, opCount); + } + case "_npi_negative": + { + Object[][] args = {{0}}; + return format("tf.negative", args, locals, opCount); + } + case "_npi_absolute": + { + Object[][] args = {{0}}; + return format("tf.abs", args, locals, opCount); + } + case "_npi_log": + { + Object[][] args = {{0}}; + return format("tf.log", args, locals, opCount); + } + case "_npi_exp": + { + Object[][] args = {{0}}; + return format("tf.exp", args, locals, opCount); + } + default: + { + Stream srcStream = + IntStream.range(0, src.length).mapToObj(i -> new Object[] {i}); + Stream paramStream = + param.stream().map(p -> new Object[] {p.getKey(), p.getKey()}); + Object[][] args = + Stream.concat(srcStream, paramStream).toArray(Object[][]::new); + return format(name, args, locals, opCount); + } + } + } + + /** + * Constructs a Python expression for the given operation and formatting arguments. 
+ * + * @param op tensorflow operation name + * @param args array of array of:
+ * [0]: index for {@link #src} or {@link #param} to retrieve argument value, or null + *
+ * [1]: tensorflow parameter name
+ * [2]: format of argument
+ * [3]: output shape of argument
+ * @param locals nodes stored in local Python variables + * @param opCount operation counter + * @return the Python expression + */ + private String format( + String op, Object[][] args, Map locals, AtomicInteger opCount) { + StringBuilder sb = new StringBuilder(op + "(\n"); + for (Object[] arg : args) { + String s = arg.length >= 3 ? String.valueOf(arg[2]) : "%s"; + Shape shape = arg.length >= 4 ? (Shape) arg[3] : null; + if (Integer.valueOf(-1).equals(arg[0])) { + s = + Stream.of(src) + .map(node -> node.toPythonExpression(locals, opCount, true)) + .map(Node::indent) + .collect(Collectors.joining(",\n", "[\n", "\n]")); + } else if (arg[0] instanceof Integer && src.length > (int) arg[0]) { + Node node = src[(int) arg[0]]; + s = String.format(s, node.toPythonExpression(locals, opCount, true)); + shape = node.outputShape; + } else if (this.param.get(String.valueOf(arg[0])) != null) { + s = String.format(s, this.param.get(String.valueOf(arg[0]))); + } else if (arg[0] != null) { + continue; // cannot resolve index, so skip + } + if (s.startsWith("(") && s.endsWith(")")) { + s = String.format("[%s]", s.substring(1, s.length() - 1)); + } + if (arg.length >= 2 && arg[1] != null) { + s = String.format("%s=%s", arg[1], s); + } + sb.append(indent(s) + "," + (shape != null ? " # " + shape : "") + "\n"); + } + sb.append( + indent( + String.format( + "name='%s_%s_',", + op.substring(op.lastIndexOf('.') + 1), opCount.incrementAndGet()))); + sb.append("\n)"); + return sb.toString(); + } + + private String addBias( + String result, + boolean setChannelFirst, + Map locals, + AtomicInteger opCount) { + if (src.length == 3) { + Object[][] args = { + {null, null, result, this.outputShape}, + {2, "bias"}, + {null, "data_format", setChannelFirst ? 
"'NCHW'" : "None"} + }; + return format("tf.nn.bias_add", args, locals, opCount); + } + return result; + } + + private void identifyMultipleUsages(Map usages) { + if (isLeaf) { + return; + } + if (usages.compute(this, (key, count) -> count == null ? 1 : count + 1) >= 2) { + return; + } + for (Node node : src) { + node.identifyMultipleUsages(usages); + } + // reposition behind src nodes + usages.put(this, usages.remove(this)); + } + + String toPythonFunctionBody(AtomicInteger opCount, String result) { + @SuppressWarnings("PMD.UseConcurrentHashMap") + Map usages = new LinkedHashMap<>(); + identifyMultipleUsages(usages); + Map locals = new ConcurrentHashMap<>(); + List statements = new ArrayList<>(); + int val = 1; + int batchnorm = 1; + for (Map.Entry usage : usages.entrySet()) { + Node node = usage.getKey(); + if (usage.getValue() >= 2) { + // save the result of an expression that is used multiple times in local variable + locals.put(node, "val".concat(Integer.toString(val++))); + } else if ("_npx_batch_norm".equals(node.name)) { + // local required to assign locals 'running_mean' and 'running_var' at the same time + locals.put(node, "batchnorm".concat(Integer.toString(batchnorm++))); + } + } + for (Map.Entry usage : usages.entrySet()) { + Node node = usage.getKey(); + if (usage.getValue() >= 2) { + statements.add( + String.format( + "%s = %s", + locals.get(node), node.toPythonExpression(locals, opCount))); + } else if ("_npx_batch_norm".equals(node.name)) { + statements.add( + String.format( + "(%s, running_mean, running_var) = %s", + locals.get(node), node.toPythonExpression(locals, opCount))); + statements.add(String.format("%s.assign(running_mean)", node.src[3].name)); + statements.add(String.format("%s.assign(running_var)", node.src[4].name)); + } + } + statements.add(String.format("%s = %s", result, toPythonExpression(locals, opCount))); + return statements.stream().map(Node::indent).collect(Collectors.joining(" \n")); + } + + static String indent(String val) 
{ + return val.replaceAll("(?m)^", " "); + } +} diff --git a/api/src/main/java/ai/djl/training/listener/TrainingListener.java b/api/src/main/java/ai/djl/training/listener/TrainingListener.java index 3d81601f20f..c228bdade2b 100644 --- a/api/src/main/java/ai/djl/training/listener/TrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/TrainingListener.java @@ -13,11 +13,13 @@ package ai.djl.training.listener; import ai.djl.Device; +import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.training.Trainer; import ai.djl.training.dataset.Batch; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * {@code TrainingListener} offers an interface that performs some actions when certain events have @@ -163,6 +165,20 @@ static TrainingListener[] logging(String outputDir) { new TimeMeasureTrainingListener(outputDir) }; } + + /** + * Returns listener for logging algebraic operation. + * + * @param outputFile the output file to store the algebraic log. Can be null which skips + * algebraic logging. + * @return the new set of listeners + */ + static TrainingListener[] algebraicLogging(String outputFile) { + if (outputFile == null) { + return new TrainingListener[] {}; // algebraic logging disabled + } + return new TrainingListener[] {new AlgebraicListener(outputFile)}; + } } /** A class to pass data from the batch into the training listeners. */ @@ -171,6 +187,8 @@ class BatchData { private Batch batch; private Map labels; private Map predictions; + private Map data; + private Map loss; /** * Constructs a new {@link BatchData}. @@ -183,6 +201,8 @@ public BatchData(Batch batch, Map labels, Map pr this.batch = batch; this.labels = labels; this.predictions = predictions; + this.data = new ConcurrentHashMap<>(); + this.loss = new ConcurrentHashMap<>(); } /** @@ -211,5 +231,23 @@ public Map getLabels() { public Map getPredictions() { return predictions; } + + /** + * Returns the data for each device. 
+ * + * @return the data for each device + */ + public Map getData() { + return data; + } + + /** + * Returns the loss for each device. + * + * @return the loss for each device + */ + public Map getLoss() { + return loss; + } } } diff --git a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java index 2a46416190a..2e2cdcb8c86 100644 --- a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java +++ b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java @@ -80,10 +80,10 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override - public void updateAccumulator(String key, NDList labels, NDList predictions) { + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { for (int i = 0; i < components.size(); i++) { Pair inputs = inputForComponent(i, labels, predictions); - components.get(i).updateAccumulator(key, inputs.getKey(), inputs.getValue()); + components.get(i).updateAccumulators(keys, inputs.getKey(), inputs.getValue()); } } diff --git a/api/src/main/java/ai/djl/training/loss/Loss.java b/api/src/main/java/ai/djl/training/loss/Loss.java index a661a3e9a0e..bcf39d23b39 100644 --- a/api/src/main/java/ai/djl/training/loss/Loss.java +++ b/api/src/main/java/ai/djl/training/loss/Loss.java @@ -385,10 +385,18 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { // this is a synchronized operation, only call it at end of batch or epoch float update = evaluate(labels, predictions).sum().getFloat(); - totalInstances.compute(key, (k, v) -> v + 1); - totalLoss.compute(key, (k, v) -> v + update); + for (String key : keys) { + 
totalInstances.compute(key, (k, v) -> v + 1); + totalLoss.compute(key, (k, v) -> v + update); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java index 3f3bb1b2d6e..f026bd431c9 100644 --- a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java +++ b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java @@ -29,10 +29,17 @@ public final class PaddingStackBatchifier implements Batchifier { private static final long serialVersionUID = 1L; + @SuppressWarnings("serial") private List arraysToPad; + + @SuppressWarnings("serial") private List dimsToPad; + private transient List paddingSuppliers; + + @SuppressWarnings("serial") private List paddingSizes; + private boolean includeValidLengths; private PaddingStackBatchifier(Builder builder) { diff --git a/api/src/main/java/ai/djl/util/Ec2Utils.java b/api/src/main/java/ai/djl/util/Ec2Utils.java index 178c3d7efe7..5408182964f 100644 --- a/api/src/main/java/ai/djl/util/Ec2Utils.java +++ b/api/src/main/java/ai/djl/util/Ec2Utils.java @@ -97,7 +97,7 @@ public static String readMetadata(String key) { * @param engine the default engine name */ public static void callHome(String engine) { - if (Boolean.getBoolean("offline") + if (Utils.isOfflineMode() || Boolean.parseBoolean(Utils.getEnvOrSystemProperty("OPT_OUT_TRACKING")) || System.currentTimeMillis() - lastCheckIn < ONE_DAY) { return; diff --git a/api/src/main/java/ai/djl/util/StringPair.java b/api/src/main/java/ai/djl/util/StringPair.java new file mode 100644 index 00000000000..a42e739614b --- /dev/null +++ b/api/src/main/java/ai/djl/util/StringPair.java @@ -0,0 +1,27 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +/** A class containing the string key-value pair. */ +public class StringPair extends Pair { + + /** + * Constructs a {@code Pair} instance with key and value. + * + * @param key the key + * @param value the value + */ + public StringPair(String key, String value) { + super(key, value); + } +} diff --git a/api/src/main/java/ai/djl/util/TarUtils.java b/api/src/main/java/ai/djl/util/TarUtils.java new file mode 100644 index 00000000000..c02b278788f --- /dev/null +++ b/api/src/main/java/ai/djl/util/TarUtils.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.util; + +import org.apache.commons.compress.archivers.tar.TarArchiveEntry; +import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; +import org.apache.commons.compress.utils.CloseShieldFilterInputStream; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; + +/** Utilities for working with zip files. */ +public final class TarUtils { + + private TarUtils() {} + + /** + * Un-compress a tar ball from InputStream. + * + * @param is the InputStream + * @param dir the target directory + * @param gzip if the bar ball is gzip + * @throws IOException for failures to untar the input directory + */ + public static void untar(InputStream is, Path dir, boolean gzip) throws IOException { + InputStream bis; + if (gzip) { + bis = new GzipCompressorInputStream(new BufferedInputStream(is)); + } else { + bis = new BufferedInputStream(is); + } + bis = new CloseShieldFilterInputStream(bis); + try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) { + TarArchiveEntry entry; + while ((entry = tis.getNextEntry()) != null) { + String entryName = entry.getName(); + if (entryName.contains("..")) { + throw new IOException("Malicious zip entry: " + entryName); + } + Path file = dir.resolve(entryName).toAbsolutePath(); + if (entry.isDirectory()) { + Files.createDirectories(file); + } else { + Path parentFile = file.getParent(); + if (parentFile == null) { + throw new AssertionError("Parent path should never be null: " + file); + } + Files.createDirectories(parentFile); + Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING); + } + } + } + } +} diff --git a/api/src/main/java/ai/djl/util/Utils.java b/api/src/main/java/ai/djl/util/Utils.java index c8e1bd514ac..270958d5b40 100644 --- a/api/src/main/java/ai/djl/util/Utils.java 
+++ b/api/src/main/java/ai/djl/util/Utils.java @@ -357,6 +357,20 @@ public static Path getCacheDir() { return Paths.get(cacheDir); } + /** + * Returns if offline mode is enabled. + * + * @return true if offline mode is enabled + */ + public static boolean isOfflineMode() { + String mode = getenv("DJL_OFFLINE", System.getProperty("ai.djl.offline")); + if (mode != null) { + return Boolean.parseBoolean(mode); + } + // backward compatible + return Boolean.getBoolean("offline"); + } + /** * Returns nested model directory if the directory contains only one subdirectory. * @@ -481,7 +495,7 @@ public static InputStream openUrl(String url) throws IOException { */ public static InputStream openUrl(URL url) throws IOException { String protocol = url.getProtocol(); - if (Boolean.getBoolean("offline") + if (isOfflineMode() && ("http".equalsIgnoreCase(protocol) || "https".equalsIgnoreCase(protocol))) { throw new IOException("Offline model is enabled."); } diff --git a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java index b0b8e3e4247..b30a208f6ab 100644 --- a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java +++ b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java @@ -22,7 +22,11 @@ import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; +import java.io.InputStream; import java.lang.management.MemoryUsage; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.regex.Pattern; @@ -33,6 +37,8 @@ public final class CudaUtils { private static final CudaLibrary LIB = loadLibrary(); + private static String[] gpuInfo; + private CudaUtils() {} /** @@ -49,7 +55,15 @@ public static boolean hasCuda() { * * @return the number of GPUs available in the system */ + @SuppressWarnings("PMD.NonThreadSafeSingleton") public static int getGpuCount() { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + if (gpuInfo == null) { + gpuInfo = execute(-1); // NOPMD + } + return 
Integer.parseInt(gpuInfo[0]); + } + if (LIB == null) { return 0; } @@ -79,7 +93,19 @@ public static int getGpuCount() { * * @return the version of CUDA runtime */ + @SuppressWarnings("PMD.NonThreadSafeSingleton") public static int getCudaVersion() { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + if (gpuInfo == null) { + gpuInfo = execute(-1); + } + int version = Integer.parseInt(gpuInfo[1]); + if (version == -1) { + throw new IllegalArgumentException("No cuda device found."); + } + return version; + } + if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); } @@ -95,9 +121,6 @@ public static int getCudaVersion() { * @return the version string of CUDA runtime */ public static String getCudaVersionString() { - if (LIB == null) { - throw new IllegalStateException("No cuda library is loaded."); - } int version = getCudaVersion(); int major = version / 1000; int minor = (version / 10) % 10; @@ -111,6 +134,14 @@ public static String getCudaVersionString() { * @return the CUDA compute capability */ public static String getComputeCapability(int device) { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + String[] ret = execute(device); + if (ret.length != 3) { + throw new IllegalArgumentException(ret[0]); + } + return ret[0]; + } + if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); } @@ -137,6 +168,16 @@ public static MemoryUsage getGpuMemory(Device device) { throw new IllegalArgumentException("Only GPU device is allowed."); } + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + String[] ret = execute(device.getDeviceId()); + if (ret.length != 3) { + throw new IllegalArgumentException(ret[0]); + } + long total = Long.parseLong(ret[1]); + long used = Long.parseLong(ret[2]); + return new MemoryUsage(-1, used, used, total); + } + if (LIB == null) { throw new IllegalStateException("No GPU device detected."); } @@ -155,8 +196,42 @@ public static MemoryUsage getGpuMemory(Device device) { return new 
MemoryUsage(-1, committed, committed, total[0]); } + /** + * The main entrypoint to get CUDA information with command line. + * + * @param args the command line arguments. + */ + @SuppressWarnings("PMD.SystemPrintln") + public static void main(String[] args) { + int gpuCount = getGpuCount(); + if (args.length == 0) { + if (gpuCount <= 0) { + System.out.println("0,-1"); + return; + } + int cudaVersion = getCudaVersion(); + System.out.println(gpuCount + "," + cudaVersion); + return; + } + try { + int deviceId = Integer.parseInt(args[0]); + if (deviceId < 0 || deviceId >= gpuCount) { + System.out.println("Invalid device: " + deviceId); + return; + } + MemoryUsage mem = getGpuMemory(Device.gpu(deviceId)); + String cc = getComputeCapability(deviceId); + System.out.println(cc + ',' + mem.getMax() + ',' + mem.getUsed()); + } catch (NumberFormatException e) { + System.out.println("Invalid device: " + args[0]); + } + } + private static CudaLibrary loadLibrary() { try { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + return null; + } if (System.getProperty("os.name").startsWith("Win")) { String path = Utils.getenv("PATH"); if (path == null) { @@ -187,15 +262,40 @@ private static CudaLibrary loadLibrary() { } catch (UnsatisfiedLinkError e) { logger.debug("cudart library not found."); logger.trace("", e); - return null; - } catch (IncompatibleClassChangeError e) { + } catch (LinkageError e) { logger.warn("You have a conflict version of JNA in the classpath."); logger.debug("", e); - return null; } catch (SecurityException e) { logger.warn("Access denied during loading cudart library."); logger.trace("", e); - return null; + } + return null; + } + + private static String[] execute(int deviceId) { + try { + String javaHome = System.getProperty("java.home"); + String classPath = System.getProperty("java.class.path"); + String os = System.getProperty("os.name"); + List cmd = new ArrayList<>(4); + if (os.startsWith("Win")) { + cmd.add(javaHome + "\\bin\\java.exe"); + } else 
{ + cmd.add(javaHome + "/bin/java"); + } + cmd.add("-cp"); + cmd.add(classPath); + cmd.add("ai.djl.util.cuda.CudaUtils"); + if (deviceId >= 0) { + cmd.add(String.valueOf(deviceId)); + } + Process ps = new ProcessBuilder(cmd).redirectErrorStream(true).start(); + try (InputStream is = ps.getInputStream()) { + String line = Utils.toString(is).trim(); + return line.split(","); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed get GPU information", e); } } diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java index 92a0474c6e7..a69a502739b 100644 --- a/api/src/test/java/ai/djl/DeviceTest.java +++ b/api/src/test/java/ai/djl/DeviceTest.java @@ -13,6 +13,7 @@ package ai.djl; +import ai.djl.Device.MultiDevice; import ai.djl.engine.Engine; import org.testng.Assert; @@ -37,6 +38,9 @@ public void testDevice() { System.setProperty("test_key", "test"); Engine.debugEnvironment(); + + Assert.assertEquals(1, Device.cpu().getDevices().size()); + Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size()); } @Test @@ -54,5 +58,9 @@ public void testDeviceName() { Device defaultDevice = Engine.getInstance().defaultDevice(); Assert.assertEquals(Device.fromName(""), defaultDevice); Assert.assertEquals(Device.fromName(null), defaultDevice); + + Assert.assertEquals( + Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1))); + Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3)); } } diff --git a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java index 8c140688124..a8b2bdfab62 100644 --- a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java +++ b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java @@ -15,32 +15,38 @@ import org.testng.Assert; import org.testng.annotations.Test; +import 
java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; public class PublisherBytesSupplierTest { @Test - public void test() { + public void test() throws ExecutionException, InterruptedException { AtomicInteger contentCount = new AtomicInteger(); PublisherBytesSupplier supplier = new PublisherBytesSupplier(); - // Add to supplier without subscriber - supplier.appendContent(new byte[] {1}, false); - Assert.assertEquals(contentCount.get(), 0); + new Thread( + () -> { + // Add to supplier without subscriber + supplier.appendContent(new byte[] {1}, false); + // Add to supplier with subscriber + supplier.appendContent(new byte[] {1}, true); + }) + .start(); // Subscribing with data should trigger subscriptions - supplier.subscribe( - d -> { - if (d == null) { - // Do nothing on completion - return; - } - contentCount.getAndIncrement(); - }); - Assert.assertEquals(contentCount.get(), 1); + CompletableFuture future = + supplier.subscribe( + d -> { + if (d == null) { + // Do nothing on completion + return; + } + contentCount.getAndIncrement(); + }); - // Add to supplier with subscriber - supplier.appendContent(new byte[] {1}, true); + future.get(); Assert.assertEquals(contentCount.get(), 2); } } diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java new file mode 100644 index 00000000000..8fbbae7301b --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.BasicTranslator; +import ai.djl.translate.Translator; + +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +public class YoloV8TranslatorFactoryTest { + + private YoloV8TranslatorFactory factory; + + @BeforeClass + public void setUp() { + factory = new YoloV8TranslatorFactory(); + } + + @Test + public void testGetSupportedTypes() { + Assert.assertEquals(factory.getSupportedTypes().size(), 5); + } + + @Test + public void testNewInstance() { + Map arguments = new HashMap<>(); + try (Model model = Model.newInstance("test")) { + Translator translator1 = + factory.newInstance(Image.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator1 instanceof YoloV8Translator); + + Translator translator2 = + factory.newInstance(Path.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator2 instanceof BasicTranslator); + + Translator translator3 = + factory.newInstance(URL.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator3 instanceof BasicTranslator); + + Translator translator4 = + factory.newInstance(InputStream.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator4 instanceof BasicTranslator); + + Translator translator5 = + factory.newInstance(Input.class, Output.class, model, arguments); + Assert.assertTrue(translator5 instanceof 
ImageServingTranslator); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(Image.class, Output.class, model, arguments)); + } + } +} diff --git a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java index 0e38c2d8be6..e89f2244203 100644 --- a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java +++ b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java @@ -107,7 +107,7 @@ private static byte[] encode(NDArray array) throws IOException { private static NDArray decode(NDManager manager, byte[] data) throws IOException { try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) { - return NDSerializer.decodeNumpy(manager, bis); + return NDList.decode(manager, bis).get(0); } } diff --git a/api/src/test/java/ai/djl/repository/ZooTest.java b/api/src/test/java/ai/djl/repository/ZooTest.java index 2b44f967144..29fc10391aa 100644 --- a/api/src/test/java/ai/djl/repository/ZooTest.java +++ b/api/src/test/java/ai/djl/repository/ZooTest.java @@ -17,6 +17,7 @@ import ai.djl.modality.Output; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; import org.testng.Assert; import org.testng.annotations.Test; @@ -48,4 +49,11 @@ public void testInvalidCriteria() Criteria criteria = Criteria.builder().build(); criteria.loadModel(); } + + @Test + public void testModelZooResolver() { + ModelZoo.setModelZooResolver(groupId -> null); + ModelZoo zoo = ModelZoo.getModelZoo("unknown"); + Assert.assertNull(zoo); + } } diff --git a/api/src/test/java/ai/djl/util/SecurityManagerTest.java b/api/src/test/java/ai/djl/util/SecurityManagerTest.java index fd9b5db72bc..1e9eb17f63c 100644 --- a/api/src/test/java/ai/djl/util/SecurityManagerTest.java +++ b/api/src/test/java/ai/djl/util/SecurityManagerTest.java @@ -74,8 +74,11 @@ public void checkPermission(Permission perm) { } }; System.setSecurityManager(sm); - - 
Assert.assertFalse(CudaUtils.hasCuda()); - Assert.assertEquals(CudaUtils.getGpuCount(), 0); + try { + Assert.assertFalse(CudaUtils.hasCuda()); + Assert.assertEquals(CudaUtils.getGpuCount(), 0); + } finally { + System.setSecurityManager(null); + } } } diff --git a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java index de1c5cb4a20..a598d8482e6 100644 --- a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java +++ b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java @@ -20,8 +20,6 @@ import org.testng.annotations.Test; import java.lang.management.MemoryUsage; -import java.util.Arrays; -import java.util.List; public class CudaUtilsTest { @@ -30,6 +28,9 @@ public class CudaUtilsTest { @Test public void testCudaUtils() { if (!CudaUtils.hasCuda()) { + Assert.assertThrows(CudaUtils::getCudaVersionString); + Assert.assertThrows(() -> CudaUtils.getComputeCapability(0)); + Assert.assertThrows(() -> CudaUtils.getGpuMemory(Device.gpu())); return; } // Possible to have CUDA and not have a GPU. 
@@ -37,16 +38,24 @@ public void testCudaUtils() { return; } - int cudaVersion = CudaUtils.getCudaVersion(); + String cudaVersion = CudaUtils.getCudaVersionString(); String smVersion = CudaUtils.getComputeCapability(0); MemoryUsage memoryUsage = CudaUtils.getGpuMemory(Device.gpu()); logger.info("CUDA runtime version: {}, sm: {}", cudaVersion, smVersion); logger.info("Memory usage: {}", memoryUsage); - Assert.assertTrue(cudaVersion >= 9020, "cuda 9.2+ required."); + Assert.assertNotNull(cudaVersion); + Assert.assertNotNull(smVersion); + } - List supportedSm = Arrays.asList("37", "52", "60", "61", "70", "75"); - Assert.assertTrue(supportedSm.contains(smVersion), "Unsupported cuda sm: " + smVersion); + @Test + public void testCudaUtilsWithFolk() { + System.setProperty("ai.djl.util.cuda.folk", "true"); + try { + testCudaUtils(); + } finally { + System.clearProperty("ai.djl.util.cuda.folk"); + } } } diff --git a/apt.txt b/apt.txt index 7083f85c374..c89953ff1f9 100644 --- a/apt.txt +++ b/apt.txt @@ -1 +1 @@ -openjdk-11-jdk +openjdk-17-jdk diff --git a/basicdataset/README.md b/basicdataset/README.md index 37bab679551..217f58d22b3 100644 --- a/basicdataset/README.md +++ b/basicdataset/README.md @@ -29,7 +29,7 @@ You can pull the module from the central Maven repository by including the follo ai.djl basicdataset - 0.23.0 + 0.26.0 ``` diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java index a92a9b6a3d4..deef04907be 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Map; /** @@ -118,8 +119,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { byte[] buf = 
Utils.toByteArray(is); try (NDArray array = manager.create( - new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) { - array.set(buf); + ByteBuffer.wrap(buf), + new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), + DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -132,8 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java index 164ba9876cb..5503e721caa 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Map; /** @@ -111,8 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(length, 28, 28, 1), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create( + ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -123,10 +125,9 @@ private NDArray readLabel(Artifact.Item item) throws IOException { if (is.skip(8) != 8) { throw new AssertionError("Failed skip data."); } - byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return 
array.toType(DataType.FLOAT32, false); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java index 42fc1744451..b04ae800a10 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/ListFeatures.java @@ -44,6 +44,7 @@ public ListFeatures(int initialCapacity) { * * @param source the source list */ + @SuppressWarnings("this-escape") public ListFeatures(List source) { super(source.size()); addAll(source); diff --git a/bom/README.md b/bom/README.md index 44519846712..c98b9d1fbe1 100644 --- a/bom/README.md +++ b/bom/README.md @@ -22,7 +22,7 @@ will need to mention the type as pom and the scope as import) as the following: ai.djl bom - 0.23.0 + 0.26.0 pom import @@ -38,7 +38,7 @@ will need to mention the type as pom and the scope as import) as the following: ai.djl bom - 0.23.0 + 0.26.0 pom import @@ -65,7 +65,7 @@ will need to mention the type as pom and the scope as import) as the following: - First you need add BOM into your build.gradle file as the following: ``` - implementation platform("ai.djl:bom:0.23.0") + implementation platform("ai.djl:bom:0.26.0") ``` - Then you import the desired DJL modules into to you pom.xml file (no version is needed): diff --git a/bom/build.gradle b/bom/build.gradle index 4708978b5b5..31317316138 100644 --- a/bom/build.gradle +++ b/bom/build.gradle @@ -28,6 +28,7 @@ dependencies { api "ai.djl.fasttext:fasttext-engine:${version}" api "ai.djl.hadoop:hadoop:${version}" api "ai.djl.huggingface:tokenizers:${version}" + api "ai.djl.llama:llama:${version}" api "ai.djl.ml.lightgbm:lightgbm:${version}" api "ai.djl.ml.xgboost:xgboost-gpu:${version}" api "ai.djl.ml.xgboost:xgboost:${version}" @@ -115,15 +116,12 @@ publishing { addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "win-x86_64", "${pytorch_version}") 
addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-aarch64", "${pytorch_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116", "linux-x86_64", "1.12.1") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116", "win-x86_64", "1.12.1") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu116-precxx11", "linux-x86_64", "1.12.1") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "linux-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "win-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "linux-x86_64", "1.13.1") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "win-x86_64", "1.13.1") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117-precxx11", "linux-x86_64", "1.13.1") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118", "linux-x86_64", "${pytorch_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118", "win-x86_64", "${pytorch_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu118-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "osx-x86_64", "${tensorflow_version}") addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "linux-x86_64", "${tensorflow_version}") addDependency(dependencies, "ai.djl.tensorflow", "tensorflow-native-cpu", "win-x86_64", "${tensorflow_version}") diff --git a/build.gradle b/build.gradle index f98b86c4e51..ca6f7e68133 100644 --- a/build.gradle +++ b/build.gradle @@ -44,6 +44,7 @@ 
configure(javaProjects()) { targetCompatibility = JavaVersion.VERSION_11 options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static,-removal" << "-Werror" } + javadoc.options.addStringOption("Xdoclint:none", "-quiet") apply plugin: 'eclipse' @@ -88,7 +89,7 @@ configure(javaProjects()) { systemProperty "disableProgressBar", "true" systemProperty "nightly", System.getProperty("nightly", "false") if (gradle.startParameter.offline) { - systemProperty "offline", "true" + systemProperty "ai.djl.offline", "true" } // This is used to avoid overriding on default engine for modules: // mxnet-engine, mxnet-model-zoo, api (MockEngine), basicdataset, fasttext, etc diff --git a/djl-zero/README.md b/djl-zero/README.md index 2d2c473cc88..34acbac07b9 100644 --- a/djl-zero/README.md +++ b/djl-zero/README.md @@ -49,6 +49,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl djl-zero - 0.23.0 + 0.26.0 ``` diff --git a/docker/README.md b/docker/README.md index 5b5bd01be2b..0df33be9f83 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,10 +1,12 @@ # Docker Resources + DJL provides docker files that you can use to setup containers with the appropriate environment for certain platforms. We recommend setting up a docker container with the provided Dockerfile when developing for the following platforms and/or engines. ## Windows + You can use the [docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/windows/Dockerfile) provided by us. Please note that this docker will only work with Windows server 2019 by default. If you want it to work with other versions of Windows, you need to pass the version as an argument as follows: @@ -14,19 +16,20 @@ docker build --build-arg version= ``` ## TensorRT + You can use the [docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) provided by us. 
This docker file is a modification of the one provided by NVIDIA in -[TensorRT](https://github.com/NVIDIA/TensorRT/blob/8.4.1/docker/ubuntu-18.04.Dockerfile) to include JDK11. -By default this sets up a container using Ubuntu 18.04 and CUDA 11.6.2. You can build the container with other versions as follows, +[TensorRT](https://github.com/NVIDIA/TensorRT/blob/8.4.1/docker/ubuntu-18.04.Dockerfile) to include JDK17. +By default this sets up a container using Ubuntu 18.04 and CUDA 11.6.2. You can build the container with other versions as follows, but keep in mind the TensorRT software requirements outlined [here](https://github.com/NVIDIA/TensorRT#prerequisites): ```bash docker build --build-arg OS_VERSION= --build-arg CUDA_VERSION= ``` -To run the container, we recommend using `nvidia-docker run ...` to ensure cuda driver and runtime are compatible. +To run the container, we recommend using `nvidia-docker run ...` to ensure cuda driver and runtime are compatible. -We recommend that you follow the setup steps in the [TensorRT guide](https://github.com/NVIDIA/TensorRT) if you -need access to the full suite of tools TensorRT provides, such as `trtexec` which can convert onnx models to -uff tensorrt models. When following that guide, make sure to use the DJL provided -[docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) to enable JDK11 in the docker container. +We recommend that you follow the setup steps in the [TensorRT guide](https://github.com/NVIDIA/TensorRT) if you +need access to the full suite of tools TensorRT provides, such as `trtexec` which can convert onnx models to +uff tensorrt models. When following that guide, make sure to use the DJL provided +[docker file](https://github.com/deepjavalibrary/djl/blob/master/docker/tensorrt/Dockerfile) to enable JDK17 in the docker container. 
diff --git a/docker/spark/Dockerfile b/docker/spark/Dockerfile index b715899e2f1..b777d5a69ed 100644 --- a/docker/spark/Dockerfile +++ b/docker/spark/Dockerfile @@ -13,7 +13,7 @@ FROM 314815235551.dkr.ecr.us-east-2.amazonaws.com/sagemaker-spark-processing:3.3 LABEL maintainer="djl-dev@amazon.com" # Install dependencies -ARG DJL_VERSION=0.23.0 +ARG DJL_VERSION=0.24.0 ARG JNA_VERSION=5.13.0 ARG JAVACV_VERSION=1.5.9 ARG JAVACPP_VERSION=1.5.9 diff --git a/docker/tensorrt/Dockerfile b/docker/tensorrt/Dockerfile index 3a99bb9cb5d..94f81230e19 100644 --- a/docker/tensorrt/Dockerfile +++ b/docker/tensorrt/Dockerfile @@ -42,7 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ fakeroot \ dh-make \ build-essential \ - openjdk-11-jdk && \ + openjdk-17-jdk && \ apt-get clean -y && rm -rf /var/lib/apt/lists/* # Install python3 diff --git a/docker/windows/Dockerfile b/docker/windows/Dockerfile index 31567b3168b..10989e8a4c8 100644 --- a/docker/windows/Dockerfile +++ b/docker/windows/Dockerfile @@ -11,4 +11,4 @@ RUN powershell -Command \ Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \ choco feature disable --name showDownloadProgress -RUN choco install -y openjdk11 +RUN choco install -y openjdk17 diff --git a/docs/README.md b/docs/README.md index cdd02661c78..81d547e92f2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,14 +20,14 @@ Note: when searching in JavaDoc, if your access is denied, please try removing t - [Troubleshooting](development/troubleshooting.md) - [Inference Optimization](development/inference_performance_optimization.md) -## [Jupyter notebook tutorials](../jupyter/README.md) - -- **[Beginner Jupyter Tutorial](../jupyter/tutorial/README.md)** -- [Run object detection with model 
zoo](../jupyter/object_detection_with_model_zoo.ipynb) -- [Load pre-trained PyTorch model](../jupyter/load_pytorch_model.ipynb) -- [Load pre-trained Apache MXNet model](../jupyter/load_mxnet_model.ipynb) -- [Transfer learning example](../jupyter/transfer_learning_on_cifar10.ipynb) -- [Question answering example](../jupyter/BERTQA.ipynb) +## [Jupyter notebook tutorials](http://docs.djl.ai/docs/demos/jupyter/index.html) + +- **[Beginner Jupyter Tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html)** +- [Run object detection with model zoo](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html) +- [Load pre-trained PyTorch model](http://docs.djl.ai/docs/demos/jupyter/load_pytorch_model.html) +- [Load pre-trained Apache MXNet model](http://docs.djl.ai/docs/demos/jupyter/load_mxnet_model.html) +- [Transfer learning example](http://docs.djl.ai/docs/demos/jupyter/transfer_learning_on_cifar10.html) +- [Question answering example](http://docs.djl.ai/docs/demos/jupyter/BERTQA.html) ## [API Examples](../examples/README.md) diff --git a/docs/development/example_dataset.md b/docs/development/example_dataset.md index 35e071f728b..63583c2fdeb 100644 --- a/docs/development/example_dataset.md +++ b/docs/development/example_dataset.md @@ -1,4 +1,4 @@ -## Example CSV Dataset +# Custom CSV Dataset Example If the provided Datasets don't meet your requirements, you can also easily extend our dataset to create your own customized dataset. @@ -24,8 +24,8 @@ api group: 'org.apache.commons', name: 'commons-csv', version: '1.7' In order to extend the dataset, the following dependencies are required: ``` -api "ai.djl:api:0.23.0" -api "ai.djl:basicdataset:0.23.0" +api "ai.djl:api:0.26.0" +api "ai.djl:basicdataset:0.26.0" ``` There are four parts we need to implement for CSVDataset. 
diff --git a/docs/development/external_libraries.md b/docs/development/external_libraries.md index 7f57fec3165..701fb9d0a03 100644 --- a/docs/development/external_libraries.md +++ b/docs/development/external_libraries.md @@ -1,5 +1,4 @@ - -## DJL external dependencies +# DJL external dependencies This document contains external libraries that DJL depends on and their versions. diff --git a/docs/development/profiler.md b/docs/development/profiler.md index 6db5739483c..4a2a9f626e4 100644 --- a/docs/development/profiler.md +++ b/docs/development/profiler.md @@ -1,4 +1,4 @@ -## Profiler (Experimental) +# Engine Profiler Support Currently, DJL supports experimental profilers for developers that investigate the performance of operator execution as well as memory consumption. diff --git a/docs/development/setup.md b/docs/development/setup.md index e4eb73b2501..fb290eb0e3a 100644 --- a/docs/development/setup.md +++ b/docs/development/setup.md @@ -10,13 +10,13 @@ you can use the $JAVA_HOME environment variable to control which version of Java For ubuntu: ```bash -sudo apt-get install openjdk-11-jdk +sudo apt-get install openjdk-17-jdk ``` For centos ```bash -sudo yum install java-11-openjdk +sudo yum install java-17-openjdk ``` For Mac: @@ -24,7 +24,7 @@ For Mac: ```bash brew tap homebrew/cask-versions brew update -brew install --cask temurin11 +brew install --cask zulu17 ``` You can also download and install [Oracle JDK](https://www.oracle.com/technetwork/java/javase/overview/index.html) diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md index ff03d32648e..1a04592dc12 100644 --- a/docs/development/troubleshooting.md +++ b/docs/development/troubleshooting.md @@ -105,6 +105,11 @@ For more information, please refer to [DJL Cache Management](cache_management.md It happened when you had a wrong version with DJL and Deep Engines. You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue. 
+### 1.6 Manual initialization + +If you are using manual engine initialization, you must both register an engine and set it as the default. +This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`. + ## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception. The following exception may appear after running the `./gradlew clean` command: diff --git a/docs/hybrid_engine.md b/docs/hybrid_engine.md index 58bdbe69cb4..ddde08337ee 100644 --- a/docs/hybrid_engine.md +++ b/docs/hybrid_engine.md @@ -21,17 +21,17 @@ to run in a hybrid mode: To use it along with Apache MXNet for additional API support, add the following two dependencies: ``` -runtimeOnly "ai.djl.mxnet:mxnet-engine:0.23.0" +runtimeOnly "ai.djl.mxnet:mxnet-engine:0.26.0" ``` You can also use PyTorch or TensorFlow Engine as the supplemental engine by adding their corresponding dependencies. ``` -runtimeOnly "ai.djl.pytorch:pytorch-engine:0.23.0" +runtimeOnly "ai.djl.pytorch:pytorch-engine:0.26.0" ``` ``` -runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.23.0" +runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.26.0" ``` ## How Hybrid works diff --git a/docs/interactive_tool.md b/docs/interactive_tool.md index ed102fedc8d..d7d267db710 100644 --- a/docs/interactive_tool.md +++ b/docs/interactive_tool.md @@ -63,7 +63,7 @@ After that, click `run` and you should see the following result: Finally, you can get the running project setup by clicking `Get Template`. This will bring you a gradle project that can be used in your local machine. -## [Java Jupyter Notebook](../jupyter/README.md) +## [Java Jupyter Notebook](http://docs.djl.ai/docs/demos/jupyter/index.html) Wait a second, are we talking about hosting Jupyter Notebook in python? No, it’s Java 11, only. @@ -71,9 +71,9 @@ No, it’s Java 11, only. 
![jupyter](https://djl-ai.s3.amazonaws.com/web-data/images/jupyter.gif) Inspired by Spencer Park’s [IJava project](https://github.com/SpencerPark/IJava), we integrated DJL with Jupyter Notebooks. -For more information on the simple setup, follow the instructions in [DJL Jupyter notebooks](../jupyter/README.md#setup). +For more information on the simple setup, follow the instructions in [DJL Jupyter notebooks](http://docs.djl.ai/docs/demos/jupyter/index.html#setup). After that, use the Jupyter Notebook freely in your hosted server. You can do all kinds of work, like block building and plotting a graph. -There are [tutorials and instructions](../jupyter/README.md#djl---jupyter-notebooks) to guide you how you can run training and/or inference with Java. +There are [tutorials and instructions](http://docs.djl.ai/docs/demos/jupyter/index.html#djl---jupyter-notebooks) to guide you how you can run training and/or inference with Java. ## About Future Lab diff --git a/docs/load_model.md b/docs/load_model.md index 621d7514605..653ba3e91d7 100644 --- a/docs/load_model.md +++ b/docs/load_model.md @@ -181,7 +181,7 @@ Here is a few tips you can use to help you debug model loading issue: See [here](development/configure_logging.md#configure-logging-level) for how to enable debug log #### List models programmatically in your code -You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.23.0/ai/djl/repository/zoo/ModelZoo.html#listModels--) API to query available models. +You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.26.0/ai/djl/repository/zoo/ModelZoo.html#listModels--) API to query available models. 
#### List available models using DJL command line diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c911bf43b2d..6511e9a865e 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -61,15 +61,15 @@ nav: - 'docs/faq.md' - Tutorials: - Beginner Tutorial: - - 'jupyter/tutorial/01_create_your_first_network.ipynb' - - 'jupyter/tutorial/02_train_your_first_model.ipynb' - - 'jupyter/tutorial/03_image_classification_with_your_model.ipynb' + - 'docs/demos/jupyter/tutorial/01_create_your_first_network.ipynb' + - 'docs/demos/jupyter/tutorial/02_train_your_first_model.ipynb' + - 'docs/demos/jupyter/tutorial/03_image_classification_with_your_model.ipynb' - 'docs/d2l.md' - - 'jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb' - - 'jupyter/transfer_learning_on_cifar10.ipynb' + - 'docs/demos/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb' + - 'docs/demos/jupyter/transfer_learning_on_cifar10.ipynb' - Load your own BERT: - - BERT with MXNet: 'jupyter/mxnet/load_your_own_mxnet_bert.ipynb' - - BERT with PyTorch: 'jupyter/pytorch/load_your_own_pytorch_bert.ipynb' + - BERT with MXNet: 'docs/demos/jupyter/mxnet/load_your_own_mxnet_bert.ipynb' + - BERT with PyTorch: 'docs/demos/jupyter/pytorch/load_your_own_pytorch_bert.ipynb' - Guides: - Models: - 'docs/load_model.md' @@ -97,25 +97,25 @@ nav: - PyTorch NDArray Operators: 'docs/pytorch/pytorch-djl-ndarray-cheatsheet.md' - PyTorch Model Zoo: 'engines/pytorch/pytorch-model-zoo/README.md' - Import PyTorch Model: 'docs/pytorch/how_to_convert_your_model_to_torchscript.md' - - Load a PyTorch Model: 'jupyter/load_pytorch_model.ipynb' + - Load a PyTorch Model: 'docs/demos/jupyter/load_pytorch_model.ipynb' - TensorFlow: - Overview: 'engines/tensorflow/README.md' - TensorFlow Engine: 'engines/tensorflow/tensorflow-engine/README.md' - TensorFlow Model Zoo: 'engines/tensorflow/tensorflow-model-zoo/README.md' - Import TensorFlow Model: 'docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md' - - Load a TensorFlow Model: 
'jupyter/tensorflow/pneumonia_detection.ipynb' + - Load a TensorFlow Model: 'docs/demos/jupyter/tensorflow/pneumonia_detection.ipynb' - Apache MXNet: - Overview: 'engines/mxnet/README.md' - MXNet Engine: 'engines/mxnet/mxnet-engine/README.md' - MXNet Model Zoo: 'engines/mxnet/mxnet-model-zoo/README.md' - Import Gluon Model: 'docs/mxnet/how_to_convert_your_model_to_symbol.md' - - Load a MXNet Model: 'jupyter/load_mxnet_model.ipynb' + - Load a MXNet Model: 'docs/demos/jupyter/load_mxnet_model.ipynb' - Backend Optimizer for MXNet: 'docs/mxnet/mxnet_backend_optimizer.md' - Hybrid engines: - Hybrid engine overview: 'docs/hybrid_engine.md' - ONNX Runtime: - Overview: 'engines/onnxruntime/onnxruntime-engine/README.md' - - Load a ONNX Model: 'jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb' + - Load a ONNX Model: 'docs/demos/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb' - PaddlePaddle: - Overview: 'engines/paddlepaddle/README.md' - PaddlePaddle Engine: 'engines/paddlepaddle/paddlepaddle-engine/README.md' @@ -124,11 +124,11 @@ nav: - English: 'docs/paddlepaddle/how_to_create_paddlepaddle_model.md' - 中文: 'docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md' - Facemask detection using PaddlePaddle: - - English: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb' - - 中文: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb' + - English: 'docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb' + - 中文: 'docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb' - PaddleOCR example: - - English: 'jupyter/paddlepaddle/paddle_ocr_java.ipynb' - - 中文: 'jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb' + - English: 'docs/demos/jupyter/paddlepaddle/paddle_ocr_java.ipynb' + - 中文: 'docs/demos/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb' - XGBoost: 'engines/ml/xgboost/README.md' - LightGBM: 'engines/ml/lightgbm/README.md' - TensorRT: 'engines/tensorrt/README.md' @@ -153,15 +153,34 @@ nav: - 
'docs/serving/serving/docs/inference.md' - 'docs/serving/serving/docs/modes.md' - 'docs/serving/serving/docs/console.md' - - 'docs/serving/serving/docs/configuration.md' - - 'docs/serving/serving/docs/configurations.md' - - 'docs/serving/serving/docs/workflows.md' + - Configuration: + - 'docs/serving/serving/docs/configuration.md' + - 'docs/serving/serving/docs/configurations_global.md' + - 'docs/serving/serving/docs/configurations.md' + - 'docs/serving/serving/docs/workflows.md' + - 'docs/serving/serving/docs/configurations_model.md' - 'docs/serving/serving/docs/architecture.md' - HTTP API: - 'docs/serving/serving/docs/inference_api.md' - 'docs/serving/serving/docs/management_api.md' - 'docs/serving/serving/docs/plugin_management.md' - 'docs/serving/wlm/README.md' + - Large Model Inference: + - 'docs/serving/serving/docs/large_model_inference.md' + - 'docs/serving/serving/docs/lmi/configurations_large_model_inference_containers.md' + - 'docs/serving/serving/docs/lmi/lmi_environment_variable_instruction.md' + - 'docs/serving/serving/docs/lmi/lmi_input_output_schema.md' + - Tutorials: + - 'docs/serving/serving/docs/lmi/tutorials/seq_scheduler_tutorial.md' + - 'docs/serving/serving/docs/lmi/tutorials/trtllm_aot_tutorial.md' + - 'docs/serving/serving/docs/lmi/tutorials/trtllm_manual_convert_tutorial.md' + - Tuning guides: + - 'docs/serving/serving/docs/lmi/tuning_guides/deepspeed_tuning_guide.md' + - 'docs/serving/serving/docs/lmi/tuning_guides/lmi_dist_tuning_guide.md' + - 'docs/serving/serving/docs/lmi/tuning_guides/tnx_tuning_guide.md' + - 'docs/serving/serving/docs/lmi/tuning_guides/trtllm_tuning_guide.md' + - SageMaker LMI containers resources: + - 'docs/demos/aws/sagemaker/large-model-inference/README.md' - Demos: - Demos: 'docs/demos/README.md' - AWS: diff --git a/docs/mxnet/how_to_convert_your_model_to_symbol.md b/docs/mxnet/how_to_convert_your_model_to_symbol.md index be178afe437..57a5b8a9b05 100644 --- a/docs/mxnet/how_to_convert_your_model_to_symbol.md +++ 
b/docs/mxnet/how_to_convert_your_model_to_symbol.md @@ -1,4 +1,4 @@ -## How to convert your Gluon model to an MXNet Symbol +# How to convert your Gluon model to an MXNet Symbol DJL currently supports symbolic model loading from MXNet. A gluon [HybridBlock](https://mxnet.apache.org/api/python/docs/api/gluon/hybrid_block.html) can be converted into a symbol for loading by doing as follows: diff --git a/docs/paddlepaddle/how_to_create_paddlepaddle_model.md b/docs/paddlepaddle/how_to_create_paddlepaddle_model.md index 042acbd2d61..b78d4406946 100644 --- a/docs/paddlepaddle/how_to_create_paddlepaddle_model.md +++ b/docs/paddlepaddle/how_to_create_paddlepaddle_model.md @@ -157,5 +157,5 @@ predictor.predict(list); As mentioned, you need to find out what is the input for the model, like images usually interpret as NCHW (batch_size, channel, height, width). -However, usage like this is really basic, you can write a `Translator` in DJL for it. You can find some code examples [here](../../jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb). +However, usage like this is really basic, you can write a `Translator` in DJL for it. You can find some code examples [here](http://docs.djl.ai/docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.html). 
diff --git a/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md b/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md index 74e5dec634f..5f79d713783 100644 --- a/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md +++ b/docs/paddlepaddle/how_to_create_paddlepaddle_model_zh.md @@ -156,4 +156,4 @@ predictor.predict(list); 在这里,你需要知道模型的输入输出格式, 比如图片经常表达成 NCHW (批大小, RGB通道, 高度, 宽度)的多维矩阵。 -虽然这样可以让模型跑起来, 但是最好还是结合 DJL 的 `Translator` class 使用。你可以在 [这里](../../jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb) 找到一些示例代码。 +虽然这样可以让模型跑起来, 但是最好还是结合 DJL 的 `Translator` class 使用。你可以在 [这里](http://docs.djl.ai/docs/demos/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.html) 找到一些示例代码。 diff --git a/docs/pytorch/how_to_convert_your_model_to_torchscript.md b/docs/pytorch/how_to_convert_your_model_to_torchscript.md index 4dd4b3102d7..f90ee468764 100644 --- a/docs/pytorch/how_to_convert_your_model_to_torchscript.md +++ b/docs/pytorch/how_to_convert_your_model_to_torchscript.md @@ -1,4 +1,4 @@ -## How to convert your PyTorch model to TorchScript +# How to convert your PyTorch model to TorchScript There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. diff --git a/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md b/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md index 7416ec50bab..37d24276d82 100644 --- a/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md +++ b/docs/pytorch/pytorch-djl-ndarray-cheatsheet.md @@ -1,4 +1,4 @@ -## PyTorch NDArray operators +# PyTorch NDArray operators In the following examples, we assume diff --git a/docs/quick_start.md b/docs/quick_start.md index f352a39156a..85a94494b2d 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -1,7 +1,7 @@ # Quick start Deep Java Library (DJL) is designed to be easy to get started with and simple to use. 
-The easiest way to learn DJL is to read the [beginner tutorial](../jupyter/tutorial/README.md) or +The easiest way to learn DJL is to read the [beginner tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/README.md) or our [examples](../examples/README.md). You can also view our 1.5 hour long (in 8 x ~10 minute segments) DJL 101 tutorial video series: @@ -22,7 +22,7 @@ See [DJL Future Labs](interactive_tool.md) ## Beginner tutorial -To get started, we recommend that you follow our short [beginner tutorial](../jupyter/tutorial/README.md). It takes you through some of the basics of deep learning to create a model, train your model, and run inference using your trained model. +To get started, we recommend that you follow our short [beginner tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/index.html). It takes you through some of the basics of deep learning to create a model, train your model, and run inference using your trained model. ## Run examples @@ -33,7 +33,7 @@ All of our examples are executed by a simple command. 
For detailed command line - [Train your first model](../examples/docs/train_mnist_mlp.md) - [Single-shot Object Detection inference example](../examples/docs/object_detection.md) - [More examples](https://github.com/deepjavalibrary/djl/tree/master/examples) -- [Jupyter examples](../jupyter/README.md) +- [Jupyter examples](http://docs.djl.ai/docs/demos/jupyter/index.html) ## Other resources diff --git a/docs/telemetry.md b/docs/telemetry.md index d6ff9b20bc1..256adf00a49 100644 --- a/docs/telemetry.md +++ b/docs/telemetry.md @@ -20,5 +20,5 @@ System.setProperty("OPT_OUT_TRACKING", "true") Usage tracking is also disable in `offline` mode: ```java -System.setProperty("offline", "true") +System.setProperty("ai.djl.offline", "true") ``` diff --git a/engines/llama/.gitignore b/engines/llama/.gitignore new file mode 100644 index 00000000000..3428b3b2f53 --- /dev/null +++ b/engines/llama/.gitignore @@ -0,0 +1,3 @@ +jnilib/ +llama.cpp/ +models/ diff --git a/engines/llama/CMakeLists.txt b/engines/llama/CMakeLists.txt new file mode 100644 index 00000000000..d1fc8131db8 --- /dev/null +++ b/engines/llama/CMakeLists.txt @@ -0,0 +1,23 @@ +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) + +project(djl_llama CXX) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(BUILD_SHARED_LIBS ON) + +set(JAVA_AWT_LIBRARY NotNeeded) +set(JAVA_AWT_INCLUDE_PATH NotNeeded) +find_package(JNI REQUIRED) + +add_subdirectory(llama.cpp) +include(build-args.cmake) +add_library(djl_llama SHARED src/main/native/ai_djl_llama.cpp) + +target_include_directories(djl_llama PRIVATE + ${JNI_INCLUDE_DIRS} + src/main/native + llama.cpp + llama.cpp/common + build/include) +target_link_libraries(djl_llama PRIVATE common llama ${LLAMA_EXTRA_LIBS}) +target_compile_features(djl_llama PRIVATE cxx_std_11) diff --git a/engines/llama/build-args.cmake b/engines/llama/build-args.cmake new file mode 100644 index 00000000000..dee0db659cd --- /dev/null +++ b/engines/llama/build-args.cmake @@ -0,0 +1,639 @@ +if (APPLE) + 
set(LLAMA_METAL_DEFAULT ON) +else() + set(LLAMA_METAL_DEFAULT OFF) +endif() + +# general +option(LLAMA_NATIVE "llama: enable -march=native flag" ON) + +# instruction set specific +if (LLAMA_NATIVE) + set(INS_ENB OFF) +else() + set(INS_ENB ON) +endif() + +option(LLAMA_AVX "llama: enable AVX" ${INS_ENB}) +option(LLAMA_AVX2 "llama: enable AVX2" ${INS_ENB}) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) +# in MSVC F16C is implied with AVX2/AVX512 +if (NOT MSVC) + option(LLAMA_F16C "llama: enable F16C" ${INS_ENB}) +endif() + +# 3rd party libs +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) +set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") +option(LLAMA_CUBLAS "llama: use CUDA" OFF) +#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) +option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) +option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") +set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") +option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) +set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") +set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING + "llama: max. 
batch size for using peer access")
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
+option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
+option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
+option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
+option(LLAMA_MPI "llama: use MPI" OFF)
+option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
+
+
+#
+# Compile flags
+#
+
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD_REQUIRED true)
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+find_package(Threads REQUIRED)
+include(CheckCXXCompilerFlag)
+
+# enable libstdc++ assertions for debug builds
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+    add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
+endif()
+
+if (NOT MSVC)
+    if (LLAMA_SANITIZE_THREAD)
+        add_compile_options(-fsanitize=thread)
+        link_libraries(-fsanitize=thread)
+    endif()
+
+    if (LLAMA_SANITIZE_ADDRESS)
+        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
+        link_libraries(-fsanitize=address)
+    endif()
+
+    if (LLAMA_SANITIZE_UNDEFINED)
+        add_compile_options(-fsanitize=undefined)
+        link_libraries(-fsanitize=undefined)
+    endif()
+endif()
+
+if (APPLE AND LLAMA_ACCELERATE)
+    find_library(ACCELERATE_FRAMEWORK Accelerate)
+    if (ACCELERATE_FRAMEWORK)
+        message(STATUS "Accelerate framework found")
+
+        add_compile_definitions(GGML_USE_ACCELERATE)
+        add_compile_definitions(ACCELERATE_NEW_LAPACK)
+        add_compile_definitions(ACCELERATE_LAPACK_ILP64)
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
+    else()
+        message(WARNING "Accelerate framework not found")
+    endif()
+endif()
+
+if (LLAMA_METAL)
+    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+    find_library(METAL_FRAMEWORK Metal REQUIRED)
+    find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+
+    message(STATUS "Metal framework found")
+    set(GGML_HEADERS_METAL ggml-metal.h)
+    set(GGML_SOURCES_METAL ggml-metal.m)
+
+    add_compile_definitions(GGML_USE_METAL)
+    if 
(LLAMA_METAL_NDEBUG)
+        add_compile_definitions(GGML_METAL_NDEBUG)
+    endif()
+
+    # get full path to the file
+    #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
+
+    set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
+        ${FOUNDATION_LIBRARY}
+        ${METAL_FRAMEWORK}
+        ${METALKIT_FRAMEWORK}
+        )
+endif()
+if (LLAMA_BLAS)
+    if (LLAMA_STATIC)
+        set(BLA_STATIC ON)
+    endif()
+    if (${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.22)
+        set(BLA_SIZEOF_INTEGER 8)
+    endif()
+
+    set(BLA_VENDOR ${LLAMA_BLAS_VENDOR})
+    find_package(BLAS)
+
+    if (BLAS_FOUND)
+        message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
+
+        if ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
+            # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
+            # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
+            find_package(PkgConfig REQUIRED)
+            if (${LLAMA_BLAS_VENDOR} MATCHES "Generic")
+                pkg_check_modules(DepBLAS REQUIRED blas)
+            elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS")
+                pkg_check_modules(DepBLAS REQUIRED openblas)
+            elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME")
+                pkg_check_modules(DepBLAS REQUIRED blis)
+            elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS")
+                pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+            elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS")
+                pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+            elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel")
+                # all Intel* libraries share the same include path
+                pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+            elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC")
+                # this doesn't provide pkg-config
+                # suggest to assign BLAS_INCLUDE_DIRS on your own
+                if ("${NVHPC_VERSION}" STREQUAL "")
+                    message(WARNING "Better to set NVHPC_VERSION")
+                else()
+                    set(DepBLAS_FOUND ON)
+                    set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
+                endif()
+            endif()
+            if (DepBLAS_FOUND)
+                set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
+            else()
+                message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been 
automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + add_compile_options(${BLAS_LINKER_FLAGS}) + add_compile_definitions(GGML_USE_OPENBLAS) + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) + + else() + message(WARNING "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct LLAMA_BLAS_VENDOR") + endif() +endif() + +if (LLAMA_QKK_64) + add_compile_definitions(GGML_QKK_64) +endif() + +if (LLAMA_CUBLAS) + cmake_minimum_required(VERSION 3.17) + + find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) + message(STATUS "cuBLAS found") + + enable_language(CUDA) + + set(GGML_HEADERS_CUDA ggml-cuda.h) + set(GGML_SOURCES_CUDA ggml-cuda.cu) + + add_compile_definitions(GGML_USE_CUBLAS) +# if (LLAMA_CUDA_CUBLAS) +# add_compile_definitions(GGML_CUDA_CUBLAS) +# endif() + if (LLAMA_CUDA_FORCE_DMMV) + add_compile_definitions(GGML_CUDA_FORCE_DMMV) + endif() + if (LLAMA_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + if (DEFINED LLAMA_CUDA_DMMV_Y) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility + endif() + if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) + add_compile_definitions(GGML_CUDA_F16) + 
endif() + add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE}) + + if (LLAMA_STATIC) + if (WIN32) + # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + else () + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + endif() + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() + + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # 52 == lowest CUDA 12 standard + # 60 == f16 CUDA intrinsics + # 61 == integer CUDA intrinsics + # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + else() + set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work + endif() + endif() + message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + + else() + message(WARNING "cuBLAS not found") + endif() +endif() + +if (LLAMA_MPI) + cmake_minimum_required(VERSION 3.10) + find_package(MPI) + if (MPI_C_FOUND) + message(STATUS "MPI found") + set(GGML_HEADERS_MPI ggml-mpi.h) + set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) + add_compile_definitions(GGML_USE_MPI) + add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) + if (NOT MSVC) + add_compile_options(-Wno-cast-qual) + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + # Even if you're only using the C header, C++ programs may bring in MPI + # C++ functions, so more linkage is needed + if (MPI_CXX_FOUND) + set(LLAMA_EXTRA_LIBS 
${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) + endif() + else() + message(WARNING "MPI not found") + endif() +endif() + +if (LLAMA_CLBLAST) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_HEADERS_OPENCL ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + message(WARNING "CLBlast not found") + endif() +endif() + +if (LLAMA_HIPBLAS) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) + + if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") + endif() + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + find_package(hip) + find_package(hipblas) + find_package(rocblas) + + if (${hipblas_FOUND} AND ${hip_FOUND}) + message(STATUS "HIP and hipBLAS found") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) + if (BUILD_SHARED_LIBS) + set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + if (LLAMA_CUDA_FORCE_DMMV) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) + endif() + if (LLAMA_CUDA_FORCE_MMQ) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ) + endif() + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) + target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + + if (LLAMA_STATIC) + message(FATAL_ERROR "Static linking not supported for HIP/ROCm") + endif() + set(LLAMA_EXTRA_LIBS 
${LLAMA_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
+function(get_flags CCID CCVER)
+    set(C_FLAGS "")
+    set(CXX_FLAGS "")
+
+    if (CCID MATCHES "Clang")
+        set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
+        set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
+
+        if (
+            (CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
+            (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
+        )
+            set(C_FLAGS ${C_FLAGS} -Wdouble-promotion)
+        endif()
+    elseif (CCID STREQUAL "GNU")
+        set(C_FLAGS -Wdouble-promotion)
+        set(CXX_FLAGS -Wno-array-bounds)
+
+        if (CCVER VERSION_GREATER_EQUAL 7.1.0)
+            set(CXX_FLAGS ${CXX_FLAGS} -Wno-format-truncation)
+        endif()
+        if (CCVER VERSION_GREATER_EQUAL 8.1.0)
+            set(CXX_FLAGS ${CXX_FLAGS} -Wextra-semi)
+        endif()
+    endif()
+
+    set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
+    set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
+endfunction()
+
+if (LLAMA_ALL_WARNINGS)
+    if (NOT MSVC)
+        set(WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
+        set(C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
+                    -Werror=implicit-int -Werror=implicit-function-declaration)
+        set(CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
+
+        set(C_FLAGS ${WARNING_FLAGS} ${C_FLAGS})
+        set(CXX_FLAGS ${WARNING_FLAGS} ${CXX_FLAGS})
+
+        get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
+
+        add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
+                            "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
+    else()
+        # todo : msvc
+        set(C_FLAGS "")
+        set(CXX_FLAGS "")
+    endif()
+endif()
+
+if (LLAMA_CUBLAS)
+    set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math)
+    if (NOT MSVC)
+        set(CUDA_FLAGS ${CUDA_FLAGS} -Wno-pedantic)
+    endif()
+
+    if (LLAMA_ALL_WARNINGS AND NOT MSVC)
+        set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+        if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+            set(NVCC_CMD ${NVCC_CMD} 
-ccbin ${CMAKE_CUDA_HOST_COMPILER})
+        endif()
+
+        execute_process(
+            COMMAND ${NVCC_CMD} -Xcompiler --version
+            OUTPUT_VARIABLE CUDA_CCFULLVER
+            ERROR_QUIET
+        )
+
+        if (NOT CUDA_CCFULLVER MATCHES clang)
+            set(CUDA_CCID "GNU")
+            execute_process(
+                COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+                OUTPUT_VARIABLE CUDA_CCVER
+                ERROR_QUIET
+            )
+        else()
+            if (CUDA_CCFULLVER MATCHES Apple)
+                set(CUDA_CCID "AppleClang")
+            else()
+                set(CUDA_CCID "Clang")
+            endif()
+            string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
+        endif()
+
+        message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+        get_flags(${CUDA_CCID} ${CUDA_CCVER})
+        list(JOIN GF_CXX_FLAGS " " CUDA_CXX_FLAGS) # pass host compiler flags as a single argument
+        if (NOT CUDA_CXX_FLAGS STREQUAL "")
+            set(CUDA_FLAGS ${CUDA_FLAGS} -Xcompiler ${CUDA_CXX_FLAGS})
+        endif()
+    endif()
+
+    add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+endif()
+
+if (WIN32)
+    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+
+    if (BUILD_SHARED_LIBS)
+        set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
+    endif()
+endif()
+
+if (LLAMA_LTO)
+    include(CheckIPOSupported)
+    check_ipo_supported(RESULT result OUTPUT output)
+    if (result)
+        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
+    else()
+        message(WARNING "IPO is not supported: ${output}")
+    endif()
+endif()
+
+# this version of Apple ld64 is buggy
+execute_process(
+    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
+    ERROR_VARIABLE output
+    OUTPUT_QUIET
+)
+if (output MATCHES "dyld-1015\.7")
+    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
+endif()
+
+# Architecture specific
+# TODO: probably these flags need to be tweaked on some architectures
+# feel free to update the Makefile for your architecture and send a pull request or issue
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+if (MSVC)
+    string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
+    message(STATUS "CMAKE_GENERATOR_PLATFORM: 
${CMAKE_GENERATOR_PLATFORM}")
+else ()
+    set(CMAKE_GENERATOR_PLATFORM_LWR "")
+endif ()
+
+if (NOT MSVC)
+    if (LLAMA_STATIC)
+        add_link_options(-static)
+        if (MINGW)
+            add_link_options(-static-libgcc -static-libstdc++)
+        endif()
+    endif()
+    if (LLAMA_GPROF)
+        add_compile_options(-pg)
+    endif()
+endif()
+
+if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64"))
+    message(STATUS "ARM detected")
+    if (MSVC)
+        add_compile_definitions(__ARM_NEON)
+        add_compile_definitions(__ARM_FEATURE_FMA)
+        add_compile_definitions(__ARM_FEATURE_DOTPROD)
+        # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16
+        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
+    else()
+        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
+        if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
+            add_compile_options(-mfp16-format=ieee)
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
+            # Raspberry Pi 1, Zero
+            add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access)
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
+            # Raspberry Pi 2
+            add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
+            # Raspberry Pi 3, 4, Zero 2 (32-bit)
+            add_compile_options(-mno-unaligned-access)
+        endif()
+    endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" )
+    message(STATUS "x86 detected")
+    if (MSVC)
+        # instruction set detection for MSVC only
+        if (LLAMA_NATIVE)
+            include(${llama.cpp_SOURCE_DIR}/cmake/FindSIMD.cmake)
+        endif ()
+        if (LLAMA_AVX512)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
+            # MSVC has no compile-time flags enabling specific
+            # AVX512 
extensions, neither it defines the
+            # macros corresponding to the extensions.
+            # Do it manually.
+            if (LLAMA_AVX512_VBMI)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
+            endif()
+            if (LLAMA_AVX512_VNNI)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
+            endif()
+        elseif (LLAMA_AVX2)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
+        elseif (LLAMA_AVX)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
+        endif()
+    else()
+        if (LLAMA_NATIVE)
+            add_compile_options(-march=native)
+        endif()
+        if (LLAMA_F16C)
+            add_compile_options(-mf16c)
+        endif()
+        if (LLAMA_FMA)
+            add_compile_options(-mfma)
+        endif()
+        if (LLAMA_AVX)
+            add_compile_options(-mavx)
+        endif()
+        if (LLAMA_AVX2)
+            add_compile_options(-mavx2)
+        endif()
+        if (LLAMA_AVX512)
+            add_compile_options(-mavx512f)
+            add_compile_options(-mavx512bw)
+        endif()
+        if (LLAMA_AVX512_VBMI)
+            add_compile_options(-mavx512vbmi)
+        endif()
+        if (LLAMA_AVX512_VNNI)
+            add_compile_options(-mavx512vnni)
+        endif()
+    endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+    message(STATUS "PowerPC detected")
+    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+        add_compile_options(-mcpu=powerpc64le)
+    else()
+        add_compile_options(-mcpu=native -mtune=native)
+        #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+    endif()
+else()
+    message(STATUS "Unknown architecture")
+endif()
+
+if (MINGW)
+    # Target Windows 8 for PrefetchVirtualMemory
+    add_compile_definitions(_WIN32_WINNT=0x602)
+endif()
+
+#
+# POSIX conformance
+#
+
+# clock_gettime came in POSIX.1b (1993)
+# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
+# posix_memalign came in POSIX.1-2001 / SUSv3
+# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
+add_compile_definitions(_XOPEN_SOURCE=600)
+
+# Somehow in OpenBSD whenever POSIX 
conformance is specified +# some string functions rely on locale_t availability, +# which was introduced in POSIX.1-2008, forcing us to go higher +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + remove_definitions(-D_XOPEN_SOURCE=600) + add_compile_definitions(_XOPEN_SOURCE=700) +endif() + +# Data types, macros and functions related to controlling CPU affinity and +# some memory allocation are available on Linux through GNU extensions in libc +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_compile_definitions(_GNU_SOURCE) +endif() + +# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +# and on macOS its availability depends on enabling Darwin extensions +# similarly on DragonFly, enabling BSD extensions is necessary +if ( + CMAKE_SYSTEM_NAME MATCHES "Darwin" OR + CMAKE_SYSTEM_NAME MATCHES "iOS" OR + CMAKE_SYSTEM_NAME MATCHES "tvOS" OR + CMAKE_SYSTEM_NAME MATCHES "DragonFly" +) + add_compile_definitions(_DARWIN_C_SOURCE) +endif() + +# alloca is a non-standard interface that is not visible on BSDs when +# POSIX conformance is specified, but not all of them provide a clean way +# to enable it in such cases +if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") + add_compile_definitions(__BSD_VISIBLE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") + add_compile_definitions(_NETBSD_SOURCE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + add_compile_definitions(_BSD_SOURCE) +endif() diff --git a/engines/llama/build.cmd b/engines/llama/build.cmd new file mode 100644 index 00000000000..93c422028bc --- /dev/null +++ b/engines/llama/build.cmd @@ -0,0 +1,23 @@ +@rem https://chocolatey.org/docs/installation#install-with-cmdexe +@rem to install rust java etc.. 
+@rem choco install jdk17 -y + +set VERSION="%1" + +if exist "llama.cpp" ( + echo Found "llama.cpp" +) else ( + git clone https://github.com/ggerganov/llama.cpp.git -b %VERSION% +) + +if exist build rd /q /s build +md build\classes +cd build +javac -sourcepath ..\src\main\java\ ..\src\main\java\ai\djl\llama\jni\LlamaLibrary.java -h include -d classes +cmake .. +cmake --build . --config Release + +@rem for nightly ci +md jnilib\win-x86_64 +copy Release\djl_llama.dll jnilib\win-x86_64\ +copy bin\Release\llama.dll jnilib\win-x86_64\ diff --git a/engines/llama/build.gradle b/engines/llama/build.gradle new file mode 100644 index 00000000000..e340758d18e --- /dev/null +++ b/engines/llama/build.gradle @@ -0,0 +1,107 @@ +import java.util.zip.GZIPInputStream + +group "ai.djl.llama" + +dependencies { + api project(":api") + + testImplementation project(":testing") + testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" +} + +compileJava.dependsOn(processResources) + +processResources { + outputs.dir file("${project.projectDir}/build/classes/java/main/native/lib") + doLast { + def url = "https://publish.djl.ai/llama/${llamacpp_version}/jnilib/${djl_version}" + def files = new String[]{ + "linux-x86_64/libdjl_llama.so", + "linux-x86_64/libllama.so", + "linux-aarch64/libdjl_llama.so", + "linux-aarch64/libllama.so", + "osx-x86_64/libdjl_llama.dylib", + "osx-x86_64/libllama.dylib", + "osx-x86_64/ggml-metal.metal", + "osx-aarch64/libdjl_llama.dylib", + "osx-aarch64/libllama.dylib", + "osx-aarch64/ggml-metal.metal", + "win-x86_64/djl_llama.dll", + "win-x86_64/llama.dll", + } + def jnilibDir = "${project.projectDir}/jnilib/${djl_version}" + files.each { entry -> + def file = new File("${jnilibDir}/${entry}") + if (file.exists()) { + project.logger.lifecycle("prebuilt or cached file found for ${entry}") + } else if (!project.hasProperty("jni")) { + project.logger.lifecycle("Downloading ${url}/${entry}") + file.getParentFile().mkdirs() + def downloadPath = new 
URL("${url}/${entry}") + downloadPath.withInputStream { i -> file.withOutputStream { it << i } } + } + } + copy { + from jnilibDir + into "${project.projectDir}/build/classes/java/main/native/lib" + } + + // write properties + def propFile = file("${project.projectDir}/build/classes/java/main/native/lib/llama.properties") + propFile.text = "version=${llamacpp_version}-${version}\n" + + url = "https://mlrepo.djl.ai/model/nlp/text_generation/ai/djl/huggingface/gguf/models.json.gz" + def prefix = "${project.projectDir}/build/classes/java/main/nlp/text_generation" + def file = new File("${prefix}/ai.djl.huggingface.gguf.json") + if (file.exists()) { + project.logger.lifecycle("gguf index file already exists") + } else { + project.logger.lifecycle("Downloading gguf index file") + file.getParentFile().mkdirs() + def downloadPath = new URL(url) + downloadPath.withInputStream { i -> file.withOutputStream { it << new GZIPInputStream(i) } } + } + } +} + +publishing { + publications { + maven(MavenPublication) { + pom { + name = "DJL NLP utilities for Llama.cpp" + description = "Deep Java Library (DJL) NLP utilities for llama.cpp" + url = "http://www.djl.ai/engines/${project.name}" + } + } + } +} + +apply from: file("${rootProject.projectDir}/tools/gradle/cpp-formatter.gradle") + +tasks.register('compileJNI') { + doFirst { + if (System.properties['os.name'].toLowerCase(Locale.ROOT).contains("mac") + || System.properties['os.name'].toLowerCase(Locale.ROOT).contains("linux")) { + def arch = System.properties["os.arch"] == "amd64" ? 
"x86_64" : System.properties["os.arch"] + exec { + commandLine "bash", "build.sh", llamacpp_version, arch + } + } else { + exec { + commandLine "${project.projectDir}/build.cmd", llamacpp_version, "x86_64" + } + } + + // for ci to upload to S3 + def ciDir = "${project.projectDir}/jnilib/${djl_version}/" + copy { + from "${project.projectDir}/build/jnilib" + into ciDir + } + delete System.getProperty("user.home") + "/.djl.ai/llama" + } +} + +clean.doFirst { + delete System.getProperty("user.home") + "/.djl.ai/llama" +} diff --git a/engines/llama/build.sh b/engines/llama/build.sh new file mode 100755 index 00000000000..1cf7151cde4 --- /dev/null +++ b/engines/llama/build.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +set -e +WORK_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NUM_PROC=1 +if [[ -n $(command -v nproc) ]]; then + NUM_PROC=$(nproc) +elif [[ -n $(command -v sysctl) ]]; then + NUM_PROC=$(sysctl -n hw.ncpu) +fi +PLATFORM=$(uname | tr '[:upper:]' '[:lower:]') + +VERSION=$1 +ARCH=$2 + +pushd $WORK_DIR +if [ ! -d "llama.cpp" ]; then + git clone https://github.com/ggerganov/llama.cpp.git -b $VERSION +fi + +if [ ! -d "build" ]; then + mkdir build +fi +cd build + +rm -rf classes +mkdir classes +javac -sourcepath ../src/main/java/ ../src/main/java/ai/djl/llama/jni/LlamaLibrary.java -h include -d classes +cmake .. +cmake --build . 
--config Release -- -j "${NUM_PROC}" + +popd + +# for nightly ci +if [[ $PLATFORM == 'darwin' ]]; then + mkdir -p build/jnilib/osx-$ARCH + cp -f build/libdjl_llama.dylib build/jnilib/osx-$ARCH/ + cp -f build/llama.cpp/libllama.dylib build/jnilib/osx-$ARCH/ + cp -f llama.cpp/ggml-metal.metal build/jnilib/osx-$ARCH/ +elif [[ $PLATFORM == 'linux' ]]; then + mkdir -p build/jnilib/linux-$ARCH + cp -f build/libdjl_llama.so build/jnilib/linux-$ARCH/ + cp -f build/llama.cpp/libllama.so build/jnilib/linux-$ARCH/ +fi diff --git a/engines/llama/gradlew b/engines/llama/gradlew new file mode 120000 index 00000000000..343e0d2caa4 --- /dev/null +++ b/engines/llama/gradlew @@ -0,0 +1 @@ +../../gradlew \ No newline at end of file diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java new file mode 100644 index 00000000000..75fdf5a5d8c --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngine.java @@ -0,0 +1,110 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.llama.engine; + +import ai.djl.Device; +import ai.djl.Model; +import ai.djl.engine.Engine; +import ai.djl.engine.EngineException; +import ai.djl.llama.jni.LibUtils; +import ai.djl.ndarray.NDManager; +import ai.djl.util.Platform; +import ai.djl.util.passthrough.PassthroughNDManager; + +/** The {@code LlamaEngine} is an implementation of the {@link Engine} based on the llama.cpp. */ +public final class LlamaEngine extends Engine { + + public static final String ENGINE_NAME = "Llama"; + static final int RANK = 10; + + private Engine alternativeEngine; + private boolean initialized; + + private LlamaEngine() { + try { + LibUtils.loadLibrary(); + } catch (EngineException e) { // NOPMD + throw e; + } catch (Throwable t) { + throw new EngineException("Failed to load llama.cpp native library", t); + } + } + + static Engine newInstance() { + return new LlamaEngine(); + } + + /** {@inheritDoc} */ + @Override + public Engine getAlternativeEngine() { + if (!initialized && !Boolean.getBoolean("ai.djl.llama.disable_alternative")) { + Engine engine = Engine.getInstance(); + if (engine.getRank() < getRank()) { + // alternativeEngine should not have the same rank as Llama + alternativeEngine = engine; + } + initialized = true; + } + return alternativeEngine; + } + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return ENGINE_NAME; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return RANK; + } + + /** {@inheritDoc} */ + @Override + public String getVersion() { + Platform platform = Platform.detectPlatform("llama"); + return platform.getVersion(); + } + + /** {@inheritDoc} */ + @Override + public boolean hasCapability(String capability) { + return false; + } + + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device) { + return new LlamaModel(name, newBaseManager(device)); + } + + /** {@inheritDoc} */ + @Override + public NDManager newBaseManager() { + return newBaseManager(null); + } + + 
/** {@inheritDoc} */ + @Override + public NDManager newBaseManager(Device device) { + return PassthroughNDManager.INSTANCE; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return getEngineName() + ':' + getVersion() + ", " + getEngineName() + ':' + getVersion(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java new file mode 100644 index 00000000000..ca5cc646498 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaEngineProvider.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.engine; + +import ai.djl.engine.Engine; +import ai.djl.engine.EngineProvider; + +/** {@code LlamaEngineProvider} is the Llama implementation of {@link EngineProvider}. 
*/ +public class LlamaEngineProvider implements EngineProvider { + + /** {@inheritDoc} */ + @Override + public String getEngineName() { + return LlamaEngine.ENGINE_NAME; + } + + /** {@inheritDoc} */ + @Override + public int getEngineRank() { + return LlamaEngine.RANK; + } + + /** {@inheritDoc} */ + @Override + public Engine getEngine() { + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = LlamaEngine.newInstance(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java new file mode 100644 index 00000000000..4b4d332fc9f --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaInput.java @@ -0,0 +1,430 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.engine; + +import ai.djl.llama.jni.InputParameters; + +import com.google.gson.annotations.SerializedName; + +import java.util.Map; + +/** A class hold input data for Llama model. */ +public class LlamaInput { + + private String inputs; + private String prefix; + private String suffix; + private Parameters parameters; + + /** + * Returns the input prompt. + * + * @return the input prompt + */ + public String getInputs() { + return inputs; + } + + /** + * Sets the input prompt. 
+ * + * @param inputs the input prompt + */ + public void setInputs(String inputs) { + this.inputs = inputs; + } + + /** + * Returns the prompt prefix. + * + * @return the prompt prefix + */ + public String getPrefix() { + return prefix; + } + + /** + * Sets the prompt prefix. + * + * @param prefix the prompt prefix + */ + public void setPrefix(String prefix) { + this.prefix = prefix; + } + + /** + * Returns the prompt suffix. + * + * @return the prompt suffix + */ + public String getSuffix() { + return suffix; + } + + /** + * Sets the prompt suffix. + * + * @param suffix the prompt suffix + */ + public void setSuffix(String suffix) { + this.suffix = suffix; + } + + /** + * Returns the input parameters. + * + * @return the input parameters + */ + public Parameters getParameters() { + if (parameters == null) { + parameters = new Parameters(); + } + return parameters; + } + + /** + * Sets the input parameters. + * + * @param parameters the input parameters + */ + public void setParameters(Parameters parameters) { + this.parameters = parameters; + } + + /** The input parameters class. 
*/ + public static final class Parameters { + + @SerializedName("max_new_tokens") + private int nPredict; + + @SerializedName("number_keep") + private int nKeep; + + @SerializedName("number_probabilities") + private int nProbs; + + @SerializedName("top_k") + private int topK; + + @SerializedName("top_p") + private float topP; + + @SerializedName("tfs_z") + private float tfsZ; + + @SerializedName("typical_p") + private float typicalP; + + @SerializedName("temperature") + private float temperature; + + @SerializedName("repeat_penalty") + private float repeatPenalty; + + @SerializedName("repeat_last_n") + private int repeatLastN; + + @SerializedName("frequency_penalty") + private float frequencyPenalty; + + @SerializedName("presence_penalty") + private float presencePenalty; + + @SerializedName("penalize_nl") + private boolean penalizeNl; + + @SerializedName("ignore_eos") + private boolean ignoreEos; + + @SerializedName("mirostat") + private int mirostat; + + @SerializedName("mirostat_tau") + private float mirostatTau; + + @SerializedName("mirostat_eta") + private float mirostatEta; + + @SerializedName("number_beams") + private int nBeams; + + @SerializedName("seed") + private int seed; + + @SerializedName("logit_bias") + private Map logitBias; + + @SerializedName("grammar") + private String grammar; + + @SerializedName("anti_prompt") + private String[] antiPrompt; + + /** + * Sets the max new tokens. + * + * @param maxNewTokens the max new tokens + */ + public void setMaxNewTokens(int maxNewTokens) { + this.nPredict = maxNewTokens; + } + + /** + * Sets the number of keep. + * + * @param nKeep the number of keep + */ + public void setNumberKeep(int nKeep) { + this.nKeep = nKeep; + } + + /** + * Sets the number of probabilities. + * + * @param nProbs the number of probabilities + */ + public void setNumberProbabilities(int nProbs) { + this.nProbs = nProbs; + } + + /** + * Sets the top K. 
+ * + * @param topK the top K + */ + public void setTopK(int topK) { + this.topK = topK; + } + + /** + * Sets the top P. + * + * @param topP the top P + */ + public void setTopP(float topP) { + this.topP = topP; + } + + /** + * Sets the tfs Z. + * + * @param tfsZ the tfs Z + */ + public void setTfsZ(float tfsZ) { + this.tfsZ = tfsZ; + } + + /** + * Sets the typical P. + * + * @param typicalP the typical P + */ + public void setTypicalP(float typicalP) { + this.typicalP = typicalP; + } + + /** + * Sets the temperature. + * + * @param temperature the temperature + */ + public void setTemperature(float temperature) { + this.temperature = temperature; + } + + /** + * Sets the repeat penalty. + * + * @param repeatPenalty the repeat penalty + */ + public void setRepeatPenalty(float repeatPenalty) { + this.repeatPenalty = repeatPenalty; + } + + /** + * Sets the repeat last N. + * + * @param repeatLastN the repeat last N + */ + public void setRepeatLastN(int repeatLastN) { + this.repeatLastN = repeatLastN; + } + + /** + * Sets the frequency penalty. + * + * @param frequencyPenalty the frequency penalty + */ + public void setFrequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + /** + * Sets the presence penalty. + * + * @param presencePenalty the presence penalty + */ + public void setPresencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + /** + * Sets the penalize nl. + * + * @param penalizeNl the penalize nl + */ + public void setPenalizeNl(boolean penalizeNl) { + this.penalizeNl = penalizeNl; + } + + /** + * Sets if ignore EOS. + * + * @param ignoreEos if ignore EOS + */ + public void setIgnoreEos(boolean ignoreEos) { + this.ignoreEos = ignoreEos; + } + + /** + * Sets the mirostat. + * + * @param mirostat the mirostat + */ + public void setMirostat(int mirostat) { + this.mirostat = mirostat; + } + + /** + * Sets the mirostat TAU. 
+ * + * @param mirostatTau the mirostat TAU + */ + public void setMirostatTau(float mirostatTau) { + this.mirostatTau = mirostatTau; + } + + /** + * Sets the mirostat ETA. + * + * @param mirostatEta the mirostat ETA + */ + public void setMirostatEta(float mirostatEta) { + this.mirostatEta = mirostatEta; + } + + /** + * Sets the number of beams. + * + * @param nBeams the number of beams + */ + public void setNumberBeams(int nBeams) { + this.nBeams = nBeams; + } + + /** + * Sets the seed. + * + * @param seed the seed + */ + public void setSeed(int seed) { + this.seed = seed; + } + + /** + * Sets the logit bias. + * + * @param logitBias the logit bias + */ + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + /** + * Sets the grammar template. + * + * @param grammar the grammar template + */ + public void setGrammar(String grammar) { + this.grammar = grammar; + } + + /** + * Sets the anti prompt. + * + * @param antiPrompt the anti prompt + */ + public void setAntiPrompt(String[] antiPrompt) { + this.antiPrompt = antiPrompt; + } + + /** + * Returns the {@link InputParameters} object. 
+ * + * @return the {@link InputParameters} object + */ + public InputParameters toInputParameters() { + setDefaultValue(); + return new InputParameters( + nPredict, + nKeep, + nProbs, + topK, + topP, + tfsZ, + typicalP, + temperature, + repeatPenalty, + repeatLastN, + frequencyPenalty, + presencePenalty, + penalizeNl, + ignoreEos, + mirostat, + mirostatTau, + mirostatEta, + nBeams, + seed, + logitBias, + grammar, + antiPrompt); + } + + private void setDefaultValue() { + if (nPredict == 0) { + nPredict = -1; + } + if (topK == 0) { + topK = 40; + } + if (topP == 0) { + topP = 0.95f; + } + if (tfsZ == 0) { + tfsZ = 1f; + } + if (typicalP == 0) { + typicalP = 1f; + } + if (temperature == 0) { + temperature = 0.8f; + } + if (repeatPenalty == 0) { + repeatPenalty = 1.10f; + } + if (repeatLastN == 0) { + repeatLastN = 64; + } + } + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java new file mode 100644 index 00000000000..0ff3c6d70c0 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaModel.java @@ -0,0 +1,112 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.engine; + +import ai.djl.BaseModel; +import ai.djl.Model; +import ai.djl.llama.jni.LlamaLibrary; +import ai.djl.llama.jni.ModelParameters; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.nn.Blocks; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +/** {@code LlamaModel} is the llama.cpp implementation of {@link Model}. */ +public class LlamaModel extends BaseModel { + + private long handle = -1; + + /** + * Constructs a new Model on a given device. + * + * @param name the model name + * @param manager the {@link NDManager} to hold the NDArray + */ + LlamaModel(String name, NDManager manager) { + super(name); + this.manager = manager; + this.manager.setName("llamaModel"); + dataType = DataType.FLOAT32; + } + + /** {@inheritDoc} */ + @Override + public void load(Path modelPath, String prefix, Map options) throws IOException { + setModelDir(modelPath); + wasLoaded = true; + if (block != null) { + throw new UnsupportedOperationException("Llama does not support dynamic blocks"); + } + + if (prefix == null) { + prefix = modelName; + } + + // search for .gguf file with prefix, folder name or "model.gguf" + Path modelFile = findModelFile(prefix, modelDir.toFile().getName(), "model.gguf"); + if (modelFile == null) { + throw new FileNotFoundException(".gguf file not found in: " + modelPath); + } + + ModelParameters param = new ModelParameters(options); + handle = LlamaLibrary.loadModel(modelFile.toString(), param); + block = Blocks.identityBlock(); + } + + long getHandle() { + return handle; + } + + private Path findModelFile(String...
prefixes) { + if (Files.isRegularFile(modelDir)) { + Path file = modelDir; + modelDir = modelDir.getParent(); + String fileName = file.toFile().getName(); + if (fileName.endsWith(".gguf")) { + modelName = fileName.substring(0, fileName.length() - 5); + } else { + modelName = fileName; + } + return file; + } + for (String prefix : prefixes) { + Path modelFile = modelDir.resolve(prefix); + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + if (!prefix.endsWith(".gguf")) { + modelFile = modelDir.resolve(prefix + ".gguf"); + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + } + } + return null; + } + + /** {@inheritDoc} */ + @Override + public void close() { + if (handle == -1) { + return; + } + LlamaLibrary.delete(handle); + handle = -1; + super.close(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java new file mode 100644 index 00000000000..c8d3692b160 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslator.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.engine; + +import ai.djl.inference.streaming.IteratorBytesSupplier; +import ai.djl.llama.jni.InputParameters; +import ai.djl.llama.jni.LlamaLibrary; +import ai.djl.llama.jni.Token; +import ai.djl.llama.jni.TokenIterator; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.BytesSupplier; +import ai.djl.ndarray.NDList; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; + +import java.util.Iterator; + +/** Built-in {@code Translator} that provides preprocessing and postprocessing for llama.cpp. */ +public class LlamaTranslator implements NoBatchifyTranslator { + + private long handle; + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) { + LlamaModel model = (LlamaModel) ctx.getModel(); + handle = model.getHandle(); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, I input) { + if (input instanceof String) { + ctx.setAttachment("out", generate((String) input)); + } else if (input instanceof LlamaInput) { + ctx.setAttachment("out", generate((LlamaInput) input)); + } else if (input instanceof Input) { + String prompt = ((Input) input).getData().getAsString(); + TokenIterator it = generate(prompt); + Output output = new Output(); + output.add(new IteratorBytesSupplier(new OutputIterator(it))); + ctx.setAttachment("out", output); + } + return new NDList(); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public O processOutput(TranslatorContext ctx, NDList list) { + return (O) ctx.getAttachment("out"); + } + + private TokenIterator generate(String input) { + LlamaInput in = JsonUtils.GSON.fromJson(input, LlamaInput.class); + return generate(in); + } + + private TokenIterator generate(LlamaInput in) { + InputParameters param = in.getParameters().toInputParameters(); + String prefix = in.getPrefix(); + String suffix = in.getSuffix(); + String 
inputs = in.getInputs(); + if (prefix != null && suffix != null) { + LlamaLibrary.infill(handle, prefix, suffix, param); + } else if (inputs != null && !inputs.isEmpty()) { + LlamaLibrary.generate(handle, inputs, param); + } else { + throw new IllegalArgumentException("Unsupported input format"); + } + return new TokenIterator(handle); + } + + private static final class OutputIterator implements Iterator { + + private TokenIterator it; + + public OutputIterator(TokenIterator it) { + this.it = it; + } + + /** {@inheritDoc} */ + @Override + public boolean hasNext() { + return it.hasNext(); + } + + /** {@inheritDoc} */ + @Override + public BytesSupplier next() { + Token token = it.next(); + return BytesSupplier.wrap(JsonUtils.GSON.toJson(token) + "\n"); + } + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java new file mode 100644 index 00000000000..089b5055b51 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/LlamaTranslatorFactory.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License.
+ */ +package ai.djl.llama.engine; + +import ai.djl.Model; +import ai.djl.llama.jni.TokenIterator; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link LlamaTranslator} instance. */ +public class LlamaTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(String.class, TokenIterator.class)); + SUPPORTED_TYPES.add(new Pair<>(LlamaInput.class, TokenIterator.class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + public boolean isSupported(Class input, Class output) { + return true; + } + + /** {@inheritDoc} */ + @Override + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + return new LlamaTranslator<>(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java new file mode 100644 index 00000000000..226e7a6ddb8 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/engine/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to interface with the underlying Llama Engine. */ +package ai.djl.llama.engine; diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java new file mode 100644 index 00000000000..d13abc5ef90 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/InputParameters.java @@ -0,0 +1,314 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import java.util.Map; + +/** A class holds input parameters. 
*/ +@SuppressWarnings({"PMD.UnusedPrivateField", "PMD.UnusedAssignment"}) +public class InputParameters { + + private int nPredict; + private int nKeep; + private int nProbs; + private int topK; + private float topP; + private float tfsZ; + private float typicalP; + private float temperature; + private float repeatPenalty; + private int repeatLastN; + private float frequencyPenalty; + private float presencePenalty; + private boolean penalizeNl; + private boolean ignoreEos; + private int mirostat; + private float mirostatTau; + private float mirostatEta; + private int nBeams; + private int seed; + private Map logitBias; + private String grammar; + private String[] antiPrompt; + + /** + * Constructs new {@code InputParameters} instance. + * + * @param nPredict the max new tokens + * @param nKeep the number of keep + * @param nProbs the number of probabilities + * @param topK the top K + * @param topP the top P + * @param tfsZ the tfs Z + * @param typicalP the typical P + * @param temperature the temperature + * @param repeatPenalty the repeat penalty + * @param repeatLastN the repeat last N + * @param frequencyPenalty the frequency penalty + * @param presencePenalty the presence penalty + * @param penalizeNl the penalize nl + * @param ignoreEos the ignore EOS + * @param mirostat the mirostat + * @param mirostatTau the mirostat TAU + * @param mirostatEta the mirostat ETA + * @param nBeams the number of beams + * @param seed the seed + * @param logitBias the logit bias + * @param grammar the grammar + * @param antiPrompt the anti prompt + */ + public InputParameters( + int nPredict, + int nKeep, + int nProbs, + int topK, + float topP, + float tfsZ, + float typicalP, + float temperature, + float repeatPenalty, + int repeatLastN, + float frequencyPenalty, + float presencePenalty, + boolean penalizeNl, + boolean ignoreEos, + int mirostat, + float mirostatTau, + float mirostatEta, + int nBeams, + int seed, + Map logitBias, + String grammar, + String[] antiPrompt) { + 
this.nPredict = nPredict; + this.nKeep = nKeep; + this.nProbs = nProbs; + this.topK = topK; + this.topP = topP; + this.tfsZ = tfsZ; + this.typicalP = typicalP; + this.temperature = temperature; + this.repeatPenalty = repeatPenalty; + this.repeatLastN = repeatLastN; + this.frequencyPenalty = frequencyPenalty; + this.presencePenalty = presencePenalty; + this.penalizeNl = penalizeNl; + this.ignoreEos = ignoreEos; + this.mirostat = mirostat; + this.mirostatTau = mirostatTau; + this.mirostatEta = mirostatEta; + this.nBeams = nBeams; + this.seed = seed; + this.logitBias = logitBias; + this.grammar = grammar; + this.antiPrompt = antiPrompt; + } + + /** + * Returns the max new tokens. + * + * @return the max new tokens + */ + public int getMaxNewTokens() { + return nPredict; + } + + /** + * Returns the number of keep. + * + * @return the number of keep + */ + public int getNumberKeep() { + return nKeep; + } + + /** + * Returns the number of probabilities. + * + * @return the number of probabilities + */ + public int getNumberProbabilities() { + return nProbs; + } + + /** + * Returns the top K. + * + * @return the top K + */ + public int getTopK() { + return topK; + } + + /** + * Return the top P. + * + * @return the top P + */ + public float getTopP() { + return topP; + } + + /** + * Return the TfsZ. + * + * @return the TfsZ + */ + public float getTfsZ() { + return tfsZ; + } + + /** + * Return the typical P. + * + * @return the typical P + */ + public float getTypicalP() { + return typicalP; + } + + /** + * Return the temperature. + * + * @return the temperature + */ + public float getTemperature() { + return temperature; + } + + /** + * Return the repeat penalty. + * + * @return the repeat penalty + */ + public float getRepeatPenalty() { + return repeatPenalty; + } + + /** + * Return the repeat last N. + * + * @return the repeat last N + */ + public int getRepeatLastN() { + return repeatLastN; + } + + /** + * Return the frequency penalty. 
+ * + * @return the frequency penalty + */ + public float getFrequencyPenalty() { + return frequencyPenalty; + } + + /** + * Return the presence penalty. + * + * @return the presence penalty + */ + public float getPresencePenalty() { + return presencePenalty; + } + + /** + * Return the penalize NL. + * + * @return the penalize NL + */ + public boolean isPenalizeNl() { + return penalizeNl; + } + + /** + * Returns {@code true} if ignore EOS. + * + * @return {@code true} if ignore EOS + */ + public boolean isIgnoreEos() { + return ignoreEos; + } + + /** + * Returns the mirostat. + * + * @return the mirostat + */ + public int getMirostat() { + return mirostat; + } + + /** + * Returns the mirostat TAU. + * + * @return the mirostat TAU + */ + public float getMirostatTau() { + return mirostatTau; + } + + /** + * Returns the mirostat ETA. + * + * @return the mirostat ETA + */ + public float getMirostatEta() { + return mirostatEta; + } + + /** + * Returns the number of beams. + * + * @return the number of beams + */ + public int getNumberBeams() { + return nBeams; + } + + /** + * Returns the seed. + * + * @return the seed + */ + public int getSeed() { + return seed; + } + + /** + * Returns the logit bias. + * + * @return the logit bias + */ + public Map getLogitBias() { + return logitBias; + } + + /** + * Returns the grammar template. + * + * @return the grammar template + */ + public String getGrammar() { + return grammar; + } + + /** + * Returns the anti-prompt. + * + * @return the anti-prompt + */ + public String[] getAntiPrompt() { + return antiPrompt; + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java new file mode 100644 index 00000000000..3792864c346 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/LibUtils.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import ai.djl.util.ClassLoaderUtils; +import ai.djl.util.Platform; +import ai.djl.util.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; +import java.util.List; + +/** Utilities for finding the llama.cpp native binary on the System. */ +public final class LibUtils { + + private static final Logger logger = LoggerFactory.getLogger(LibUtils.class); + + private static final String LIB_NAME = System.mapLibraryName("djl_llama"); + private static final String LLAMA_NAME = System.mapLibraryName("llama"); + + private LibUtils() {} + + /** Loads llama.cpp native library. 
*/ + public static void loadLibrary() { + List libs = new ArrayList<>(3); + libs.add(LLAMA_NAME); + libs.add(LIB_NAME); + if (System.getProperty("os.name").startsWith("Mac")) { + libs.add("ggml-metal.metal"); + } + Path dir = copyJniLibraryFromClasspath(libs.toArray(new String[0])); + logger.debug("Loading llama.cpp library from: {}", dir); + + for (int i = 0; i < 2; ++i) { + String lib = libs.get(i); + String path = dir.resolve(lib).toString(); + logger.debug("Loading native library: {}", path); + String nativeHelper = System.getProperty("ai.djl.llama.native_helper"); + if (nativeHelper != null && !nativeHelper.isEmpty()) { + ClassLoaderUtils.nativeLoad(nativeHelper, path); + } + System.load(path); // NOPMD + } + } + + private static Path copyJniLibraryFromClasspath(String... libs) { + Path cacheDir = Utils.getEngineCacheDir("llama"); + Platform platform = Platform.detectPlatform("llama"); + String classifier = platform.getClassifier(); + String version = platform.getVersion(); + Path dir = cacheDir.resolve(version + '-' + classifier); + Path path = dir.resolve(LIB_NAME); + logger.debug("Using cache dir: {}", dir); + if (Files.exists(path)) { + return dir.toAbsolutePath(); + } + + Path tmp = null; + try { + Files.createDirectories(cacheDir); + tmp = Files.createTempDirectory(cacheDir, "tmp"); + + for (String libName : libs) { + String libPath = "native/lib/" + classifier + "/" + libName; + logger.info("Extracting {} to cache ...", libPath); + try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) { + Path target = tmp.resolve(libName); + Files.copy(is, target, StandardCopyOption.REPLACE_EXISTING); + } + } + Utils.moveQuietly(tmp, dir); + return dir.toAbsolutePath(); + } catch (IOException e) { + throw new IllegalStateException("Cannot copy jni files", e); + } finally { + if (tmp != null) { + Utils.deleteQuietly(tmp); + } + } + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java 
b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java new file mode 100644 index 00000000000..5d40fa29830 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/LlamaLibrary.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +/** Native library for llama.cpp. */ +@SuppressWarnings("MissingJavadocMethod") +public final class LlamaLibrary { + + private LlamaLibrary() {} + + public static native long loadModel(String filePath, ModelParameters param); + + public static native void generate(long handle, String prompt, InputParameters param); + + public static native void infill( + long handle, String prefix, String suffix, InputParameters param); + + public static native Token getNext(long handle, long count, long pos); + + public static native float[] embed(long handle, String prompt); + + public static native int[] encode(long handle, String prompt); + + public static native byte[] decodeBytes(long handle, int[] tokens); + + public static native void delete(long handle); +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java new file mode 100644 index 00000000000..e3e440474a8 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/ModelParameters.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import java.util.Map; + +/** A class holds llama.cpp model loading parameters. */ +@SuppressWarnings("PMD.SingularField") +public final class ModelParameters { + + private int nThreads; + private int nCtx; + private int nBatch; + private int nGpuLayers; + private int mainGpu; + private float ropeFreqBase; + private float ropeFreqScale; + private boolean mulMatQ; + private boolean f16Kv; + private boolean logitsAll; + private boolean vocabOnly; + private boolean useMmap; + private boolean useMlock; + private boolean embedding; + private boolean memoryF16; + private boolean memTest; + private boolean numa; + private boolean verbosePrompt; + private float[] tensorSplit; + private String loraAdapter; + private String loraBase; + + /** + * Constructs a new {@code ModelParameters} instance. 
+ * + * @param options the model loading options + */ + public ModelParameters(Map options) { + nThreads = intValue(options, "number_threads", Runtime.getRuntime().availableProcessors()); + nCtx = intValue(options, "max_context_length", 512); + nBatch = intValue(options, "max_rolling_batch", 512); + nGpuLayers = intValue(options, "number_gpu_layers", -1); + mainGpu = intValue(options, "tensor_parallel_degree", 0); + ropeFreqBase = floatValue(options, "rope_freq_base"); + ropeFreqScale = floatValue(options, "rope_freq_scale"); + f16Kv = booleanValue(options, "f16_kv"); + mulMatQ = booleanValue(options, "mulmat_q", true); + logitsAll = booleanValue(options, "logits_all"); + vocabOnly = booleanValue(options, "vocab_only"); + useMmap = booleanValue(options, "use_mmap", true); + useMlock = booleanValue(options, "use_mlock"); + embedding = booleanValue(options, "embedding"); + memoryF16 = booleanValue(options, "memory_f16", true); + memTest = booleanValue(options, "mem_test"); + numa = booleanValue(options, "numa"); + verbosePrompt = booleanValue(options, "verbose_prompt"); + String val = stringValue(options, "tensor_split"); + if (val != null && !val.isEmpty()) { + String[] tokens = val.split(","); + tensorSplit = new float[tokens.length]; + for (int i = 0; i < tokens.length; ++i) { + tensorSplit[i] = Float.parseFloat(tokens[i].trim()); + } + } + loraAdapter = stringValue(options, "lora_adapter"); + loraBase = stringValue(options, "lora_base"); + } + + private static int intValue(Map arguments, String key, int def) { + Object value = arguments.get(key); + if (value == null) { + return def; + } + return (int) Double.parseDouble(value.toString()); + } + + private static float floatValue(Map arguments, String key) { + Object value = arguments.get(key); + if (value == null) { + return 0f; + } + return (float) Double.parseDouble(value.toString()); + } + + private static boolean booleanValue(Map arguments, String key) { + return booleanValue(arguments, key, false); + } + +
private static boolean booleanValue(Map arguments, String key, boolean def) { + Object value = arguments.get(key); + if (value == null) { + return def; + } + return Boolean.parseBoolean(value.toString()); + } + + private static String stringValue(Map arguments, String key) { + Object value = arguments.get(key); + if (value == null) { + return null; + } + return value.toString(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/Token.java b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java new file mode 100644 index 00000000000..b8d74306b56 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/Token.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.jni; + +import ai.djl.util.JsonUtils; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** The output token class. */ +public final class Token { + + private int token; + private String text; + private Map probabilities; + transient long count; + transient long pos; + transient boolean hasNext; + + /** + * Constructs a new {@code Token} instance. 
+ * + * @param token the token id + * @param generated the token text + * @param probabilities the token probabilities + * @param count the generated token count + * @param pos the token index + * @param hasNext has more tokens + */ + public Token( + int token, + byte[] generated, + Map probabilities, + long count, + long pos, + boolean hasNext) { + this.token = token; + this.text = new String(generated, StandardCharsets.UTF_8); + this.probabilities = probabilities; + this.count = count; + this.pos = pos; + this.hasNext = hasNext; + } + + /** + * Returns the token id. + * + * @return the token id + */ + public int getToken() { + return token; + } + + /** + * Returns the token text. + * + * @return the token text + */ + public String getText() { + return text; + } + + /** + * Returns the token probabilities. + * + * @return the token probabilities + */ + public Map getProbabilities() { + return probabilities; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return JsonUtils.GSON.toJson(this) + '\n'; + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java new file mode 100644 index 00000000000..cab6575d8f7 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/TokenIterator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.jni; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** A iterator class holds generated tokens. */ +public class TokenIterator implements Iterator { + + private static final Logger logger = LoggerFactory.getLogger(TokenIterator.class); + + private static AtomicBoolean active = new AtomicBoolean(); + + private long handle; + private long count; + private long pos; + private boolean hasNext; + + /** + * Constructs a new {@code TokenIterator} instance. + * + * @param handle the llama.cpp handle + */ + public TokenIterator(long handle) { + this.handle = handle; + hasNext = true; + if (!active.compareAndSet(false, true)) { + active.set(true); + logger.warn("Previous inference has been reset"); + } + } + + /** {@inheritDoc} */ + @Override + public boolean hasNext() { + return hasNext; + } + + /** {@inheritDoc} */ + @Override + public Token next() { + if (!hasNext) { + throw new NoSuchElementException(); + } + Token token = LlamaLibrary.getNext(handle, count, pos); + count = token.count; + pos = token.pos; + hasNext = token.hasNext; + if (!hasNext) { + active.set(false); + } + return token; + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java new file mode 100644 index 00000000000..6f429aceda2 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/jni/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains classes to interface with the native llama.cpp code. */ +package ai.djl.llama.jni; diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java new file mode 100644 index 00000000000..69d4f200ba9 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaModelZoo.java @@ -0,0 +1,172 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.zoo; + +import ai.djl.Application; +import ai.djl.repository.Repository; +import ai.djl.repository.zoo.ModelLoader; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.util.ClassLoaderUtils; +import ai.djl.util.JsonUtils; +import ai.djl.util.Utils; + +import com.google.gson.reflect.TypeToken; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.io.Writer; +import java.lang.reflect.Type; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.zip.GZIPInputStream; + +/** LlamaModelZoo is a repository that contains llama.cpp models. */ +public class LlamaModelZoo extends ModelZoo { + + private static final Logger logger = LoggerFactory.getLogger(LlamaModelZoo.class); + + private static final String REPO = "https://mlrepo.djl.ai/"; + private static final Repository REPOSITORY = Repository.newInstance("gguf", REPO); + private static final String GROUP_ID = "ai.djl.huggingface.gguf"; + + private static final long ONE_DAY = Duration.ofDays(1).toMillis(); + + private boolean initialized; + + LlamaModelZoo() {} + + /** {@inheritDoc} */ + @Override + public String getGroupId() { + return GROUP_ID; + } + + /** {@inheritDoc} */ + @Override + public Set getSupportedEngines() { + return Collections.singleton("Llama"); + } + + /** {@inheritDoc} */ + @Override + public Collection getModelLoaders() { + init(); + return super.getModelLoaders(); + } + + /** {@inheritDoc} */ + @Override + public ModelLoader getModelLoader(String name) { + init(); + return super.getModelLoader(name); + } + + private void init() { + if (!initialized) { + Application app = Application.NLP.TEXT_GENERATION; + Map map = listModels(app); + for (Map.Entry entry : map.entrySet()) { + 
String artifactId = entry.getKey(); + Map gguf = entry.getValue().getGguf(); + if (gguf != null) { + for (String key : gguf.keySet()) { + addModel(REPOSITORY.model(app, GROUP_ID, artifactId, "0.0.1", key)); + } + } + } + initialized = true; + } + } + + private Map listModels(Application app) { + try { + String path = "model/" + app.getPath() + "/ai/djl/huggingface/gguf/"; + Path dir = Utils.getCacheDir().resolve("cache/repo/" + path); + if (Files.notExists(dir)) { + Files.createDirectories(dir); + } else if (!Files.isDirectory(dir)) { + logger.warn("Failed initialize cache directory: " + dir); + return Collections.emptyMap(); + } + Type type = new TypeToken>() {}.getType(); + + Path file = dir.resolve("models.json"); + if (Files.exists(file)) { + long lastModified = Files.getLastModifiedTime(file).toMillis(); + if (Utils.isOfflineMode() || System.currentTimeMillis() - lastModified < ONE_DAY) { + try (Reader reader = Files.newBufferedReader(file)) { + return JsonUtils.GSON.fromJson(reader, type); + } + } + } + + URL url = URI.create(REPO).resolve(path + "models.json.gz").toURL(); + Path tmp = Files.createTempFile(dir, "models", ".tmp"); + try (GZIPInputStream gis = new GZIPInputStream(Utils.openUrl(url))) { + String json = Utils.toString(gis); + try (Writer writer = Files.newBufferedWriter(tmp)) { + writer.write(json); + } + Utils.moveQuietly(tmp, file); + return JsonUtils.GSON.fromJson(json, type); + } catch (IOException e) { + logger.warn("Failed to download Huggingface gguf index: {}", app); + if (Files.exists(file)) { + try (Reader reader = Files.newBufferedReader(file)) { + return JsonUtils.GSON.fromJson(reader, type); + } + } + + String resource = app.getPath() + "/" + GROUP_ID + ".json"; + try (InputStream is = ClassLoaderUtils.getResourceAsStream(resource)) { + String json = Utils.toString(is); + try (Writer writer = Files.newBufferedWriter(tmp)) { + writer.write(json); + } + Utils.moveQuietly(tmp, file); + return JsonUtils.GSON.fromJson(json, type); + } + } 
finally { + Utils.deleteQuietly(tmp); + } + } catch (IOException e) { + logger.warn("Failed to load gguf index file", e); + } + + return Collections.emptyMap(); + } + + private static final class ModelDetail { + + private Map gguf; + + public Map getGguf() { + return gguf; + } + + public void setGguf(Map gguf) { + this.gguf = gguf; + } + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java new file mode 100644 index 00000000000..ba2b04722c1 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/zoo/LlamaZooProvider.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.llama.zoo; + +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooProvider; + +/** + * A Huggingface llama.cpp model zoo provider that implements the {@link + * ai.djl.repository.zoo.ZooProvider} interface. + */ +public class LlamaZooProvider implements ZooProvider {

+ /** {@inheritDoc} */ + @Override + public ModelZoo getModelZoo() { + return new LlamaModelZoo(); + } +} diff --git a/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java new file mode 100644 index 00000000000..a9c1df64cd0 --- /dev/null +++ b/engines/llama/src/main/java/ai/djl/llama/zoo/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains the built-in {@link ai.djl.llama.zoo.LlamaModelZoo}. */ +package ai.djl.llama.zoo; diff --git a/engines/llama/src/main/javadoc/overview.html b/engines/llama/src/main/javadoc/overview.html new file mode 100644 index 00000000000..05dec7d0bd4 --- /dev/null +++ b/engines/llama/src/main/javadoc/overview.html @@ -0,0 +1,14 @@ + + + + + +

This document is the API specification for the Deep Java Library (DJL) Llama Engine.

+ +
+ + + diff --git a/engines/llama/src/main/native/ai_djl_llama.cpp b/engines/llama/src/main/native/ai_djl_llama.cpp new file mode 100644 index 00000000000..1d6072751f2 --- /dev/null +++ b/engines/llama/src/main/native/ai_djl_llama.cpp @@ -0,0 +1,1025 @@ +#include +#include +#include +#include + +#include "ai_djl_llama_jni_LlamaLibrary.h" +#include "common.h" +#include "grammar-parser.h" +#include "llama.h" +#include "sampling.h" + +// classes +static jclass c_lib_utils = 0; +static jclass c_model_params = 0; +static jclass c_input_params = 0; +static jclass c_token = 0; +static jclass c_standard_charsets = 0; +static jclass c_string = 0; +static jclass c_hash_map = 0; +static jclass c_map = 0; +static jclass c_set = 0; +static jclass c_entry = 0; +static jclass c_integer = 0; +static jclass c_float = 0; +static jclass c_logger = 0; +static jclass c_engine_exception = 0; + +// constructors +static jmethodID cc_token = 0; +static jmethodID cc_hash_map = 0; +static jmethodID cc_integer = 0; +static jmethodID cc_float = 0; + +// methods +static jmethodID m_get_bytes = 0; +static jmethodID m_entry_set = 0; +static jmethodID m_set_iterator = 0; +static jmethodID m_iterator_has_next = 0; +static jmethodID m_iterator_next = 0; +static jmethodID m_entry_key = 0; +static jmethodID m_entry_value = 0; +static jmethodID m_map_put = 0; +static jmethodID m_int_value = 0; +static jmethodID m_float_value = 0; +static jmethodID m_log_debug = 0; +static jmethodID m_log_info = 0; +static jmethodID m_log_warn = 0; +static jmethodID m_log_error = 0; + +// fields +static jfieldID f_logger = 0; +// inference parameters +static jfieldID f_n_predict = 0; +static jfieldID f_n_keep = 0; +static jfieldID f_n_probs = 0; +static jfieldID f_logit_bias = 0; +static jfieldID f_top_k = 0; +static jfieldID f_top_p = 0; +static jfieldID f_tfs_z = 0; +static jfieldID f_typical_p = 0; +static jfieldID f_temperature = 0; +static jfieldID f_repeat_penalty = 0; +static jfieldID f_repeat_last_n = 0; +static 
jfieldID f_frequency_penalty = 0; +static jfieldID f_presence_penalty = 0; +static jfieldID f_penalize_nl = 0; +static jfieldID f_ignore_eos = 0; +static jfieldID f_mirostat = 0; +static jfieldID f_mirostat_tau = 0; +static jfieldID f_mirostat_eta = 0; +static jfieldID f_n_beams = 0; +static jfieldID f_grammar = 0; +static jfieldID f_antiprompt = 0; +static jfieldID f_infer_seed = 0; +// model parameters +static jfieldID f_n_threads = 0; +static jfieldID f_n_ctx = 0; +static jfieldID f_n_batch = 0; +static jfieldID f_n_gpu_layers = 0; +static jfieldID f_main_gpu = 0; +static jfieldID f_tensor_split = 0; +static jfieldID f_rope_freq_base = 0; +static jfieldID f_rope_freq_scale = 0; +static jfieldID f_mul_mat_q = 0; +static jfieldID f_f16_kv = 0; +static jfieldID f_logits_all = 0; +static jfieldID f_vocab_only = 0; +static jfieldID f_use_mmap = 0; +static jfieldID f_use_mlock = 0; +static jfieldID f_embedding = 0; +static jfieldID f_lora_adapter = 0; +static jfieldID f_lora_base = 0; +static jfieldID f_memory_f16 = 0; +static jfieldID f_mem_test = 0; +static jfieldID f_numa = 0; +static jfieldID f_verbose_prompt = 0; +// log level +static jfieldID f_utf_8 = 0; +// objects +static jobject o_utf_8 = 0; +static jobject o_logger = 0; + +static JavaVM *g_vm = nullptr; + +static void null_log_callback(enum ggml_log_level level, const char *text, void *user_data) {} + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env = 0; + + if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) { + return JNI_ERR; + } + + log_disable(); + llama_log_set(null_log_callback, nullptr); + + // find classes + c_input_params = env->FindClass("ai/djl/llama/jni/InputParameters"); + c_model_params = env->FindClass("ai/djl/llama/jni/ModelParameters"); + c_lib_utils = env->FindClass("ai/djl/llama/jni/LibUtils"); + c_token = env->FindClass("ai/djl/llama/jni/Token"); + c_engine_exception = env->FindClass("ai/djl/engine/EngineException"); + c_logger = 
env->FindClass("org/slf4j/Logger"); + c_standard_charsets = env->FindClass("java/nio/charset/StandardCharsets"); + c_string = env->FindClass("java/lang/String"); + c_hash_map = env->FindClass("java/util/HashMap"); + c_map = env->FindClass("java/util/Map"); + c_set = env->FindClass("java/util/Set"); + c_entry = env->FindClass("java/util/Map$Entry"); + c_integer = env->FindClass("java/lang/Integer"); + c_float = env->FindClass("java/lang/Float"); + + // create references + c_input_params = (jclass) env->NewGlobalRef(c_input_params); + c_model_params = (jclass) env->NewGlobalRef(c_model_params); + c_lib_utils = (jclass) env->NewGlobalRef(c_lib_utils); + c_token = (jclass) env->NewGlobalRef(c_token); + c_engine_exception = (jclass) env->NewGlobalRef(c_engine_exception); + c_logger = (jclass) env->NewGlobalRef(c_logger); + c_string = (jclass) env->NewGlobalRef(c_string); + c_hash_map = (jclass) env->NewGlobalRef(c_hash_map); + c_map = (jclass) env->NewGlobalRef(c_map); + c_set = (jclass) env->NewGlobalRef(c_set); + c_entry = (jclass) env->NewGlobalRef(c_entry); + c_integer = (jclass) env->NewGlobalRef(c_integer); + c_float = (jclass) env->NewGlobalRef(c_float); + + // find constructors + cc_token = env->GetMethodID(c_token, "", "(I[BLjava/util/Map;JJZ)V"); + cc_hash_map = env->GetMethodID(c_hash_map, "", "()V"); + cc_integer = env->GetMethodID(c_integer, "", "(I)V"); + cc_float = env->GetMethodID(c_float, "", "(F)V"); + + // find methods + m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B"); + m_entry_set = env->GetMethodID(c_map, "entrySet", "()Ljava/util/Set;"); + m_entry_key = env->GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;"); + m_entry_value = env->GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;"); + m_map_put = env->GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); + m_int_value = env->GetMethodID(c_integer, "intValue", "()I"); + m_float_value = env->GetMethodID(c_float, "floatValue", 
"()F"); + m_log_debug = env->GetMethodID(c_logger, "debug", "(Ljava/lang/String;)V"); + m_log_info = env->GetMethodID(c_logger, "info", "(Ljava/lang/String;)V"); + m_log_warn = env->GetMethodID(c_logger, "warn", "(Ljava/lang/String;)V"); + m_log_error = env->GetMethodID(c_logger, "error", "(Ljava/lang/String;)V"); + + // find fields + f_logger = env->GetStaticFieldID(c_lib_utils, "logger", "Lorg/slf4j/Logger;"); + + f_n_predict = env->GetFieldID(c_input_params, "nPredict", "I"); + f_n_keep = env->GetFieldID(c_input_params, "nKeep", "I"); + f_n_probs = env->GetFieldID(c_input_params, "nProbs", "I"); + f_logit_bias = env->GetFieldID(c_input_params, "logitBias", "Ljava/util/Map;"); + f_top_k = env->GetFieldID(c_input_params, "topK", "I"); + f_top_p = env->GetFieldID(c_input_params, "topP", "F"); + f_tfs_z = env->GetFieldID(c_input_params, "tfsZ", "F"); + f_typical_p = env->GetFieldID(c_input_params, "typicalP", "F"); + f_temperature = env->GetFieldID(c_input_params, "temperature", "F"); + f_repeat_penalty = env->GetFieldID(c_input_params, "repeatPenalty", "F"); + f_repeat_last_n = env->GetFieldID(c_input_params, "repeatLastN", "I"); + f_frequency_penalty = env->GetFieldID(c_input_params, "frequencyPenalty", "F"); + f_presence_penalty = env->GetFieldID(c_input_params, "presencePenalty", "F"); + f_penalize_nl = env->GetFieldID(c_input_params, "penalizeNl", "Z"); + f_ignore_eos = env->GetFieldID(c_input_params, "ignoreEos", "Z"); + f_mirostat = env->GetFieldID(c_input_params, "mirostat", "I"); + f_mirostat_tau = env->GetFieldID(c_input_params, "mirostatTau", "F"); + f_mirostat_eta = env->GetFieldID(c_input_params, "mirostatEta", "F"); + f_n_beams = env->GetFieldID(c_input_params, "nBeams", "I"); + f_grammar = env->GetFieldID(c_input_params, "grammar", "Ljava/lang/String;"); + f_antiprompt = env->GetFieldID(c_input_params, "antiPrompt", "[Ljava/lang/String;"); + f_infer_seed = env->GetFieldID(c_input_params, "seed", "I"); + + f_n_threads = env->GetFieldID(c_model_params, 
"nThreads", "I"); + f_n_ctx = env->GetFieldID(c_model_params, "nCtx", "I"); + f_n_batch = env->GetFieldID(c_model_params, "nBatch", "I"); + f_n_gpu_layers = env->GetFieldID(c_model_params, "nGpuLayers", "I"); + f_main_gpu = env->GetFieldID(c_model_params, "mainGpu", "I"); + f_tensor_split = env->GetFieldID(c_model_params, "tensorSplit", "[F"); + f_rope_freq_base = env->GetFieldID(c_model_params, "ropeFreqBase", "F"); + f_rope_freq_scale = env->GetFieldID(c_model_params, "ropeFreqScale", "F"); + f_mul_mat_q = env->GetFieldID(c_model_params, "mulMatQ", "Z"); + f_f16_kv = env->GetFieldID(c_model_params, "f16Kv", "Z"); + f_logits_all = env->GetFieldID(c_model_params, "logitsAll", "Z"); + f_vocab_only = env->GetFieldID(c_model_params, "vocabOnly", "Z"); + f_use_mmap = env->GetFieldID(c_model_params, "useMmap", "Z"); + f_use_mlock = env->GetFieldID(c_model_params, "useMlock", "Z"); + f_embedding = env->GetFieldID(c_model_params, "embedding", "Z"); + f_lora_adapter = env->GetFieldID(c_model_params, "loraAdapter", "Ljava/lang/String;"); + f_lora_base = env->GetFieldID(c_model_params, "loraBase", "Ljava/lang/String;"); + f_memory_f16 = env->GetFieldID(c_model_params, "memoryF16", "Z"); + f_mem_test = env->GetFieldID(c_model_params, "memTest", "Z"); + f_numa = env->GetFieldID(c_model_params, "numa", "Z"); + f_verbose_prompt = env->GetFieldID(c_model_params, "verbosePrompt", "Z"); + + f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;"); + o_utf_8 = env->NewStringUTF("UTF-8"); + o_utf_8 = (jobject) env->NewGlobalRef(o_utf_8); + o_logger = env->GetStaticObjectField(c_lib_utils, f_logger); + o_logger = (jobject) env->NewGlobalRef(o_logger); + + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + return JNI_ERR; + } + + return JNI_VERSION_1_1; +} + +JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv *env = 0; + + if (JNI_OK != vm->GetEnv((void **) &env, JNI_VERSION_1_1)) { + return; + } + + 
env->DeleteGlobalRef(c_input_params); + env->DeleteGlobalRef(c_model_params); + env->DeleteGlobalRef(c_token); + env->DeleteGlobalRef(c_string); + env->DeleteGlobalRef(c_hash_map); + env->DeleteGlobalRef(c_map); + env->DeleteGlobalRef(c_set); + env->DeleteGlobalRef(c_entry); + env->DeleteGlobalRef(c_integer); + env->DeleteGlobalRef(c_float); + env->DeleteGlobalRef(c_logger); + env->DeleteGlobalRef(c_engine_exception); + + env->DeleteGlobalRef(o_utf_8); +} + +static void log(JNIEnv *env, enum ggml_log_level level, const char *text) { + jstring java_text = env->NewStringUTF(text); + + switch (level) { + case GGML_LOG_LEVEL_ERROR: + env->CallVoidMethod(o_logger, m_log_error, java_text); + break; + case GGML_LOG_LEVEL_WARN: + env->CallVoidMethod(o_logger, m_log_warn, java_text); + break; + case GGML_LOG_LEVEL_INFO: + env->CallVoidMethod(o_logger, m_log_info, java_text); + break; + default: + env->CallVoidMethod(o_logger, m_log_debug, java_text); + break; + } + env->DeleteLocalRef(java_text); +} + +static void log(JNIEnv *env, enum ggml_log_level level, std::string text) { log(env, level, text.c_str()); } + +static std::string parse_jstring(JNIEnv *env, jstring java_string) { + const jbyteArray string_bytes = (jbyteArray) env->CallObjectMethod(java_string, m_get_bytes, o_utf_8); + + size_t length = (size_t) env->GetArrayLength(string_bytes); + jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr); + + std::string string = std::string((char *) byte_elements, length); + + env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT); + env->DeleteLocalRef(string_bytes); + + return string; +} + +static int parse_jinteger(JNIEnv *env, jobject java_integer) { + if (!java_integer) return 0; + return env->CallIntMethod(java_integer, m_int_value); +} + +static float parse_jfloat(JNIEnv *env, jobject java_float) { + if (!java_float) return 0; + return env->CallFloatMethod(java_float, m_float_value); +} + +static jbyteArray parse_jbytes(JNIEnv *env, 
std::string string) { + jsize len = string.size(); + jbyteArray bytes = env->NewByteArray(len); + env->SetByteArrayRegion(bytes, 0, len, reinterpret_cast(string.c_str())); + return bytes; +} + +// completion token output with probabilities +struct completion_token_output { + struct token_prob { + llama_token tok; + float prob; + }; + + std::vector probs; + llama_token tok; +}; + +static size_t common_part(const std::vector &a, const std::vector &b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) { + } + return i; +} + +enum stop_type { + STOP_FULL, + STOP_PARTIAL, +}; + +static bool ends_with(const std::string &str, const std::string &suffix) { + return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { + const char text_last_char = text.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { + const std::string current_partial = stop.substr(0, char_index + 1); + if (ends_with(text, current_partial)) { + return text.size() - char_index - 1; + } + } + } + } + return std::string::npos; +} + +template +static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += llama_token_to_piece(ctx, *begin); + } + return ret; +} + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) { + std::string out = token == -1 ? 
"" : llama_token_to_piece(ctx, token); + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + return out; +} + +struct jllama_context { + bool has_next_token = false; + std::string generated_text; + std::vector generated_token_probs; + + size_t num_prompt_tokens = 0; + size_t num_tokens_predicted = 0; + size_t n_past = 0; + size_t n_remain = 0; + + std::string prompt; + std::vector embd; + std::vector last_n_tokens; + + llama_model *model = nullptr; + llama_context *ctx = nullptr; + gpt_params params; + llama_sampling_context ctx_sampling; + int n_ctx; + + grammar_parser::parse_state parsed_grammar; + llama_grammar *grammar = nullptr; + + bool truncated = false; + bool stopped_eos = false; + bool stopped_word = false; + bool stopped_limit = false; + std::string stopping_word; + int32_t multibyte_pending = 0; + + std::mutex mutex; + + std::unique_lock lock() { return std::unique_lock(mutex); } + + ~jllama_context() { + if (ctx) { + llama_free(ctx); + ctx = nullptr; + } + if (model) { + llama_free_model(model); + model = nullptr; + } + if (grammar) { + llama_grammar_free(grammar); + grammar = nullptr; + } + } + + void rewind() { + params.antiprompt.clear(); + params.sparams.grammar.clear(); + num_prompt_tokens = 0; + num_tokens_predicted = 0; + generated_text = ""; + generated_text.reserve(n_ctx); + generated_token_probs.clear(); + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + multibyte_pending = 0; + n_remain = 0; + n_past = 0; + + if (grammar != nullptr) { + llama_grammar_free(grammar); + grammar = nullptr; + ctx_sampling = *llama_sampling_init(params.sparams); + } + } + + bool loadModel(const gpt_params ¶ms_) { + params = params_; + std::tie(model, ctx) = 
llama_init_from_gpt_params(params); + if (model == nullptr) { + return false; + } + n_ctx = llama_n_ctx(ctx); + last_n_tokens.resize(n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + return true; + } + + std::vector tokenize(std::string prompt, bool add_bos) const { + return ::llama_tokenize(ctx, prompt, add_bos); + } + + bool loadGrammar(JNIEnv *env) { + if (!params.sparams.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.sparams.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + log(env, GGML_LOG_LEVEL_ERROR, "grammar parse error"); + return false; + } + grammar_parser::print_grammar(stderr, parsed_grammar); + + { + auto it = params.sparams.logit_bias.find(llama_token_eos(model)); + if (it != params.sparams.logit_bias.end() && it->second == -INFINITY) { + log(env, GGML_LOG_LEVEL_WARN, "EOS token is disabled, which will cause most grammars to fail"); + } + } + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + ctx_sampling = *llama_sampling_init(params.sparams); + return true; + } + + void loadInfill(JNIEnv *env) { + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + + auto prefix_tokens = tokenize(params.input_prefix, false); + auto suffix_tokens = tokenize(params.input_suffix, false); + const int space_token = 29871; + if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); + prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); + prefix_tokens.insert(prefix_tokens.end(), 
suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(model)); + auto prompt_tokens = prefix_tokens; + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) { + params.n_keep = (int) num_prompt_tokens; + } + params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t) params.n_ctx) { + // todo we probably want to cut from both sides + const int n_left = (params.n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert( + new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); + + truncated = true; + prompt_tokens = new_tokens; + } else { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + embd = prompt_tokens; + + if (n_past == num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. 
+ n_past--; + } + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + + has_next_token = true; + } + + void loadPrompt(JNIEnv *env) { + auto prompt_tokens = tokenize(prompt, true); // always add BOS + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) { + params.n_keep = (int) num_prompt_tokens; + } + params.n_keep = std::min(n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t) n_ctx) { + const int n_left = (n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert( + new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); + + truncated = true; + prompt_tokens = new_tokens; + } else { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + + embd = prompt_tokens; + if (n_past == num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. 
+ n_past--; + } + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + + has_next_token = true; + } + + void beginCompletion() { + // number of tokens to keep when resetting context + n_remain = params.n_predict; + llama_set_rng_seed(ctx, params.seed); + } + + completion_token_output nextToken(JNIEnv *env) { + completion_token_output result; + result.tok = -1; + + if (embd.size() >= (size_t) n_ctx) { + // Shift context + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left / 2; + + llama_kv_cache_seq_rm(ctx, 0, params.n_keep + 1, params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) { + embd[i - n_discard] = embd[i]; + } + embd.resize(embd.size() - n_discard); + + n_past -= n_discard; + + truncated = true; + log(env, GGML_LOG_LEVEL_INFO, "input truncated n_left=" + std::to_string(n_left)); + } + + bool tg = true; + while (n_past < embd.size()) { + int n_eval = (int) embd.size() - n_past; + tg = n_eval == 1; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) { + log(env, GGML_LOG_LEVEL_ERROR, "failed to eval n_eval=" + std::to_string(n_eval)); + has_next_token = false; + return result; + } + n_past += n_eval; + } + + if (params.n_predict == 0) { + has_next_token = false; + result.tok = llama_token_eos(model); + return result; + } + + { + // out of user input, sample next token + result.tok = llama_sampling_sample(&ctx_sampling, ctx, NULL); + + llama_token_data_array candidates_p = {ctx_sampling.cur.data(), ctx_sampling.cur.size(), false}; + + const int32_t n_probs = params.sparams.n_probs; + if (params.sparams.temp <= 0 && n_probs > 0) { + // For llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &candidates_p); + } + + for 
(size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) { + result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + } + + llama_sampling_accept(&ctx_sampling, ctx, result.tok, true); + if (tg) { + num_tokens_predicted++; + } + } + + // add it to the context + embd.push_back(result.tok); + // decrement remaining sampling budget + --n_remain; + + if (!embd.empty() && embd.back() == llama_token_eos(model)) { + // stopping_word = llama_token_to_piece(ctx, embd.back()); + has_next_token = false; + stopped_eos = true; + return result; + } + + has_next_token = params.n_predict == -1 || n_remain != 0; + return result; + } + + size_t findStoppingStrings(const std::string &text, const size_t last_token_size, const stop_type type) { + size_t stop_pos = std::string::npos; + for (const std::string &word : params.antiprompt) { + size_t pos; + if (type == STOP_FULL) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + pos = text.find(word, from_pos); + } else { + pos = find_partial_stop_string(word, text); + } + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (type == STOP_FULL) { + stopping_word = word; + stopped_word = true; + has_next_token = false; + } + stop_pos = pos; + } + } + return stop_pos; + } + + completion_token_output doCompletion(JNIEnv *env) { + auto token_with_probs = nextToken(env); + + const std::string token_text = token_with_probs.tok == -1 ? 
"" : llama_token_to_piece(ctx, token_with_probs.tok); + generated_text += token_text; + + if (params.sparams.n_probs > 0) { + generated_token_probs.push_back(token_with_probs); + } + + if (multibyte_pending > 0) { + multibyte_pending -= token_text.size(); + } else if (token_text.size() == 1) { + const char c = token_text[0]; + // 2-byte characters: 110xxxxx 10xxxxxx + if ((c & 0xE0) == 0xC0) { + multibyte_pending = 1; + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + } else if ((c & 0xF0) == 0xE0) { + multibyte_pending = 2; + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + } else if ((c & 0xF8) == 0xF0) { + multibyte_pending = 3; + } else { + multibyte_pending = 0; + } + } + + if (multibyte_pending > 0 && !has_next_token) { + has_next_token = true; + n_remain++; + } + + if (!has_next_token && n_remain == 0) { + stopped_limit = true; + } + + return token_with_probs; + } + + std::vector getEmbedding(JNIEnv *env) { + static const int n_embd = llama_n_embd(model); + if (!params.embedding) { + log(env, GGML_LOG_LEVEL_ERROR, "embedding disabled"); + return std::vector(n_embd, 0.0f); + } + const float *data = llama_get_embeddings(ctx); + std::vector embedding(data, data + n_embd); + return embedding; + } +}; + +static gpt_params parse_model_params(JNIEnv *env, jobject jparams, jstring java_file_path) { + gpt_params params; + + params.model = parse_jstring(env, java_file_path); + params.n_threads = env->GetIntField(jparams, f_n_threads); + params.n_ctx = env->GetIntField(jparams, f_n_ctx); + params.n_batch = env->GetIntField(jparams, f_n_batch); + params.n_gpu_layers = env->GetIntField(jparams, f_n_gpu_layers); + params.main_gpu = env->GetIntField(jparams, f_main_gpu); + params.rope_freq_base = env->GetFloatField(jparams, f_rope_freq_base); + params.rope_freq_scale = env->GetFloatField(jparams, f_rope_freq_scale); + params.mul_mat_q = env->GetBooleanField(jparams, f_mul_mat_q); + params.embedding = env->GetBooleanField(jparams, f_embedding); + params.escape 
= env->GetIntField(jparams, f_n_predict); + params.use_mmap = env->GetBooleanField(jparams, f_use_mmap); + params.use_mlock = env->GetBooleanField(jparams, f_use_mlock); + params.numa = env->GetBooleanField(jparams, f_numa); + params.verbose_prompt = env->GetBooleanField(jparams, f_verbose_prompt); + + if (params.model_alias == "unknown") { + params.model_alias = params.model; + } + + return params; +} + +static void setup_infer_params(JNIEnv *env, jllama_context *llama, jobject jparams) { + auto ¶ms = llama->params; + + params.seed = env->GetIntField(jparams, f_infer_seed); + params.n_predict = env->GetIntField(jparams, f_n_predict); + params.n_keep = env->GetIntField(jparams, f_n_keep); + + auto &sparams = params.sparams; + + sparams.top_k = env->GetIntField(jparams, f_top_k); + sparams.top_p = env->GetFloatField(jparams, f_top_p); + sparams.tfs_z = env->GetFloatField(jparams, f_tfs_z); + sparams.typical_p = env->GetFloatField(jparams, f_typical_p); + sparams.temp = env->GetFloatField(jparams, f_temperature); + sparams.penalty_repeat = env->GetFloatField(jparams, f_repeat_penalty); + sparams.n_prev = env->GetIntField(jparams, f_repeat_last_n); + sparams.penalty_freq = env->GetFloatField(jparams, f_frequency_penalty); + sparams.penalty_present = env->GetFloatField(jparams, f_presence_penalty); + sparams.penalize_nl = env->GetBooleanField(jparams, f_penalize_nl); + sparams.mirostat = env->GetIntField(jparams, f_mirostat); + sparams.mirostat_tau = env->GetFloatField(jparams, f_mirostat_tau); + sparams.mirostat_eta = env->GetFloatField(jparams, f_mirostat_eta); + sparams.n_probs = env->GetIntField(jparams, f_n_probs); + + jstring j_grammar = (jstring) env->GetObjectField(jparams, f_grammar); + if (j_grammar != nullptr) { + sparams.grammar = parse_jstring(env, j_grammar); + env->DeleteLocalRef(j_grammar); + if (!llama->loadGrammar(env)) { + env->ThrowNew(c_engine_exception, "could not load grammar"); + } + } + + sparams.logit_bias.clear(); + jboolean ignore_eos = 
env->GetBooleanField(jparams, f_ignore_eos); + if (ignore_eos) { + sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY; + } + + jobject logit_bias = env->GetObjectField(jparams, f_logit_bias); + if (logit_bias != nullptr) { + jobject entry_set = env->CallObjectMethod(logit_bias, m_entry_set); + jobject iterator = env->CallObjectMethod(entry_set, m_set_iterator); + while (env->CallBooleanMethod(iterator, m_iterator_has_next)) { + jobject entry = env->CallObjectMethod(iterator, m_iterator_next); + jobject key = env->CallObjectMethod(entry, m_entry_key); + jobject value = env->CallObjectMethod(entry, m_entry_value); + + int tok = parse_jinteger(env, key); + float bias = parse_jfloat(env, value); + sparams.logit_bias[tok] = bias; + + env->DeleteLocalRef(entry); + env->DeleteLocalRef(key); + env->DeleteLocalRef(value); + } + } + + params.antiprompt.clear(); + jobjectArray antiprompt = (jobjectArray) env->GetObjectField(jparams, f_antiprompt); + if (antiprompt != nullptr) { + jsize array_length = env->GetArrayLength(antiprompt); + for (jsize i = 0; i < array_length; i++) { + jstring java_string = (jstring) env->GetObjectArrayElement(antiprompt, i); + if (java_string != nullptr) { + std::string string = parse_jstring(env, java_string); + params.antiprompt.push_back(string); + env->DeleteLocalRef(java_string); + } + } + } + + llama->ctx_sampling = *llama_sampling_init(params.sparams); +} + +static void setup_answering(JNIEnv *env, jllama_context *llama, jstring prompt, jobject params) { + llama->prompt = parse_jstring(env, prompt); + llama->params.input_prefix = ""; + llama->params.input_suffix = ""; + setup_infer_params(env, llama, params); +} + +static void setup_infilling(JNIEnv *env, jllama_context *llama, jstring prefix, jstring suffix, jobject params) { + llama->prompt = ""; + llama->params.input_prefix = parse_jstring(env, prefix); + llama->params.input_suffix = parse_jstring(env, suffix); + setup_infer_params(env, llama, params); +} + +JNIEXPORT jlong 
JNICALL Java_ai_djl_llama_jni_LlamaLibrary_loadModel( + JNIEnv *env, jclass clazz, jstring file_path, jobject jparams) { + gpt_params params = parse_model_params(env, jparams, file_path); + + jllama_context *llama = new jllama_context; + llama_backend_init(false); + + if (!llama->loadModel(params)) { + env->ThrowNew(c_engine_exception, "could not load model from given file path"); + return 0; + } + + return reinterpret_cast(llama); +} + +JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_generate( + JNIEnv *env, jclass clazz, jlong handle, jstring prompt, jobject params) { + auto *llama = reinterpret_cast(handle); + + llama->rewind(); + llama_reset_timings(llama->ctx); + setup_answering(env, llama, prompt, params); + + llama->loadPrompt(env); + llama->beginCompletion(); +} + +JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_infill( + JNIEnv *env, jclass clazz, jlong handle, jstring prefix, jstring suffix, jobject params) { + auto *llama = reinterpret_cast(handle); + + llama->rewind(); + + llama_reset_timings(llama->ctx); + + setup_infilling(env, llama, prefix, suffix, params); + + llama->loadInfill(env); + llama->beginCompletion(); +} + +JNIEXPORT jobject JNICALL Java_ai_djl_llama_jni_LlamaLibrary_getNext( + JNIEnv *env, jclass clazz, jlong handle, jlong sent_count, jlong sent_token_probs_index) { + auto *llama = reinterpret_cast(handle); + + completion_token_output token_with_probs; + while (llama->has_next_token) { + token_with_probs = llama->doCompletion(env); + if (token_with_probs.tok >= 0 && llama->multibyte_pending <= 0) { + break; + } + } + const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok); + + size_t pos = std::min((size_t) sent_count, llama->generated_text.size()); + + const std::string str_test = llama->generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_FULL); + if (stop_pos != std::string::npos) { + is_stop_full = 
true; + llama->generated_text.erase(llama->generated_text.begin() + pos + stop_pos, llama->generated_text.end()); + pos = std::min((size_t) sent_count, llama->generated_text.size()); + } else { + is_stop_full = false; + stop_pos = llama->findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL); + } + + std::string to_send; + if (stop_pos == std::string::npos || + // Send rest of the text if we are at the end of the generation + (!llama->has_next_token && !is_stop_full && stop_pos > 0)) { + to_send = llama->generated_text.substr(pos, std::string::npos); + + sent_count += to_send.size(); + std::vector probs_output = {}; + + if (llama->params.sparams.n_probs > 0) { + const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false); + size_t probs_pos = std::min((size_t) sent_token_probs_index, llama->generated_token_probs.size()); + size_t probs_stop_pos = + std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector( + llama->generated_token_probs.begin() + probs_pos, llama->generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + } else { + to_send = ""; + } + + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + for (const auto &tp : token_with_probs.probs) { + jobject jtoken = env->NewObject(c_integer, cc_integer, tp.tok); + jobject jprob = env->NewObject(c_float, cc_float, tp.prob); + env->CallObjectMethod(o_probabilities, m_map_put, jtoken, jprob); + } + + jbyteArray jbytes = parse_jbytes(env, to_send); + return env->NewObject(c_token, cc_token, token_with_probs.tok, jbytes, o_probabilities, sent_count, + sent_token_probs_index, llama->has_next_token); +} + +JNIEXPORT jfloatArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_embed( + JNIEnv *env, jclass clazz, jlong handle, jstring java_prompt) { + auto *llama = reinterpret_cast(handle); + + llama->rewind(); + 
llama_reset_timings(llama->ctx); + llama->prompt = parse_jstring(env, java_prompt); + llama->params.n_predict = 0; + llama->loadPrompt(env); + llama->beginCompletion(); + llama->doCompletion(env); + + static const int n_embd = llama_n_embd(llama->model); + const float *data = llama_get_embeddings(llama->ctx); + std::vector embedding(data, data + n_embd); + + jfloatArray java_embedding = env->NewFloatArray(embedding.size()); + env->SetFloatArrayRegion(java_embedding, 0, embedding.size(), reinterpret_cast(embedding.data())); + + return java_embedding; +} + +JNIEXPORT jintArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_encode( + JNIEnv *env, jclass clazz, jlong handle, jstring jprompt) { + auto *llama = reinterpret_cast(handle); + + std::string prompt = parse_jstring(env, jprompt); + std::vector tokens = llama->tokenize(prompt, false); + + jintArray java_tokens = env->NewIntArray(tokens.size()); + env->SetIntArrayRegion(java_tokens, 0, tokens.size(), reinterpret_cast(tokens.data())); + + return java_tokens; +} + +JNIEXPORT jbyteArray JNICALL Java_ai_djl_llama_jni_LlamaLibrary_decodeBytes( + JNIEnv *env, jclass clazz, jlong handle, jintArray java_tokens) { + auto *llama = reinterpret_cast(handle); + + jsize length = env->GetArrayLength(java_tokens); + jint *elements = env->GetIntArrayElements(java_tokens, nullptr); + std::vector tokens(elements, elements + length); + std::string text = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend()); + + env->ReleaseIntArrayElements(java_tokens, elements, 0); + + return parse_jbytes(env, text); +} + +JNIEXPORT void JNICALL Java_ai_djl_llama_jni_LlamaLibrary_delete(JNIEnv *env, jclass clazz, jlong handle) { + auto *llama = reinterpret_cast(handle); + delete llama; +} diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider new file mode 100644 index 00000000000..d2f8ca8e42c --- /dev/null +++ 
b/engines/llama/src/main/resources/META-INF/services/ai.djl.engine.EngineProvider @@ -0,0 +1 @@ +ai.djl.llama.engine.LlamaEngineProvider diff --git a/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider new file mode 100644 index 00000000000..92f6245340f --- /dev/null +++ b/engines/llama/src/main/resources/META-INF/services/ai.djl.repository.zoo.ZooProvider @@ -0,0 +1 @@ +ai.djl.llama.zoo.LlamaZooProvider diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java new file mode 100644 index 00000000000..429cd569392 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaInputTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.engine; + +import ai.djl.llama.engine.LlamaInput.Parameters; +import ai.djl.llama.jni.InputParameters; +import ai.djl.util.JsonUtils; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.Reader; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; + +public class LlamaInputTest { + + @Test + public void testInputParameters() throws IOException { + Path file = Paths.get("src/test/resources/inputs.json"); + try (Reader reader = Files.newBufferedReader(file)) { + LlamaInput in = JsonUtils.GSON.fromJson(reader, LlamaInput.class); + checkParameters(in); + } + + Parameters param = new Parameters(); + LlamaInput in = new LlamaInput(); + in.setInputs("prompt"); + in.setPrefix("prefix"); + in.setSuffix("suffix"); + in.setParameters(param); + param.setMaxNewTokens(2); + param.setNumberKeep(2); + param.setNumberProbabilities(2); + param.setTopK(2); + param.setTopP(2f); + param.setTfsZ(2f); + param.setTypicalP(2f); + param.setTemperature(2f); + param.setRepeatPenalty(2f); + param.setRepeatLastN(2); + param.setFrequencyPenalty(2f); + param.setFrequencyPenalty(2f); + param.setPresencePenalty(2f); + param.setPenalizeNl(true); + param.setIgnoreEos(true); + param.setMirostat(2); + param.setMirostatTau(2f); + param.setMirostatEta(2f); + param.setNumberBeams(5); + param.setSeed(2); + Map logitBias = Map.of(2, 0.4f, 3, 0.5f); + param.setLogitBias(logitBias); + param.setGrammar("grammar"); + param.setAntiPrompt(new String[] {"User: "}); + checkParameters(in); + } + + private void checkParameters(LlamaInput in) { + InputParameters param = in.getParameters().toInputParameters(); + Assert.assertEquals(param.getMaxNewTokens(), 2); + Assert.assertEquals(param.getNumberKeep(), 2); + Assert.assertEquals(param.getNumberProbabilities(), 2); + Assert.assertEquals(param.getTopK(), 2); + Assert.assertEquals(param.getTopP(), 2f); + 
Assert.assertEquals(param.getTfsZ(), 2f); + Assert.assertEquals(param.getTypicalP(), 2f); + Assert.assertEquals(param.getTemperature(), 2f); + Assert.assertEquals(param.getRepeatPenalty(), 2f); + Assert.assertEquals(param.getRepeatLastN(), 2); + Assert.assertEquals(param.getFrequencyPenalty(), 2f); + Assert.assertEquals(param.getFrequencyPenalty(), 2f); + Assert.assertEquals(param.getPresencePenalty(), 2f); + Assert.assertTrue(param.isPenalizeNl()); + Assert.assertTrue(param.isIgnoreEos()); + Assert.assertEquals(param.getMirostat(), 2); + Assert.assertEquals(param.getMirostatTau(), 2f); + Assert.assertEquals(param.getMirostatEta(), 2f); + Assert.assertEquals(param.getNumberBeams(), 5); + Assert.assertEquals(param.getSeed(), 2); + Map logitBias = param.getLogitBias(); + Assert.assertNotNull(logitBias); + Assert.assertEquals(logitBias.size(), 2); + Assert.assertEquals(logitBias.get(2), 0.4f); + Assert.assertNotNull(param.getGrammar()); + Assert.assertNotNull(param.getAntiPrompt()[0], "User: "); + } +} diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java new file mode 100644 index 00000000000..7b372ee4258 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/engine/LlamaTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.engine; + +import ai.djl.ModelException; +import ai.djl.engine.Engine; +import ai.djl.engine.StandardCapabilities; +import ai.djl.inference.Predictor; +import ai.djl.llama.jni.Token; +import ai.djl.llama.jni.TokenIterator; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; +import ai.djl.training.util.DownloadUtils; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; + +public class LlamaTest { + + private static final Logger logger = LoggerFactory.getLogger(LlamaTest.class); + + @BeforeClass + public void setUp() { + System.setProperty("DJL_CACHE_DIR", "build/cache"); + } + + @AfterClass + public void tierDown() { + System.clearProperty("DJL_CACHE_DIR"); + } + + @Test + public void testLlamaVersion() { + Engine engine = Engine.getEngine("Llama"); + Assert.assertEquals(engine.getVersion(), "b1696-" + Engine.getDjlVersion()); + Assert.assertNotNull(engine.toString()); + Assert.assertEquals(engine.getRank(), 10); + Assert.assertFalse(engine.hasCapability(StandardCapabilities.CUDA)); + Assert.assertNull(engine.getAlternativeEngine()); + try (NDManager manager = engine.newBaseManager()) { + Assert.assertNotNull(manager); + } + } + + @Test + public void testLlama() throws TranslateException, ModelException, IOException { + TestRequirements.nightly(); + downloadModel(); + Path path = Paths.get("models"); + Criteria criteria = + Criteria.builder() + .setTypes(String.class, TokenIterator.class) + .optModelPath(path) + .optModelName("tinyllama-1.1b-1t-openorca.Q4_K_M") + 
.optEngine("Llama") + .optOption("number_gpu_layers", "43") + .optTranslatorFactory(new LlamaTranslatorFactory()) + .build(); + + String prompt = + "{\"inputs\": \"<|im_start|>system\n" + + "{system_message}<|im_end|>\n" + + "<|im_start|>user\n" + + "{prompt}<|im_end|>\n" + + "<|im_start|>assistant\", \"parameters\": {\"max_new_tokens\": 10}}"; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + TokenIterator it = predictor.predict(prompt); + StringBuilder sb = new StringBuilder(); + while (it.hasNext()) { + Token token = it.next(); + Assert.assertNotNull(token.getText()); + Assert.assertTrue(token.getToken() >= 0); + Assert.assertNotNull(token.getProbabilities()); + sb.append(token.getText()); + logger.info("{}", token); + } + Assert.assertTrue(sb.length() > 1); + } + } + + @Test + public void testLlamaInfill() throws TranslateException, ModelException, IOException { + TestRequirements.nightly(); + downloadModel(); + Path path = Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf"); + Criteria criteria = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(path) + .optOption("number_gpu_layers", "43") + .optEngine("Llama") + .optTranslatorFactory(new LlamaTranslatorFactory()) + .build(); + + String prompt = + "{\n" + + " \"prefix\":\"def remove_non_ascii(s: str) -> str:\n\",\n" + + " \"suffix\":\"\n return result\n\",\n" + + " \"parameters\":{\n" + + " \"max_new_tokens\": 10" + + " }\n" + + "}"; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Input in = new Input(); + in.add("data", prompt); + Output out = predictor.predict(in); + Assert.assertNotNull(out.getData().getAsString()); + } + } + + private void downloadModel() throws IOException { + String url = + "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf?download=true"; + Path dir = 
Paths.get("models/tinyllama-1.1b-1t-openorca.Q4_K_M.gguf"); + DownloadUtils.download(URI.create(url).toURL(), dir, null); + } +} diff --git a/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java new file mode 100644 index 00000000000..b2ee786419f --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/engine/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains test classes for llama engine. */ +package ai.djl.llama.engine; diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java new file mode 100644 index 00000000000..fab7bacb9e3 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/zoo/LlamaModelZooTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.llama.zoo; + +import ai.djl.repository.zoo.ModelLoader; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.util.Utils; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.nio.file.Paths; +import java.util.Collection; + +public class LlamaModelZooTest { + + @Test + public void testLlamaModelZoo() { + System.setProperty("DJL_CACHE_DIR", "build/cache"); + Utils.deleteQuietly(Paths.get("build/cache/cache")); + try { + ModelZoo zoo = ModelZoo.getModelZoo("ai.djl.huggingface.gguf"); + Collection models = zoo.getModelLoaders(); + Assert.assertFalse(models.isEmpty()); + Assert.assertEquals(zoo.getSupportedEngines().size(), 1); + ModelLoader loader = zoo.getModelLoader("TinyLlama/TinyLlama-1.1B-Chat-v0.6"); + Assert.assertNotNull(loader); + + ModelZoo llamaModelZoo = new LlamaModelZoo(); + Assert.assertFalse(llamaModelZoo.getModelLoaders().isEmpty()); + } finally { + System.clearProperty("DJL_CACHE_DIR"); + } + } + + @Test + public void testOffLine() { + System.setProperty("DJL_CACHE_DIR", "build/cache"); + System.setProperty("ai.djl.offline", "true"); + Utils.deleteQuietly(Paths.get("build/cache/cache")); + try { + // static variables cannot not be initialized properly if directly use LlamaModelZoo() + ModelZoo.getModelZoo("ai.djl.huggingface.gguf"); + + ModelZoo zoo = new LlamaModelZoo(); + Assert.assertFalse(zoo.getModelLoaders().isEmpty()); + } finally { + System.clearProperty("DJL_CACHE_DIR"); + System.clearProperty("ai.djl.offline"); + } + } +} diff --git a/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java new file mode 100644 index 00000000000..145b2ddcca9 --- /dev/null +++ b/engines/llama/src/test/java/ai/djl/llama/zoo/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains test classes for llama model zoo. */ +package ai.djl.llama.zoo; diff --git a/engines/llama/src/test/resources/inputs.json b/engines/llama/src/test/resources/inputs.json new file mode 100644 index 00000000000..ab77386e1b6 --- /dev/null +++ b/engines/llama/src/test/resources/inputs.json @@ -0,0 +1,33 @@ +{ + "prefix": "def remove_non_ascii(s: str) -> str:", + "suffix": " return result", + "parameters": { + "max_new_tokens": 2, + "number_keep": 2, + "number_probabilities": 2, + "top_k": 2, + "top_p": 2, + "tfs_z": 2, + "typical_p": 2, + "temperature": 2, + "repeat_penalty": 2, + "repeat_last_n": 2, + "frequency_penalty": 2, + "presence_penalty": 2, + "penalize_nl": true, + "ignore_eos": true, + "mirostat": 2, + "mirostat_tau": 2, + "mirostat_eta": 2, + "number_beams": 5, + "seed": 2, + "logit_bias": { + "2": 0.4, + "5": 0.6 + }, + "grammar": "root ::= (expr \"=\" term \"\\n\")+\nexpr ::= term ([-+*/] term)*\nterm ::= [0-9]", + "anti_prompt": [ + "User: " + ] + } +} diff --git a/engines/ml/lightgbm/README.md b/engines/ml/lightgbm/README.md index 3ea950c8935..74ab3eba411 100644 --- a/engines/ml/lightgbm/README.md +++ b/engines/ml/lightgbm/README.md @@ -36,13 +36,13 @@ LightGBM can only run on top of the Linux/Mac/Windows machine using x86_64. 
## Installation You can pull the LightGBM engine from the central Maven repository by including the following dependency: -- ai.djl.ml.lightgbm:lightgbm:0.23.0 +- ai.djl.ml.lightgbm:lightgbm:0.26.0 ```xml ai.djl.ml.lightgbm lightgbm - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index a253ce3d246..583cd8132b2 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,7 +18,8 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (LgbmEngineProvider.class) { - engine = LgbmEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = LgbmEngine.newInstance(); + } } } return engine; diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java index 0bb92645a89..826b1a0f900 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmSymbolBlock.java @@ -46,6 +46,7 @@ public class LgbmSymbolBlock extends AbstractSymbolBlock implements AutoCloseabl * @param iterations the number of iterations the model was trained for * @param handle the Booster handle */ + @SuppressWarnings("this-escape") public LgbmSymbolBlock(LgbmNDManager manager, int iterations, SWIGTYPE_p_p_void 
handle) { this.handle = new AtomicReference<>(handle); this.iterations = iterations; diff --git a/engines/ml/xgboost/README.md b/engines/ml/xgboost/README.md index d69f1830193..df0a7897e3c 100644 --- a/engines/ml/xgboost/README.md +++ b/engines/ml/xgboost/README.md @@ -37,13 +37,13 @@ XGBoost can only run on top of the Linux/Mac machine. User can build from source ## Installation You can pull the XGBoost engine from the central Maven repository by including the following dependency: -- ai.djl.ml.xgboost:xgboost:0.23.0 +- ai.djl.ml.xgboost:xgboost:0.26.0 ```xml ai.djl.ml.xgboost xgboost - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 19cba32cc71..8b534d5196c 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,7 +18,8 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. 
*/ public class XgbEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (XgbEngineProvider.class) { - engine = XgbEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = XgbEngine.newInstance(); + } } } return engine; diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java index bf41acb9b6c..1b3c5ae277f 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbModel.java @@ -80,6 +80,8 @@ private Path findModelFile(String prefix) { String fileName = file.toFile().getName(); if (fileName.endsWith(".json")) { modelName = fileName.substring(0, fileName.length() - 5); + } else if (fileName.endsWith(".xgb")) { + modelName = fileName.substring(0, fileName.length() - 4); } else { modelName = fileName; } @@ -90,13 +92,22 @@ private Path findModelFile(String prefix) { } Path modelFile = modelDir.resolve(prefix); if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) { - if (prefix.endsWith(".json")) { + if (prefix.endsWith(".json") || prefix.endsWith(".xgb")) { return null; } modelFile = modelDir.resolve(prefix + ".json"); - if (Files.notExists(modelFile) || !Files.isRegularFile(modelFile)) { - return null; + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + modelFile = modelDir.resolve(prefix + ".xgb"); + if (Files.isRegularFile(modelFile)) { + return modelFile; + } + modelFile = modelDir.resolve("model.xgb"); + if (Files.isRegularFile(modelFile)) { + return modelFile; } + return null; } return modelFile; } diff --git 
a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java index 3b56cbca241..81f9708e72b 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java @@ -39,6 +39,7 @@ public class XgbNDManager extends BaseNDManager { private static final XgbNDManager SYSTEM_MANAGER = new SystemManager(); private float missingValue = Float.NaN; + private int nthread = 1; private XgbNDManager(NDManager parent, Device device) { super(parent, device); @@ -57,6 +58,15 @@ public void setMissingValue(float missingValue) { this.missingValue = missingValue; } + /** + * Sets the default number of threads. + * + * @param nthread the default number of threads + */ + public void setNthread(int nthread) { + this.nthread = nthread; + } + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -166,7 +176,7 @@ public NDArray createCSR(Buffer buffer, long[] indptr, long[] indices, Shape sha int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray(); float[] data = new float[buffer.remaining()]; ((FloatBuffer) buffer).get(data); - long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data); + long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data, missingValue, nthread); return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.CSR); } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java index 1e2bcddd999..43a9e129dea 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java @@ -45,6 +45,7 @@ public class XgbSymbolBlock extends AbstractSymbolBlock implements AutoCloseable * @param manager the manager to use for the block * @param handle 
the Booster handle */ + @SuppressWarnings("this-escape") public XgbSymbolBlock(XgbNDManager manager, long handle) { this.handle = new AtomicReference<>(handle); this.manager = manager; diff --git a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java index fefbe7f0716..eb071552fd0 100644 --- a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java +++ b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java @@ -67,9 +67,12 @@ public static long createDMatrix(ColumnBatch columnBatch, float missing, int nth return handles[0]; } - public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) { + public static long createDMatrixCSR( + long[] indptr, int[] indices, float[] array, float missing, int nthread) { long[] handles = new long[1]; - checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles)); + checkCall( + XGBoostJNI.XGDMatrixCreateFromCSR( + indptr, indices, array, 0, missing, nthread, handles)); return handles[0]; } diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java index 0b09ed6807c..acbfa998867 100644 --- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java +++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java @@ -53,7 +53,7 @@ public void downloadXGBoostModel() throws IOException { @Test public void testVersion() { Engine engine = Engine.getEngine("XGBoost"); - Assert.assertEquals("1.7.5", engine.getVersion()); + Assert.assertEquals("2.0.3", engine.getVersion()); } /* @@ -93,6 +93,7 @@ public void testNDArray() { try (XgbNDManager manager = (XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) { manager.setMissingValue(Float.NaN); + manager.setNthread(1); NDArray zeros = manager.zeros(new Shape(1, 2)); 
Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray); diff --git a/engines/mxnet/jnarator/build.gradle b/engines/mxnet/jnarator/build.gradle index b9cc0d4cd5f..b9fd8ceab14 100644 --- a/engines/mxnet/jnarator/build.gradle +++ b/engines/mxnet/jnarator/build.gradle @@ -17,6 +17,11 @@ dependencies { checkstyleMain.source = 'src/main/java' pmdMain.source = 'src/main/java' +compileJava { + options.compilerArgs.clear() + options.compilerArgs << "--release" << "11" << "-proc:none" << "-Xlint:all,-options,-static" +} + jar { manifest { attributes ( diff --git a/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java b/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java index 3105ec9cd48..ba3e18fea3b 100644 --- a/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java +++ b/engines/mxnet/jnarator/src/main/java/ai/djl/mxnet/jnarator/JnaGenerator.java @@ -276,6 +276,7 @@ public void writeNativeSize() throws IOException { writer.append(" public NativeSizeByReference() {\n"); writer.append(" this(new NativeSize(0));\n"); writer.append(" }\n\n"); + writer.append(" @SuppressWarnings(\"this-escape\")\n"); writer.append(" public NativeSizeByReference(NativeSize value) {\n"); writer.append(" super(NativeSize.SIZE);\n"); writer.append(" setValue(value);\n"); diff --git a/engines/mxnet/mxnet-engine/README.md b/engines/mxnet/mxnet-engine/README.md index cef559f1e31..66b2c98adc1 100644 --- a/engines/mxnet/mxnet-engine/README.md +++ b/engines/mxnet/mxnet-engine/README.md @@ -7,7 +7,7 @@ This module contains the Deep Java Library (DJL) EngineProvider for Apache MXNet We don't recommend that developers use classes in this module directly. Use of these classes will couple your code with Apache MXNet and make switching between engines difficult. Even so, developers are not restricted from using engine-specific features. 
For more information, -see [NDManager#invoke()](https://javadoc.io/static/ai.djl/api/0.23.0/ai/djl/ndarray/NDManager.html#invoke-java.lang.String-ai.djl.ndarray.NDArray:A-ai.djl.ndarray.NDArray:A-ai.djl.util.PairList-). +see [NDManager#invoke()](https://javadoc.io/static/ai.djl/api/0.26.0/ai/djl/ndarray/NDManager.html#invoke-java.lang.String-ai.djl.ndarray.NDArray:A-ai.djl.ndarray.NDArray:A-ai.djl.util.PairList-). ## Documentation @@ -33,7 +33,7 @@ You can pull the MXNet engine from the central Maven repository by including the ai.djl.mxnet mxnet-engine - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java index 62398b1868e..b1ca8e49aa4 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java @@ -63,6 +63,7 @@ public class CachedOp extends NativeResource { * @param dataIndices the input data names required by the model and their corresponding * location */ + @SuppressWarnings("this-escape") public CachedOp( Pointer handle, MxNDManager manager, diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index f30a6a89252..2a5ab970560 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,7 +18,8 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. 
*/ public class MxEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (MxEngineProvider.class) { - engine = MxEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = MxEngine.newInstance(); + } } } return engine; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 87ccba78e96..8b884b3993a 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -888,6 +888,13 @@ public NDArray atan() { return manager.invoke("_npi_arctan", this, null); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + other = manager.from(other); + return manager.invoke("_npi_arctan2", new NDArray[] {this, other}, null); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { @@ -1153,6 +1160,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { @@ -1601,6 +1620,12 @@ public NDArray erfinv() { return manager.invoke("erfinv", this, null); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return manager.invoke("erf", this, null); + } + /** {@inheritDoc} */ 
@Override public NDArray norm(boolean keepDims) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index c7efd80eba3..e1ff0db645a 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -287,7 +287,7 @@ public NDArray globalMaxPool() { params.add("pool_type", "max"); params.addParam("global_pool", true); try (NDArray temp = getManager().invoke("_npx_pooling", getArray(), params)) { - return temp.reshape(temp.getShape().size(0), temp.getShape().size(1)); + return temp.reshape(-1, temp.getShape().size(1)); } } @@ -318,7 +318,7 @@ public NDArray globalAvgPool() { params.add("pool_type", "avg"); params.addParam("global_pool", true); try (NDArray temp = getManager().invoke("_npx_pooling", getArray(), params)) { - return temp.reshape(temp.getShape().size(0), temp.getShape().size(1)); + return temp.reshape(-1, temp.getShape().size(1)); } } @@ -355,7 +355,7 @@ public NDArray globalLpPool(float normType) { params.addParam("p_value", (int) normType); params.addParam("global_pool", true); try (NDArray temp = getManager().invoke("_npx_pooling", getArray(), params)) { - return temp.reshape(temp.getShape().size(0), temp.getShape().size(1)); + return temp.reshape(-1, temp.getShape().size(1)); } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java index 5f08cf5910c..99e415cf62c 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java @@ -23,6 +23,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; +import 
ai.djl.training.listener.AlgebraicListener; import ai.djl.util.PairList; import com.sun.jna.Pointer; @@ -338,12 +339,15 @@ public MxNDManager newSubManager(Device dev) { public void invoke( String operation, NDArray[] src, NDArray[] dest, PairList params) { JnaUtils.op(operation).invoke(this, src, dest, params); + AlgebraicListener.record(operation, src, dest, params); } /** {@inheritDoc} */ @Override public NDList invoke(String operation, NDList src, PairList params) { - return new NDList(JnaUtils.op(operation).invoke(this, src.toArray(EMPTY), params)); + NDArray[] dest = JnaUtils.op(operation).invoke(this, src.toArray(EMPTY), params); + AlgebraicListener.record(operation, src.toArray(EMPTY), dest, params); + return new NDList(dest); } /** @@ -379,7 +383,9 @@ public void invoke(String operation, NDList src, NDList dest, PairList params) { - return JnaUtils.op(operation).invoke(this, src, params)[0]; + NDArray[] dest = JnaUtils.op(operation).invoke(this, src, params); + AlgebraicListener.record(operation, src, dest, params); + return dest[0]; } /** diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java index 36bead164e4..952ca2f0995 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java @@ -40,6 +40,7 @@ public class MxParameterServer extends NativeResource implements Parame * * @param optimizer the optimizer to use for the parameter server updates */ + @SuppressWarnings("this-escape") public MxParameterServer(Optimizer optimizer) { super(createdKVStore()); callback = new OptimizerCallback(optimizer); diff --git a/engines/mxnet/mxnet-model-zoo/README.md b/engines/mxnet/mxnet-model-zoo/README.md index c4f44fe358c..f32678944c0 100644 --- a/engines/mxnet/mxnet-model-zoo/README.md +++ 
b/engines/mxnet/mxnet-model-zoo/README.md @@ -27,7 +27,7 @@ You can pull the MXNet engine from the central Maven repository by including the ai.djl.mxnet mxnet-model-zoo - 0.23.0 + 0.26.0 ``` diff --git a/engines/mxnet/native/build.gradle b/engines/mxnet/native/build.gradle index 3f8ee285054..dc9d6e5e12d 100644 --- a/engines/mxnet/native/build.gradle +++ b/engines/mxnet/native/build.gradle @@ -89,6 +89,7 @@ flavorNames.each { flavor -> } from file("${BINARY_ROOT}/${flavor}/${osName}") archiveClassifier = "${osName}-x86_64" + archiveBaseName = "mxnet-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.mxnet_native_${flavor}_${osName}") diff --git a/engines/onnxruntime/onnxruntime-android/README.md b/engines/onnxruntime/onnxruntime-android/README.md index e304e78d5c3..eba92b84288 100644 --- a/engines/onnxruntime/onnxruntime-android/README.md +++ b/engines/onnxruntime/onnxruntime-android/README.md @@ -6,13 +6,13 @@ This module contains the DJL ONNX Runtime engine for Android. ## Installation You can pull the ONNX Runtime for Android from the central Maven repository by including the following dependency: -- ai.djl.android:onnxruntime:0.23.0 +- ai.djl.android:onnxruntime:0.26.0 ```xml ai.djl.android onnxruntime - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md index c287819d23f..b89b14f4473 100644 --- a/engines/onnxruntime/onnxruntime-engine/README.md +++ b/engines/onnxruntime/onnxruntime-engine/README.md @@ -37,13 +37,13 @@ for the official ONNX Runtime project. 
## Installation You can pull the ONNX Runtime engine from the central Maven repository by including the following dependency: -- ai.djl.onnxruntime:onnxruntime-engine:0.23.0 +- ai.djl.onnxruntime:onnxruntime-engine:0.26.0 ```xml ai.djl.onnxruntime onnxruntime-engine - 0.23.0 + 0.26.0 runtime ``` @@ -61,7 +61,7 @@ Maven: ai.djl.onnxruntime onnxruntime-engine - 0.23.0 + 0.26.0 runtime @@ -81,7 +81,7 @@ Maven: Gradle: ```groovy -implementation("ai.djl.onnxruntime:onnxruntime-engine:0.23.0") { +implementation("ai.djl.onnxruntime:onnxruntime-engine:0.26.0") { exclude group: "com.microsoft.onnxruntime", module: "onnxruntime" } implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.14.0" diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 89599722435..243377785d8 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -97,7 +97,7 @@ public int getRank() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.15.1"; + return "1.16.3"; } /** {@inheritDoc} */ diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index c673b3dcbf1..5616eb80edb 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. 
*/ public class OrtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (OrtEngineProvider.class) { - engine = OrtEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = OrtEngine.newInstance(); + } } } return engine; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index aa54b43f376..4e8df210d40 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -59,6 +59,7 @@ public class OrtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable * @param session the {@link OrtSession} contains the model information * @param manager the {@link NDManager} to holds the NDArray */ + @SuppressWarnings("this-escape") public OrtSymbolBlock(OrtSession session, OrtNDManager manager) { this.session = session; this.manager = manager; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java index 9d8037cfa8b..d61cb81f1ee 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java @@ -31,6 +31,7 @@ public class OrtModelZoo extends ModelZoo { OrtModelZoo() { addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", 
"0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo5s", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); addModel(REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1")); } diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json new file mode 100644 index 00000000000..1e0169a2561 --- /dev/null +++ b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.onnxruntime", + "artifactId": "yolov8n", + "name": "yolov8n", + "description": "YoloV8 Model", + "website": "http://www.djl.ai/engines/onnxruntime/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n", + "arguments": { + "width": 640, + "height": 640, + "resize": true, + "rescale": true, + "optApplyRatio": true, + "threshold": 0.6, + "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/yolov8n.zip", + "name": "", + "sha1Hash": "9fbad7f706713843cbb8c8d6a56c81a640ec6fa2", + "size": 11053839 + } + } + } + ] +} diff --git a/engines/paddlepaddle/paddlepaddle-engine/README.md b/engines/paddlepaddle/paddlepaddle-engine/README.md index 9e65fb76601..6671cfbcd42 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/README.md +++ b/engines/paddlepaddle/paddlepaddle-engine/README.md @@ -30,7 +30,7 @@ You can pull the PaddlePaddle engine from the central Maven 
repository by includ ai.djl.paddlepaddle paddlepaddle-engine - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index e2b5bdd35a0..e2fb86974f5 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,7 +18,8 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (PpEngineProvider.class) { - engine = PpEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = PpEngine.newInstance(); + } } } return engine; diff --git a/engines/paddlepaddle/paddlepaddle-model-zoo/README.md b/engines/paddlepaddle/paddlepaddle-model-zoo/README.md index e2c9cf6036c..55d3c67fe50 100644 --- a/engines/paddlepaddle/paddlepaddle-model-zoo/README.md +++ b/engines/paddlepaddle/paddlepaddle-model-zoo/README.md @@ -26,7 +26,7 @@ from the central Maven repository by including the following dependency: ai.djl.paddlepaddle paddlepaddle-model-zoo - 0.23.0 + 0.26.0 ``` diff --git a/engines/paddlepaddle/paddlepaddle-native/build.gradle b/engines/paddlepaddle/paddlepaddle-native/build.gradle index 74a573debad..de1ea58da2b 100644 --- a/engines/paddlepaddle/paddlepaddle-native/build.gradle +++ b/engines/paddlepaddle/paddlepaddle-native/build.gradle @@ -213,6 +213,7 @@ flavorNames.each 
{ flavor -> } from file("${BINARY_ROOT}/${flavor}/${osName}") archiveClassifier = "${osName}-x86_64" + archiveBaseName = "paddlepaddle-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.paddlepaddle_native_${flavor}_${osName}") diff --git a/engines/pytorch/pytorch-engine/README.md b/engines/pytorch/pytorch-engine/README.md index ef74cf98808..c8571c54781 100644 --- a/engines/pytorch/pytorch-engine/README.md +++ b/engines/pytorch/pytorch-engine/README.md @@ -24,13 +24,13 @@ The javadocs output is built in the `build/doc/javadoc` folder. ## Installation You can pull the PyTorch engine from the central Maven repository by including the following dependency: -- ai.djl.pytorch:pytorch-engine:0.23.0 +- ai.djl.pytorch:pytorch-engine:0.26.0 ```xml ai.djl.pytorch pytorch-engine - 0.23.0 + 0.26.0 runtime ``` @@ -46,6 +46,9 @@ The following table illustrates which pytorch version that DJL supports: | PyTorch engine version | PyTorch native library version | |------------------------|-------------------------------------------| +| pytorch-engine:0.26.0 | 1.13.1, 2.0.1, **2.1.1** | +| pytorch-engine:0.25.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 | +| pytorch-engine:0.24.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 | | pytorch-engine:0.23.0 | 1.11.0, 1.12.1, **1.13.1**, 2.0.1 | | pytorch-engine:0.22.1 | 1.11.0, 1.12.1, **1.13.1**, 2.0.0 | | pytorch-engine:0.21.0 | 1.11.0, 1.12.1, **1.13.1** | @@ -110,21 +113,21 @@ export PYTORCH_FLAVOR=cpu ### macOS For macOS, you can use the following library: -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:osx-x86_64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:osx-x86_64 ```xml ai.djl.pytorch pytorch-native-cpu osx-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` @@ -134,21 +137,21 @@ For macOS, you can use the following library: ### macOS M1 For macOS M1, you can use the following library: -- 
ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:osx-aarch64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:osx-aarch64 ```xml ai.djl.pytorch pytorch-native-cpu osx-aarch64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` @@ -159,29 +162,29 @@ installed on your GPU machine, you can use one of the following library: #### Linux GPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cu118:2.0.1:linux-x86_64 - CUDA 11.8 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cu121:2.1.1:linux-x86_64 - CUDA 12.1 ```xml ai.djl.pytorch - pytorch-native-cu118 + pytorch-native-cu121 linux-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` ### Linux CPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:linux-x86_64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:linux-x86_64 ```xml @@ -189,20 +192,20 @@ installed on your GPU machine, you can use one of the following library: pytorch-native-cpu linux-x86_64 runtime - 2.0.1 + 2.1.1 ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` ### For aarch64 build -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.0.1:linux-aarch64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1:linux-aarch64 ```xml @@ -210,12 +213,12 @@ installed on your GPU machine, you can use one of the following library: pytorch-native-cpu-precxx11 linux-aarch64 runtime - 2.0.1 + 2.1.1 ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` @@ -225,22 +228,22 @@ installed on your GPU machine, you can use one of the following library: We also provide packages for the system like CentOS 7/Ubuntu 14.04 with GLIBC >= 2.17. 
All the package were built with GCC 7, we provided a newer `libstdc++.so.6.24` in the package that contains `CXXABI_1.3.9` to use the package successfully. -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cu118-precxx11:2.0.1:linux-x86_64 - CUDA 11.8 -- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.0.1:linux-x86_64 - CPU +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cu121-precxx11:2.1.1:linux-x86_64 - CUDA 12.1 +- ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1:linux-x86_64 - CPU ```xml ai.djl.pytorch - pytorch-native-cu118-precxx11 + pytorch-native-cu121-precxx11 linux-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` @@ -250,13 +253,13 @@ All the package were built with GCC 7, we provided a newer `libstdc++.so.6.24` i ai.djl.pytorch pytorch-native-cpu-precxx11 linux-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` @@ -271,29 +274,29 @@ For the Windows platform, you can choose between CPU and GPU. #### Windows GPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cu118:2.0.1:win-x86_64 - CUDA 11.8 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cu121:2.1.1:win-x86_64 - CUDA 12.1 ```xml ai.djl.pytorch - pytorch-native-cu118 + pytorch-native-cu121 win-x86_64 - 2.0.1 + 2.1.1 runtime ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` ### Windows CPU -- ai.djl.pytorch:pytorch-jni:2.0.1-0.23.0 -- ai.djl.pytorch:pytorch-native-cpu:2.0.1:win-x86_64 +- ai.djl.pytorch:pytorch-jni:2.1.1-0.26.0 +- ai.djl.pytorch:pytorch-native-cpu:2.1.1:win-x86_64 ```xml @@ -301,12 +304,12 @@ For the Windows platform, you can choose between CPU and GPU. 
pytorch-native-cpu win-x86_64 runtime - 2.0.1 + 2.1.1 ai.djl.pytorch pytorch-jni - 2.0.1-0.23.0 + 2.1.1-0.26.0 runtime ``` diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 57ae6c09d34..24be3e91d7a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (PtEngineProvider.class) { - engine = PtEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = PtEngine.newInstance(); + } } } return engine; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index e72e98c9495..35e95f7de86 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -18,6 +18,7 @@ import ai.djl.Model; import ai.djl.ndarray.types.DataType; import ai.djl.nn.Parameter; +import ai.djl.nn.Parameter.Type; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -189,7 +190,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { } if (wasLoaded) { // Unfreeze parameters if training directly - 
block.freezeParameters(false); + block.freezeParameters( + false, + p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR); } for (Pair> pair : initializer) { if (pair.getKey() != null && pair.getValue() != null) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 9e36ec35884..499f51ebad5 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -60,6 +60,7 @@ public class PtNDArray extends NativeResource implements NDArray { * @param manager the manager to attach the new array to * @param handle the pointer to the native PyTorch memory */ + @SuppressWarnings("this-escape") public PtNDArray(PtNDManager manager, long handle) { super(handle); this.manager = manager; @@ -76,6 +77,7 @@ public PtNDArray(PtNDManager manager, long handle) { * @param handle the pointer to the native PyTorch memory * @param data the direct buffer of the data */ + @SuppressWarnings("this-escape") public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { super(handle); this.manager = manager; @@ -93,6 +95,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { * @param strs the string array * @param shape the {@link Shape} of the {@link NDArray} */ + @SuppressWarnings("this-escape") public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { super(-1L); this.manager = manager; @@ -888,6 +891,12 @@ public PtNDArray atan() { return JniUtils.atan(this); } + /** {@inheritDoc} */ + @Override + public PtNDArray atan2(NDArray other) { + return JniUtils.atan2(this, manager.from(other)); + } + /** {@inheritDoc} */ @Override public PtNDArray sinh() { @@ -1097,6 +1106,18 @@ public NDArray stft( this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); } + /** 
{@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + return JniUtils.fft2(this, sizes, axes); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + return JniUtils.ifft2(this, sizes, axes); + } + /** {@inheritDoc} */ @Override public PtNDArray reshape(Shape shape) { @@ -1539,6 +1560,12 @@ public PtNDArray erfinv() { return JniUtils.erfinv(this); } + /** {@inheritDoc} */ + @Override + public PtNDArray erf() { + return JniUtils.erf(this); + } + /** {@inheritDoc} */ @Override public PtNDArray inverse() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index fa4ee81f26c..b7f92cbd1c3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.NDUtils; @@ -24,6 +25,8 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.jni.JniUtils; +import java.util.Arrays; +import java.util.Comparator; import java.util.List; /** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. 
*/ @@ -760,7 +763,152 @@ public NDList multiBoxDetection( float nmsThreshold, boolean forceSuppress, int nmsTopK) { - throw new UnsupportedOperationException("Not implemented"); + assert (inputs.size() == 3); + + NDArray clsProb = inputs.get(0); + NDArray locPred = inputs.get(1); + NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4)); + + NDManager ndManager = array.getManager(); + + NDArray variances = ndManager.create(new float[] {0.1f, 0.1f, 0.2f, 0.2f}); + + assert (variances.size() == 4); // << "Variance size must be 4"; + final int numClasses = (int) clsProb.size(1); + final int numAnchors = (int) clsProb.size(2); + final int numBatches = (int) clsProb.size(0); + + final float[] pAnchor = anchors.toFloatArray(); + + // [id, prob, xmin, ymin, xmax, ymax] + // TODO Move to NDArray-based implementation + NDList batchOutputs = new NDList(); + for (int nbatch = 0; nbatch < numBatches; ++nbatch) { + float[][] outputs = new float[numAnchors][6]; + final float[] pClsProb = clsProb.get(nbatch).toFloatArray(); + final float[] pLocPred = locPred.get(nbatch).toFloatArray(); + + for (int i = 0; i < numAnchors; ++i) { + // find the predicted class id and probability + float score = -1; + int id = 0; + for (int j = 1; j < numClasses; ++j) { + float temp = pClsProb[j * numAnchors + i]; + if (temp > score) { + score = temp; + id = j; + } + } + + if (id > 0 && score < threshold) { + id = 0; + } + + // [id, prob, xmin, ymin, xmax, ymax] + outputs[i][0] = id - 1; + outputs[i][1] = score; + int offset = i * 4; + float[] pAnchorRow4 = new float[4]; + pAnchorRow4[0] = pAnchor[offset]; + pAnchorRow4[1] = pAnchor[offset + 1]; + pAnchorRow4[2] = pAnchor[offset + 2]; + pAnchorRow4[3] = pAnchor[offset + 3]; + float[] pLocPredRow4 = new float[4]; + pLocPredRow4[0] = pLocPred[offset]; + pLocPredRow4[1] = pLocPred[offset + 1]; + pLocPredRow4[2] = pLocPred[offset + 2]; + pLocPredRow4[3] = pLocPred[offset + 3]; + float[] outRowLast4 = + transformLocations( + pAnchorRow4, + pLocPredRow4, 
+ clip, + variances.toFloatArray()[0], + variances.toFloatArray()[1], + variances.toFloatArray()[2], + variances.toFloatArray()[3]); + outputs[i][2] = outRowLast4[0]; + outputs[i][3] = outRowLast4[1]; + outputs[i][4] = outRowLast4[2]; + outputs[i][5] = outRowLast4[3]; + } + + outputs = + Arrays.stream(outputs) + .filter(o -> o[0] >= 0) + .sorted(Comparator.comparing(o -> -o[1])) + .toArray(float[][]::new); + + // apply nms + for (int i = 0; i < outputs.length; ++i) { + for (int j = i + 1; j < outputs.length; ++j) { + if (outputs[i][0] == outputs[j][0]) { + float[] outputsIRow4 = new float[4]; + float[] outputsJRow4 = new float[4]; + outputsIRow4[0] = outputs[i][2]; + outputsIRow4[1] = outputs[i][3]; + outputsIRow4[2] = outputs[i][4]; + outputsIRow4[3] = outputs[i][5]; + outputsJRow4[0] = outputs[j][2]; + outputsJRow4[1] = outputs[j][3]; + outputsJRow4[2] = outputs[j][4]; + outputsJRow4[3] = outputs[j][5]; + float iou = calculateOverlap(outputsIRow4, outputsJRow4); + if (iou >= nmsThreshold) { + outputs[j][0] = -1; + } + } + } + } + batchOutputs.add(ndManager.create(outputs)); + } // end iter batch + + NDArray pOutNDArray = NDArrays.stack(batchOutputs); + NDList resultNDList = new NDList(); + resultNDList.add(pOutNDArray); + assert (resultNDList.size() == 1); + return resultNDList; + } + + private float[] transformLocations( + final float[] anchors, + final float[] locPred, + final boolean clip, + final float vx, + final float vy, + final float vw, + final float vh) { + float[] outRowLast4 = new float[4]; + // transform predictions to detection results + float al = anchors[0]; + float at = anchors[1]; + float ar = anchors[2]; + float ab = anchors[3]; + float aw = ar - al; + float ah = ab - at; + float ax = (al + ar) / 2.f; + float ay = (at + ab) / 2.f; + float px = locPred[0]; + float py = locPred[1]; + float pw = locPred[2]; + float ph = locPred[3]; + float ox = px * vx * aw + ax; + float oy = py * vy * ah + ay; + float ow = (float) (Math.exp(pw * vw) * aw / 2); + 
float oh = (float) (Math.exp(ph * vh) * ah / 2); + outRowLast4[0] = clip ? Math.max(0f, Math.min(1f, ox - ow)) : (ox - ow); + outRowLast4[1] = clip ? Math.max(0f, Math.min(1f, oy - oh)) : (oy - oh); + outRowLast4[2] = clip ? Math.max(0f, Math.min(1f, ox + ow)) : (ox + ow); + outRowLast4[3] = clip ? Math.max(0f, Math.min(1f, oy + oh)) : (oy + oh); + return outRowLast4; + } + + private float calculateOverlap(final float[] a, final float[] b) { + float w = Math.max(0f, Math.min(a[2], b[2]) - Math.max(a[0], b[0])); + float h = Math.max(0f, Math.min(a[3], b[3]) - Math.max(a[1], b[1])); + float i = w * h; + float u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; + return u <= 0.f ? 0f : (i / u); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 8bc28a2c21b..7075cb05efa 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -67,6 +67,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable * @param manager the manager to use for the block * @param handle the module handle */ + @SuppressWarnings("this-escape") public PtSymbolBlock(PtNDManager manager, long handle) { this(manager); this.handle = new AtomicReference<>(handle); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index aad38ae8f0c..40a6a0065bc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1040,6 +1040,18 @@ public static PtNDArray stft( return new PtNDArray(ndArray.getManager(), handle); } + public static PtNDArray 
fft2(PtNDArray ndArray, long[] sizes, long[] axes) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes)); + } + + public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes)); + } + public static PtNDArray real(PtNDArray ndArray) { long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle()); if (handle == -1) { @@ -1145,6 +1157,12 @@ public static PtNDArray atan(PtNDArray ndArray) { ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } + public static PtNDArray atan2(PtNDArray self, PtNDArray other) { + return new PtNDArray( + self.getManager(), + PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle())); + } + public static PtNDArray sqrt(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); @@ -1334,6 +1352,11 @@ public static PtNDArray erfinv(PtNDArray ndArray) { ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } + public static PtNDArray erf(PtNDArray ndArray) { + return new PtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle())); + } + public static PtNDArray inverse(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index 9d422463910..03835b6ca68 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -106,9 +106,16 @@ public static String getLibtorchPath() { private static void loadLibTorch(LibTorch libTorch) { Path libDir = libTorch.dir.toAbsolutePath(); 
- if ("1.8.1".equals(getVersion()) && System.getProperty("os.name").startsWith("Mac")) { - // PyTorch 1.8.1 libtorch_cpu.dylib cannot be loaded individually - return; + if (Files.exists(libDir.resolve("libstdc++.so.6"))) { + String libstd = Utils.getEnvOrSystemProperty("LIBSTDCXX_LIBRARY_PATH"); + if (libstd != null) { + try { + logger.info("Loading libstdc++.so.6 from: {}", libstd); + System.load(libstd); + } catch (UnsatisfiedLinkError e) { + logger.warn("Failed Loading libstdc++.so.6 from: {}", libstd); + } + } } boolean isCuda = libTorch.flavor.contains("cu"); List deferred = @@ -120,6 +127,7 @@ private static void loadLibTorch(LibTorch libTorch) { System.mapLibraryName("torch_cuda_cpp"), System.mapLibraryName("torch_cuda_cu"), System.mapLibraryName("torch_cuda"), + System.mapLibraryName("nvfuser_codegen"), System.mapLibraryName("torch")); Set loadLater = new HashSet<>(deferred); @@ -133,7 +141,8 @@ private static void loadLibTorch(LibTorch libTorch) { && name.contains("cudart") && name.contains("nvTools")) { return false; - } else if (name.startsWith("libarm_compute-")) { + } else if (name.startsWith("libarm_compute-") + || name.startsWith("libopenblasp")) { rank.put(path, 2); return true; } else if (name.startsWith("libarm_compute_")) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index c0f7b553ab2..54fc5419145 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -273,6 +273,10 @@ native long torchStft( boolean normalize, boolean returnComplex); + native long torchFft2(long handle, long[] sizes, long[] axes); + + native long torchIfft2(long handle, long[] sizes, long[] axes); + native long torchViewAsReal(long handle); native long torchViewAsComplex(long handle); @@ -332,6 +336,8 @@ native 
long[] torchUnique( native long torchAtan(long handle); + native long torchAtan2(long self, long other); + native long torchSqrt(long handle); native long torchSinh(long handle); @@ -405,6 +411,8 @@ native long tensorUniform( native long torchErfinv(long handle); + native long torchErf(long handle); + native long torchInverse(long self); native long torchNNInterpolate(long handle, long[] size, int mode, boolean alignCorners); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java similarity index 73% rename from engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java rename to engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java index 617d2cfb809..f6cfda91106 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/LibUtilsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java @@ -18,17 +18,21 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class LibUtilsTest { +// Ensure this test run first +public class ALibUtilsTest { @BeforeClass public void setup() { - System.setProperty( - "ai.djl.pytorch.native_helper", "ai.djl.pytorch.integration.LibUtilsTest"); + System.setProperty("ai.djl.pytorch.native_helper", ALibUtilsTest.class.getName()); + System.setProperty("STDCXX_LIBRARY_PATH", "/usr/lib/non-exists"); + System.setProperty("PYTORCH_PRECXX11", "true"); } @AfterClass public void teardown() { System.clearProperty("ai.djl.pytorch.native_helper"); + System.clearProperty("LIBSTDCXX_LIBRARY_PATH"); + System.clearProperty("PYTORCH_PRECXX11"); } @Test diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java index 
8b4e2326f26..5b6ed349e10 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.integration; import ai.djl.Device; +import ai.djl.modality.Classifications; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; @@ -21,6 +22,10 @@ import org.testng.SkipException; import org.testng.annotations.Test; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + public class MpsTest { @Test @@ -36,4 +41,39 @@ public void testMps() { Assert.assertEquals(array.getDevice().getDeviceType(), "mps"); } } + + private static boolean checkMpsCompatible() { + return "aarch64".equals(System.getProperty("os.arch")) + && System.getProperty("os.name").startsWith("Mac"); + } + + @Test + public void testToTensorMPS() { + if (!checkMpsCompatible()) { + throw new SkipException("MPS toTensor test requires Apple Silicon macOS."); + } + + // Test that toTensor does not fail on MPS (e.g. due to use of float64 for division) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + NDArray array = manager.create(127f).reshape(1, 1, 1, 1); + NDArray tensor = array.getNDArrayInternal().toTensor(); + Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f}); + } + } + + @Test + public void testClassificationsMPS() { + if (!checkMpsCompatible()) { + throw new SkipException("MPS classification test requires Apple Silicon macOS."); + } + + // Test that classifications do not fail on MPS (e.g. 
due to conversion of probabilities to + // float64) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + List names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth"); + NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f}); + Classifications classifications = new Classifications(names, tensor); + Assert.assertEquals(classifications.topK(1), Collections.singletonList("Third")); + } + } } diff --git a/engines/pytorch/pytorch-jni/build.gradle b/engines/pytorch/pytorch-jni/build.gradle index 450c832e803..c2b0ee9dc7b 100644 --- a/engines/pytorch/pytorch-jni/build.gradle +++ b/engines/pytorch/pytorch-jni/build.gradle @@ -24,7 +24,13 @@ processResources { "osx-x86_64/cpu/libdjl_torch.dylib", "win-x86_64/cpu/djl_torch.dll" ] - if (ptVersion.startsWith("2.0.")) { + if (ptVersion.startsWith("2.1.")) { + files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so") + files.add("linux-x86_64/cu121/libdjl_torch.so") + files.add("linux-x86_64/cu121-precxx11/libdjl_torch.so") + files.add("win-x86_64/cu121/djl_torch.dll") + files.add("osx-aarch64/cpu/libdjl_torch.dylib") + } else if (ptVersion.startsWith("2.0.")) { files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so") files.add("linux-x86_64/cu118/libdjl_torch.so") files.add("linux-x86_64/cu118-precxx11/libdjl_torch.so") diff --git a/engines/pytorch/pytorch-model-zoo/README.md b/engines/pytorch/pytorch-model-zoo/README.md index 8d3113842e1..f598dd2aecd 100644 --- a/engines/pytorch/pytorch-model-zoo/README.md +++ b/engines/pytorch/pytorch-model-zoo/README.md @@ -25,7 +25,7 @@ You can pull the PyTorch engine from the central Maven repository by including t ai.djl.pytorch pytorch-model-zoo - 0.23.0 + 0.26.0 ``` diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index ea70871eff0..abb820cced9 100644 --- 
a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -38,6 +38,7 @@ public class PtModelZoo extends ModelZoo { REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1")); addModel(REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1")); addModel(REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep", "0.0.1")); diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json new file mode 100644 index 00000000000..399b79b4889 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/yolov8n/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.pytorch", + "artifactId": "yolov8n", + "name": "yolov8n", + "description": "YoloV8 Model", + "website": "http://www.djl.ai/engines/onnxruntime/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n", + "arguments": { + "width": 640, + "height": 640, + "resize": true, + "rescale": true, + "optApplyRatio": true, + "threshold": 0.6, + "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory" + }, + "files": 
{ + "model": { + "uri": "0.0.1/yolov8n.zip", + "name": "", + "sha1Hash": "a868778452ef8d6d2f9cb7109a9e14a64e851d48", + "size": 11183356 + } + } + } + ] +} diff --git a/engines/pytorch/pytorch-native/CMakeLists.txt b/engines/pytorch/pytorch-native/CMakeLists.txt index 4453186be6f..c53d71dc93e 100644 --- a/engines/pytorch/pytorch-native/CMakeLists.txt +++ b/engines/pytorch/pytorch-native/CMakeLists.txt @@ -60,11 +60,12 @@ if(USE_CUDA) endif() add_library(djl_torch SHARED ${SOURCE_FILES}) +set_property(TARGET djl_torch PROPERTY CXX_STANDARD 17) + # build host if(NOT BUILD_ANDROID) target_link_libraries(djl_torch "${TORCH_LIBRARIES}") target_include_directories(djl_torch PUBLIC build/include ${JNI_INCLUDE_DIRS} ${UTILS_INCLUDE_DIR}) - set_property(TARGET djl_torch PROPERTY CXX_STANDARD 14) # We have to kill the default rpath and use current dir set(CMAKE_SKIP_RPATH TRUE) if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") diff --git a/engines/pytorch/pytorch-native/build.gradle b/engines/pytorch/pytorch-native/build.gradle index b4a195e109f..99a658bf3ed 100644 --- a/engines/pytorch/pytorch-native/build.gradle +++ b/engines/pytorch/pytorch-native/build.gradle @@ -24,6 +24,8 @@ if (project.hasProperty("cu11")) { FLAVOR = "cu117" } else if (VERSION.startsWith("2.0.")) { FLAVOR = "cu118" + } else if (VERSION.startsWith("2.1.")) { + FLAVOR = "cu121" } else { throw new GradleException("Unsupported PyTorch version: ${VERSION}") } @@ -88,15 +90,17 @@ def prepareNativeLib(String binaryRoot, String ver) { def officialPytorchUrl = "https://download.pytorch.org/libtorch" def aarch64PytorchUrl = "https://djl-ai.s3.amazonaws.com/publish/pytorch" - String cu11 + String cuda if (ver.startsWith("1.11.")) { - cu11 = "cu113" + cuda = "cu113" } else if (ver.startsWith("1.12.")) { - cu11 = "cu116" + cuda = "cu116" } else if (ver.startsWith("1.13.")) { - cu11 = "cu117" + cuda = "cu117" } else if (ver.startsWith("2.0.")) { - cu11 = "cu118" + cuda = "cu118" + } else if (ver.startsWith("2.1.")) { + cuda 
= "cu121" } else { throw new GradleException("Unsupported PyTorch version: ${ver}") } @@ -105,10 +109,10 @@ def prepareNativeLib(String binaryRoot, String ver) { "cpu/libtorch-cxx11-abi-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/linux-x86_64", "cpu/libtorch-macos-${ver}.zip" : "cpu/osx-x86_64", "cpu/libtorch-win-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/win-x86_64", - "${cu11}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cu11}.zip": "${cu11}/linux-x86_64", - "${cu11}/libtorch-win-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}/win-x86_64", + "${cuda}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cuda}.zip": "${cuda}/linux-x86_64", + "${cuda}/libtorch-win-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}/win-x86_64", "cpu/libtorch-shared-with-deps-${ver}%2Bcpu.zip" : "cpu-precxx11/linux-x86_64", - "${cu11}/libtorch-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}-precxx11/linux-x86_64" + "${cuda}/libtorch-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}-precxx11/linux-x86_64" ] def aarch64Files = [ @@ -138,17 +142,12 @@ def copyNativeLibToOutputDir(Map fileStoreMap, String binaryRoot from zipTree(file) into outputDir } - // CPU dependencies - copy { - from("${outputDir}/libtorch/lib/") { - include "libc10.*", "c10.dll", "libiomp5*.*", "libarm_compute*.*", "libgomp*.*", "libnvfuser_codegen.so", "libtorch.*", "libtorch_cpu.*", "torch.dll", "torch_cpu.dll", "fbgemm.dll", "asmjit.dll", "uv.dll", "nvfuser_codegen.dll" - } - into("${outputDir}/native/lib") - } - // GPU dependencies + delete "${outputDir}/libtorch/lib/*.lib" + delete "${outputDir}/libtorch/lib/*.a" + copy { from("${outputDir}/libtorch/lib/") { - include "libtorch_cuda*.so", "torch_cuda*.dll", "libc10_cuda.so", "c10_cuda.dll", "libcaffe2_nvrtc.so", "libnvrtc*.so.*", "libcudart*.*", "*nvToolsExt*.*", "cudnn*.dll", "caffe2_nvrtc.dll", "nvrtc64*.dll", "uv.dll", "libcublas*", "zlibwapi.dll" + include "libarm_compute*", "libc10_cuda.so", "libc10.*", "libcaffe2_nvrtc.so", "libcu*", "libgfortran-*", 
"libgomp*", "libiomp*", "libnv*", "libopenblasp-*", "libtorch_cpu.*", "libtorch_cuda*.so", "libtorch.*", "asmjit.dll", "c10_cuda.dll", "c10.dll", "caffe2_nvrtc.dll", "cu*.dll", "fbgemm.dll", "nv*.dll", "torch_cpu.dll", "torch_cuda*.dll", "torch.dll", "uv.dll", "zlibwapi.dll" } into("${outputDir}/native/lib") } @@ -287,9 +286,9 @@ tasks.register('uploadS3') { "${BINARY_ROOT}/cpu/win-x86_64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-aarch64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-x86_64/native/lib/", - "${BINARY_ROOT}/cu118/linux-x86_64/native/lib/", - "${BINARY_ROOT}/cu118/win-x86_64/native/lib/", - "${BINARY_ROOT}/cu118-precxx11/linux-x86_64/native/lib/" + "${BINARY_ROOT}/cu121/linux-x86_64/native/lib/", + "${BINARY_ROOT}/cu121/win-x86_64/native/lib/", + "${BINARY_ROOT}/cu121-precxx11/linux-x86_64/native/lib/" ] uploadDirs.each { item -> fileTree(item).files.name.each { diff --git a/engines/pytorch/pytorch-native/build.sh b/engines/pytorch/pytorch-native/build.sh index 78c59d6bf2a..ae0456bec62 100755 --- a/engines/pytorch/pytorch-native/build.sh +++ b/engines/pytorch/pytorch-native/build.sh @@ -23,22 +23,22 @@ ARCH=$4 if [[ ! -d "libtorch" ]]; then if [[ $PLATFORM == 'linux' ]]; then - if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118)$ ]]; then + if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118|cu121)$ ]]; then echo "$FLAVOR is not supported." 
exit 1 fi if [[ $ARCH == 'aarch64' ]]; then - curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch${AARCH64_CXX11ABI}-shared-with-deps-${VERSION}-aarch64.zip | jar xv + curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch${AARCH64_CXX11ABI}-shared-with-deps-${VERSION}-aarch64.zip | jar xv > /dev/null else - curl -s https://download.pytorch.org/libtorch/${FLAVOR}/libtorch${CXX11ABI}-shared-with-deps-${VERSION}%2B${FLAVOR}.zip | jar xv + curl -s https://download.pytorch.org/libtorch/${FLAVOR}/libtorch${CXX11ABI}-shared-with-deps-${VERSION}%2B${FLAVOR}.zip | jar xv > /dev/null fi elif [[ $PLATFORM == 'darwin' ]]; then if [[ $ARCH == 'aarch64' ]]; then - curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch-macos-${VERSION}-aarch64.zip | jar xv + curl -s https://djl-ai.s3.amazonaws.com/publish/pytorch/${VERSION}/libtorch-macos-${VERSION}-aarch64.zip | jar xv > /dev/null else - curl -s https://download.pytorch.org/libtorch/cpu/libtorch-macos-${VERSION}.zip | jar xv + curl -s https://download.pytorch.org/libtorch/cpu/libtorch-macos-${VERSION}.zip | jar xv > /dev/null fi else echo "$PLATFORM is not supported." @@ -62,6 +62,12 @@ mkdir classes javac -sourcepath ../../pytorch-engine/src/main/java/ ../../pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java -h include -d classes cmake -DCMAKE_PREFIX_PATH=libtorch -DPT_VERSION=${PT_VERSION} -DUSE_CUDA=$USE_CUDA .. cmake --build . --config Release -- -j "${NUM_PROC}" +if [[ "$FLAVOR" = cu* ]]; then + # avoid link with libcudart.so.11.0 + sed -i -r "s/\/usr\/local\/cuda(.{5})?\/lib64\/lib(cudart|nvrtc).so//g" CMakeFiles/djl_torch.dir/link.txt + rm libdjl_torch.so + . 
CMakeFiles/djl_torch.dir/link.txt +fi if [[ $PLATFORM == 'darwin' ]]; then install_name_tool -add_rpath @loader_path libdjl_torch.dylib diff --git a/engines/pytorch/pytorch-native/build_android.sh b/engines/pytorch/pytorch-native/build_android.sh index b37dd96a86d..72050b20a85 100755 --- a/engines/pytorch/pytorch-native/build_android.sh +++ b/engines/pytorch/pytorch-native/build_android.sh @@ -20,7 +20,7 @@ if [[ ! -d libtorch_android/"$FLAVOR" ]]; then mkdir -p libtorch_android/"$FLAVOR" cd libtorch_android/"$FLAVOR" echo "Downloading https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" - curl -s "https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" | jar xv + curl -s "https://publish.djl.ai/pytorch/$VERSION/android_native/${FLAVOR}_native.zip" | jar xv > /dev/null mv install/include include cd - fi diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc index 5a65e1eca69..08932098da9 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_other.cc @@ -34,6 +34,28 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchFft2( + JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js); + const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes); + const auto* result_ptr = new torch::Tensor(torch::fft_fft2(*tensor_ptr, sizes, axes)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL 
Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIfft2( + JNIEnv* env, jobject jthis, jlong jhandle, jlongArray js, jlongArray jaxes) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const std::vector sizes = djl::utils::jni::GetVecFromJLongArray(env, js); + const std::vector axes = djl::utils::jni::GetVecFromJLongArray(env, jaxes); + const auto* result_ptr = new torch::Tensor(torch::fft_ifft2(*tensor_ptr, sizes, axes)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchStft(JNIEnv* env, jobject jthis, jlong jhandle, jlong jn_fft, jlong jhop_length, jlong jwindow, jboolean jcenter, jboolean jnormalize, jboolean jreturn_complex) { #ifdef V1_11_X diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index 28e40e916be..ccf2616dc65 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -355,6 +355,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv* API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2( +JNIEnv* env, jobject jthis, jlong jself, jlong jother) { + API_BEGIN() + const auto* self_ptr = reinterpret_cast(jself); + const auto* other_ptr = reinterpret_cast(jother); + const auto* result_ptr = new torch::Tensor(self_ptr->atan2(*other_ptr)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSqrt(JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); @@ -496,6 +506,14 @@ JNIEXPORT jlong JNICALL 
Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErfinv(JNIEn API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErf(JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* result_ptr = new torch::Tensor(tensor_ptr->erf()); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchInverse(JNIEnv* env, jobject jthis, jlong jself) { API_BEGIN() const auto* self_ptr = reinterpret_cast(jself); diff --git a/engines/tensorflow/tensorflow-api/README.md b/engines/tensorflow/tensorflow-api/README.md index fd2741dc9e4..9e151a274a0 100644 --- a/engines/tensorflow/tensorflow-api/README.md +++ b/engines/tensorflow/tensorflow-api/README.md @@ -16,6 +16,6 @@ You can pull the TensorFlow core java API from the central Maven repository by i ai.djl.tensorflow tensorflow-api - 0.23.0 + 0.26.0 ``` diff --git a/engines/tensorflow/tensorflow-engine/README.md b/engines/tensorflow/tensorflow-engine/README.md index 57bcdda98d7..5a6ac3e6da1 100644 --- a/engines/tensorflow/tensorflow-engine/README.md +++ b/engines/tensorflow/tensorflow-engine/README.md @@ -28,13 +28,13 @@ The javadocs output is built in the `build/doc/javadoc` folder. 
You can pull the TensorFlow engine from the central Maven repository by including the following dependency: -- ai.djl.tensorflow:tensorflow-engine:0.23.0 +- ai.djl.tensorflow:tensorflow-engine:0.26.0 ```xml ai.djl.tensorflow tensorflow-engine - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index d964ea5c295..fa7813a49fb 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TfEngineProvider.class) { - engine = TfEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = TfEngine.newInstance(); + } } } return engine; diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 07c31bacd99..419be4c09f6 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -457,6 +457,12 @@ public NDArray erfinv() { return manager.opExecutor("Erfinv").addInput(this).buildSingletonOrThrow(); } + /** {@inheritDoc} */ + @Override + public 
NDArray erf() { + return manager.opExecutor("Erf").addInput(this).buildSingletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray norm(boolean keepDims) { @@ -911,6 +917,12 @@ public NDArray atan() { return manager.opExecutor("Atan").addInput(this).buildSingletonOrThrow(); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + return manager.opExecutor("Atan2").addInput(this).addInput(other).buildSingletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { @@ -1172,6 +1184,18 @@ public NDArray stft( throw new UnsupportedOperationException("Not implemented yet."); } + /** {@inheritDoc} */ + @Override + public NDArray fft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public NDArray ifft2(long[] sizes, long[] axes) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray reshape(Shape shape) { diff --git a/engines/tensorflow/tensorflow-model-zoo/README.md b/engines/tensorflow/tensorflow-model-zoo/README.md index b34154fa126..975caa6df82 100644 --- a/engines/tensorflow/tensorflow-model-zoo/README.md +++ b/engines/tensorflow/tensorflow-model-zoo/README.md @@ -26,7 +26,7 @@ from the central Maven repository by including the following dependency: ai.djl.tensorflow tensorflow-model-zoo - 0.23.0 + 0.26.0 ``` diff --git a/engines/tensorflow/tensorflow-native/build.gradle b/engines/tensorflow/tensorflow-native/build.gradle index 8138d93334d..56cd6eed9e2 100644 --- a/engines/tensorflow/tensorflow-native/build.gradle +++ b/engines/tensorflow/tensorflow-native/build.gradle @@ -153,6 +153,7 @@ flavorNames.each { flavor -> } from file("${BINARY_ROOT}/${flavor}/${osName}") archiveClassifier = "${osName}-x86_64" + archiveBaseName = "tensorflow-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.tensorflow_native_${flavor}_${osName}") diff --git 
a/engines/tensorrt/README.md b/engines/tensorrt/README.md index 6373386479e..8100b615e24 100644 --- a/engines/tensorrt/README.md +++ b/engines/tensorrt/README.md @@ -28,13 +28,13 @@ The javadocs output is generated in the `build/doc/javadoc` folder. ## Installation You can pull the TensorRT engine from the central Maven repository by including the following dependency: -- ai.djl.tensorrt:tensorrt:0.23.0 +- ai.djl.tensorrt:tensorrt:0.26.0 ```xml ai.djl.tensorrt tensorrt - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index 05a7eceeb41..8c90859c6c6 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TrtEngineProvider.class) { - engine = TrtEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = TrtEngine.newInstance(); + } } } return engine; diff --git a/engines/tflite/tflite-engine/README.md b/engines/tflite/tflite-engine/README.md index b1dd8fc9778..861a66f9aaa 100644 --- a/engines/tflite/tflite-engine/README.md +++ b/engines/tflite/tflite-engine/README.md @@ -24,13 +24,13 @@ The javadocs output is built in the `build/doc/javadoc` folder. 
## Installation You can pull the TensorFlow Lite engine from the central Maven repository by including the following dependency: -- ai.djl.tflite:tflite-engine:0.23.0 +- ai.djl.tflite:tflite-engine:0.26.0 ```xml ai.djl.tflite tflite-engine - 0.23.0 + 0.26.0 runtime ``` diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index aa0fdb73d21..b46cad53b99 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TfLiteEngineProvider.class) { - engine = TfLiteEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = TfLiteEngine.newInstance(); + } } } return engine; diff --git a/engines/tflite/tflite-native/build.gradle b/engines/tflite/tflite-native/build.gradle index eb045331c12..3e2a6008f38 100644 --- a/engines/tflite/tflite-native/build.gradle +++ b/engines/tflite/tflite-native/build.gradle @@ -155,6 +155,7 @@ flavorNames.each { flavor -> from file("src/main/resources") from file("${project.buildDir}/classes/java/main") archiveClassifier = "${osName}" + archiveBaseName = "tflite-native-${flavor}" manifest { attributes("Automatic-Module-Name": "ai.djl.tflite_native_${flavor}_${osName}") diff --git a/examples/docs/image_classification.md 
b/examples/docs/image_classification.md index 1f515f9680f..c8f331320a8 100644 --- a/examples/docs/image_classification.md +++ b/examples/docs/image_classification.md @@ -6,7 +6,7 @@ In this example, you learn how to implement inference code with Deep Java Librar The image classification example code can be found at [ImageClassification.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java). -You can also use the [Jupyter notebook tutorial](../../jupyter/tutorial/03_image_classification_with_your_model.ipynb). +You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/03_image_classification_with_your_model.html). The Jupyter notebook explains the key concepts in detail. ## Setup Guide diff --git a/examples/docs/object_detection.md b/examples/docs/object_detection.md index 7d0898128b9..84286fb6e00 100644 --- a/examples/docs/object_detection.md +++ b/examples/docs/object_detection.md @@ -7,7 +7,7 @@ In this example, you learn how to implement inference code with a [ModelZoo mode The source code can be found at [ObjectDetection.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java). -You can also use the [Jupyter notebook tutorial](../../jupyter/object_detection_with_model_zoo.ipynb). +You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html). The Jupyter notebook explains the key concepts in detail. ## Setup guide diff --git a/examples/docs/stable_diffusion.md b/examples/docs/stable_diffusion.md index 7eb544646ee..be3cbb48d6e 100644 --- a/examples/docs/stable_diffusion.md +++ b/examples/docs/stable_diffusion.md @@ -1,4 +1,4 @@ -## Stable Diffusion in DJL +# Stable Diffusion in DJL [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) is an open-source model developed by Stability.ai. 
It aimed to produce images (artwork, pictures, etc.) based on diff --git a/examples/docs/train_cifar10_resnet.md b/examples/docs/train_cifar10_resnet.md index cfaf03f8a61..1cdfcb495c2 100644 --- a/examples/docs/train_cifar10_resnet.md +++ b/examples/docs/train_cifar10_resnet.md @@ -5,7 +5,7 @@ In this example, you learn how to train the [CIFAR-10](https://www.cs.toronto.ed You can find the example source code in: [TrainResnetWithCifar10.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java). -You can also find the Jupyter notebook tutorial [here](../../jupyter/transfer_learning_on_cifar10.ipynb). +You can also find the Jupyter notebook tutorial [here](http://docs.djl.ai/docs/demos/jupyter/transfer_learning_on_cifar10.html). The Jupyter notebook explains the key concepts in detail. ## Setup guide diff --git a/examples/docs/train_mnist_mlp.md b/examples/docs/train_mnist_mlp.md index 72b591d062a..40a32ca365f 100644 --- a/examples/docs/train_mnist_mlp.md +++ b/examples/docs/train_mnist_mlp.md @@ -6,7 +6,7 @@ In this example, you learn how to train the MNIST dataset with Deep Java Library The source code for this example can be found at [TrainMnist.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/TrainMnist.java). -You can also use the [Jupyter notebook tutorial](../../jupyter/tutorial/02_train_your_first_model.ipynb). +You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/02_train_your_first_model.html). The Jupyter notebook explains the key concepts in detail. 
## Setup guide diff --git a/examples/pom.xml b/examples/pom.xml index 9eb2ee32fa0..cc18358e947 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -5,12 +5,12 @@ ai.djl examples - 0.24.0-SNAPSHOT + 0.27.0-SNAPSHOT 11 11 - 0.24.0-SNAPSHOT + 0.27.0-SNAPSHOT ai.djl.examples.inference.ObjectDetection @@ -41,7 +41,7 @@ org.apache.logging.log4j log4j-slf4j-impl - 2.18.0 + 2.21.0 ai.djl diff --git a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java index b667cd29f90..093e159bebb 100644 --- a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java +++ b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java @@ -34,9 +34,8 @@ *

See: * *

    - *
  • the jupyter - * demo with more information about BERT. + *
  • the jupyter demo with more + * information about BERT. *
  • the docs * for information about running this example. diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java new file mode 100644 index 00000000000..3d2cfb26409 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.YoloV8TranslatorFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** An example of inference using an yolov8 model. 
*/ +public final class Yolov8Detection { + + private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class); + + private Yolov8Detection() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + DetectedObjects detection = Yolov8Detection.predict(); + logger.info("{}", detection); + } + + public static DetectedObjects predict() throws IOException, ModelException, TranslateException { + Path imgPath = Paths.get("src/test/resources/yolov8_test.jpg"); + Image img = ImageFactory.getInstance().fromFile(imgPath); + + Criteria criteria = + Criteria.builder() + .setTypes(Image.class, DetectedObjects.class) + .optModelUrls("djl://ai.djl.onnxruntime/yolov8n") + .optEngine("OnnxRuntime") + .optArgument("width", 640) + .optArgument("height", 640) + .optArgument("resize", true) + .optArgument("toTensor", true) + .optArgument("applyRatio", true) + .optArgument("threshold", 0.6f) + // for performance optimization maxBox parameter can reduce number of + // considered boxes from 8400 + .optArgument("maxBox", 1000) + .optTranslatorFactory(new YoloV8TranslatorFactory()) + .optProgress(new ProgressBar()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Path outputPath = Paths.get("build/output"); + Files.createDirectories(outputPath); + + DetectedObjects detection = predictor.predict(img); + if (detection.getNumberOfObjects() > 0) { + img.drawBoundingBoxes(detection); + Path output = outputPath.resolve("yolov8_detected.png"); + try (OutputStream os = Files.newOutputStream(output)) { + img.save(os, "png"); + } + logger.info("Detected object saved in: {}", output); + } + return detection; + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java index aa2b12af420..193f6643d56 100644 --- 
a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java @@ -28,6 +28,7 @@ import ai.djl.training.TrainingConfig; import ai.djl.training.TrainingResult; import ai.djl.training.initializer.TruncatedNormalInitializer; +import ai.djl.training.listener.TrainingListener; import ai.djl.training.listener.TrainingListener.Defaults; import ai.djl.training.optimizer.Adam; import ai.djl.training.optimizer.Optimizer; @@ -109,6 +110,8 @@ private static TrainingConfig createTrainingConfig(BertArguments arguments) { return new DefaultTrainingConfig(new BertPretrainingLoss()) .optOptimizer(optimizer) .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .addTrainingListeners( + TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile())) .addTrainingListeners(Defaults.logging()); } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java index 786a71bfbed..85d09145081 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java @@ -107,6 +107,8 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { .addEvaluator(new Accuracy()) .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) + .addTrainingListeners( + TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile())) .addTrainingListeners(listener); } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java index e0143ed524b..33db2efd2ff 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java @@ -214,6 +214,8 @@ private 
static DefaultTrainingConfig setupTrainingConfig( .addEvaluator(new Rmsse(distributionOutput)) .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT) + .addTrainingListeners( + TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile())) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java index 7acb2f3531f..aa6dbc389dc 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java @@ -215,6 +215,8 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .addTrainingListeners( + TrainingListener.Defaults.algebraicLogging(arguments.getAlgebraicLogFile())) .addTrainingListeners(TrainingListener.Defaults.logging(arguments.getOutputDir())); } diff --git a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java index bbfa48f6381..e72f0b94fa6 100644 --- a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java +++ b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java @@ -38,6 +38,7 @@ public class Arguments { protected long limit; protected String modelDir; protected Map criteria; + protected String algebraicLogFile; protected void initialize() { epoch = 2; @@ -45,6 +46,7 @@ protected void initialize() { outputDir = "build/model"; limit = Long.MAX_VALUE; modelDir = null; + algebraicLogFile = null; 
} protected void setCmd(CommandLine cmd) { @@ -75,6 +77,9 @@ protected void setCmd(CommandLine cmd) { Type type = new TypeToken>() {}.getType(); criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type); } + if (cmd.hasOption("algebraic-log")) { + algebraicLogFile = cmd.getOptionValue("algebraic-log"); + } } public Arguments parseArgs(String[] args) { @@ -162,6 +167,15 @@ public Options getOptions() { .argName("CRITERIA") .desc("The criteria used for the model.") .build()); + options.addOption( + Option.builder("a") + .longOpt("algebraic-log") + .hasArg() + .argName("ALGEBRAIC-LOG") + .desc( + "File to log algebraic operations executed during training as" + + " Python program.") + .build()); return options; } @@ -193,6 +207,10 @@ public String getOutputDir() { return outputDir; } + public String getAlgebraicLogFile() { + return algebraicLogFile; + } + public long getLimit() { return limit; } diff --git a/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java new file mode 100644 index 00000000000..35e3fc434aa --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.testing.TestRequirements; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +public class Yolov8DetectionTest { + + @Test + public void testYolov8Detection() throws ModelException, TranslateException, IOException { + TestRequirements.engine("MXNet", "PyTorch"); + + DetectedObjects result = Yolov8Detection.predict(); + + Assert.assertTrue(result.getNumberOfObjects() >= 1); + Classifications.Classification obj = result.best(); + String className = obj.getClassName(); + Assert.assertEquals(className, "dog"); + Assert.assertTrue(obj.getProbability() > 0.6); + } +} diff --git a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java index 2a61e25862e..1a5699836c8 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java @@ -27,7 +27,6 @@ public class TrainPikachuTest { @Test public void testDetection() throws IOException, MalformedModelException, TranslateException { - TestRequirements.engine("MXNet"); TestRequirements.nightly(); String[] args; diff --git a/examples/src/test/java/ai/djl/examples/training/TrainWithAlgebraicLogging.java b/examples/src/test/java/ai/djl/examples/training/TrainWithAlgebraicLogging.java new file mode 100644 index 00000000000..a57373b5891 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/training/TrainWithAlgebraicLogging.java @@ -0,0 +1,90 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.training; + +import ai.djl.ModelException; +import ai.djl.engine.Engine; +import ai.djl.examples.training.transferlearning.TrainResnetWithCifar10; +import ai.djl.testing.TestRequirements; +import ai.djl.training.TrainingResult; +import ai.djl.translate.TranslateException; +import ai.djl.util.Utils; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; + +public class TrainWithAlgebraicLogging { + + private static final int SEED = 1234; + + @Test + public void testTrainMnist() throws ModelException, TranslateException, IOException { + TestRequirements.engine("MXNet"); + + Path logDir = Paths.get("build/tmp/algebraiclog"); + Path algebraicLogFile = logDir.resolve("TrainMnist.py"); + if (!algebraicLogFile.toFile().delete()) { + Files.createDirectories(logDir); + } + + String[] args = + new String[] {"-g", "1", "-m", "2", "-a", algebraicLogFile.toFile().toString()}; + + TrainMnist.runExample(args); + Path path = Paths.get("src/test/resources/algebraiclog/TrainMnist.py"); + + try (InputStream is = Files.newInputStream(path); + InputStream isActual = Files.newInputStream(algebraicLogFile)) { + List expected = Utils.readLines(is); + List actual = Utils.readLines(isActual); + Assert.assertEquals(expected, actual); + } + } + + @Test + public void testTrainResNetImperative() throws ModelException, IOException, TranslateException { + TestRequirements.engine("MXNet"); + + Path logDir = 
Paths.get("build/tmp/algebraiclog"); + Path algebraicLogFile = logDir.resolve("TrainResnetWithCifar10.py"); + if (!algebraicLogFile.toFile().delete()) { + Files.createDirectories(logDir); + } + + // Limit max 4 gpu for cifar10 training to make it converge faster. + // and only train 10 batch for unit test. + String[] args = { + "-e", "2", "-g", "4", "-m", "1", "-b", "111", "-a", algebraicLogFile.toFile().toString() + }; + + Engine.getInstance().setRandomSeed(SEED); + TrainingResult result = TrainResnetWithCifar10.runExample(args); + Assert.assertNotNull(result); + + Path path = Paths.get("src/test/resources/algebraiclog/TrainResnetWithCifar10.py"); + + try (InputStream is = Files.newInputStream(path); + InputStream isActual = Files.newInputStream(algebraicLogFile)) { + List expected = Utils.readLines(is); + List actual = Utils.readLines(isActual); + Assert.assertEquals(expected, actual); + } + } +} diff --git a/examples/src/test/java/ai/djl/testing/TestRequirements.java b/examples/src/test/java/ai/djl/testing/TestRequirements.java index e8c9bd4bdda..01eef756201 100644 --- a/examples/src/test/java/ai/djl/testing/TestRequirements.java +++ b/examples/src/test/java/ai/djl/testing/TestRequirements.java @@ -14,6 +14,7 @@ import ai.djl.engine.Engine; import ai.djl.engine.EngineException; +import ai.djl.util.Utils; import org.testng.SkipException; @@ -45,7 +46,7 @@ public static void weekly() { /** Requires a test not be run in offline mode. 
*/ public static void notOffline() { - if (Boolean.getBoolean("offline")) { + if (Utils.isOfflineMode()) { throw new SkipException("This test can not run while offline"); } } diff --git a/examples/src/test/resources/algebraiclog/TrainMnist.py b/examples/src/test/resources/algebraiclog/TrainMnist.py new file mode 100644 index 00000000000..e5fa7b2763e --- /dev/null +++ b/examples/src/test/resources/algebraiclog/TrainMnist.py @@ -0,0 +1,121 @@ +class MyModel(tf.keras.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._02Linear_weight = tf.Variable( + tf.random.normal( + shape=[128, 784], + mean=0.0, + stddev=0.050507627, + dtype=tf.dtypes.float32, + name='normal_1_', + ) # (128, 784) + ) + self._02Linear_bias = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_2_', + ) # (128) + ) + self._04Linear_weight = tf.Variable( + tf.random.normal( + shape=[64, 128], + mean=0.0, + stddev=0.125, + dtype=tf.dtypes.float32, + name='normal_3_', + ) # (64, 128) + ) + self._04Linear_bias = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_4_', + ) # (64) + ) + self._06Linear_weight = tf.Variable( + tf.random.normal( + shape=[10, 64], + mean=0.0, + stddev=0.17677669, + dtype=tf.dtypes.float32, + name='normal_5_', + ) # (10, 64) + ) + self._06Linear_bias = tf.Variable( + tf.zeros( + shape=[10], + dtype=tf.dtypes.float32, + name='zeros_6_', + ) # (10) + ) + +## 4 + def call(self, x): + result = tf.nn.bias_add( + tf.matmul( + tf.nn.relu( + tf.nn.bias_add( + tf.matmul( + tf.nn.relu( + tf.nn.bias_add( + tf.matmul( + tf.reshape( + x, # (32, 1, 28, 28) + shape=[-1, 784], + name='reshape_7_', + ), # (32, 784) + b=self._02Linear_weight, # (128, 784) + transpose_b=True, + name='matmul_8_', + ), # (32, 128) + bias=self._02Linear_bias, # (128) + data_format=None, + name='bias_add_9_', + ), # (32, 128) + name='relu_10_', + ), # (32, 128) + b=self._04Linear_weight, # (64, 128) + transpose_b=True, + 
name='matmul_11_', + ), # (32, 64) + bias=self._04Linear_bias, # (64) + data_format=None, + name='bias_add_12_', + ), # (32, 64) + name='relu_13_', + ), # (32, 64) + b=self._06Linear_weight, # (10, 64) + transpose_b=True, + name='matmul_14_', + ), # (32, 10) + bias=self._06Linear_bias, # (10) + data_format=None, + name='bias_add_15_', + ) # (32, 10) + return result + +## 4 +def loss(label, prediction): + result = tf.reduce_mean( + tf.negative( + tf.gather( + tf.nn.log_softmax( + prediction, # (32, 10) + axis=-1, + name='log_softmax_16_', + ), # (32, 10) + indices=label, # (32) + batch_dims=1, + name='gather_17_', + ), # (32, 1) + name='negative_18_', + ), # (32, 1) + name='reduce_mean_19_', + ) # () + return result + +# number of epochs was 2 +# number of prediction functions is 1 +# number of loss functions is 1 + diff --git a/examples/src/test/resources/algebraiclog/TrainResnetWithCifar10.py b/examples/src/test/resources/algebraiclog/TrainResnetWithCifar10.py new file mode 100644 index 00000000000..cb3b619f810 --- /dev/null +++ b/examples/src/test/resources/algebraiclog/TrainResnetWithCifar10.py @@ -0,0 +1,4084 @@ +class MyModel(tf.keras.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[64, 3, 3, 3], + mean=0.0, + stddev=0.27216554, + dtype=tf.dtypes.float32, + name='normal_1_', + ), # (64, 3, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_2_', + ) # (3, 3, 3, 64) + ) + self._02ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[64, 64, 1, 1], + mean=0.0, + stddev=0.17677669, + dtype=tf.dtypes.float32, + name='normal_3_', + ), # (64, 64, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_4_', + ) # (1, 1, 64, 64) + ) + self._02ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_5_', + ) # (64) + ) + 
self._02ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_6_', + ) # (64) + ) + self._02ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_7_', + ) # (64) + ) + self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_8_', + ) # (64) + , trainable = False + ) + self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_9_', + ) # (64) + , trainable = False + ) + self._02ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[64, 64, 3, 3], + mean=0.0, + stddev=0.058925565, + dtype=tf.dtypes.float32, + name='normal_10_', + ), # (64, 64, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_11_', + ) # (3, 3, 64, 64) + ) + self._02ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_12_', + ) # (64) + ) + self._02ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_13_', + ) # (64) + ) + self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_14_', + ) # (64) + , trainable = False + ) + self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_15_', + ) # (64) + , trainable = False + ) + self._02ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 64, 1, 1], + mean=0.0, + stddev=0.17677669, + dtype=tf.dtypes.float32, + name='normal_16_', + ), # (256, 64, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_17_', + ) # (1, 
1, 64, 256) + ) + self._02ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_18_', + ) # (256) + ) + self._02ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_19_', + ) # (256) + ) + self._02ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_20_', + ) # (256) + ) + self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_21_', + ) # (256) + , trainable = False + ) + self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_22_', + ) # (256) + , trainable = False + ) + self._02ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 64, 1, 1], + mean=0.0, + stddev=0.17677669, + dtype=tf.dtypes.float32, + name='normal_23_', + ), # (256, 64, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_24_', + ) # (1, 1, 64, 256) + ) + self._02ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_25_', + ) # (256) + ) + self._02ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_26_', + ) # (256) + ) + self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_27_', + ) # (256) + , trainable = False + ) + self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_28_', + ) # (256) + , trainable = False + ) + self._03ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + 
tf.random.normal( + shape=[64, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_29_', + ), # (64, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_30_', + ) # (1, 1, 256, 64) + ) + self._03ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_31_', + ) # (64) + ) + self._03ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_32_', + ) # (64) + ) + self._03ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_33_', + ) # (64) + ) + self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_34_', + ) # (64) + , trainable = False + ) + self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_35_', + ) # (64) + , trainable = False + ) + self._03ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[64, 64, 3, 3], + mean=0.0, + stddev=0.058925565, + dtype=tf.dtypes.float32, + name='normal_36_', + ), # (64, 64, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_37_', + ) # (3, 3, 64, 64) + ) + self._03ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_38_', + ) # (64) + ) + self._03ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_39_', + ) # (64) + ) + self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_40_', + ) # (64) + , trainable = False + ) + self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + 
shape=[64], + dtype=tf.dtypes.float32, + name='ones_41_', + ) # (64) + , trainable = False + ) + self._03ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 64, 1, 1], + mean=0.0, + stddev=0.17677669, + dtype=tf.dtypes.float32, + name='normal_42_', + ), # (256, 64, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_43_', + ) # (1, 1, 64, 256) + ) + self._03ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_44_', + ) # (256) + ) + self._03ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_45_', + ) # (256) + ) + self._03ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_46_', + ) # (256) + ) + self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_47_', + ) # (256) + , trainable = False + ) + self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_48_', + ) # (256) + , trainable = False + ) + self._04ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[64, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_49_', + ), # (64, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_50_', + ) # (1, 1, 256, 64) + ) + self._04ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_51_', + ) # (64) + ) + self._04ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_52_', + ) # (64) + ) + self._04ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + 
shape=[64], + dtype=tf.dtypes.float32, + name='zeros_53_', + ) # (64) + ) + self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_54_', + ) # (64) + , trainable = False + ) + self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_55_', + ) # (64) + , trainable = False + ) + self._04ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[64, 64, 3, 3], + mean=0.0, + stddev=0.058925565, + dtype=tf.dtypes.float32, + name='normal_56_', + ), # (64, 64, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_57_', + ) # (3, 3, 64, 64) + ) + self._04ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_58_', + ) # (64) + ) + self._04ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_59_', + ) # (64) + ) + self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[64], + dtype=tf.dtypes.float32, + name='zeros_60_', + ) # (64) + , trainable = False + ) + self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[64], + dtype=tf.dtypes.float32, + name='ones_61_', + ) # (64) + , trainable = False + ) + self._04ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 64, 1, 1], + mean=0.0, + stddev=0.17677669, + dtype=tf.dtypes.float32, + name='normal_62_', + ), # (256, 64, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_63_', + ) # (1, 1, 64, 256) + ) + self._04ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_64_', + ) # (256) + ) + self._04ParallelBlock_01SequentialBlock_08BatchNorm_gamma = 
tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_65_', + ) # (256) + ) + self._04ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_66_', + ) # (256) + ) + self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_67_', + ) # (256) + , trainable = False + ) + self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_68_', + ) # (256) + , trainable = False + ) + self._05ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_69_', + ), # (128, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_70_', + ) # (1, 1, 256, 128) + ) + self._05ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_71_', + ) # (128) + ) + self._05ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_72_', + ) # (128) + ) + self._05ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_73_', + ) # (128) + ) + self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_74_', + ) # (128) + , trainable = False + ) + self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_75_', + ) # (128) + , trainable = False + ) + self._05ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 128, 3, 3], + mean=0.0, + 
stddev=0.041666668, + dtype=tf.dtypes.float32, + name='normal_76_', + ), # (128, 128, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_77_', + ) # (3, 3, 128, 128) + ) + self._05ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_78_', + ) # (128) + ) + self._05ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_79_', + ) # (128) + ) + self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_80_', + ) # (128) + , trainable = False + ) + self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_81_', + ) # (128) + , trainable = False + ) + self._05ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 128, 1, 1], + mean=0.0, + stddev=0.125, + dtype=tf.dtypes.float32, + name='normal_82_', + ), # (512, 128, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_83_', + ) # (1, 1, 128, 512) + ) + self._05ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_84_', + ) # (512) + ) + self._05ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_85_', + ) # (512) + ) + self._05ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_86_', + ) # (512) + ) + self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_87_', + ) # (512) + , trainable = False + ) + self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + 
name='ones_88_', + ) # (512) + , trainable = False + ) + self._05ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_89_', + ), # (512, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_90_', + ) # (1, 1, 256, 512) + ) + self._05ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_91_', + ) # (512) + ) + self._05ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_92_', + ) # (512) + ) + self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_93_', + ) # (512) + , trainable = False + ) + self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_94_', + ) # (512) + , trainable = False + ) + self._06ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_95_', + ), # (128, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_96_', + ) # (1, 1, 512, 128) + ) + self._06ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_97_', + ) # (128) + ) + self._06ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_98_', + ) # (128) + ) + self._06ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_99_', + ) # (128) + ) + self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + 
dtype=tf.dtypes.float32, + name='zeros_100_', + ) # (128) + , trainable = False + ) + self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_101_', + ) # (128) + , trainable = False + ) + self._06ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 128, 3, 3], + mean=0.0, + stddev=0.041666668, + dtype=tf.dtypes.float32, + name='normal_102_', + ), # (128, 128, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_103_', + ) # (3, 3, 128, 128) + ) + self._06ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_104_', + ) # (128) + ) + self._06ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_105_', + ) # (128) + ) + self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_106_', + ) # (128) + , trainable = False + ) + self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_107_', + ) # (128) + , trainable = False + ) + self._06ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 128, 1, 1], + mean=0.0, + stddev=0.125, + dtype=tf.dtypes.float32, + name='normal_108_', + ), # (512, 128, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_109_', + ) # (1, 1, 128, 512) + ) + self._06ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_110_', + ) # (512) + ) + self._06ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_111_', + ) # (512) + ) + self._06ParallelBlock_01SequentialBlock_08BatchNorm_beta = 
tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_112_', + ) # (512) + ) + self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_113_', + ) # (512) + , trainable = False + ) + self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_114_', + ) # (512) + , trainable = False + ) + self._07ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_115_', + ), # (128, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_116_', + ) # (1, 1, 512, 128) + ) + self._07ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_117_', + ) # (128) + ) + self._07ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_118_', + ) # (128) + ) + self._07ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_119_', + ) # (128) + ) + self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_120_', + ) # (128) + , trainable = False + ) + self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_121_', + ) # (128) + , trainable = False + ) + self._07ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 128, 3, 3], + mean=0.0, + stddev=0.041666668, + dtype=tf.dtypes.float32, + name='normal_122_', + ), # (128, 128, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_123_', + ) # (3, 3, 128, 128) + ) + 
self._07ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_124_', + ) # (128) + ) + self._07ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_125_', + ) # (128) + ) + self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_126_', + ) # (128) + , trainable = False + ) + self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_127_', + ) # (128) + , trainable = False + ) + self._07ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 128, 1, 1], + mean=0.0, + stddev=0.125, + dtype=tf.dtypes.float32, + name='normal_128_', + ), # (512, 128, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_129_', + ) # (1, 1, 128, 512) + ) + self._07ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_130_', + ) # (512) + ) + self._07ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_131_', + ) # (512) + ) + self._07ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_132_', + ) # (512) + ) + self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_133_', + ) # (512) + , trainable = False + ) + self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_134_', + ) # (512) + , trainable = False + ) + self._08ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + 
tf.random.normal( + shape=[128, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_135_', + ), # (128, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_136_', + ) # (1, 1, 512, 128) + ) + self._08ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_137_', + ) # (128) + ) + self._08ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_138_', + ) # (128) + ) + self._08ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_139_', + ) # (128) + ) + self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_140_', + ) # (128) + , trainable = False + ) + self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_141_', + ) # (128) + , trainable = False + ) + self._08ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[128, 128, 3, 3], + mean=0.0, + stddev=0.041666668, + dtype=tf.dtypes.float32, + name='normal_142_', + ), # (128, 128, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_143_', + ) # (3, 3, 128, 128) + ) + self._08ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_144_', + ) # (128) + ) + self._08ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_145_', + ) # (128) + ) + self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[128], + dtype=tf.dtypes.float32, + name='zeros_146_', + ) # (128) + , trainable = False + ) + self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = 
tf.Variable( + tf.ones( + shape=[128], + dtype=tf.dtypes.float32, + name='ones_147_', + ) # (128) + , trainable = False + ) + self._08ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 128, 1, 1], + mean=0.0, + stddev=0.125, + dtype=tf.dtypes.float32, + name='normal_148_', + ), # (512, 128, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_149_', + ) # (1, 1, 128, 512) + ) + self._08ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_150_', + ) # (512) + ) + self._08ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_151_', + ) # (512) + ) + self._08ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_152_', + ) # (512) + ) + self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_153_', + ) # (512) + , trainable = False + ) + self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_154_', + ) # (512) + , trainable = False + ) + self._09ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_155_', + ), # (256, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_156_', + ) # (1, 1, 512, 256) + ) + self._09ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_157_', + ) # (256) + ) + self._09ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_158_', + ) # (256) + ) + 
self._09ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_159_', + ) # (256) + ) + self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_160_', + ) # (256) + , trainable = False + ) + self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_161_', + ) # (256) + , trainable = False + ) + self._09ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 256, 3, 3], + mean=0.0, + stddev=0.029462783, + dtype=tf.dtypes.float32, + name='normal_162_', + ), # (256, 256, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_163_', + ) # (3, 3, 256, 256) + ) + self._09ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_164_', + ) # (256) + ) + self._09ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_165_', + ) # (256) + ) + self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_166_', + ) # (256) + , trainable = False + ) + self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_167_', + ) # (256) + , trainable = False + ) + self._09ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_168_', + ), # (1024, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_169_', + ) # (1, 1, 256, 1024) + ) + self._09ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[1024], + 
dtype=tf.dtypes.float32, + name='zeros_170_', + ) # (1024) + ) + self._09ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_171_', + ) # (1024) + ) + self._09ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_172_', + ) # (1024) + ) + self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_173_', + ) # (1024) + , trainable = False + ) + self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_174_', + ) # (1024) + , trainable = False + ) + self._09ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_175_', + ), # (1024, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_176_', + ) # (1, 1, 512, 1024) + ) + self._09ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_177_', + ) # (1024) + ) + self._09ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_178_', + ) # (1024) + ) + self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_179_', + ) # (1024) + , trainable = False + ) + self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_180_', + ) # (1024) + , trainable = False + ) + self._10ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + 
dtype=tf.dtypes.float32, + name='normal_181_', + ), # (256, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_182_', + ) # (1, 1, 1024, 256) + ) + self._10ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_183_', + ) # (256) + ) + self._10ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_184_', + ) # (256) + ) + self._10ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_185_', + ) # (256) + ) + self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_186_', + ) # (256) + , trainable = False + ) + self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_187_', + ) # (256) + , trainable = False + ) + self._10ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 256, 3, 3], + mean=0.0, + stddev=0.029462783, + dtype=tf.dtypes.float32, + name='normal_188_', + ), # (256, 256, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_189_', + ) # (3, 3, 256, 256) + ) + self._10ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_190_', + ) # (256) + ) + self._10ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_191_', + ) # (256) + ) + self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_192_', + ) # (256) + , trainable = False + ) + self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + 
name='ones_193_', + ) # (256) + , trainable = False + ) + self._10ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_194_', + ), # (1024, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_195_', + ) # (1, 1, 256, 1024) + ) + self._10ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_196_', + ) # (1024) + ) + self._10ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_197_', + ) # (1024) + ) + self._10ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_198_', + ) # (1024) + ) + self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_199_', + ) # (1024) + , trainable = False + ) + self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_200_', + ) # (1024) + , trainable = False + ) + self._11ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + dtype=tf.dtypes.float32, + name='normal_201_', + ), # (256, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_202_', + ) # (1, 1, 1024, 256) + ) + self._11ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_203_', + ) # (256) + ) + self._11ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_204_', + ) # (256) + ) + self._11ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + 
shape=[256], + dtype=tf.dtypes.float32, + name='zeros_205_', + ) # (256) + ) + self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_206_', + ) # (256) + , trainable = False + ) + self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_207_', + ) # (256) + , trainable = False + ) + self._11ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 256, 3, 3], + mean=0.0, + stddev=0.029462783, + dtype=tf.dtypes.float32, + name='normal_208_', + ), # (256, 256, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_209_', + ) # (3, 3, 256, 256) + ) + self._11ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_210_', + ) # (256) + ) + self._11ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_211_', + ) # (256) + ) + self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_212_', + ) # (256) + , trainable = False + ) + self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_213_', + ) # (256) + , trainable = False + ) + self._11ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_214_', + ), # (1024, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_215_', + ) # (1, 1, 256, 1024) + ) + self._11ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_216_', + ) # (1024) + ) + 
self._11ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_217_', + ) # (1024) + ) + self._11ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_218_', + ) # (1024) + ) + self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_219_', + ) # (1024) + , trainable = False + ) + self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_220_', + ) # (1024) + , trainable = False + ) + self._12ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + dtype=tf.dtypes.float32, + name='normal_221_', + ), # (256, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_222_', + ) # (1, 1, 1024, 256) + ) + self._12ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_223_', + ) # (256) + ) + self._12ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_224_', + ) # (256) + ) + self._12ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_225_', + ) # (256) + ) + self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_226_', + ) # (256) + , trainable = False + ) + self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_227_', + ) # (256) + , trainable = False + ) + self._12ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + 
tf.transpose( + tf.random.normal( + shape=[256, 256, 3, 3], + mean=0.0, + stddev=0.029462783, + dtype=tf.dtypes.float32, + name='normal_228_', + ), # (256, 256, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_229_', + ) # (3, 3, 256, 256) + ) + self._12ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_230_', + ) # (256) + ) + self._12ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_231_', + ) # (256) + ) + self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_232_', + ) # (256) + , trainable = False + ) + self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_233_', + ) # (256) + , trainable = False + ) + self._12ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_234_', + ), # (1024, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_235_', + ) # (1, 1, 256, 1024) + ) + self._12ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_236_', + ) # (1024) + ) + self._12ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_237_', + ) # (1024) + ) + self._12ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_238_', + ) # (1024) + ) + self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_239_', + ) # (1024) + , trainable = False + ) + 
self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_240_', + ) # (1024) + , trainable = False + ) + self._13ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + dtype=tf.dtypes.float32, + name='normal_241_', + ), # (256, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_242_', + ) # (1, 1, 1024, 256) + ) + self._13ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_243_', + ) # (256) + ) + self._13ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_244_', + ) # (256) + ) + self._13ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_245_', + ) # (256) + ) + self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_246_', + ) # (256) + , trainable = False + ) + self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_247_', + ) # (256) + , trainable = False + ) + self._13ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 256, 3, 3], + mean=0.0, + stddev=0.029462783, + dtype=tf.dtypes.float32, + name='normal_248_', + ), # (256, 256, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_249_', + ) # (3, 3, 256, 256) + ) + self._13ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_250_', + ) # (256) + ) + self._13ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + 
name='zeros_251_', + ) # (256) + ) + self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_252_', + ) # (256) + , trainable = False + ) + self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_253_', + ) # (256) + , trainable = False + ) + self._13ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 256, 1, 1], + mean=0.0, + stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_254_', + ), # (1024, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_255_', + ) # (1, 1, 256, 1024) + ) + self._13ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_256_', + ) # (1024) + ) + self._13ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_257_', + ) # (1024) + ) + self._13ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_258_', + ) # (1024) + ) + self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_259_', + ) # (1024) + , trainable = False + ) + self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_260_', + ) # (1024) + , trainable = False + ) + self._14ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + dtype=tf.dtypes.float32, + name='normal_261_', + ), # (256, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_262_', + ) # (1, 1, 1024, 256) + ) + self._14ParallelBlock_01SequentialBlock_01Conv2d_bias = 
tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_263_', + ) # (256) + ) + self._14ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_264_', + ) # (256) + ) + self._14ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_265_', + ) # (256) + ) + self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_266_', + ) # (256) + , trainable = False + ) + self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_267_', + ) # (256) + , trainable = False + ) + self._14ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[256, 256, 3, 3], + mean=0.0, + stddev=0.029462783, + dtype=tf.dtypes.float32, + name='normal_268_', + ), # (256, 256, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_269_', + ) # (3, 3, 256, 256) + ) + self._14ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_270_', + ) # (256) + ) + self._14ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_271_', + ) # (256) + ) + self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[256], + dtype=tf.dtypes.float32, + name='zeros_272_', + ) # (256) + , trainable = False + ) + self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[256], + dtype=tf.dtypes.float32, + name='ones_273_', + ) # (256) + , trainable = False + ) + self._14ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[1024, 256, 1, 1], + mean=0.0, + 
stddev=0.088388346, + dtype=tf.dtypes.float32, + name='normal_274_', + ), # (1024, 256, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_275_', + ) # (1, 1, 256, 1024) + ) + self._14ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_276_', + ) # (1024) + ) + self._14ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_277_', + ) # (1024) + ) + self._14ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_278_', + ) # (1024) + ) + self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[1024], + dtype=tf.dtypes.float32, + name='zeros_279_', + ) # (1024) + , trainable = False + ) + self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[1024], + dtype=tf.dtypes.float32, + name='ones_280_', + ) # (1024) + , trainable = False + ) + self._15ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + dtype=tf.dtypes.float32, + name='normal_281_', + ), # (512, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_282_', + ) # (1, 1, 1024, 512) + ) + self._15ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_283_', + ) # (512) + ) + self._15ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_284_', + ) # (512) + ) + self._15ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_285_', + ) # (512) + ) + self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + 
name='zeros_286_', + ) # (512) + , trainable = False + ) + self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_287_', + ) # (512) + , trainable = False + ) + self._15ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 512, 3, 3], + mean=0.0, + stddev=0.020833334, + dtype=tf.dtypes.float32, + name='normal_288_', + ), # (512, 512, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_289_', + ) # (3, 3, 512, 512) + ) + self._15ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_290_', + ) # (512) + ) + self._15ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_291_', + ) # (512) + ) + self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_292_', + ) # (512) + , trainable = False + ) + self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_293_', + ) # (512) + , trainable = False + ) + self._15ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[2048, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_294_', + ), # (2048, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_295_', + ) # (1, 1, 512, 2048) + ) + self._15ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_296_', + ) # (2048) + ) + self._15ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_297_', + ) # (2048) + ) + self._15ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + 
tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_298_', + ) # (2048) + ) + self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_299_', + ) # (2048) + , trainable = False + ) + self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_300_', + ) # (2048) + , trainable = False + ) + self._15ParallelBlock_02SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[2048, 1024, 1, 1], + mean=0.0, + stddev=0.044194173, + dtype=tf.dtypes.float32, + name='normal_301_', + ), # (2048, 1024, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_302_', + ) # (1, 1, 1024, 2048) + ) + self._15ParallelBlock_02SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_303_', + ) # (2048) + ) + self._15ParallelBlock_02SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_304_', + ) # (2048) + ) + self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_305_', + ) # (2048) + , trainable = False + ) + self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_306_', + ) # (2048) + , trainable = False + ) + self._16ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 2048, 1, 1], + mean=0.0, + stddev=0.03125, + dtype=tf.dtypes.float32, + name='normal_307_', + ), # (512, 2048, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_308_', + ) # (1, 1, 2048, 512) + ) + self._16ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_309_', + ) # (512) + ) 
+ self._16ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_310_', + ) # (512) + ) + self._16ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_311_', + ) # (512) + ) + self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_312_', + ) # (512) + , trainable = False + ) + self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_313_', + ) # (512) + , trainable = False + ) + self._16ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 512, 3, 3], + mean=0.0, + stddev=0.020833334, + dtype=tf.dtypes.float32, + name='normal_314_', + ), # (512, 512, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_315_', + ) # (3, 3, 512, 512) + ) + self._16ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_316_', + ) # (512) + ) + self._16ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_317_', + ) # (512) + ) + self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_318_', + ) # (512) + , trainable = False + ) + self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_319_', + ) # (512) + , trainable = False + ) + self._16ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[2048, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_320_', + ), # (2048, 512, 1, 1) + perm=[2, 3, 1, 
0], + name='transpose_321_', + ) # (1, 1, 512, 2048) + ) + self._16ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_322_', + ) # (2048) + ) + self._16ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_323_', + ) # (2048) + ) + self._16ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_324_', + ) # (2048) + ) + self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_325_', + ) # (2048) + , trainable = False + ) + self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_326_', + ) # (2048) + , trainable = False + ) + self._17ParallelBlock_01SequentialBlock_01Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 2048, 1, 1], + mean=0.0, + stddev=0.03125, + dtype=tf.dtypes.float32, + name='normal_327_', + ), # (512, 2048, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_328_', + ) # (1, 1, 2048, 512) + ) + self._17ParallelBlock_01SequentialBlock_01Conv2d_bias = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_329_', + ) # (512) + ) + self._17ParallelBlock_01SequentialBlock_02BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_330_', + ) # (512) + ) + self._17ParallelBlock_01SequentialBlock_02BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_331_', + ) # (512) + ) + self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_332_', + ) # (512) + , trainable = False + ) + 
self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_333_', + ) # (512) + , trainable = False + ) + self._17ParallelBlock_01SequentialBlock_04Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[512, 512, 3, 3], + mean=0.0, + stddev=0.020833334, + dtype=tf.dtypes.float32, + name='normal_334_', + ), # (512, 512, 3, 3) + perm=[2, 3, 1, 0], + name='transpose_335_', + ) # (3, 3, 512, 512) + ) + self._17ParallelBlock_01SequentialBlock_05BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_336_', + ) # (512) + ) + self._17ParallelBlock_01SequentialBlock_05BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_337_', + ) # (512) + ) + self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[512], + dtype=tf.dtypes.float32, + name='zeros_338_', + ) # (512) + , trainable = False + ) + self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[512], + dtype=tf.dtypes.float32, + name='ones_339_', + ) # (512) + , trainable = False + ) + self._17ParallelBlock_01SequentialBlock_07Conv2d_weight = tf.Variable( + tf.transpose( + tf.random.normal( + shape=[2048, 512, 1, 1], + mean=0.0, + stddev=0.0625, + dtype=tf.dtypes.float32, + name='normal_340_', + ), # (2048, 512, 1, 1) + perm=[2, 3, 1, 0], + name='transpose_341_', + ) # (1, 1, 512, 2048) + ) + self._17ParallelBlock_01SequentialBlock_07Conv2d_bias = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_342_', + ) # (2048) + ) + self._17ParallelBlock_01SequentialBlock_08BatchNorm_gamma = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_343_', + ) # (2048) + ) + self._17ParallelBlock_01SequentialBlock_08BatchNorm_beta = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + 
name='zeros_344_', + ) # (2048) + ) + self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningMean = tf.Variable( + tf.zeros( + shape=[2048], + dtype=tf.dtypes.float32, + name='zeros_345_', + ) # (2048) + , trainable = False + ) + self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningVar = tf.Variable( + tf.ones( + shape=[2048], + dtype=tf.dtypes.float32, + name='ones_346_', + ) # (2048) + , trainable = False + ) + self._20Linear_weight = tf.Variable( + tf.random.normal( + shape=[10, 2048], + mean=0.0, + stddev=0.03125, + dtype=tf.dtypes.float32, + name='normal_347_', + ) # (10, 2048) + ) + self._20Linear_bias = tf.Variable( + tf.zeros( + shape=[10], + dtype=tf.dtypes.float32, + name='zeros_348_', + ) # (10) + ) + +## 2 + def call(self, x): + val1 = tf.nn.convolution( + x, # (111, 3, 32, 32) + filters=self._01Conv2d_weight, # (3, 3, 3, 64) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_349_', + ) # (111, 64, 32, 32) + (batchnorm1, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val1, # (111, 64, 32, 32) + filters=self._02ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 64, 64) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_350_', + ), # (111, 64, 32, 32) + bias=self._02ParallelBlock_01SequentialBlock_01Conv2d_bias, # (64) + data_format='NCHW', + name='bias_add_351_', + ), # (111, 64, 32, 32) + scale=self._02ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (64) + offset=self._02ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (64) + mean=self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (64) + variance=self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (64) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_352_', + ) # (111, 64, 32, 32) + 
self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._02ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm2, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm1, # (111, 64, 32, 32) + name='relu_353_', + ), # (111, 64, 32, 32) + filters=self._02ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 64, 64) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_354_', + ), # (111, 64, 32, 32) + scale=self._02ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (64) + offset=self._02ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (64) + mean=self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (64) + variance=self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (64) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_355_', + ) # (111, 64, 32, 32) + self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._02ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm3, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm2, # (111, 64, 32, 32) + name='relu_356_', + ), # (111, 64, 32, 32) + filters=self._02ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 64, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_357_', + ), # (111, 256, 32, 32) + bias=self._02ParallelBlock_01SequentialBlock_07Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_358_', + ), # (111, 256, 32, 32) + scale=self._02ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (256) + offset=self._02ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (256) + mean=self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # 
(256) + variance=self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_359_', + ) # (111, 256, 32, 32) + self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._02ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + (batchnorm4, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + val1, # (111, 64, 32, 32) + filters=self._02ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 64, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_360_', + ), # (111, 256, 32, 32) + scale=self._02ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._02ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (256) + mean=self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (256) + variance=self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_361_', + ) # (111, 256, 32, 32) + self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._02ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var) + val2 = tf.nn.relu( + tf.add( + batchnorm3, # (111, 256, 32, 32) + batchnorm4, # (111, 256, 32, 32) + name='add_362_', + ), # (111, 256, 32, 32) + name='relu_363_', + ) # (111, 256, 32, 32) + (batchnorm5, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val2, # (111, 256, 32, 32) + filters=self._03ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 256, 64) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_364_', + ), # (111, 64, 32, 32) + bias=self._03ParallelBlock_01SequentialBlock_01Conv2d_bias, 
# (64) + data_format='NCHW', + name='bias_add_365_', + ), # (111, 64, 32, 32) + scale=self._03ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (64) + offset=self._03ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (64) + mean=self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (64) + variance=self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (64) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_366_', + ) # (111, 64, 32, 32) + self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._03ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm6, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm5, # (111, 64, 32, 32) + name='relu_367_', + ), # (111, 64, 32, 32) + filters=self._03ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 64, 64) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_368_', + ), # (111, 64, 32, 32) + scale=self._03ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (64) + offset=self._03ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (64) + mean=self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (64) + variance=self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (64) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_369_', + ) # (111, 64, 32, 32) + self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._03ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm7, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm6, # (111, 64, 32, 32) + name='relu_370_', + ), # (111, 64, 32, 32) + 
filters=self._03ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 64, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_371_', + ), # (111, 256, 32, 32) + bias=self._03ParallelBlock_01SequentialBlock_07Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_372_', + ), # (111, 256, 32, 32) + scale=self._03ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (256) + offset=self._03ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (256) + mean=self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (256) + variance=self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_373_', + ) # (111, 256, 32, 32) + self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._03ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val3 = tf.nn.relu( + tf.add( + batchnorm7, # (111, 256, 32, 32) + val2, # (111, 256, 32, 32) + name='add_374_', + ), # (111, 256, 32, 32) + name='relu_375_', + ) # (111, 256, 32, 32) + (batchnorm8, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val3, # (111, 256, 32, 32) + filters=self._04ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 256, 64) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_376_', + ), # (111, 64, 32, 32) + bias=self._04ParallelBlock_01SequentialBlock_01Conv2d_bias, # (64) + data_format='NCHW', + name='bias_add_377_', + ), # (111, 64, 32, 32) + scale=self._04ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (64) + offset=self._04ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (64) + mean=self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (64) + variance=self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (64) + 
epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_378_', + ) # (111, 64, 32, 32) + self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._04ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm9, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm8, # (111, 64, 32, 32) + name='relu_379_', + ), # (111, 64, 32, 32) + filters=self._04ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 64, 64) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_380_', + ), # (111, 64, 32, 32) + scale=self._04ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (64) + offset=self._04ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (64) + mean=self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (64) + variance=self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (64) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_381_', + ) # (111, 64, 32, 32) + self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._04ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm10, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm9, # (111, 64, 32, 32) + name='relu_382_', + ), # (111, 64, 32, 32) + filters=self._04ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 64, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_383_', + ), # (111, 256, 32, 32) + bias=self._04ParallelBlock_01SequentialBlock_07Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_384_', + ), # (111, 256, 32, 32) + scale=self._04ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (256) + 
offset=self._04ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (256) + mean=self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (256) + variance=self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_385_', + ) # (111, 256, 32, 32) + self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._04ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val4 = tf.nn.relu( + tf.add( + batchnorm10, # (111, 256, 32, 32) + val3, # (111, 256, 32, 32) + name='add_386_', + ), # (111, 256, 32, 32) + name='relu_387_', + ) # (111, 256, 32, 32) + (batchnorm11, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val4, # (111, 256, 32, 32) + filters=self._05ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 256, 128) + strides=[2, 2], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_388_', + ), # (111, 128, 16, 16) + bias=self._05ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128) + data_format='NCHW', + name='bias_add_389_', + ), # (111, 128, 16, 16) + scale=self._05ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128) + offset=self._05ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128) + mean=self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128) + variance=self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_390_', + ) # (111, 128, 16, 16) + self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._05ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm12, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + 
batchnorm11, # (111, 128, 16, 16) + name='relu_391_', + ), # (111, 128, 16, 16) + filters=self._05ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_392_', + ), # (111, 128, 16, 16) + scale=self._05ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128) + offset=self._05ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128) + mean=self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128) + variance=self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_393_', + ) # (111, 128, 16, 16) + self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._05ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm13, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm12, # (111, 128, 16, 16) + name='relu_394_', + ), # (111, 128, 16, 16) + filters=self._05ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_395_', + ), # (111, 512, 16, 16) + bias=self._05ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_396_', + ), # (111, 512, 16, 16) + scale=self._05ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512) + offset=self._05ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512) + mean=self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512) + variance=self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_397_', + ) # (111, 512, 16, 16) + 
self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._05ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + (batchnorm14, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + val4, # (111, 256, 32, 32) + filters=self._05ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 256, 512) + strides=[2, 2], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_398_', + ), # (111, 512, 16, 16) + scale=self._05ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (512) + offset=self._05ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (512) + mean=self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (512) + variance=self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_399_', + ) # (111, 512, 16, 16) + self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._05ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var) + val5 = tf.nn.relu( + tf.add( + batchnorm13, # (111, 512, 16, 16) + batchnorm14, # (111, 512, 16, 16) + name='add_400_', + ), # (111, 512, 16, 16) + name='relu_401_', + ) # (111, 512, 16, 16) + (batchnorm15, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val5, # (111, 512, 16, 16) + filters=self._06ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 128) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_402_', + ), # (111, 128, 16, 16) + bias=self._06ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128) + data_format='NCHW', + name='bias_add_403_', + ), # (111, 128, 16, 16) + scale=self._06ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128) + offset=self._06ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128) 
+ mean=self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128) + variance=self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_404_', + ) # (111, 128, 16, 16) + self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._06ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm16, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm15, # (111, 128, 16, 16) + name='relu_405_', + ), # (111, 128, 16, 16) + filters=self._06ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_406_', + ), # (111, 128, 16, 16) + scale=self._06ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128) + offset=self._06ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128) + mean=self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128) + variance=self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_407_', + ) # (111, 128, 16, 16) + self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._06ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm17, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm16, # (111, 128, 16, 16) + name='relu_408_', + ), # (111, 128, 16, 16) + filters=self._06ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_409_', + ), # (111, 512, 16, 16) + 
bias=self._06ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_410_', + ), # (111, 512, 16, 16) + scale=self._06ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512) + offset=self._06ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512) + mean=self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512) + variance=self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_411_', + ) # (111, 512, 16, 16) + self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._06ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val6 = tf.nn.relu( + tf.add( + batchnorm17, # (111, 512, 16, 16) + val5, # (111, 512, 16, 16) + name='add_412_', + ), # (111, 512, 16, 16) + name='relu_413_', + ) # (111, 512, 16, 16) + (batchnorm18, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val6, # (111, 512, 16, 16) + filters=self._07ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 128) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_414_', + ), # (111, 128, 16, 16) + bias=self._07ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128) + data_format='NCHW', + name='bias_add_415_', + ), # (111, 128, 16, 16) + scale=self._07ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128) + offset=self._07ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128) + mean=self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128) + variance=self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_416_', + ) # (111, 128, 16, 16) + 
self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._07ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm19, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm18, # (111, 128, 16, 16) + name='relu_417_', + ), # (111, 128, 16, 16) + filters=self._07ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_418_', + ), # (111, 128, 16, 16) + scale=self._07ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128) + offset=self._07ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128) + mean=self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128) + variance=self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_419_', + ) # (111, 128, 16, 16) + self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._07ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm20, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm19, # (111, 128, 16, 16) + name='relu_420_', + ), # (111, 128, 16, 16) + filters=self._07ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_421_', + ), # (111, 512, 16, 16) + bias=self._07ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_422_', + ), # (111, 512, 16, 16) + scale=self._07ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512) + offset=self._07ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512) + 
mean=self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512) + variance=self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_423_', + ) # (111, 512, 16, 16) + self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._07ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val7 = tf.nn.relu( + tf.add( + batchnorm20, # (111, 512, 16, 16) + val6, # (111, 512, 16, 16) + name='add_424_', + ), # (111, 512, 16, 16) + name='relu_425_', + ) # (111, 512, 16, 16) + (batchnorm21, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val7, # (111, 512, 16, 16) + filters=self._08ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 128) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_426_', + ), # (111, 128, 16, 16) + bias=self._08ParallelBlock_01SequentialBlock_01Conv2d_bias, # (128) + data_format='NCHW', + name='bias_add_427_', + ), # (111, 128, 16, 16) + scale=self._08ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (128) + offset=self._08ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (128) + mean=self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (128) + variance=self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (128) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_428_', + ) # (111, 128, 16, 16) + self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._08ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm22, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm21, # (111, 128, 16, 16) + name='relu_429_', + ), # (111, 128, 16, 
16) + filters=self._08ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 128, 128) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_430_', + ), # (111, 128, 16, 16) + scale=self._08ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (128) + offset=self._08ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (128) + mean=self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (128) + variance=self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (128) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_431_', + ) # (111, 128, 16, 16) + self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._08ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm23, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm22, # (111, 128, 16, 16) + name='relu_432_', + ), # (111, 128, 16, 16) + filters=self._08ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 128, 512) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_433_', + ), # (111, 512, 16, 16) + bias=self._08ParallelBlock_01SequentialBlock_07Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_434_', + ), # (111, 512, 16, 16) + scale=self._08ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (512) + offset=self._08ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (512) + mean=self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (512) + variance=self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_435_', + ) # (111, 512, 16, 16) + self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + 
self._08ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val8 = tf.nn.relu( + tf.add( + batchnorm23, # (111, 512, 16, 16) + val7, # (111, 512, 16, 16) + name='add_436_', + ), # (111, 512, 16, 16) + name='relu_437_', + ) # (111, 512, 16, 16) + (batchnorm24, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val8, # (111, 512, 16, 16) + filters=self._09ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 512, 256) + strides=[2, 2], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_438_', + ), # (111, 256, 8, 8) + bias=self._09ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_439_', + ), # (111, 256, 8, 8) + scale=self._09ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._09ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256) + mean=self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256) + variance=self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_440_', + ) # (111, 256, 8, 8) + self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._09ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm25, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm24, # (111, 256, 8, 8) + name='relu_441_', + ), # (111, 256, 8, 8) + filters=self._09ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_442_', + ), # (111, 256, 8, 8) + scale=self._09ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256) + offset=self._09ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256) + 
mean=self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256) + variance=self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_443_', + ) # (111, 256, 8, 8) + self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._09ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm26, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm25, # (111, 256, 8, 8) + name='relu_444_', + ), # (111, 256, 8, 8) + filters=self._09ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_445_', + ), # (111, 1024, 8, 8) + bias=self._09ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024) + data_format='NCHW', + name='bias_add_446_', + ), # (111, 1024, 8, 8) + scale=self._09ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024) + offset=self._09ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024) + mean=self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024) + variance=self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_447_', + ) # (111, 1024, 8, 8) + self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._09ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + (batchnorm27, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + val8, # (111, 512, 16, 16) + filters=self._09ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 512, 1024) + strides=[2, 2], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + 
name='convolution_448_', + ), # (111, 1024, 8, 8) + scale=self._09ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (1024) + offset=self._09ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (1024) + mean=self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (1024) + variance=self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_449_', + ) # (111, 1024, 8, 8) + self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._09ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var) + val9 = tf.nn.relu( + tf.add( + batchnorm26, # (111, 1024, 8, 8) + batchnorm27, # (111, 1024, 8, 8) + name='add_450_', + ), # (111, 1024, 8, 8) + name='relu_451_', + ) # (111, 1024, 8, 8) + (batchnorm28, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val9, # (111, 1024, 8, 8) + filters=self._10ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_452_', + ), # (111, 256, 8, 8) + bias=self._10ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_453_', + ), # (111, 256, 8, 8) + scale=self._10ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._10ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256) + mean=self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256) + variance=self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_454_', + ) # (111, 256, 8, 8) + self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + 
self._10ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm29, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm28, # (111, 256, 8, 8) + name='relu_455_', + ), # (111, 256, 8, 8) + filters=self._10ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_456_', + ), # (111, 256, 8, 8) + scale=self._10ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256) + offset=self._10ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256) + mean=self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256) + variance=self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_457_', + ) # (111, 256, 8, 8) + self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._10ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm30, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm29, # (111, 256, 8, 8) + name='relu_458_', + ), # (111, 256, 8, 8) + filters=self._10ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_459_', + ), # (111, 1024, 8, 8) + bias=self._10ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024) + data_format='NCHW', + name='bias_add_460_', + ), # (111, 1024, 8, 8) + scale=self._10ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024) + offset=self._10ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024) + mean=self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024) + variance=self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # 
(1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_461_', + ) # (111, 1024, 8, 8) + self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._10ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val10 = tf.nn.relu( + tf.add( + batchnorm30, # (111, 1024, 8, 8) + val9, # (111, 1024, 8, 8) + name='add_462_', + ), # (111, 1024, 8, 8) + name='relu_463_', + ) # (111, 1024, 8, 8) + (batchnorm31, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val10, # (111, 1024, 8, 8) + filters=self._11ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_464_', + ), # (111, 256, 8, 8) + bias=self._11ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_465_', + ), # (111, 256, 8, 8) + scale=self._11ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._11ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256) + mean=self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256) + variance=self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_466_', + ) # (111, 256, 8, 8) + self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._11ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm32, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm31, # (111, 256, 8, 8) + name='relu_467_', + ), # (111, 256, 8, 8) + filters=self._11ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + 
name='convolution_468_', + ), # (111, 256, 8, 8) + scale=self._11ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256) + offset=self._11ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256) + mean=self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256) + variance=self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_469_', + ) # (111, 256, 8, 8) + self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._11ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm33, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm32, # (111, 256, 8, 8) + name='relu_470_', + ), # (111, 256, 8, 8) + filters=self._11ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_471_', + ), # (111, 1024, 8, 8) + bias=self._11ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024) + data_format='NCHW', + name='bias_add_472_', + ), # (111, 1024, 8, 8) + scale=self._11ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024) + offset=self._11ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024) + mean=self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024) + variance=self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_473_', + ) # (111, 1024, 8, 8) + self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._11ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val11 = tf.nn.relu( + tf.add( + batchnorm33, # (111, 1024, 8, 8) + val10, # (111, 1024, 8, 8) + name='add_474_', 
+ ), # (111, 1024, 8, 8) + name='relu_475_', + ) # (111, 1024, 8, 8) + (batchnorm34, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val11, # (111, 1024, 8, 8) + filters=self._12ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_476_', + ), # (111, 256, 8, 8) + bias=self._12ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_477_', + ), # (111, 256, 8, 8) + scale=self._12ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._12ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256) + mean=self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256) + variance=self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_478_', + ) # (111, 256, 8, 8) + self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._12ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm35, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm34, # (111, 256, 8, 8) + name='relu_479_', + ), # (111, 256, 8, 8) + filters=self._12ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_480_', + ), # (111, 256, 8, 8) + scale=self._12ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256) + offset=self._12ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256) + mean=self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256) + variance=self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + 
data_format='NCHW', + name='fused_batch_norm_481_', + ) # (111, 256, 8, 8) + self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._12ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm36, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm35, # (111, 256, 8, 8) + name='relu_482_', + ), # (111, 256, 8, 8) + filters=self._12ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_483_', + ), # (111, 1024, 8, 8) + bias=self._12ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024) + data_format='NCHW', + name='bias_add_484_', + ), # (111, 1024, 8, 8) + scale=self._12ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024) + offset=self._12ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024) + mean=self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024) + variance=self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_485_', + ) # (111, 1024, 8, 8) + self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._12ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val12 = tf.nn.relu( + tf.add( + batchnorm36, # (111, 1024, 8, 8) + val11, # (111, 1024, 8, 8) + name='add_486_', + ), # (111, 1024, 8, 8) + name='relu_487_', + ) # (111, 1024, 8, 8) + (batchnorm37, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val12, # (111, 1024, 8, 8) + filters=self._13ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_488_', + ), # (111, 256, 8, 8) 
+ bias=self._13ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_489_', + ), # (111, 256, 8, 8) + scale=self._13ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._13ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256) + mean=self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256) + variance=self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_490_', + ) # (111, 256, 8, 8) + self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._13ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm38, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm37, # (111, 256, 8, 8) + name='relu_491_', + ), # (111, 256, 8, 8) + filters=self._13ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_492_', + ), # (111, 256, 8, 8) + scale=self._13ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256) + offset=self._13ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256) + mean=self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256) + variance=self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_493_', + ) # (111, 256, 8, 8) + self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._13ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm39, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm38, # (111, 256, 8, 8) + name='relu_494_', + ), # 
(111, 256, 8, 8) + filters=self._13ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_495_', + ), # (111, 1024, 8, 8) + bias=self._13ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024) + data_format='NCHW', + name='bias_add_496_', + ), # (111, 1024, 8, 8) + scale=self._13ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024) + offset=self._13ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024) + mean=self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024) + variance=self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_497_', + ) # (111, 1024, 8, 8) + self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._13ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val13 = tf.nn.relu( + tf.add( + batchnorm39, # (111, 1024, 8, 8) + val12, # (111, 1024, 8, 8) + name='add_498_', + ), # (111, 1024, 8, 8) + name='relu_499_', + ) # (111, 1024, 8, 8) + (batchnorm40, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val13, # (111, 1024, 8, 8) + filters=self._14ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 256) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_500_', + ), # (111, 256, 8, 8) + bias=self._14ParallelBlock_01SequentialBlock_01Conv2d_bias, # (256) + data_format='NCHW', + name='bias_add_501_', + ), # (111, 256, 8, 8) + scale=self._14ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (256) + offset=self._14ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (256) + mean=self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (256) + 
variance=self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (256) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_502_', + ) # (111, 256, 8, 8) + self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._14ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm41, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm40, # (111, 256, 8, 8) + name='relu_503_', + ), # (111, 256, 8, 8) + filters=self._14ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 256, 256) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_504_', + ), # (111, 256, 8, 8) + scale=self._14ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (256) + offset=self._14ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (256) + mean=self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (256) + variance=self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (256) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_505_', + ) # (111, 256, 8, 8) + self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._14ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm42, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm41, # (111, 256, 8, 8) + name='relu_506_', + ), # (111, 256, 8, 8) + filters=self._14ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 256, 1024) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_507_', + ), # (111, 1024, 8, 8) + bias=self._14ParallelBlock_01SequentialBlock_07Conv2d_bias, # (1024) + data_format='NCHW', + name='bias_add_508_', + ), # (111, 1024, 
8, 8) + scale=self._14ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (1024) + offset=self._14ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (1024) + mean=self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (1024) + variance=self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (1024) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_509_', + ) # (111, 1024, 8, 8) + self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._14ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val14 = tf.nn.relu( + tf.add( + batchnorm42, # (111, 1024, 8, 8) + val13, # (111, 1024, 8, 8) + name='add_510_', + ), # (111, 1024, 8, 8) + name='relu_511_', + ) # (111, 1024, 8, 8) + (batchnorm43, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val14, # (111, 1024, 8, 8) + filters=self._15ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 512) + strides=[2, 2], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_512_', + ), # (111, 512, 4, 4) + bias=self._15ParallelBlock_01SequentialBlock_01Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_513_', + ), # (111, 512, 4, 4) + scale=self._15ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (512) + offset=self._15ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (512) + mean=self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (512) + variance=self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_514_', + ) # (111, 512, 4, 4) + self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._15ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm44, running_mean, 
running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm43, # (111, 512, 4, 4) + name='relu_515_', + ), # (111, 512, 4, 4) + filters=self._15ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 512, 512) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_516_', + ), # (111, 512, 4, 4) + scale=self._15ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (512) + offset=self._15ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (512) + mean=self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (512) + variance=self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (512) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_517_', + ) # (111, 512, 4, 4) + self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._15ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm45, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm44, # (111, 512, 4, 4) + name='relu_518_', + ), # (111, 512, 4, 4) + filters=self._15ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 512, 2048) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_519_', + ), # (111, 2048, 4, 4) + bias=self._15ParallelBlock_01SequentialBlock_07Conv2d_bias, # (2048) + data_format='NCHW', + name='bias_add_520_', + ), # (111, 2048, 4, 4) + scale=self._15ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (2048) + offset=self._15ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (2048) + mean=self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (2048) + variance=self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (2048) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + 
name='fused_batch_norm_521_', + ) # (111, 2048, 4, 4) + self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._15ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + (batchnorm46, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + val14, # (111, 1024, 8, 8) + filters=self._15ParallelBlock_02SequentialBlock_01Conv2d_weight, # (1, 1, 1024, 2048) + strides=[2, 2], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_522_', + ), # (111, 2048, 4, 4) + scale=self._15ParallelBlock_02SequentialBlock_02BatchNorm_gamma, # (2048) + offset=self._15ParallelBlock_02SequentialBlock_02BatchNorm_beta, # (2048) + mean=self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningMean, # (2048) + variance=self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningVar, # (2048) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_523_', + ) # (111, 2048, 4, 4) + self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._15ParallelBlock_02SequentialBlock_02BatchNorm_runningVar.assign(running_var) + val15 = tf.nn.relu( + tf.add( + batchnorm45, # (111, 2048, 4, 4) + batchnorm46, # (111, 2048, 4, 4) + name='add_524_', + ), # (111, 2048, 4, 4) + name='relu_525_', + ) # (111, 2048, 4, 4) + (batchnorm47, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val15, # (111, 2048, 4, 4) + filters=self._16ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 2048, 512) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_526_', + ), # (111, 512, 4, 4) + bias=self._16ParallelBlock_01SequentialBlock_01Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_527_', + ), # (111, 512, 4, 4) + scale=self._16ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (512) + 
offset=self._16ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (512) + mean=self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (512) + variance=self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_528_', + ) # (111, 512, 4, 4) + self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._16ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm48, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm47, # (111, 512, 4, 4) + name='relu_529_', + ), # (111, 512, 4, 4) + filters=self._16ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 512, 512) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_530_', + ), # (111, 512, 4, 4) + scale=self._16ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (512) + offset=self._16ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (512) + mean=self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (512) + variance=self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (512) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_531_', + ) # (111, 512, 4, 4) + self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._16ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm49, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm48, # (111, 512, 4, 4) + name='relu_532_', + ), # (111, 512, 4, 4) + filters=self._16ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 512, 2048) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_533_', + ), # 
(111, 2048, 4, 4) + bias=self._16ParallelBlock_01SequentialBlock_07Conv2d_bias, # (2048) + data_format='NCHW', + name='bias_add_534_', + ), # (111, 2048, 4, 4) + scale=self._16ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (2048) + offset=self._16ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (2048) + mean=self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (2048) + variance=self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (2048) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_535_', + ) # (111, 2048, 4, 4) + self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._16ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + val16 = tf.nn.relu( + tf.add( + batchnorm49, # (111, 2048, 4, 4) + val15, # (111, 2048, 4, 4) + name='add_536_', + ), # (111, 2048, 4, 4) + name='relu_537_', + ) # (111, 2048, 4, 4) + (batchnorm50, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + val16, # (111, 2048, 4, 4) + filters=self._17ParallelBlock_01SequentialBlock_01Conv2d_weight, # (1, 1, 2048, 512) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_538_', + ), # (111, 512, 4, 4) + bias=self._17ParallelBlock_01SequentialBlock_01Conv2d_bias, # (512) + data_format='NCHW', + name='bias_add_539_', + ), # (111, 512, 4, 4) + scale=self._17ParallelBlock_01SequentialBlock_02BatchNorm_gamma, # (512) + offset=self._17ParallelBlock_01SequentialBlock_02BatchNorm_beta, # (512) + mean=self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningMean, # (512) + variance=self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningVar, # (512) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_540_', + ) # (111, 512, 4, 4) + 
self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningMean.assign(running_mean) + self._17ParallelBlock_01SequentialBlock_02BatchNorm_runningVar.assign(running_var) + (batchnorm51, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.convolution( + tf.nn.relu( + batchnorm50, # (111, 512, 4, 4) + name='relu_541_', + ), # (111, 512, 4, 4) + filters=self._17ParallelBlock_01SequentialBlock_04Conv2d_weight, # (3, 3, 512, 512) + strides=[1, 1], + padding='SAME', + dilations=[1, 1], + data_format='NCHW', + name='convolution_542_', + ), # (111, 512, 4, 4) + scale=self._17ParallelBlock_01SequentialBlock_05BatchNorm_gamma, # (512) + offset=self._17ParallelBlock_01SequentialBlock_05BatchNorm_beta, # (512) + mean=self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningMean, # (512) + variance=self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningVar, # (512) + epsilon=2.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_543_', + ) # (111, 512, 4, 4) + self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningMean.assign(running_mean) + self._17ParallelBlock_01SequentialBlock_05BatchNorm_runningVar.assign(running_var) + (batchnorm52, running_mean, running_var) = tf.compat.v1.nn.fused_batch_norm( + tf.nn.bias_add( + tf.nn.convolution( + tf.nn.relu( + batchnorm51, # (111, 512, 4, 4) + name='relu_544_', + ), # (111, 512, 4, 4) + filters=self._17ParallelBlock_01SequentialBlock_07Conv2d_weight, # (1, 1, 512, 2048) + strides=[1, 1], + padding='VALID', + dilations=[1, 1], + data_format='NCHW', + name='convolution_545_', + ), # (111, 2048, 4, 4) + bias=self._17ParallelBlock_01SequentialBlock_07Conv2d_bias, # (2048) + data_format='NCHW', + name='bias_add_546_', + ), # (111, 2048, 4, 4) + scale=self._17ParallelBlock_01SequentialBlock_08BatchNorm_gamma, # (2048) + offset=self._17ParallelBlock_01SequentialBlock_08BatchNorm_beta, # (2048) + 
mean=self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningMean, # (2048) + variance=self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningVar, # (2048) + epsilon=1.0E-5, + is_training=True, + exponential_avg_factor=0.9, + data_format='NCHW', + name='fused_batch_norm_547_', + ) # (111, 2048, 4, 4) + self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningMean.assign(running_mean) + self._17ParallelBlock_01SequentialBlock_08BatchNorm_runningVar.assign(running_var) + result = tf.reshape( + tf.nn.bias_add( + tf.matmul( + tf.reshape( + tf.reshape( + tf.reduce_mean( + tf.nn.relu( + tf.add( + batchnorm52, # (111, 2048, 4, 4) + val16, # (111, 2048, 4, 4) + name='add_548_', + ), # (111, 2048, 4, 4) + name='relu_549_', + ), # (111, 2048, 4, 4) + axis=[2, 3], + name='reduce_mean_550_', + ), # (111, 2048, 1, 1) + shape=[-1, 2048], + name='reshape_551_', + ), # (111, 2048) + shape=[-1, 2048], + name='reshape_552_', + ), # (111, 2048) + b=self._20Linear_weight, # (10, 2048) + transpose_b=True, + name='matmul_553_', + ), # (111, 10) + bias=self._20Linear_bias, # (10) + data_format=None, + name='bias_add_554_', + ), # (111, 10) + shape=[-1, 10], + name='reshape_555_', + ) # (111, 10) + return result + +## 2 +def loss(label, prediction): + result = tf.reduce_mean( + tf.negative( + tf.gather( + tf.nn.log_softmax( + prediction, # (111, 10) + axis=-1, + name='log_softmax_556_', + ), # (111, 10) + indices=label, # (111) + batch_dims=1, + name='gather_557_', + ), # (111, 1) + name='negative_558_', + ), # (111, 1) + name='reduce_mean_559_', + ) # () + return result + +# number of epochs was 2 +# number of prediction functions is 1 +# number of loss functions is 1 + diff --git a/examples/src/test/resources/yolov8_synset.txt b/examples/src/test/resources/yolov8_synset.txt new file mode 100644 index 00000000000..ffba2064933 --- /dev/null +++ b/examples/src/test/resources/yolov8_synset.txt @@ -0,0 +1,84 @@ +# Classes for coco dataset on which yelov8 is trained +# source config 
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml. +# COCO dataset website: https://cocodataset.org/#home +# Ultralytics Coco doc page: https://docs.ultralytics.com/datasets/detect/coco/ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush \ No newline at end of file diff --git a/examples/src/test/resources/yolov8_test.jpg b/examples/src/test/resources/yolov8_test.jpg new file mode 100644 index 00000000000..01e43374348 Binary files /dev/null and b/examples/src/test/resources/yolov8_test.jpg differ diff --git a/examples/src/test/resources/yolov8n.onnx b/examples/src/test/resources/yolov8n.onnx new file mode 100644 index 00000000000..430f7f2beb0 Binary files /dev/null and b/examples/src/test/resources/yolov8n.onnx differ diff --git a/extensions/audio/README.md b/extensions/audio/README.md index 7e2c89692bc..95ed8c53a84 100644 --- a/extensions/audio/README.md +++ b/extensions/audio/README.md @@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.audio audio - 0.23.0 + 0.26.0 ``` diff --git a/extensions/aws-ai/README.md b/extensions/aws-ai/README.md index 829df0bb0ca..16d412904c5 100644 --- a/extensions/aws-ai/README.md +++ b/extensions/aws-ai/README.md @@ -58,6 +58,6 @@ You can pull the module from the central Maven repository by including 
the follo ai.djl.aws aws-ai - 0.23.0 + 0.26.0 ``` diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md index 6f5a25064ea..f0c60d39bf1 100644 --- a/extensions/fasttext/README.md +++ b/extensions/fasttext/README.md @@ -34,7 +34,7 @@ You can pull the fastText engine from the central Maven repository by including ai.djl.fasttext fasttext-engine - 0.23.0 + 0.26.0 ``` diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 5b421ff431f..4395ddf1a6c 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -41,6 +41,7 @@ import java.io.IOException; import java.io.InputStream; +import java.net.URI; import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; @@ -129,7 +130,9 @@ public void testWord2Vec() throws IOException, MalformedModelException, ModelNot public void testBlazingText() throws IOException, ModelException { TestRequirements.nightly(); - URL url = new URL("https://resources.djl.ai/test-models/blazingtext_classification.bin"); + URL url = + URI.create("https://resources.djl.ai/test-models/blazingtext_classification.bin") + .toURL(); Path path = Paths.get("build/tmp/model"); Path modelFile = path.resolve("text_classification.bin"); if (!Files.exists(modelFile)) { diff --git a/extensions/hadoop/README.md b/extensions/hadoop/README.md index b3c4ebcc762..8a376e22d85 100644 --- a/extensions/hadoop/README.md +++ b/extensions/hadoop/README.md @@ -52,6 +52,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.hadoop hadoop - 0.23.0 + 0.26.0 ``` diff --git a/extensions/opencv/README.md b/extensions/opencv/README.md index d6c58f518dc..c23e0c58532 100644 --- a/extensions/opencv/README.md +++ b/extensions/opencv/README.md @@ -23,6 +23,6 @@ You can pull 
the module from the central Maven repository by including the follo ai.djl.opencv opencv - 0.23.0 + 0.26.0 ``` diff --git a/extensions/sentencepiece/README.md b/extensions/sentencepiece/README.md index 4308308111f..de28d5334df 100644 --- a/extensions/sentencepiece/README.md +++ b/extensions/sentencepiece/README.md @@ -23,6 +23,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.sentencepiece sentencepiece - 0.23.0 + 0.26.0 ``` diff --git a/extensions/spark/README.md b/extensions/spark/README.md index 02ebcc07a1d..da3171ca008 100644 --- a/extensions/spark/README.md +++ b/extensions/spark/README.md @@ -34,7 +34,7 @@ You can pull the module from the central Maven repository by including the follo ai.djl.spark spark_2.12 - 0.23.0 + 0.26.0 ``` diff --git a/extensions/tablesaw/README.md b/extensions/tablesaw/README.md index 010c6395eb9..b4287d9733d 100644 --- a/extensions/tablesaw/README.md +++ b/extensions/tablesaw/README.md @@ -25,6 +25,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.tablesaw tablesaw - 0.23.0 + 0.26.0 ``` diff --git a/extensions/timeseries/README.md b/extensions/timeseries/README.md index 9706c9334a4..3ef6887825c 100644 --- a/extensions/timeseries/README.md +++ b/extensions/timeseries/README.md @@ -245,6 +245,6 @@ You can pull the module from the central Maven repository by including the follo ai.djl.timeseries timeseries - 0.23.0 + 0.26.0 ``` diff --git a/extensions/timeseries/docs/forecast_with_M5_data.md b/extensions/timeseries/docs/forecast_with_M5_data.md index a4f1a24a1d9..7b8e1c78210 100644 --- a/extensions/timeseries/docs/forecast_with_M5_data.md +++ b/extensions/timeseries/docs/forecast_with_M5_data.md @@ -1,5 +1,7 @@ # Forecast the future in a timeseries data with Deep Java Library (DJL) + ## -- Demonstration on M5forecasting and airpassenger datasests + Junyuan Zhang, Kexin Feng Time series data are commonly seen in the world. 
They can contain valued information that helps forecast for the future, monitor the status of a procedure and feedforward a control. Generic applications includes the following: sales forecasting, stock market analysis, yield projections, process and quality control, and many many more. See [link1](https://www.itl.nist.gov/div898/handbook/pmc/section4/pmc41.htm) and [link2](https://www.influxdata.com/time-series-forecasting-methods/#:~:text=Time%20series%20forecasting%20means%20to,on%20what%20has%20already%20happened) for further examples of timeseries data. @@ -54,7 +56,7 @@ repositories { } dependencies { implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.17.1" - implementation platform("ai.djl:bom:0.23.0") + implementation platform("ai.djl:bom:0.26.0") implementation "ai.djl:api" implementation "ai.djl.timeseries" runtimeOnly "ai.djl.mxnet:mxnet-engine" diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java index 5b642285c3e..9edb45ff5f0 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java @@ -94,15 +94,23 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { Pair update = evaluateHelper(labels, predictions); - totalInstances.compute(key, (k, v) -> v + update.getKey()); - totalLoss.compute( - key, - (k, v) -> { - try (NDArray array = update.getValue().sum()) { - return v + array.getFloat(); - } - }); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + update.getKey()); + totalLoss.compute( + key, + (k, v) -> { + try (NDArray 
array = update.getValue().sum()) { + return v + array.getFloat(); + } + }); + } } /** {@inheritDoc} */ diff --git a/extensions/tokenizers/README.md b/extensions/tokenizers/README.md index 1b85625572c..2cdf4f19137 100644 --- a/extensions/tokenizers/README.md +++ b/extensions/tokenizers/README.md @@ -23,7 +23,7 @@ You can pull the module from the central Maven repository by including the follo ai.djl.huggingface tokenizers - 0.23.0 + 0.26.0 ``` diff --git a/extensions/tokenizers/build.cmd b/extensions/tokenizers/build.cmd index 3a481d33bab..d83f2c1ed74 100644 --- a/extensions/tokenizers/build.cmd +++ b/extensions/tokenizers/build.cmd @@ -3,7 +3,7 @@ @rem choco install rust -y @rem choco install jdk8 -y -set VERSION=python-v"%1" +set VERSION=v"%1" if exist "tokenizers" ( echo Found "tokenizers" diff --git a/extensions/tokenizers/build.sh b/extensions/tokenizers/build.sh index 4ba45a09965..229e8124914 100755 --- a/extensions/tokenizers/build.sh +++ b/extensions/tokenizers/build.sh @@ -10,7 +10,7 @@ elif [[ -n $(command -v sysctl) ]]; then fi PLATFORM=$(uname | tr '[:upper:]' '[:lower:]') -VERSION=python-v$1 +VERSION=v$1 ARCH=$2 pushd $WORK_DIR diff --git a/extensions/tokenizers/rust/Cargo.toml b/extensions/tokenizers/rust/Cargo.toml index f6b846f636c..3418c8f5129 100644 --- a/extensions/tokenizers/rust/Cargo.toml +++ b/extensions/tokenizers/rust/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] jni = "0.19.0" -tokenizers = { path = "../tokenizers/tokenizers", version = "*" } +tokenizers = { path = "../tokenizers/tokenizers", version = "*", features = ["http"] } [target.'cfg(target_os = "linux")'.dependencies] openssl = { version = "0.10", features = ["vendored"] } diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs index d1c0c455c19..590099c2ecf 100644 --- a/extensions/tokenizers/rust/src/lib.rs +++ b/extensions/tokenizers/rust/src/lib.rs @@ -490,7 +490,7 @@ pub extern "system" fn 
Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } } let decoding: String = tokenizer - .decode(decode_ids, skip_special_tokens == JNI_TRUE) + .decode(&*decode_ids, skip_special_tokens == JNI_TRUE) .unwrap(); let ret = env .new_string(decoding) @@ -527,8 +527,12 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ } batch_decode_input.push(decode_ids); } + let mut references: Vec<&[u32]> = Vec::new(); + for reference in batch_decode_input.iter() { + references.push(reference); + } let decoding: Vec = tokenizer - .decode_batch(batch_decode_input, skip_special_tokens == JNI_TRUE) + .decode_batch(&references, skip_special_tokens == JNI_TRUE) .unwrap(); let ret: jobjectArray = env .new_object_array(batch_len, "java/lang/String", JObject::null()) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java index e58d6ada5ee..887f01646dc 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java @@ -27,6 +27,7 @@ public class Encoding { private long[] specialTokenMask; private CharSpan[] charTokenSpans; private Encoding[] overflowing; + private boolean exceedMaxLength; protected Encoding( long[] ids, @@ -36,6 +37,7 @@ protected Encoding( long[] attentionMask, long[] specialTokenMask, CharSpan[] charTokenSpans, + boolean exceedMaxLength, Encoding[] overflowing) { this.ids = ids; this.typeIds = typeIds; @@ -44,6 +46,7 @@ protected Encoding( this.attentionMask = attentionMask; this.specialTokenMask = specialTokenMask; this.charTokenSpans = charTokenSpans; + this.exceedMaxLength = exceedMaxLength; this.overflowing = overflowing; } @@ -127,6 +130,15 @@ public CharSpan[] getCharTokenSpans() { return charTokenSpans; } + /** + * Returns if tokens exceed max length. 
+ * + * @return {@code true} if tokens exceed max length + */ + public boolean exceedMaxLength() { + return exceedMaxLength; + } + /** * Returns an array of overflowing encodings. * diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index f75342b7cb8..ba4d61b79b1 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -44,6 +44,7 @@ public final class HuggingFaceTokenizer extends NativeResource implements private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class); private boolean addSpecialTokens; + private boolean withOverflowingTokens; private TruncationStrategy truncation; private PaddingStrategy padding; private int maxLength; @@ -64,6 +65,8 @@ private HuggingFaceTokenizer(long handle, Map options) { if (options != null) { val = options.getOrDefault("addSpecialTokens", "true"); addSpecialTokens = Boolean.parseBoolean(val); + val = options.getOrDefault("withOverflowingTokens", "false"); + withOverflowingTokens = Boolean.parseBoolean(val); modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512); if (options.containsKey("truncation")) { truncation = TruncationStrategy.fromValue(options.get("truncation")); @@ -203,11 +206,12 @@ public void close() { * @param text the input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence */ - public Encoding encode(String text, boolean addSpecialTokens) { + public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) { long encoding = TokenizersLibrary.LIB.encode(getHandle(), 
text, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -217,7 +221,7 @@ public Encoding encode(String text, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence */ public Encoding encode(String text) { - return encode(text, addSpecialTokens); + return encode(text, addSpecialTokens, withOverflowingTokens); } /** @@ -227,12 +231,14 @@ public Encoding encode(String text) { * @param textPair the second input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence */ - public Encoding encode(String text, String textPair, boolean addSpecialTokens) { + public Encoding encode( + String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) { long encoding = TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -243,7 +249,7 @@ public Encoding encode(String text, String textPair, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence */ public Encoding encode(String text, String textPair) { - return encode(text, textPair, addSpecialTokens); + return encode(text, textPair, addSpecialTokens, withOverflowingTokens); } /** @@ -252,11 +258,13 @@ public Encoding encode(String text, String textPair) { * @param inputs the input sentences * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentences */ - public Encoding encode(List inputs, boolean addSpecialTokens) { + public Encoding encode( + List inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { String[] array = 
inputs.toArray(Utils.EMPTY_ARRAY); - return encode(array, addSpecialTokens); + return encode(array, addSpecialTokens, withOverflowingTokens); } /** @@ -266,7 +274,7 @@ public Encoding encode(List inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentences */ public Encoding encode(List inputs) { - return encode(inputs, addSpecialTokens); + return encode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -275,11 +283,13 @@ public Encoding encode(List inputs) { * @param inputs the input sentences * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentences */ - public Encoding encode(String[] inputs, boolean addSpecialTokens) { + public Encoding encode( + String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -289,7 +299,7 @@ public Encoding encode(String[] inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentences */ public Encoding encode(String[] inputs) { - return encode(inputs, addSpecialTokens); + return encode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -298,11 +308,13 @@ public Encoding encode(String[] inputs) { * @param inputs the batch of input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence in batch */ - public Encoding[] batchEncode(List inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + List inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { String[] array = 
inputs.toArray(Utils.EMPTY_ARRAY); - return batchEncode(array, addSpecialTokens); + return batchEncode(array, addSpecialTokens, withOverflowingTokens); } /** @@ -312,7 +324,7 @@ public Encoding[] batchEncode(List inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence in batch */ public Encoding[] batchEncode(List inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -321,13 +333,15 @@ public Encoding[] batchEncode(List inputs) { * @param inputs the batch of input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence in batch */ - public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens); Encoding[] ret = new Encoding[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - ret[i] = toEncoding(encodings[i]); + ret[i] = toEncoding(encodings[i], withOverflowingTokens); } return ret; } @@ -339,7 +353,7 @@ public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence in batch */ public Encoding[] batchEncode(String[] inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -348,9 +362,13 @@ public Encoding[] batchEncode(String[] inputs) { * @param inputs the batch of input text pair * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input text pair in 
batch */ - public Encoding[] batchEncode(PairList inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + PairList inputs, + boolean addSpecialTokens, + boolean withOverflowingTokens) { String[] text = inputs.keyArray(Utils.EMPTY_ARRAY); String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY); long[] encodings = @@ -358,7 +376,7 @@ public Encoding[] batchEncode(PairList inputs, boolean addSpecia getHandle(), text, textPair, addSpecialTokens); Encoding[] ret = new Encoding[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - ret[i] = toEncoding(encodings[i]); + ret[i] = toEncoding(encodings[i], withOverflowingTokens); } return ret; } @@ -370,7 +388,7 @@ public Encoding[] batchEncode(PairList inputs, boolean addSpecia * @return the {@code Encoding} of the input text pair in batch */ public Encoding[] batchEncode(PairList inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -431,6 +449,53 @@ public void enableBatch() { } } + /** + * Returns the truncation policy. + * + * @return the truncation policy + */ + public String getTruncation() { + return truncation.name(); + } + + /** + * Returns the padding policy. + * + * @return the padding policy + */ + public String getPadding() { + return padding.name(); + } + + /** + * Returns the max token length. + * + * @return the max token length + */ + public int getMaxLength() { + return maxLength; + } + + /** + * Returns the stride to use in overflow overlap when truncating sequences longer than the model + * supports. + * + * @return the stride to use in overflow overlap when truncating sequences longer than the model + * supports + */ + public int getStride() { + return stride; + } + + /** + * Returns the padToMultipleOf for padding. 
+ * + * @return the padToMultipleOf for padding + */ + public int getPadToMultipleOf() { + return padToMultipleOf; + } + /** * Creates a builder to build a {@code HuggingFaceTokenizer}. * @@ -503,7 +568,7 @@ private void updateTruncationAndPadding() { } } - private Encoding toEncoding(long encoding) { + private Encoding toEncoding(long encoding, boolean withOverflowingTokens) { long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding); long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding); String[] tokens = TokenizersLibrary.LIB.getTokens(encoding); @@ -511,11 +576,17 @@ private Encoding toEncoding(long encoding) { long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding); long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding); CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding); - long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding); - Encoding[] overflowing = new Encoding[overflowingHandles.length]; - for (int i = 0; i < overflowingHandles.length; ++i) { - overflowing[i] = toEncoding(overflowingHandles[i]); + long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding); + boolean exceedMaxLength = overflowingHandles.length > 0; + Encoding[] overflowing; + if (withOverflowingTokens) { + overflowing = new Encoding[overflowingHandles.length]; + for (int i = 0; i < overflowingHandles.length; ++i) { + overflowing[i] = toEncoding(overflowingHandles[i], true); + } + } else { + overflowing = new Encoding[0]; } TokenizersLibrary.LIB.deleteEncoding(encoding); @@ -527,6 +598,7 @@ private Encoding toEncoding(long encoding) { attentionMask, specialTokenMask, charSpans, + exceedMaxLength, overflowing); } @@ -651,6 +723,17 @@ public Builder optAddSpecialTokens(boolean addSpecialTokens) { return this; } + /** + * Sets if add special tokens. 
+ * + * @param withOverflowingTokens true to return overflowing tokens + * @return this builder + */ + public Builder optWithOverflowingTokens(boolean withOverflowingTokens) { + options.put("withOverflowingTokens", String.valueOf(withOverflowingTokens)); + return this; + } + /** * Enables or Disables default truncation behavior for the tokenizer. * @@ -787,7 +870,7 @@ public HuggingFaceTokenizer build() throws IOException { return managed(HuggingFaceTokenizer.newInstance(vocab, merges, options)); } throw new IOException("tokenizer.json file not found."); - } else if (Files.exists(tokenizerPath)) { + } else if (!Files.exists(tokenizerPath)) { throw new IOException("Tokenizer file not exits: " + tokenizerPath); } return managed(HuggingFaceTokenizer.newInstance(tokenizerPath, options)); diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java new file mode 100644 index 00000000000..6f43c7cb480 --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.huggingface.translator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.Batchifier; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.PairList; +import ai.djl.util.StringPair; + +import java.util.Arrays; + +/** The translator for Huggingface cross encoder model. */ +public class CrossEncoderBatchTranslator implements NoBatchifyTranslator { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier; + + CrossEncoderBatchTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; + this.batchifier = batchifier; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, StringPair[] inputs) + throws TranslateException { + NDManager manager = ctx.getNDManager(); + PairList list = new PairList<>(Arrays.asList(inputs)); + Encoding[] encodings = tokenizer.batchEncode(list); + NDList[] batch = new NDList[encodings.length]; + for (int i = 0; i < encodings.length; ++i) { + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); + } + return batchifier.batchify(batch); + } + + /** {@inheritDoc} */ + @Override + public float[][] processOutput(TranslatorContext ctx, NDList list) { + NDList[] batch = batchifier.unbatchify(list); + float[][] ret = new float[batch.length][]; + for (int i = 0; i < batch.length; ++i) { + NDArray logits = list.get(0); + NDArray result = logits.getNDArrayInternal().sigmoid(); + ret[i] = result.toFloatArray(); + } + return ret; + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java 
b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java new file mode 100644 index 00000000000..b88347bc60e --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java @@ -0,0 +1,149 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.StringPair; + +import java.io.IOException; +import java.util.Map; + +/** The translator for Huggingface cross encoder model. 
*/ +public class CrossEncoderTranslator implements Translator { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier; + + CrossEncoderTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; + this.batchifier = batchifier; + } + + /** {@inheritDoc} */ + @Override + public Batchifier getBatchifier() { + return batchifier; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, StringPair input) { + Encoding encoding = tokenizer.encode(input.getKey(), input.getValue()); + ctx.setAttachment("encoding", encoding); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); + } + + /** {@inheritDoc} */ + @Override + public float[] processOutput(TranslatorContext ctx, NDList list) { + NDArray logits = list.get(0); + NDArray result = logits.getNDArrayInternal().sigmoid(); + return result.toFloatArray(); + } + + /** {@inheritDoc} */ + @Override + public CrossEncoderBatchTranslator toBatchTranslator(Batchifier batchifier) { + tokenizer.enableBatch(); + return new CrossEncoderBatchTranslator(tokenizer, includeTokenTypes, batchifier); + } + + /** + * Creates a builder to build a {@code CrossEncoderTranslator}. + * + * @param tokenizer the tokenizer + * @return a new builder + */ + public static Builder builder(HuggingFaceTokenizer tokenizer) { + return new Builder(tokenizer); + } + + /** + * Creates a builder to build a {@code CrossEncoderTranslator}. + * + * @param tokenizer the tokenizer + * @param arguments the models' arguments + * @return a new builder + */ + public static Builder builder(HuggingFaceTokenizer tokenizer, Map arguments) { + Builder builder = builder(tokenizer); + builder.configure(arguments); + + return builder; + } + + /** The builder for question answering translator. 
*/ + public static final class Builder { + + private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; + private Batchifier batchifier = Batchifier.STACK; + + Builder(HuggingFaceTokenizer tokenizer) { + this.tokenizer = tokenizer; + } + + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + + /** + * Sets the {@link Batchifier} for the {@link Translator}. + * + * @param batchifier true to include token types + * @return this builder + */ + public Builder optBatchifier(Batchifier batchifier) { + this.batchifier = batchifier; + return this; + } + + /** + * Configures the builder with the model arguments. + * + * @param arguments the model arguments + */ + public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); + String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); + optBatchifier(Batchifier.fromString(batchifierStr)); + } + + /** + * Builds the translator. + * + * @return the new translator + * @throws IOException if I/O error occurs + */ + public CrossEncoderTranslator build() throws IOException { + return new CrossEncoderTranslator(tokenizer, includeTokenTypes, batchifier); + } + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java new file mode 100644 index 00000000000..f4f9af02c4b --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.translator; + +import ai.djl.Model; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; +import ai.djl.util.StringPair; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.reflect.Type; +import java.nio.file.Path; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link CrossEncoderTranslatorFactory} instance. 
*/ +public class CrossEncoderTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class)); + SUPPORTED_TYPES.add(new Pair<>(StringPair[].class, float[][].class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) + throws TranslateException { + Path modelPath = model.getModelPath(); + try { + HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder(arguments) + .optTokenizerPath(modelPath) + .optManager(model.getNDManager()) + .build(); + CrossEncoderTranslator translator = + CrossEncoderTranslator.builder(tokenizer, arguments).build(); + if (input == StringPair.class && output == float[].class) { + return (Translator) translator; + } else if (input == StringPair[].class && output == float[][].class) { + return (Translator) translator.toBatchTranslator(); + } else if (input == Input.class && output == Output.class) { + return (Translator) new CrossEncoderServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } catch (IOException e) { + throw new TranslateException("Failed to load tokenizer.", e); + } + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java index 43b120cac43..9a4ccba42b5 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java +++ 
b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java @@ -37,7 +37,7 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator { this.maskToken = maskToken; this.topK = topK; this.batchifier = batchifier; - Encoding encoding = tokenizer.encode(maskToken, false); + Encoding encoding = tokenizer.encode(maskToken, false, false); maskTokenId = encoding.getIds()[0]; } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java index 6dc1a4ed454..326a641dee0 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java @@ -113,7 +113,7 @@ private static NDArray meanPool(NDArray embeddings, NDArray attentionMask, boole long[] shape = embeddings.getShape().getShape(); attentionMask = attentionMask.expandDims(-1).broadcast(shape); NDArray inputAttentionMaskSum = attentionMask.sum(AXIS); - NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12); + NDArray clamp = inputAttentionMaskSum.clip(1e-9f, 1e12f); NDArray prod = embeddings.mul(attentionMask); NDArray sum = prod.sum(AXIS); if (sqrt) { diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java index 9ee8fc19cf8..a0099073c97 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/zoo/HfModelZoo.java @@ -131,8 +131,7 @@ private Map> listModels(Application app) { Path file = dir.resolve("models.json"); if (Files.exists(file)) { long lastModified = Files.getLastModifiedTime(file).toMillis(); - if (Boolean.getBoolean("offline") - || System.currentTimeMillis() - lastModified < ONE_DAY) { + 
if (Utils.isOfflineMode() || System.currentTimeMillis() - lastModified < ONE_DAY) { try (Reader reader = Files.newBufferedReader(file)) { return JsonUtils.GSON.fromJson(reader, type); } diff --git a/extensions/tokenizers/src/main/python/huggingface_converter.py b/extensions/tokenizers/src/main/python/huggingface_converter.py index f3b85c241ec..1efabb61c14 100644 --- a/extensions/tokenizers/src/main/python/huggingface_converter.py +++ b/extensions/tokenizers/src/main/python/huggingface_converter.py @@ -41,9 +41,20 @@ def save_model(self, model_info, args: Namespace, temp_dir: str): if not os.path.exists(temp_dir): os.makedirs(temp_dir) - hf_pipeline = self.load_model(model_id) - # Save tokenizer.json to temp dir - self.save_tokenizer(hf_pipeline, temp_dir) + try: + hf_pipeline = self.load_model(model_id) + except Exception as e: + logging.warning(f"Failed to load model: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to load model", -1 + + try: + # Save tokenizer.json to temp dir + self.save_tokenizer(hf_pipeline, temp_dir) + except Exception as e: + logging.warning(f"Failed to save tokenizer: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to save tokenizer", -1 # Save config.json just for reference config = hf_hub_download(repo_id=model_id, filename="config.json") @@ -112,7 +123,7 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str, logging.info(f"Saving torchscript model: {model_name}.pt ...") model_file = os.path.join(temp_dir, f"{model_name}.pt") script_module.save(model_file) - except (RuntimeError, ValueError) as e: + except Exception as e: logging.warning(f"Failed to trace model: {model_id}.") logging.warning(e, exc_info=True) return None diff --git a/extensions/tokenizers/src/main/python/huggingface_models.py b/extensions/tokenizers/src/main/python/huggingface_models.py index 3418815d5c4..5b1c6debe5d 100644 --- a/extensions/tokenizers/src/main/python/huggingface_models.py +++ 
b/extensions/tokenizers/src/main/python/huggingface_models.py @@ -16,7 +16,7 @@ from argparse import Namespace from typing import List -from huggingface_hub import HfApi, ModelSearchArguments +from huggingface_hub import HfApi from huggingface_hub import hf_hub_download from huggingface_hub.hf_api import ModelInfo @@ -27,7 +27,7 @@ "ForMultipleChoice": "text-classification", "ForMaskedLM": "fill-mask", } -LANGUAGES = ModelSearchArguments().language +LANGUAGES = HfApi().get_model_tags()["language"] def get_lang_tags(model_info): @@ -56,23 +56,32 @@ def __init__(self, output_dir: str): self.temp_dir = f"{self.output_dir}/tmp" def list_models(self, args: Namespace) -> List[dict]: + import_all = os.environ.get("HF_IMPORT_ALL") + api = HfApi() if args.model_name: - models = api.list_models(filter="pytorch", - search=args.model_name, - sort="downloads", - direction=-1, - limit=args.limit) - if not models: - logging.warning(f"no model found: {args.model_name}.") + all_models = api.list_models(search=args.model_name, + sort="downloads", + direction=-1, + limit=args.limit) + import_all = True else: - models = api.list_models(filter=f"{args.category},pytorch", - sort="downloads", - direction=-1, - limit=args.limit) - if not models: + all_models = api.list_models(filter=args.category, + sort="downloads", + direction=-1, + limit=args.limit) + models = [ + model for model in all_models + if 'pytorch' in model.tags or 'safetensors' in model.tags + ] + if not models: + if args.model_name: + logging.warning(f"no model found: {args.model_name}.") + else: logging.warning(f"no model matches category: {args.category}.") + return [] + ret = [] for model_info in models: model_id = model_info.modelId @@ -83,7 +92,7 @@ def list_models(self, args: Namespace) -> List[dict]: continue languages = get_lang_tags(model_info) - if "en" not in languages: + if "en" not in languages and not import_all: logging.warning(f"Skip non-English model: {model_id}.") continue @@ -94,6 +103,12 @@ def 
list_models(self, args: Namespace) -> List[dict]: logging.info(f"Skip converted model: {model_id}.") continue + if model_info.downloads < 50 and not import_all: + logging.info( + f"Skip model {model_info.modelId}, downloads {model_info.downloads} < 50" + ) + continue + try: config = hf_hub_download(repo_id=model_id, filename="config.json") diff --git a/extensions/tokenizers/src/main/python/model_zoo_importer.py b/extensions/tokenizers/src/main/python/model_zoo_importer.py index 9ed32ec58ef..0ed67bd1018 100644 --- a/extensions/tokenizers/src/main/python/model_zoo_importer.py +++ b/extensions/tokenizers/src/main/python/model_zoo_importer.py @@ -49,9 +49,17 @@ def main(): model_info = model["model_info"] converter = SUPPORTED_TASK[task] - result, reason, size = converter.save_model(model_info, args, temp_dir) - if not result: - logging.error(f"{model_info.modelId}: {reason}") + try: + result, reason, size = converter.save_model( + model_info, args, temp_dir) + if not result: + logging.error(f"{model_info.modelId}: {reason}") + except Exception as e: + logging.warning(f"Failed to convert model: {model_info.modelId}.") + logging.warning(e, exc_info=True) + result = False + reason = "Failed to convert model" + size = -1 huggingface_models.update_progress(model_info, converter.application, result, reason, size, args.cpu_only) diff --git a/extensions/tokenizers/src/main/python/requirements.txt b/extensions/tokenizers/src/main/python/requirements.txt index bf197b644ea..05ce0bc4833 100644 --- a/extensions/tokenizers/src/main/python/requirements.txt +++ b/extensions/tokenizers/src/main/python/requirements.txt @@ -1,4 +1,4 @@ huggingface_hub transformers -torch==1.11.0 +torch protobuf==3.20.2 diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java new file mode 100644 index 00000000000..f3ee102e325 --- /dev/null +++ 
b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -0,0 +1,204 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.huggingface.tokenizers; + +import ai.djl.Model; +import ai.djl.ModelException; +import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory; +import ai.djl.inference.Predictor; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.nn.Block; +import ai.djl.nn.LambdaBlock; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; +import ai.djl.util.JsonUtils; +import ai.djl.util.StringPair; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +public class CrossEncoderTranslatorTest { + + @Test + public void testCrossEncoderTranslator() + throws ModelException, IOException, TranslateException { + String text1 = "Sentence 1"; + String text2 = "Sentence 2"; + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = manager.create(new float[] {-0.7329f}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + 
Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(StringPair.class, float[].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair input = new StringPair(text1, text2); + float[] res = predictor.predict(input); + Assert.assertEquals(res[0], 0.32456556f, 0.0001); + } + + Criteria criteria2 = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria2.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input = new Input(); + input.add("key", text1); + input.add("value", text2); + Output res = predictor.predict(input); + float[] buf = (float[]) res.getData().getAsObject(); + Assert.assertEquals(buf[0], 0.32455865, 0.0001); + + Assert.assertThrows(TranslateException.class, () -> predictor.predict(new Input())); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.add("something", "false"); + predictor.predict(req); + }); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.addProperty("Content-Type", "application/json"); + req.add("Invalid json"); + predictor.predict(req); + }); + + Assert.assertThrows( + TranslateException.class, + () -> { + Input req = new Input(); + req.addProperty("Content-Type", "application/json"); + req.add(JsonUtils.GSON.toJson(new StringPair(text1, null))); + predictor.predict(req); + }); + } + + try (Model model = Model.newInstance("test")) { + 
model.setBlock(block); + Map options = new HashMap<>(); + options.put("hasParameter", "false"); + model.load(modelDir, "test", options); + + CrossEncoderTranslatorFactory factory = new CrossEncoderTranslatorFactory(); + Map arguments = new HashMap<>(); + + Assert.assertThrows( + TranslateException.class, + () -> factory.newInstance(String.class, Integer.class, model, arguments)); + + arguments.put("tokenizer", "bert-base-cased"); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(String.class, Integer.class, model, arguments)); + } + } + + @Test + public void testCrossEncoderBatchTranslator() + throws ModelException, IOException, TranslateException { + StringPair pair1 = new StringPair("Sentence 1", "Sentence 2"); + StringPair pair2 = new StringPair("Sentence 3", "Sentence 4"); + + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = manager.create(new float[][] {{-0.7329f}, {-0.7329f}}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(StringPair[].class, float[][].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair[] inputs = {pair1, pair2}; + float[][] res = predictor.predict(inputs); + Assert.assertEquals(res[1][0], 0.32455865, 0.0001); + } + + Criteria criteria2 = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optOption("hasParameter", "false") + .optTranslatorFactory(new CrossEncoderTranslatorFactory()) + .build(); + + try (ZooModel 
model = criteria2.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input = new Input(); + input.add(JsonUtils.GSON.toJson(new StringPair[] {pair1, pair2})); + input.addProperty("Content-Type", "application/json"); + Output out = predictor.predict(input); + float[][] buf = (float[][]) out.getData().getAsObject(); + Assert.assertEquals(buf[0][0], 0.32455865, 0.0001); + } + } +} diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index dcbef24748d..0c548d51aec 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -41,6 +41,12 @@ public void testTokenizer() throws IOException { }; try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + Assert.assertEquals(tokenizer.getTruncation(), "DO_NOT_TRUNCATE"); + Assert.assertEquals(tokenizer.getPadding(), "DO_NOT_PAD"); + Assert.assertEquals(tokenizer.getMaxLength(), -1); + Assert.assertEquals(tokenizer.getStride(), 0); + Assert.assertEquals(tokenizer.getPadToMultipleOf(), 0); + List ret = tokenizer.tokenize(input); Assert.assertEquals(ret.toArray(Utils.EMPTY_ARRAY), expected); Encoding encoding = tokenizer.encode(input); @@ -115,6 +121,12 @@ public void testTokenizer() throws IOException { Assert.assertEquals(encodings.length, 2); Assert.assertEquals(encodings[0].getIds(), ids); } + + Assert.assertThrows( + () -> { + Path file = Paths.get("build/tokenizer/non-exists.json"); + HuggingFaceTokenizer.builder().optTokenizerPath(file).build(); + }); } @Test @@ -294,6 +306,7 @@ public void testTruncationStride() throws IOException { HuggingFaceTokenizer.builder() .optTokenizerName("bert-base-cased") .optAddSpecialTokens(false) + .optWithOverflowingTokens(true) 
.optTruncation(true) .optMaxLength(3) .optStride(1) @@ -316,13 +329,16 @@ public void testTruncationStride() throws IOException { HuggingFaceTokenizer.builder() .optTokenizerName("bert-base-cased") .optAddSpecialTokens(false) + .optWithOverflowingTokens(true) .optTruncation(true) .optMaxLength(8) .optStride(2) .build()) { String text = "Hello there my friend I am happy to see you"; String textPair = "How are you my friend"; - Encoding[] overflowing = tokenizer.encode(text, textPair).getOverflowing(); + Encoding encoding = tokenizer.encode(text, textPair); + Assert.assertTrue(encoding.exceedMaxLength()); + Encoding[] overflowing = encoding.getOverflowing(); int expectedNumberOfOverflowEncodings = 7; Assert.assertEquals(overflowing.length, expectedNumberOfOverflowEncodings); @@ -452,13 +468,13 @@ public void testBatchProcessing() throws IOException { Assert.assertEquals(outputs, outputsWithSpecialTokens); // encode with special tokens, decode with special tokens - encodings = tokenizer.batchEncode(inputs, true); + encodings = tokenizer.batchEncode(inputs, true, false); batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); outputs = tokenizer.batchDecode(batchIds, false); Assert.assertEquals(outputs, outputsWithSpecialTokens); // encode without special tokens, decode without special tokens - encodings = tokenizer.batchEncode(inputs, false); + encodings = tokenizer.batchEncode(inputs, false, false); batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); outputs = tokenizer.batchDecode(batchIds, true); Assert.assertEquals(outputs, outputsWithoutSpecialTokens); diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java index e585a219f17..5c3ed6c3ed0 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java +++ 
b/extensions/tokenizers/src/test/java/ai/djl/huggingface/zoo/ModelZooTest.java @@ -103,21 +103,21 @@ public void testFutureVersion() throws IOException { @Test public void testOffLine() throws IOException { System.setProperty("DJL_CACHE_DIR", "build/cache"); - System.setProperty("offline", "true"); + System.setProperty("ai.djl.offline", "true"); try { Utils.deleteQuietly(Paths.get("build/cache")); // static variables cannot not be initialized properly if directly use new HfModelZoo() ModelZoo.getModelZoo("ai.djl.huggingface.pytorch"); ModelZoo zoo = new HfModelZoo(); - Assert.assertTrue(zoo.getModelLoaders().size() > 0); + Assert.assertFalse(zoo.getModelLoaders().isEmpty()); Set engines = zoo.getSupportedEngines(); Assert.assertEquals(engines.size(), 1); Assert.assertEquals(engines.iterator().next(), "PyTorch"); } finally { System.clearProperty("DJL_CACHE_DIR"); - System.clearProperty("offline"); + System.clearProperty("ai.djl.offline"); } } } diff --git a/gradle.properties b/gradle.properties index 23a6019761a..4ce8a7ecef3 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,31 +11,32 @@ systemProp.org.gradle.internal.http.connectionTimeout=60000 # FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308 systemProp.org.gradle.internal.publish.checksums.insecure=true -djl_version=0.24.0 +djl_version=0.27.0 mxnet_version=1.9.1 -pytorch_version=2.0.1 +pytorch_version=2.1.1 tensorflow_version=2.10.1 tflite_version=2.6.2 trt_version=8.4.1 -onnxruntime_version=1.15.1 +onnxruntime_version=1.16.3 paddlepaddle_version=2.3.2 sentencepiece_version=0.1.97 -tokenizers_version=0.13.3 +tokenizers_version=0.15.0 +llamacpp_version=b1696 fasttext_version=0.9.2 -xgboost_version=1.7.5 +xgboost_version=2.0.3 lightgbm_version=3.2.110 rapis_version=22.12.0 -commons_cli_version=1.5.0 -commons_compress_version=1.23.0 +commons_cli_version=1.6.0 +commons_compress_version=1.25.0 commons_csv_version=1.10.0 commons_logging_version=1.2 gson_version=2.10.1 
jna_version=5.13.0 slf4j_version=1.7.36 -log4j_slf4j_version=2.20.0 -awssdk_version=2.20.121 -hadoop_version=3.3.5 +log4j_slf4j_version=2.22.1 +awssdk_version=2.22.12 +hadoop_version=3.3.6 javacpp_version=1.5.9 javacv_version=1.5.9 ffmpeg_version=6.0-1.5.9 @@ -45,6 +46,6 @@ spark_version=3.3.2 openpnp_opencv_version=4.7.0-0 antlr_version=4.11.1 -testng_version=7.8.0 +testng_version=7.9.0 junit_version=4.13.2 mockito_version=5.3.1 diff --git a/index1.0.html b/index1.0.html index 1a7d9841065..98b4a3f2911 100644 --- a/index1.0.html +++ b/index1.0.html @@ -59,7 +59,7 @@
  • JavaDoc
  • Demos
  • Blogs
  • -
  • Tutorial
  • +
  • Tutorial
  • Examples
  • Slack @@ -73,7 +73,7 @@
  • JavaDoc
  • Demos
  • Blogs
  • -
  • Tutorial
  • +
  • Tutorial
  • Examples
  • Slack diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java index b5907925ee4..008d652dc82 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java @@ -31,6 +31,7 @@ import ai.djl.nn.LambdaBlock; import ai.djl.nn.SequentialBlock; import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -123,10 +124,8 @@ private TrainingConfig setupTrainingConfig() { } private ZooModel getModel() throws IOException, ModelException { - // SSD-pikachu model only available in MXNet - // TODO: Add PyTorch model to model zoo - TestUtils.requiresEngine("MXNet"); - + TestUtils.requiresEngine( + ModelZoo.getModelZoo("ai.djl.zoo").getSupportedEngines().toArray(String[]::new)); Criteria criteria = Criteria.builder() .optApplication(Application.CV.OBJECT_DETECTION) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java index 410b4009a6d..04779187267 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java @@ -22,6 +22,7 @@ import org.testng.annotations.Test; import java.util.stream.DoubleStream; +import java.util.stream.IntStream; public class NDArrayNumericOpTest { @@ -499,6 +500,42 @@ public void testAtan() { } } + @Test + public void testAtan2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + double[] x1 = 
{1.0, -1.0, -1.0, 0.0, 0.0, 0.0}; + NDArray array = manager.create(x1); + double[] y1 = {1.0, 0.0, -1.0, 1.0, -1.0, 0.0}; + NDArray other = manager.create(y1); + double[] output = + IntStream.range(0, x1.length) + .mapToDouble(i -> Math.atan2(x1[i], y1[i])) + .toArray(); + NDArray expected = manager.create(output); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test multi-dim + double[] x2 = {-1.0, -0.5, 0, 0.5, 1.0}; + array = manager.create(x2, new Shape(5, 1)); + double[] y2 = {-2.0, 3.0, 6.0, 0.0, -0.3}; + other = manager.create(y2, new Shape(5, 1)); + output = + IntStream.range(0, x2.length) + .mapToDouble(i -> Math.atan2(x2[i], y2[i])) + .toArray(); + expected = manager.create(output, new Shape(5, 1)); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test scalar + array = manager.create(0f); + other = manager.create(0f); + expected = manager.create(0f); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test zero-dim + array = manager.create(new Shape(1, 0)); + other = manager.create(new Shape(1, 0)); + Assert.assertEquals(array.atan2(other), array); + } + } + @Test public void testToDegrees() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 6788a405f22..66bb136ab37 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -875,6 +875,40 @@ public void testErfinv() { } } + @Test + public void testErf() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + // test 1-D + NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY}); + NDArray expected = manager.create(new float[] {0f, 0.5f, -1f}); + 
Assertions.assertAlmostEquals(NDArrays.erf(array), expected); + // test 3-D + array = + manager.create( + new float[] { + Float.NEGATIVE_INFINITY, + -0.8134f, + -0.4769f, + -0.2253f, + 0f, + 0.2253f, + 0.4769f, + 0.8134f, + Float.POSITIVE_INFINITY + }) + .reshape(3, 1, 3); + expected = manager.linspace(-1.0f, 1.0f, 9).reshape(3, 1, 3); + Assertions.assertAlmostEquals(array.erf(), expected); + // test scalar + array = manager.create(Float.POSITIVE_INFINITY); + expected = manager.create(1f); + Assertions.assertAlmostEquals(array.erf(), expected); + // test zero-dim + array = manager.create(new Shape(2, 0)); + Assertions.assertAlmostEquals(array.erf(), array); + } + } + @Test public void testInverse() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { @@ -1053,4 +1087,58 @@ public void testStft() { Assertions.assertAlmostEquals(result.real().flatten(), expected); } } + + @Test + public void testFft2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array = + manager.create( + new float[][] { + {1f, 6.6f, 4.315f, 2.0f}, + {16.9f, 6.697f, 2.399f, 67.9f}, + {0f, 5f, 67.09f, 9.87f} + }); + NDArray result = array.fft2(new long[] {3, 4}, new long[] {0, 1}); + result = result.real().flatten(1, 2); // flatten complex numbers + NDArray expected = + manager.create( + new float[][] { + {189.771f, 0f, -55.904f, 61.473f, -6.363f, 0f, -55.904f, -61.473f}, + { + -74.013f, + -10.3369f, + 71.7653f, + -108.2964f, + -1.746f, + 93.1133f, + -25.8063f, + -33.0234f + }, + { + -74.013f, 10.3369f, -25.8063f, 33.0234f, -1.746f, -93.1133f, + 71.7653f, 108.2964f + } + }); + Assertions.assertAlmostEquals(result, expected); + } + } + + @Test + public void testIfft2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray array = + manager.create( + new float[][] { + {1f, 6.6f, 4.315f, 2.0f}, + {16.9f, 6.697f, 2.399f, 67.9f}, + {0f, 5f, 67.09f, 9.87f} + }); + long[] sizes = {3, 4}; + long[] axes 
= {0, 1}; + NDArray fft2 = array.fft2(sizes, axes); + NDArray actual = fft2.ifft2(sizes, axes).real(); + NDArray expected = array.toType(DataType.COMPLEX64, true).real(); + Assertions.assertAlmostEquals(expected, actual); + } + } } diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java new file mode 100644 index 00000000000..9aee2661411 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.integration.tests.training.listener; + +import ai.djl.Model; +import ai.djl.basicdataset.cv.classification.Mnist; +import ai.djl.basicmodelzoo.basic.Mlp; +import ai.djl.integration.util.TestUtils; +import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Activation; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.EasyTrain; +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.listener.EarlyStoppingListener; +import ai.djl.training.listener.TrainingListener; +import ai.djl.training.loss.Loss; +import ai.djl.training.optimizer.Optimizer; +import ai.djl.training.tracker.Tracker; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.time.Duration; + +public class EarlyStoppingListenerTest { + + private final Optimizer sgd = + Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build(); + + private NDManager manager; + private Mnist testMnistDataset; + private Mnist trainMnistDataset; + + @BeforeTest + public void setUp() throws IOException, TranslateException { + manager = NDManager.newBaseManager(TestUtils.getEngine()); + testMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TEST) + .optManager(manager) + .optLimit(8) + .setSampling(8, false) + .build(); + testMnistDataset.prepare(); + + trainMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TRAIN) + .optManager(manager) + .optLimit(16) + .setSampling(8, false) + .build(); + trainMnistDataset.prepare(); + } + + @AfterTest + public void closeResources() { + manager.close(); + } + + @Test + public void testEarlyStoppingStopsOnEpoch2() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + 
try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(99) + .optMaxDuration(Duration.ofMinutes(1)) + .optMinEpochs(1) + .build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertEquals( + e.getMessage(), "failed to achieve 99.0% improvement 1 times in a row"); + Assert.assertEquals(e.getStopEpoch(), 2); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 2); + } + } + } + + @Test + public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(50) + .optMaxMillis(60_000) + .optMinEpochs(3) + .build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, 
trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertEquals( + e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row"); + Assert.assertEquals(e.getStopEpoch(), 3); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 3); + } + } + } + + @Test + public void testEarlyStoppingStopsOnEpoch1AsMaxDurationIs1ms() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder().optMaxMillis(1).build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 10 as we expect the early stopping to stop after the second + // epoch + EasyTrain.fit(trainer, 10, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertTrue(e.getMessage().contains("ms elapsed >=")); + Assert.assertTrue(e.getStopEpoch() < 10); // Stop epoch is before 10 + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertTrue(trainingResult.getEpoch() < 10); // Stop epoch is before 10 + } + } + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java new file mode 100644 index 00000000000..88680e5fe89 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests using the listeners {@link ai.djl.training}. */ +package ai.djl.integration.tests.training.listener; diff --git a/jacoco/build.gradle b/jacoco/build.gradle index d9196393283..fa570c50a3b 100644 --- a/jacoco/build.gradle +++ b/jacoco/build.gradle @@ -10,6 +10,7 @@ repositories { dependencies { jacocoAggregation project(":api") jacocoAggregation project(":basicdataset") + jacocoAggregation project(":engines:llama") jacocoAggregation project(":engines:ml:xgboost") jacocoAggregation project(":engines:ml:lightgbm") jacocoAggregation project(":engines:mxnet:mxnet-engine") @@ -39,7 +40,9 @@ dependencies { jacocoAggregation project(":extensions:tokenizers") jacocoAggregation project(":extensions:tablesaw") jacocoAggregation project(":extensions:timeseries") - jacocoAggregation project(":extensions:spark") + if (JavaVersion.current() < JavaVersion.VERSION_19) { + jacocoAggregation project(":extensions:spark") + } jacocoAggregation project(":integration") jacocoAggregation project(":model-zoo") } diff --git a/jupyter/BERTQA.ipynb b/jupyter/BERTQA.ipynb deleted file mode 100644 index 4ec97cbd838..00000000000 --- a/jupyter/BERTQA.ipynb +++ /dev/null @@ -1,214 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# DJL BERT Inference Demo\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you walk through running inference using DJL on a 
[BERT](https://towardsdatascience.com/bert-explained-state-of-the-art-language-model-for-nlp-f8b21a9b6270) QA model trained with MXNet and PyTorch. \n", - "You can provide a question and a paragraph containing the answer to the model. The model is then able to find the best answer from the answer paragraph.\n", - "\n", - "Example:\n", - "```text\n", - "Q: When did BBC Japan start broadcasting?\n", - "```\n", - "\n", - "Answer paragraph:\n", - "```text\n", - "BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.\n", - "It ceased operations after its Japanese distributor folded.\n", - "```\n", - "And it picked the right answer:\n", - "```text\n", - "A: December 2004\n", - "```\n", - "\n", - "One of the most powerful features of DJL is that it's engine agnostic. Because of this, you can run different backend engines seamlessly. We showcase BERT QA first with an MXNet pre-trained model, then with a PyTorch model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages by running the following:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.modality.nlp.qa.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.repository.zoo.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that all of the prerequisites are complete, start writing code to run inference with this example.\n", - "\n", - "\n", - "## Load the model and input\n", - "\n", - "**First, load the input**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then load the model and vocabulary. Create a variable `model` by using the `ModelZoo` as shown in the following code." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .optApplication(Application.NLP.QUESTION_ANSWER)\n", - " .setTypes(QAInput.class, String.class)\n", - " .optEngine(\"MXNet\") // For DJL to use MXNet engine\n", - " .optProgress(new ProgressBar()).build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run inference\n", - "Once the model is loaded, you can call `Predictor` and run inference as follows" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "String answer = predictor.predict(input);\n", - "answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Running inference on DJL is that easy. Now, let's try the PyTorch engine by specifying PyTorch engine in Criteria.optEngine(\"PyTorch\"). Let's rerun the inference code." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .optApplication(Application.NLP.QUESTION_ANSWER)\n", - " .setTypes(QAInput.class, String.class)\n", - " .optFilter(\"modelType\", \"distilbert\")\n", - " .optEngine(\"PyTorch\") // Use PyTorch engine\n", - " .optProgress(new ProgressBar()).build();\n", - "ZooModel model = criteria.loadModel();\n", - "Predictor predictor = model.newPredictor();\n", - "String answer = predictor.predict(input);\n", - "answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "Suprisingly, there are no differences between the PyTorch code snippet and MXNet code snippet. \n", - "This is power of DJL. We define a unified API where you can switch to different backend engines on the fly.\n", - "Next chapter: Inference with your own BERT: [MXNet](mxnet/load_your_own_mxnet_bert.ipynb) [PyTorch](pytorch/load_your_own_pytorch_bert.ipynb)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/Dockerfile b/jupyter/Dockerfile deleted file mode 100644 index 9c79ec3e54a..00000000000 --- a/jupyter/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM ubuntu:18.04 - -RUN apt-get update || true -RUN apt-get install -y openjdk-11-jdk-headless -RUN apt-get install -y python3-pip git -RUN pip3 install jupyter -RUN apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y locales \ - && sed -i -e 's/# en_US.UTF-8 UTF-8/en_US.UTF-8 UTF-8/' /etc/locale.gen \ - && dpkg-reconfigure --frontend=noninteractive locales \ - && update-locale LANG=en_US.UTF-8 -RUN apt-get install -y curl - -RUN git clone https://github.com/frankfliu/IJava.git -RUN cd IJava/ && ./gradlew installKernel && cd .. && rm -rf IJava/ -RUN rm -rf ~/.gradle - -WORKDIR /home/jupyter - -ENV LANG en_US.UTF-8 -ENV LC_ALL en_US.UTF-8 - -EXPOSE 8888 -ENTRYPOINT ["jupyter", "notebook", "--ip=0.0.0.0", "--no-browser", "--allow-root", "--NotebookApp.token=''", "--NotebookApp.password=''"] diff --git a/jupyter/README.md b/jupyter/README.md index 17b0a9c9405..1b9a2584238 100644 --- a/jupyter/README.md +++ b/jupyter/README.md @@ -1,83 +1,3 @@ # DJL - Jupyter notebooks -## Overview - -This folder contains tutorials that illustrate how to accomplish basic AI tasks with Deep Java Library (DJL). 
- -## [Beginner Tutorial](tutorial/README.md) - -## More Tutorial Notebooks - -- [Run object detection with model zoo](object_detection_with_model_zoo.ipynb) -- [Load pre-trained PyTorch model](load_pytorch_model.ipynb) -- [Load pre-trained Apache MXNet model](load_mxnet_model.ipynb) -- [Transfer learning example](transfer_learning_on_cifar10.ipynb) -- [Question answering example](BERTQA.ipynb) - -You can run our notebook online: [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/deepjavalibrary/djl/master?filepath=jupyter) - -## Setup - -### JDK 11 (not jre) - -JDK 11 (or above are required) to run the examples provided in this folder. - -to confirm the java path is configured properly: - -```bash -java --list-modules | grep "jdk.jshell" - -> jdk.jshell@12.0.1 -``` - -### Install jupyter notebook on python3 - -```bash -pip3 install jupyter -``` - -### Install IJava kernel for jupyter - -```bash -git clone https://github.com/frankfliu/IJava.git -cd IJava/ -./gradlew installKernel -``` - -## Start jupyter notebook - -```bash -jupyter notebook -``` - -## Docker setup - -You may want to use docker for simple installation or you are using Windows. - -### Run docker image - -```sh -cd jupyter -docker run -itd -p 127.0.0.1:8888:8888 -v $PWD:/home/jupyter deepjavalibrary/jupyter -``` - -You can open the `http://localhost:8888` to see the hosted instance on docker. - -### Build docker image by yourself - -You can read [Dockerfile](https://github.com/deepjavalibrary/djl/blob/master/jupyter/Dockerfile) for detail. To build docker image: - -```sh -cd jupyter -docker build -t deepjavalibrary/jupyter . -``` - -### Run docker compose - -```sh -cd jupyter -docker-compose build -docker-compose up -d -``` - -You can open the `http://localhost:8888` to see the hosted instance on docker compose. +The jupyter notebook documentation and examples have been moved to the [DJL Demos repo](http://docs.djl.ai/docs/demos/jupyter/index.html). 
\ No newline at end of file diff --git a/jupyter/docker-compose.yml b/jupyter/docker-compose.yml deleted file mode 100644 index e8e4d2f83b8..00000000000 --- a/jupyter/docker-compose.yml +++ /dev/null @@ -1,12 +0,0 @@ -version: "2.4" -services: - deepjavalibrary_container: - build: - context: . - dockerfile: Dockerfile - ports: - - 8888:8888 - volumes: - - ./:/home/jupyter - restart: always - diff --git a/jupyter/load_mxnet_model.ipynb b/jupyter/load_mxnet_model.ipynb deleted file mode 100644 index f90091d1ef4..00000000000 --- a/jupyter/load_mxnet_model.ipynb +++ /dev/null @@ -1,190 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load MXNet model\n", - "\n", - "In this tutorial, you learn how to load an existing MXNet model and use it to run a prediction task.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. For more information on installing the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.awt.image.*;\n", - "import java.nio.file.*;\n", - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Prepare your MXNet model\n", - "\n", - "This tutorial assumes that you have a MXNet model trained using Python. A MXNet symbolic model usually contains the following files:\n", - "* Symbol file: {MODEL_NAME}-symbol.json - a json file that contains network information about the model\n", - "* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the parameter weight and bias\n", - "* Synset file: synset.txt - an optional text file that stores classification classes labels\n", - "\n", - "This tutorial uses a pre-trained MXNet `resnet18_v1` model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We use `DownloadUtils` for downloading files from internet." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json\", \"build/resnet/resnet18_v1-symbol.json\", new ProgressBar());\n", - "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz\", \"build/resnet/resnet18_v1-0000.params\", new ProgressBar());\n", - "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt\", \"build/resnet/synset.txt\", new ProgressBar());\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Load your model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/resnet\");\n", - "Model model = Model.newInstance(\"resnet\");\n", - "model.load(modelDir, \"resnet18_v1\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Create a `Translator`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Pipeline pipeline = new Pipeline();\n", - "pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());\n", - "Translator translator = ImageClassificationTranslator.builder()\n", - " .setPipeline(pipeline)\n", - " .optSynsetArtifactName(\"synset.txt\")\n", - " .optApplySoftmax(true)\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Load image for classification" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/kitten.jpg\");\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - 
"source": [ - "## Step 5: Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor(translator);\n", - "Classifications classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you can load any MXNet symbolic model and run inference.\n", - "\n", - "You might also want to check out [load_pytorch_model.ipynb](https://github.com/deepjavalibrary/djl/blob/master/jupyter/load_pytorch_model.ipynb) which demonstrates loading a local model using the ModelZoo API." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/load_pytorch_model.ipynb b/jupyter/load_pytorch_model.ipynb deleted file mode 100644 index bf4e3db3e3f..00000000000 --- a/jupyter/load_pytorch_model.ipynb +++ /dev/null @@ -1,232 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "# Load PyTorch model\n", - "\n", - "In this tutorial, you learn how to load an existing PyTorch model and use it to run a prediction task.\n", - "\n", - "We will run the inference in DJL way with [example](https://pytorch.org/hub/pytorch_vision_resnet/) on the pytorch official website.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. For more information on installing the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.nio.file.*;\n", - "import java.awt.image.*;\n", - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Prepare your model\n", - "\n", - "This tutorial assumes that you have a TorchScript model.\n", - "DJL only supports the TorchScript format for loading models from PyTorch, so other models will need to be [converted](https://github.com/deepjavalibrary/djl/blob/master/docs/pytorch/how_to_convert_your_model_to_torchscript.md).\n", - "A TorchScript model includes the model structure and all of the parameters.\n", - "\n", - "We will be using a pre-trained `resnet18` model. 
First, use the `DownloadUtils` to download the model files and save them in the `build/pytorch_models` folder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz\", \"build/pytorch_models/resnet18/resnet18.pt\", new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In order to do image classification, you will also need the synset.txt which stores the classification class labels. We will need the synset containing the Imagenet labels with which resnet18 was originally trained." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt\", \"build/pytorch_models/resnet18/synset.txt\", new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Create a Translator\n", - "\n", - "We will create a transformation pipeline which maps the transforms shown in the [PyTorch example](https://pytorch.org/hub/pytorch_vision_resnet/).\n", - "```python\n", - "...\n", - "preprocess = transforms.Compose([\n", - " transforms.Resize(256),\n", - " transforms.CenterCrop(224),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", - "])\n", - "...\n", - "```\n", - "\n", - "Then, we will use this pipeline to create the [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Translator translator = ImageClassificationTranslator.builder()\n", - " .addTransform(new Resize(256))\n", - " .addTransform(new 
CenterCrop(224, 224))\n", - " .addTransform(new ToTensor())\n", - " .addTransform(new Normalize(\n", - " new float[] {0.485f, 0.456f, 0.406f},\n", - " new float[] {0.229f, 0.224f, 0.225f}))\n", - " .optApplySoftmax(true)\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Load your model\n", - "\n", - "Next, we add some search criteria to find the resnet18 model and load it. In this case, we need to tell `Criteria` where to locate the model by calling `.optModelPath()` API." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelPath(Paths.get(\"build/pytorch_models/resnet18\"))\n", - " .optOption(\"mapLocation\", \"true\") // this model requires mapLocation for GPU\n", - " .optTranslator(translator)\n", - " .optProgress(new ProgressBar()).build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Load image for classification\n", - "\n", - "We will use a sample dog image to run our prediction on." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg\");\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 5: Run inference\n", - "\n", - "Lastly, we will need to create a predictor using our model and translator. Once we have a predictor, we simply need to call the predict method on our test image." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "Classifications classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you can load any TorchScript model and run inference using it.\n", - "\n", - "You might also want to check out [load_mxnet_model.ipynb](https://github.com/deepjavalibrary/djl/blob/master/jupyter/load_mxnet_model.ipynb) which demonstrates loading a local model directly instead of through the Model Zoo API." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/mxnet/load_your_own_mxnet_bert.ipynb b/jupyter/mxnet/load_your_own_mxnet_bert.ipynb deleted file mode 100644 index 9691a4d683a..00000000000 --- a/jupyter/mxnet/load_your_own_mxnet_bert.ipynb +++ /dev/null @@ -1,485 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load your own MXNet BERT model\n", - "\n", - "In the previous [example](../BERTQA.ipynb), you run BERT inference with the model from Model Zoo. 
You can also load the model on your own pre-trained BERT and use custom classes as the input and output.\n", - "\n", - "In general, the MXNet BERT model requires these three inputs:\n", - "\n", - "- word indices: The index of each word in a sentence\n", - "- word types: The type index of the word.\n", - "- valid length: The actual length of the question and resource document tokens\n", - "\n", - "We will dive deep into these details later." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are dependencies we will use." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.stream.*;\n", - "\n", - "import ai.djl.*;\n", - "import ai.djl.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.qa.*;\n", - "import ai.djl.mxnet.zoo.nlp.qa.*;\n", - "import ai.djl.modality.nlp.bert.*;\n", - "\n", - "import 
com.google.gson.annotations.SerializedName;\n", - "import java.nio.charset.StandardCharsets;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Reuse the previous input**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput input = new QAInput(question, resourceDocument);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dive deep into Translator\n", - "\n", - "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n", - "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", - "\n", - "![https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "The red block (\"Images\") in the workflow is the input that DJL expects from you. The green block (\"Images\n", - "bounding box\") is the output that you expect. Because DJL does not know which input to expect and which output format that you prefer, DJL provides the [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface so you can define your own\n", - "input and output.\n", - "\n", - "The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. 
The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Pre-processing\n", - "\n", - "Now, you need to convert the sentences into tokens. We provide a powerful tool [`BertTokenizer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/bert/BertTokenizer.html) that you can use to convert questions and answers into tokens, and batchify your sequence together. Once you have properly formatted tokens, you can use [`Vocabulary`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/Vocabulary.html) to map your token to BERT index.\n", - "\n", - "The following code block demonstrates tokenizing the question and answer defined earlier into BERT-formatted tokens." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var tokenizer = new BertTokenizer();\n", - "List tokenQ = tokenizer.tokenize(question.toLowerCase());\n", - "List tokenA = tokenizer.tokenize(resourceDocument.toLowerCase());\n", - "\n", - "System.out.println(\"Question Token: \" + tokenQ);\n", - "System.out.println(\"Answer Token: \" + tokenA);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`BertTokenizer` can also help you batchify questions and resource documents together by calling `encode()`.\n", - "The output contains information that BERT ingests.\n", - "\n", - "- getTokens: It returns a list of strings, including the question, resource document and special word to let the model tell which part is the question and which part is the resource document. Because MXNet BERT was trained with a fixed sequence length, you see the `[PAD]` in the tokens as well.\n", - "- getTokenTypes: It returns a list of type indices of the word to indicate the location of the resource document. All Questions will be labelled with 0 and all resource documents will be labelled with 1.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [000000...11111....0000]\n", - " \n", - "\n", - "- getValidLength: It returns the actual length of the question and tokens, which are required by MXNet BERT.\n", - "- getAttentionMask: It returns the mask for the model to indicate which part should be paid attention to and which part is the padding. It is required by PyTorch BERT.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [111111...11111....0000]\n", - " \n", - "MXNet BERT was trained with fixed sequence length 384, so we need to pass that in when we encode the question and resource doc. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertToken token = tokenizer.encode(question.toLowerCase(), resourceDocument.toLowerCase(), 384);\n", - "System.out.println(\"Encoded tokens: \" + token.getTokens());\n", - "System.out.println(\"Encoded token type: \" + token.getTokenTypes());\n", - "System.out.println(\"Valid length: \" + token.getValidLength());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Normally, words and sentences are represented as indices instead of tokens for training. \n", - "They typically work like a vector in a n-dimensional space. In this case, you need to map them into indices.\n", - "DJL provides `Vocabulary` to take care of you vocabulary mapping.\n", - "\n", - "Assume your vocab.json is of the following format\n", - "```\n", - "{'token_to_idx':{'\"slots\": 19832,...}, 'idx_to_token':[\"[UNK]\", \"[PAD]\", ...]}\n", - "```\n", - "We provide the `vocab.json` from our pre-trained BERT for demonstration." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/mxnet/bertqa/vocab.json\", \"build/mxnet/bertqa/vocab.json\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class VocabParser {\n", - " @SerializedName(\"idx_to_token\")\n", - " List idx2token;\n", - "\n", - " public static List parseToken(URL file) {\n", - " try (InputStream is = file.openStream();\n", - " Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {\n", - " return JsonUtils.GSON.fromJson(reader, VocabParser.class).idx2token;\n", - " } catch (IOException e) {\n", - " throw new IllegalArgumentException(\"Invalid url: \" + file, e);\n", - " }\n", - " }\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "URL url = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n", - "var vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromCustomizedFile(url, VocabParser::parseToken)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can easily convert the token to the index using `vocabulary.getIndex(token)` and the other way around using `vocabulary.getToken(index)`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "long index = vocabulary.getIndex(\"car\");\n", - "String token = vocabulary.getToken(2482);\n", - "System.out.println(\"The index of the car is \" + index);\n", - "System.out.println(\"The token of the index 2482 is \" + token);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To properly convert them into `float[]` for `NDArray` creation, use the following helper function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "/**\n", - " * Convert a List of Number to float array.\n", - " *\n", - " * @param list the list to be converted\n", - " * @return float array\n", - " */\n", - "public static float[] toFloatArray(List list) {\n", - " float[] ret = new float[list.size()];\n", - " int idx = 0;\n", - " for (Number n : list) {\n", - " ret[idx++] = n.floatValue();\n", - " }\n", - " return ret;\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that you have everything you need, you can create an NDList and populate all of the inputs you formatted earlier. You're done with pre-processing! \n", - "\n", - "#### Construct `Translator`\n", - "\n", - "You need to do this processing within an implementation of the `Translator` interface. `Translator` is designed to do pre-processing and post-processing. You must define the input and output objects. It contains the following two override classes:\n", - "- `public NDList processInput(TranslatorContext ctx, I)`\n", - "- `public String processOutput(TranslatorContext ctx, O)`\n", - "\n", - "Every translator takes in input and returns output in the form of generic objects. In this case, the translator takes input in the form of `QAInput` (I) and returns output as a `String` (O). `QAInput` is just an object that holds questions and answer; We have prepared the Input class for you." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Armed with the needed knowledge, you can write an implementation of the `Translator` interface. `BertTranslator` uses the code snippets explained previously to implement the `processInput`method. For more information, see [`NDManager`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDManager.html).\n", - "\n", - "```\n", - "manager.create(Number[] data, Shape)\n", - "manager.create(Number[] data)\n", - "```\n", - "\n", - "The `Shape` for `data0` and `data1` is sequence_length. For `data2` the `Shape` is just 1." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public class BertTranslator implements NoBatchifyTranslator {\n", - " private List tokens;\n", - " private Vocabulary vocabulary;\n", - " private BertTokenizer tokenizer;\n", - " \n", - " @Override\n", - " public void prepare(TranslatorContext ctx) throws IOException {\n", - " URL path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n", - " vocabulary =\n", - " DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromCustomizedFile(path, VocabParser::parseToken)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();\n", - " tokenizer = new BertTokenizer();\n", - " }\n", - " \n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, QAInput input) {\n", - " BertToken token =\n", - " tokenizer.encode(\n", - " input.getQuestion().toLowerCase(),\n", - " input.getParagraph().toLowerCase(),\n", - " 384);\n", - " // get the encoded tokens that would be used in precessOutput\n", - " tokens = token.getTokens();\n", - " // map the tokens(String) to indices(long)\n", - " List indices =\n", - " token.getTokens().stream().map(vocabulary::getIndex).collect(Collectors.toList());\n", - " float[] indexesFloat = toFloatArray(indices);\n", - " float[] types = toFloatArray(token.getTokenTypes());\n", - " int validLength = 
token.getValidLength();\n", - "\n", - " NDManager manager = ctx.getNDManager();\n", - " NDArray data0 = manager.create(indexesFloat);\n", - " data0.setName(\"data0\");\n", - " NDArray data1 = manager.create(types);\n", - " data1.setName(\"data1\");\n", - " NDArray data2 = manager.create(new float[] {validLength});\n", - " data2.setName(\"data2\");\n", - " return new NDList(data0, data1, data2);\n", - " }\n", - "\n", - " @Override\n", - " public String processOutput(TranslatorContext ctx, NDList list) {\n", - " NDArray array = list.singletonOrThrow();\n", - " NDList output = array.split(2, 2);\n", - " // Get the formatted logits result\n", - " NDArray startLogits = output.get(0).reshape(new Shape(1, -1));\n", - " NDArray endLogits = output.get(1).reshape(new Shape(1, -1));\n", - " int startIdx = (int) startLogits.argMax(1).getLong();\n", - " int endIdx = (int) endLogits.argMax(1).getLong();\n", - " return tokens.subList(startIdx, endIdx + 1).toString();\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Congrats! You have created your first Translator! We have pre-filled the `processOutput()` function to process the `NDList` and return it in a desired format. `processInput()` and `processOutput()` offer the flexibility to get the predictions from the model in any format you desire. \n", - "\n", - "With the Translator implemented, you need to bring up the predictor that uses your `Translator` to start making predictions. You can find the usage for `Predictor` in the [Predictor Javadoc](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html). Create a translator and use the `question` and `resourceDocument` provided previously." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/mxnet/bertqa/0.0.1/static_bert_qa-symbol.json\", \"build/mxnet/bertqa/bertqa-symbol.json\", new ProgressBar());\n", - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/mxnet/bertqa/0.0.1/static_bert_qa-0002.params.gz\", \"build/mxnet/bertqa/bertqa-0000.params\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertTranslator translator = new BertTranslator();\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(QAInput.class, String.class)\n", - " .optModelPath(Paths.get(\"build/mxnet/bertqa/\")) // Search for models in the build/mxnet/bert folder\n", - " .optTranslator(translator)\n", - " .optProgress(new ProgressBar()).build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String predictResult = null;\n", - "QAInput input = new QAInput(question, resourceDocument);\n", - "\n", - "// Create a Predictor and use it to predict the output\n", - "try (Predictor predictor = model.newPredictor(translator)) {\n", - " predictResult = predictor.predict(input);\n", - "}\n", - "\n", - "System.out.println(question);\n", - "System.out.println(predictResult);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Based on the input, the following result will be shown:\n", - "```\n", - "[december, 2004]\n", - "```\n", - "That's it! \n", - "\n", - "You can try with more questions and answers. 
Here are the samples:\n", - "\n", - "**Answer Material**\n", - "\n", - "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.\n", - "\n", - "\n", - "**Question**\n", - "\n", - "Q: When were the Normans in Normandy?\n", - "A: 10th and 11th centuries\n", - "\n", - "Q: In what country is Normandy located?\n", - "A: france\n", - "\n", - "For the full source code,see the [DJL repo](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java) and translator implementation [MXNet](https://github.com/deepjavalibrary/djl/blob/master/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator.java) [PyTorch](https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/qa/PtBertQATranslator.java)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/object_detection_with_model_zoo.ipynb b/jupyter/object_detection_with_model_zoo.ipynb deleted file mode 100644 index 9435b9de7aa..00000000000 --- a/jupyter/object_detection_with_model_zoo.ipynb +++ /dev/null @@ -1,159 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Object detection with model zoo model\n", - "\n", - "In this tutorial, you learn how to use a built-in model zoo model (SSD) to achieve an [object detection](https://en.wikipedia.org/wiki/Object_detection) task.\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.mxnet.zoo.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Load image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/dog_bike_car.jpg\");\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Load model zoo model\n", - "\n", - "In this example, you load a SSD (Single Shot MultiBox Detector) model from the MXNet model zoo.\n", - "For more information about model zoo, see the [Model Zoo Documentation](https://github.com/deepjavalibrary/djl/blob/master/docs/model-zoo.md) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optArtifactId(\"ssd\")\n", - " .optProgress(new ProgressBar())\n", - " .build();\n", - "var model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Create Predictor and detect an object in the image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": 
[], - "source": [ - "var detections = model.newPredictor().predict(img);\n", - "\n", - "detections" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Check detected result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "img.drawBoundingBoxes(detections);\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Using the model zoo model provided, you can run inference with just the following lines of code:\n", - "\n", - "```\n", - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/dog_bike_car.jpg\");\n", - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optArtifactId(\"ssd\")\n", - " .build();\n", - "var model = criteria.loadModel();\n", - "var detections = model.newPredictor().predict(img);\n", - "```\n", - "\n", - "You can find full SsdExample source code [here](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/object_detection.md).\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb b/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb deleted file mode 100644 index d068a97e78b..00000000000 --- a/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb +++ /dev/null @@ -1,224 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Classification on Iris dataset with sklearn and DJL\n", - "\n", - "In this notebook, you will try to use a pre-trained sklearn model to run on DJL for a 
general classification task. The model was trained with [Iris flower dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set).\n", - "\n", - "## Background \n", - "\n", - "### Iris Dataset\n", - "\n", - "The dataset contains a set of 150 records under five attributes - sepal length, sepal width, petal length, petal width and species.\n", - "\n", - "Iris setosa | Iris versicolor | Iris virginica\n", - ":-------------------------:|:-------------------------:|:-------------------------:\n", - "![](https://upload.wikimedia.org/wikipedia/commons/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg) | ![](https://upload.wikimedia.org/wikipedia/commons/4/41/Iris_versicolor_3.jpg) | ![](https://upload.wikimedia.org/wikipedia/commons/9/9f/Iris_virginica.jpg) \n", - "\n", - "The chart above shows three different kinds of the Iris flowers. \n", - "\n", - "We will use sepal length, sepal width, petal length, petal width as the feature and species as the label to train the model.\n", - "\n", - "### Sklearn Model\n", - "\n", - "You can find more information [here](http://onnx.ai/sklearn-onnx/). You can use the sklearn built-in iris dataset to load the data. Then we defined a [RandomForestClassifer](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) to train the model. After that, we convert the model to onnx format for DJL to run inference. The following code is a sample classification setup using sklearn:\n", - "\n", - "```python\n", - "# Train a model.\n", - "from sklearn.datasets import load_iris\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "iris = load_iris()\n", - "X, y = iris.data, iris.target\n", - "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", - "clr = RandomForestClassifier()\n", - "clr.fit(X_train, y_train)\n", - "```\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. 
To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md).\n", - "\n", - "These are dependencies we will use. To enhance the NDArray operation capability, we are importing ONNX Runtime and PyTorch Engine at the same time. Please find more information [here](https://github.com/deepjavalibrary/djl/blob/master/docs/hybrid_engine.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.onnxruntime:onnxruntime-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1 create a Translator\n", - "\n", - "Inference in machine learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n", - "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", - "\n", - "![https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface encompasses the two white blocks: Pre-processing and Post-processing. 
The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format.\n", - "\n", - "In our use case, we use a class namely `IrisFlower` as our input class type. We will use [`Classifications`](https://javadoc.io/doc/ai.djl/api/0.23.0/ai/djl/modality/Classifications.html) as our output class type." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public static class IrisFlower {\n", - "\n", - " public float sepalLength;\n", - " public float sepalWidth;\n", - " public float petalLength;\n", - " public float petalWidth;\n", - "\n", - " public IrisFlower(float sepalLength, float sepalWidth, float petalLength, float petalWidth) {\n", - " this.sepalLength = sepalLength;\n", - " this.sepalWidth = sepalWidth;\n", - " this.petalLength = petalLength;\n", - " this.petalWidth = petalWidth;\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create a translator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public static class MyTranslator implements NoBatchifyTranslator {\n", - "\n", - " private final List synset;\n", - "\n", - " public MyTranslator() {\n", - " // species name\n", - " synset = Arrays.asList(\"setosa\", \"versicolor\", \"virginica\");\n", - " }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, IrisFlower input) {\n", - " float[] data = {input.sepalLength, input.sepalWidth, input.petalLength, input.petalWidth};\n", - " NDArray array = 
ctx.getNDManager().create(data, new Shape(1, 4));\n", - " return new NDList(array);\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " float[] data = list.get(1).toFloatArray();\n", - " List probabilities = new ArrayList<>(data.length);\n", - " for (float f : data) {\n", - " probabilities.add((double) f);\n", - " }\n", - " return new Classifications(synset, probabilities);\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2 Prepare your model\n", - "\n", - "We will load a pretrained sklearn model into DJL. We defined a [`ModelZoo`](https://javadoc.io/doc/ai.djl/api/0.23.0/ai/djl/repository/zoo/ModelZoo.html) concept to allow user load model from varity of locations, such as remote URL, local files or DJL pretrained model zoo. We need to define [`Criteria`](https://javadoc.io/doc/ai.djl/api/0.23.0/ai/djl/repository/zoo/Criteria.html) class to help the modelzoo locate the model and attach translator. In this example, we download a compressed ONNX model from S3." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String modelUrl = \"https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip\";\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(IrisFlower.class, Classifications.class)\n", - " .optModelUrls(modelUrl)\n", - " .optTranslator(new MyTranslator())\n", - " .optEngine(\"OnnxRuntime\") // use OnnxRuntime engine by default\n", - " .build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3 Run inference\n", - "\n", - "User will just need to create a `Predictor` from model to run the inference." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "IrisFlower info = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);\n", - "predictor.predict(info);" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb deleted file mode 100644 index 1249ee12e2f..00000000000 --- a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb +++ /dev/null @@ -1,369 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Face Mask Detection using PaddlePaddle\n", - "\n", - "In this tutorial, we will be using pretrained PaddlePaddle model from [PaddleHub](https://github.com/PaddlePaddle/PaddleHub/tree/release/v1.5/demo/mask_detection/cpp) to do mask detection on the sample image. To complete this procedure, there are two steps needs to be done:\n", - "\n", - "- Recognize face on the image (no matter wearing mask or not) using Face object detection model\n", - "- classify the face is wearing mask or not\n", - "\n", - "These two steps will involve two paddle models. We will implement the corresponding preprocess and postprocess logic to it.\n", - "\n", - "## Import dependencies and classes\n", - "\n", - "PaddlePaddle is one of the Deep Engines that requires DJL hybrid mode to run inference. Itself does not contains NDArray operations and needs a supplemental DL framework to help with that. So we import Pytorch DL engine as well in here to do the processing works." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Face Detection model\n", - "\n", - "Now we can start working on the first model. The model can do face detection and require some additional processing before we feed into it:\n", - "\n", - "- Resize: Shrink the image with a certain ratio to feed in\n", - "- Normalize the image with a scale\n", - "\n", - "Fortunatly, DJL offers a [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface that can help you with these processing. 
The rough Translator architecture looks like below:\n", - "\n", - "![](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "In the following sections, we will implement a `FaceTranslator` class to do the work.\n", - "\n", - "### Preprocessing\n", - "\n", - "In this stage, we will load an image and do some preprocessing work to it. Let's load the image first and take a look at it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://raw.githubusercontent.com/PaddlePaddle/PaddleHub/release/v1.5/demo/mask_detection/python/images/mask.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, let's try to apply some transformation to it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "NDList processImageInput(NDManager manager, Image input, float shrink) {\n", - " NDArray array = input.toNDArray(manager);\n", - " Shape shape = array.getShape();\n", - " array = NDImageUtils.resize(\n", - " array, (int) (shape.get(1) * shrink), (int) (shape.get(0) * shrink));\n", - " array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW BGR -> RGB\n", - " NDArray mean = manager.create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));\n", - " array = array.sub(mean).mul(0.007843f); // normalization\n", - " array = array.expandDims(0); // make batch dimension\n", - " return new NDList(array);\n", - "}\n", - "\n", - "processImageInput(NDManager.newBaseManager(), img, 0.5f);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, we convert the image to a NDArray with shape following (number_of_batches, channel (RGB), height, width). 
This is the required input for the model to run object detection.\n", - "\n", - "### Postprocessing\n", - "\n", - "For postprocessing, The output is in shape of (number_of_boxes, (class_id, probability, xmin, ymin, xmax, ymax)). We can store them into the prebuilt DJL [`DetectedObjects`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/output/DetectedObjects.html) classes for further processing. Let's assume we have an inference output of ((1, 0.99, 0.2, 0.4, 0.5, 0.8)) and try to draw this box out." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "DetectedObjects processImageOutput(NDList list, List className, float threshold) {\n", - " NDArray result = list.singletonOrThrow();\n", - " float[] probabilities = result.get(\":,1\").toFloatArray();\n", - " List names = new ArrayList<>();\n", - " List prob = new ArrayList<>();\n", - " List boxes = new ArrayList<>();\n", - " for (int i = 0; i < probabilities.length; i++) {\n", - " if (probabilities[i] >= threshold) {\n", - " float[] array = result.get(i).toFloatArray();\n", - " names.add(className.get((int) array[0]));\n", - " prob.add((double) probabilities[i]);\n", - " boxes.add(\n", - " new Rectangle(\n", - " array[2], array[3], array[4] - array[2], array[5] - array[3]));\n", - " }\n", - " }\n", - " return new DetectedObjects(names, prob, boxes);\n", - "}\n", - "\n", - "NDArray tempOutput = NDManager.newBaseManager().create(new float[]{1f, 0.99f, 0.1f, 0.1f, 0.2f, 0.2f}, new Shape(1, 6));\n", - "DetectedObjects testBox = processImageOutput(new NDList(tempOutput), Arrays.asList(\"Not Face\", \"Face\"), 0.7f);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(testBox);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create Translator and run inference\n", - "\n", - "After this step, you might understand how process and postprocess 
works in DJL. Now, let's do something real and put them together in a single piece:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class FaceTranslator implements NoBatchifyTranslator {\n", - "\n", - " private float shrink;\n", - " private float threshold;\n", - " private List className;\n", - "\n", - " FaceTranslator(float shrink, float threshold) {\n", - " this.shrink = shrink;\n", - " this.threshold = threshold;\n", - " className = Arrays.asList(\"Not Face\", \"Face\");\n", - " }\n", - "\n", - " @Override\n", - " public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {\n", - " return processImageOutput(list, className, threshold);\n", - " }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " return processImageInput(ctx.getNDManager(), input, shrink);\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To run inference with this model, we need to load the model from Paddle model zoo. To load a model in DJL, you need to specify a [`Criteria`](https://javadoc.io/doc/ai.djl/api/0.23.1/ai/djl/repository/zoo/Criteria.html). `Criteria` is used identify where to load the model and which `Translator` should apply to it. 
Then, all we need to do is to get a [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) from the model and use it to do inference:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/face_detection/0.0.1/mask_detection\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(new FaceTranslator(0.5f, 0.7f))\n", - " .build();\n", - " \n", - "var model = criteria.loadModel();\n", - "var predictor = model.newPredictor();\n", - "\n", - "DetectedObjects inferenceResult = predictor.predict(img);\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(inferenceResult);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, it brings you three faces detections.\n", - "\n", - "## Mask Classification model\n", - "\n", - "\n", - "So, once we have the image location ready, we can crop the image and feed it to the Mask Classification model for further processing.\n", - "\n", - "### Crop the image\n", - "\n", - "The output of the box location is a value from 0 - 1 that can be mapped to the actual box pixel location if we simply multiply by width/height. For better accuracy on the cropped image, we extend the detection box to square. 
Let's try to get a cropped image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int[] extendSquare(\n", - " double xmin, double ymin, double width, double height, double percentage) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " double maxDist = Math.max(width / 2, height / 2) * (1 + percentage);\n", - " return new int[] {\n", - " (int) (centerx - maxDist), (int) (centery - maxDist), (int) (2 * maxDist)\n", - " };\n", - "}\n", - "\n", - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] squareBox =\n", - " extendSquare(\n", - " rect.getX() * width,\n", - " rect.getY() * height,\n", - " rect.getWidth() * width,\n", - " rect.getHeight() * height,\n", - " 0.18);\n", - " return img.getSubImage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);\n", - "}\n", - "\n", - "List faces = inferenceResult.items();\n", - "getSubImage(img, faces.get(2).getBoundingBox()).getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Prepare Translator and load the model\n", - "\n", - "For the face classification model, we can use DJL prebuilt [`ImageClassificationTranslator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/translator/ImageClassificationTranslator.html) with a few transformation. This Translator brings a basic image translation process and can be extended with additional standard processing steps. So in our case, we don't have to create another `Translator` and just leverage on this prebuilt one." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/mask_classification/0.0.1/mask_classification\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(\n", - " ImageClassificationTranslator.builder()\n", - " .addTransform(new Resize(128, 128))\n", - " .addTransform(new ToTensor()) // HWC -> CHW div(255)\n", - " .addTransform(\n", - " new Normalize(\n", - " new float[] {0.5f, 0.5f, 0.5f},\n", - " new float[] {1.0f, 1.0f, 1.0f}))\n", - " .addTransform(nd -> nd.flip(0)) // RGB -> GBR\n", - " .build())\n", - " .build();\n", - "\n", - "var classifyModel = criteria.loadModel();\n", - "var classifier = classifyModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Run inference\n", - "\n", - "So all we need to do is to apply the previous implemented functions and apply them all together. We firstly crop the image and then use it for inference. 
After these steps, we create a new DetectedObjects with new Classification classes:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "for (DetectedObjects.DetectedObject face : faces) {\n", - " Image subImg = getSubImage(img, face.getBoundingBox());\n", - " Classifications classifications = classifier.predict(subImg);\n", - " names.add(classifications.best().getClassName());\n", - " prob.add(face.getProbability());\n", - " rect.add(face.getBoundingBox());\n", - "}\n", - "\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb deleted file mode 100644 index 46c86461bdb..00000000000 --- a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb +++ /dev/null @@ -1,352 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 用飛槳+ DJL 實作人臉口罩辨識\n", - "在這個教學中我們將會展示利用 PaddleHub 下載預訓練好的 PaddlePaddle 模型並針對範例照片做人臉口罩辨識。這個範例總共會分成兩個步驟:\n", - "\n", - "- 用臉部檢測模型識別圖片中的人臉(無論是否有戴口罩) \n", - "- 確認圖片中的臉是否有戴口罩\n", - "\n", - "這兩個步驟會包含使用兩個 Paddle 模型,我們會在接下來的內容介紹兩個模型對應需要做的前後處理邏輯\n", - "\n", - "## 導入相關環境依賴及子類別\n", - "在這個例子中的前處理飛槳深度學習引擎需要搭配 DJL 混合模式進行深度學習推理,原因是引擎本身沒有包含 NDArray 操作,因此需要藉用其他引擎的 NDArray 操作能力來完成。這邊我們導入 PyTorch 來做協同的前處理工作:" - ] - }, - { - "cell_type": "code", - "execution_count": null, 
- "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 臉部偵測模型\n", - "現在我們可以開始處理第一個模型,在將圖片輸入臉部檢測模型前我們必須先做一些預處理:\n", - "•\t調整圖片尺寸: 以特定比例縮小圖片\n", - "•\t用一個數值對縮小後圖片正規化\n", - "對開發者來說好消息是,DJL 提供了 Translator 介面來幫助開發做這樣的預處理. 
一個比較粗略的 Translator 架構如下:\n", - "\n", - "![](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "在接下來的段落,我們會利用一個 FaceTranslator 子類別實作來完成工作\n", - "### 預處理\n", - "在這個階段我們會讀取一張圖片並且對其做一些事先的預處理,讓我們先示範讀取一張圖片:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://raw.githubusercontent.com/PaddlePaddle/PaddleHub/release/v1.5/demo/mask_detection/python/images/mask.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接著,讓我們試著對圖片做一些預處理的轉換:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "NDList processImageInput(NDManager manager, Image input, float shrink) {\n", - " NDArray array = input.toNDArray(manager);\n", - " Shape shape = array.getShape();\n", - " array = NDImageUtils.resize(\n", - " array, (int) (shape.get(1) * shrink), (int) (shape.get(0) * shrink));\n", - " array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW BGR -> RGB\n", - " NDArray mean = manager.create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));\n", - " array = array.sub(mean).mul(0.007843f); // normalization\n", - " array = array.expandDims(0); // make batch dimension\n", - " return new NDList(array);\n", - "}\n", - "\n", - "processImageInput(NDManager.newBaseManager(), img, 0.5f);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "如上述所見,我們已經把圖片轉成如下尺寸的 NDArray: (披量, 通道(RGB), 高度, 寬度). 這是物件檢測模型輸入的格式\n", - "### 後處理\n", - "當我們做後處理時, 模型輸出的格式是 (number_of_boxes, (class_id, probability, xmin, ymin, xmax, ymax)). 我們可以將其存入預先建立好的 DJL 子類別 DetectedObjects 以便做後續操作. 
我們假設有一組推論後的輸出是 ((1, 0.99, 0.2, 0.4, 0.5, 0.8)) 並且試著把人像框顯示在圖片上" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "DetectedObjects processImageOutput(NDList list, List className, float threshold) {\n", - " NDArray result = list.singletonOrThrow();\n", - " float[] probabilities = result.get(\":,1\").toFloatArray();\n", - " List names = new ArrayList<>();\n", - " List prob = new ArrayList<>();\n", - " List boxes = new ArrayList<>();\n", - " for (int i = 0; i < probabilities.length; i++) {\n", - " if (probabilities[i] >= threshold) {\n", - " float[] array = result.get(i).toFloatArray();\n", - " names.add(className.get((int) array[0]));\n", - " prob.add((double) probabilities[i]);\n", - " boxes.add(\n", - " new Rectangle(\n", - " array[2], array[3], array[4] - array[2], array[5] - array[3]));\n", - " }\n", - " }\n", - " return new DetectedObjects(names, prob, boxes);\n", - "}\n", - "\n", - "NDArray tempOutput = NDManager.newBaseManager().create(new float[]{1f, 0.99f, 0.1f, 0.1f, 0.2f, 0.2f}, new Shape(1, 6));\n", - "DetectedObjects testBox = processImageOutput(new NDList(tempOutput), Arrays.asList(\"Not Face\", \"Face\"), 0.7f);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(testBox);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 生成一個翻譯器並執行推理任務\n", - "透過這個步驟,你會理解 DJL 中的前後處理如何運作,現在讓我們把前數的幾個步驟串在一起並對真實圖片進行操作:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class FaceTranslator implements NoBatchifyTranslator {\n", - "\n", - " private float shrink;\n", - " private float threshold;\n", - " private List className;\n", - "\n", - " FaceTranslator(float shrink, float threshold) {\n", - " this.shrink = shrink;\n", - " this.threshold = threshold;\n", - " className = Arrays.asList(\"Not Face\", \"Face\");\n", - " }\n", - "\n", - " 
@Override\n", - " public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {\n", - " return processImageOutput(list, className, threshold);\n", - " }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " return processImageInput(ctx.getNDManager(), input, shrink);\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "要執行這個人臉檢測推理,我們必須先從 DJL 的 Paddle Model Zoo 讀取模型,在讀取模型之前我們必須指定好 `Crieteria` . `Crieteria` 是用來確認要從哪邊讀取模型而後執行 `Translator` 來進行模型導入. 接著,我們只要利用 `Predictor` 就可以開始進行推論" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/face_detection/0.0.1/mask_detection\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(new FaceTranslator(0.5f, 0.7f))\n", - " .build();\n", - " \n", - "var model = criteria.loadModel();\n", - "var predictor = model.newPredictor();\n", - "\n", - "DetectedObjects inferenceResult = predictor.predict(img);\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(inferenceResult);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "如圖片所示,這個推論服務已經可以正確的辨識出圖片中的三張人臉\n", - "## 口罩分類模型\n", - "一旦有了圖片的座標,我們就可以將圖片裁剪到適當大小並且將其傳給口罩分類模型做後續的推論\n", - "### 圖片裁剪\n", - "圖中方框位置的數值範圍從0到1, 只要將這個數值乘上圖片的長寬我們就可以將方框對應到圖片中的準確位置. 
為了使裁剪後的圖片有更好的精確度,我們將圖片裁剪成方形,讓我們示範一下:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int[] extendSquare(\n", - " double xmin, double ymin, double width, double height, double percentage) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " double maxDist = Math.max(width / 2, height / 2) * (1 + percentage);\n", - " return new int[] {\n", - " (int) (centerx - maxDist), (int) (centery - maxDist), (int) (2 * maxDist)\n", - " };\n", - "}\n", - "\n", - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] squareBox =\n", - " extendSquare(\n", - " rect.getX() * width,\n", - " rect.getY() * height,\n", - " rect.getWidth() * width,\n", - " rect.getHeight() * height,\n", - " 0.18);\n", - " return img.getSubImage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);\n", - "}\n", - "\n", - "List faces = inferenceResult.items();\n", - "getSubImage(img, faces.get(2).getBoundingBox()).getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 事先準備 Translator 並讀取模型\n", - "在使用臉部檢測模型的時候,我們可以利用 DJL 預先建好的 `ImageClassificationTranslator` 並且加上一些轉換。這個 Translator 提供了一些基礎的圖片翻譯處理並且同時包含一些進階的標準化圖片處理。以這個例子來說, 我們不需要額外建立新的 `Translator` 而使用預先建立的就可以" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"djl://ai.djl.paddlepaddle/mask_classification/0.0.1/mask_classification\")\n", - " .optFilter(\"flavor\", \"server\")\n", - " .optTranslator(\n", - " ImageClassificationTranslator.builder()\n", - " .addTransform(new Resize(128, 128))\n", - " .addTransform(new ToTensor()) // HWC -> CHW div(255)\n", - " .addTransform(\n", - " new Normalize(\n", - " new 
float[] {0.5f, 0.5f, 0.5f},\n", - " new float[] {1.0f, 1.0f, 1.0f}))\n", - " .addTransform(nd -> nd.flip(0)) // RGB -> GBR\n", - " .build())\n", - " .build();\n", - "\n", - "var classifyModel = criteria.loadModel();\n", - "var classifier = classifyModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 執行推論任務\n", - "最後,要完成一個口罩識別的任務,我們只需要將上述的步驟合在一起即可。我們先將圖片做裁剪後並對其做上述的推論操作,結束之後再生成一個新的分類子類別 `DetectedObjects`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "for (DetectedObjects.DetectedObject face : faces) {\n", - " Image subImg = getSubImage(img, face.getBoundingBox());\n", - " Classifications classifications = classifier.predict(subImg);\n", - " names.add(classifications.best().getClassName());\n", - " prob.add(face.getProbability());\n", - " rect.add(face.getBoundingBox());\n", - "}\n", - "\n", - "newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/paddle_ocr_java.ipynb b/jupyter/paddlepaddle/paddle_ocr_java.ipynb deleted file mode 100644 index da8527020ab..00000000000 --- a/jupyter/paddlepaddle/paddle_ocr_java.ipynb +++ /dev/null @@ -1,313 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PaddleOCR DJL example\n", - "\n", - "In this tutorial, we will be using pretrained PaddlePaddle model from 
[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) to do Optical character recognition (OCR) from the given image. There are three models involved in this tutorial:\n", - "\n", - "- Word detection model: used to detect the word block from the image\n", - "- Word direction model: used to find if the text needs to rotate\n", - "- Word recognition model: Used to recognize test from the word block\n", - "\n", - "## Import dependencies and classes\n", - "\n", - "PaddlePaddle is one of the Deep Engines that requires DJL hybrid mode to run inference. Itself does not contains NDArray operations and needs a supplemental DL framework to help with that. So we import Pytorch DL engine as well in here to do the processing works." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.Predictor;\n", - "import ai.djl.modality.Classifications;\n", - "import ai.djl.modality.cv.Image;\n", - "import ai.djl.modality.cv.ImageFactory;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.util.NDImageUtils;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.DataType;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n", - "import 
ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n", - "import ai.djl.translate.*;\n", - "import java.util.concurrent.ConcurrentHashMap;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## the Image\n", - "Firstly, let's take a look at our sample image, a flight ticket:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Word detection model\n", - "\n", - "In our word detection model, we load the model exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model). After that, we can spawn a DJL Predictor from it called detector." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria1 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n", - " .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap()))\n", - " .build();\n", - "var detectionModel = criteria1.loadModel();\n", - "var detector = detectionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, we can detect the word block from it. The original output from the model is a bitmap that marked all word regions. The `PpWordDetectionTranslator` convert the output bitmap into a rectangle bounded box for us to crop the image." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var detectedObj = detector.predict(img);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(detectedObj);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, the word block are very narrow and does not include the whole body of all words. Let's try to extend it a bit for a better result. `extendRect` extend the box height and width to a certain scale. `getSubImage` will crop the image and extract the word block." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] recovered = {\n", - " (int) (extended[0] * width),\n", - " (int) (extended[1] * height),\n", - " (int) (extended[2] * width),\n", - " (int) (extended[3] * height)\n", - " };\n", - " return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);\n", - "}\n", - "\n", - "double[] extendRect(double xmin, double ymin, double width, double height) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " if (width > height) {\n", - " width += height * 2.0;\n", - " height *= 3.0;\n", - " } else {\n", - " height += width * 2.0;\n", - " width *= 3.0;\n", - " }\n", - " double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n", - " double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n", - " double newWidth = newX + width > 1 ? 1 - newX : width;\n", - " double newHeight = newY + height > 1 ? 
1 - newY : height;\n", - " return new double[] {newX, newY, newWidth, newHeight};\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's try to extract one block out:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List boxes = detectedObj.items();\n", - "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n", - "sample.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Word Direction model\n", - "\n", - "This model is exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) that can help to identify if the image is required to rotate. The following code will load this model and create a rotateClassifier." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria2 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n", - " .optTranslator(new PpWordRotateTranslator())\n", - " .build();\n", - "var rotateModel = criteria2.loadModel();\n", - "var rotateClassifier = rotateModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Word Recgonition model\n", - "\n", - "The word recognition model is exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) that can recognize the text on the image. 
Let's load this model as well.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria3 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, String.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n", - " .optTranslator(new PpWordRecognitionTranslator())\n", - " .build();\n", - "var recognitionModel = criteria3.loadModel();\n", - "var recognizer = recognitionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then we can try to play with these two models on the previous cropped image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "System.out.println(rotateClassifier.predict(sample));\n", - "recognizer.predict(sample);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, let's run these models on the whole image and see the outcome. DJL offers a rich image toolkit that allows you to draw the text on image and display them." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image rotateImg(Image image) {\n", - " try (NDManager manager = NDManager.newBaseManager()) {\n", - " NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n", - " return ImageFactory.getInstance().fromNDArray(rotated);\n", - " }\n", - "}\n", - "\n", - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "\n", - "for (int i = 0; i < boxes.size(); i++) {\n", - " Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n", - " if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " Classifications.Classification result = rotateClassifier.predict(subImg).best();\n", - " if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " String name = recognizer.predict(subImg);\n", - " names.add(name);\n", - " prob.add(-1.0);\n", - " rect.add(boxes.get(i).getBoundingBox());\n", - "}\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb b/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb deleted file mode 100644 index 2419baf89c7..00000000000 --- a/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb +++ /dev/null @@ -1,309 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PaddleOCR在DJL 上的實現\n", - "在這個教程裡,我們會展示利用 PaddleOCR 下載預訓練好文字處理模型並對指定的照片進行文學文字檢測 
(OCR)。這個教程總共會分成三個部分:\n", - "\n", - "- 文字區塊檢測: 從圖片檢測出文字區塊\n", - "- 文字角度檢測: 確認文字是否需要旋轉\n", - "- 文字識別: 確認區塊內的文字\n", - "\n", - "## 導入相關環境依賴及子類別\n", - "在這個例子中的前處理飛槳深度學習引擎需要搭配DJL混合模式進行深度學習推理,原因是引擎本身沒有包含ND數組操作,因此需要藉用其他引擎的數組操作能力來完成。這邊我們導入Pytorch來做協同的前處理工作:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// second engine to do preprocessing and postprocessing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.inference.Predictor;\n", - "import ai.djl.modality.Classifications;\n", - "import ai.djl.modality.cv.Image;\n", - "import ai.djl.modality.cv.ImageFactory;\n", - "import ai.djl.modality.cv.output.*;\n", - "import ai.djl.modality.cv.util.NDImageUtils;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.DataType;\n", - "import ai.djl.ndarray.types.Shape;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n", - "import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n", - "import ai.djl.translate.*;\n", - "import java.util.concurrent.ConcurrentHashMap;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 圖片讀取\n", - "首先讓我們載入這次教程會用到的機票範例圖片:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n", - "Image img = ImageFactory.getInstance().fromUrl(url);\n", - 
"img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 文字區塊檢測\n", - "我們首先從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model) 開發套件中讀取文字檢測的模型,之後我們可以生成一個DJL `Predictor` 並將其命名為 `detector`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria1 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, DetectedObjects.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n", - " .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap()))\n", - " .build();\n", - "var detectionModel = criteria1.loadModel();\n", - "var detector = detectionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接著我們檢測出圖片中的文字區塊,這個模型的原始輸出是含有標註所有文字區域的圖算法(Bitmap),我們可以利用`PpWordDetectionTranslator` 函式將圖算法的輸出轉成長方形的方框來裁剪圖片" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var detectedObj = detector.predict(img);\n", - "Image newImage = img.duplicate();\n", - "newImage.drawBoundingBoxes(detectedObj);\n", - "newImage.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "如上所示,所標註的文字區塊都非常窄,且沒有包住所有完整的文字區塊。讓我們嘗試使用`extendRect`函式來擴展文字框的長寬到需要的大小, 再利用 `getSubImage` 裁剪並擷取出文子區塊。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image getSubImage(Image img, BoundingBox box) {\n", - " Rectangle rect = box.getBounds();\n", - " double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n", - " int width = img.getWidth();\n", - " int height = img.getHeight();\n", - " int[] recovered = {\n", - " (int) (extended[0] * width),\n", - " (int) (extended[1] * height),\n", - " (int) 
(extended[2] * width),\n", - " (int) (extended[3] * height)\n", - " };\n", - " return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);\n", - "}\n", - "\n", - "double[] extendRect(double xmin, double ymin, double width, double height) {\n", - " double centerx = xmin + width / 2;\n", - " double centery = ymin + height / 2;\n", - " if (width > height) {\n", - " width += height * 2.0;\n", - " height *= 3.0;\n", - " } else {\n", - " height += width * 2.0;\n", - " width *= 3.0;\n", - " }\n", - " double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n", - " double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n", - " double newWidth = newX + width > 1 ? 1 - newX : width;\n", - " double newHeight = newY + height > 1 ? 1 - newY : height;\n", - " return new double[] {newX, newY, newWidth, newHeight};\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "讓我們輸出其中一個文字區塊" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "List boxes = detectedObj.items();\n", - "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n", - "sample.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 文字角度檢測\n", - "我們從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) 輸出這個模型並確認圖片及文字是否需要旋轉。以下的代碼會讀入這個模型並生成a `rotateClassifier` 子類別" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria2 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n", - " .optTranslator(new PpWordRotateTranslator())\n", - " .build();\n", - "var rotateModel = criteria2.loadModel();\n", - "var rotateClassifier = 
rotateModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 文字識別\n", - "\n", - "我們從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) 輸出這個模型並識別圖片中的文字, 我們一樣仿造上述的步驟讀取這個模型\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var criteria3 = Criteria.builder()\n", - " .optEngine(\"PaddlePaddle\")\n", - " .setTypes(Image.class, String.class)\n", - " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n", - " .optTranslator(new PpWordRecognitionTranslator())\n", - " .build();\n", - "var recognitionModel = criteria3.loadModel();\n", - "var recognizer = recognitionModel.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接著我們可以試著套用這兩個模型在先前剪裁好的文字區塊上" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "System.out.println(rotateClassifier.predict(sample));\n", - "recognizer.predict(sample);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "最後我們把這些模型串連在一起並套用在整張圖片上看看結果會如何。DJL提供了豐富的影像工具包讓你可以從圖片中擷取出文字並且完美呈現" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Image rotateImg(Image image) {\n", - " try (NDManager manager = NDManager.newBaseManager()) {\n", - " NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n", - " return ImageFactory.getInstance().fromNDArray(rotated);\n", - " }\n", - "}\n", - "\n", - "List names = new ArrayList<>();\n", - "List prob = new ArrayList<>();\n", - "List rect = new ArrayList<>();\n", - "\n", - "for (int i = 0; i < boxes.size(); i++) {\n", - " Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n", - " if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n", - " subImg = 
rotateImg(subImg);\n", - " }\n", - " Classifications.Classification result = rotateClassifier.predict(subImg).best();\n", - " if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n", - " subImg = rotateImg(subImg);\n", - " }\n", - " String name = recognizer.predict(subImg);\n", - " names.add(name);\n", - " prob.add(-1.0);\n", - " rect.add(boxes.get(i).getBoundingBox());\n", - "}\n", - "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", - "newImage.getWrappedImage();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/pytorch/load_your_own_pytorch_bert.ipynb b/jupyter/pytorch/load_your_own_pytorch_bert.ipynb deleted file mode 100644 index 3c52ee599b0..00000000000 --- a/jupyter/pytorch/load_your_own_pytorch_bert.ipynb +++ /dev/null @@ -1,441 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load your own PyTorch BERT model\n", - "\n", - "In the previous [example](https://github.com/deepjavalibrary/djl/blob/master/jupyter/BERTQA.ipynb), you run BERT inference with the model from Model Zoo. You can also load the model on your own pre-trained BERT and use custom classes as the input and output.\n", - "\n", - "In general, the PyTorch BERT model from [HuggingFace](https://github.com/huggingface/transformers) requires these three inputs:\n", - "\n", - "- word indices: The index of each word in a sentence\n", - "- word types: The type index of the word.\n", - "- attention mask: The mask indicates to the model which tokens should be attended to, and which should not after batching sequence together.\n", - "\n", - "We will dive deep into these details later." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are dependencies we will use." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0\n", - "%maven ai.djl.pytorch:pytorch-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.stream.*;\n", - "\n", - "import ai.djl.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.qa.*;\n", - "import ai.djl.modality.nlp.bert.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Reuse the previous input**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var question = \"When did BBC Japan start broadcasting?\";\n", - "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", - " \"Which operated between December 2004 and April 2006.\\n\" +\n", - " \"It ceased operations after its Japanese distributor folded.\";\n", - "\n", - "QAInput 
input = new QAInput(question, resourceDocument);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dive deep into Translator\n", - "\n", - "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n", - "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", - "\n", - "![https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", - "\n", - "The red block (\"Images\") in the workflow is the input that DJL expects from you. The green block (\"Images\n", - "bounding box\") is the output that you expect. Because DJL does not know which input to expect and which output format that you prefer, DJL provides the [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface so you can define your own\n", - "input and output.\n", - "\n", - "The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Pre-processing\n", - "\n", - "Now, you need to convert the sentences into tokens. 
We provide a powerful tool [`BertTokenizer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/bert/BertTokenizer.html) that you can use to convert questions and answers into tokens, and batchify your sequence together. Once you have properly formatted tokens, you can use [`Vocabulary`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/Vocabulary.html) to map your token to BERT index.\n", - "\n", - "The following code block demonstrates tokenizing the question and answer defined earlier into BERT-formatted tokens." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var tokenizer = new BertTokenizer();\n", - "List tokenQ = tokenizer.tokenize(question.toLowerCase());\n", - "List tokenA = tokenizer.tokenize(resourceDocument.toLowerCase());\n", - "\n", - "System.out.println(\"Question Token: \" + tokenQ);\n", - "System.out.println(\"Answer Token: \" + tokenA);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`BertTokenizer` can also help you batchify questions and resource documents together by calling `encode()`.\n", - "The output contains information that BERT ingests.\n", - "\n", - "- getTokens: It returns a list of strings including the question, resource document and special word to let the model tell which part is the question and which part is the resource document. Because PyTorch BERT was trained with varioue sequence length, you don't pad the tokens.\n", - "- getTokenTypes: It returns a list of type indices of the word to indicate the location of the resource document. 
All Questions will be labelled with 0 and all resource documents will be labelled with 1.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [000000...11111....0000]\n", - " \n", - "\n", - "- getValidLength: It returns the actual length of the question and resource document tokens tokens, which are required by MXNet BERT.\n", - "- getAttentionMask: It returns the mask for the model to indicate which part should be paid attention to and which part is the padding. It is required by PyTorch BERT.\n", - "\n", - " [Question tokens...DocResourceTokens...padding tokens] => [111111...11111....0000]\n", - " \n", - "PyTorch BERT was trained with varioue sequence length, so we don't need to pad the tokens." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertToken token = tokenizer.encode(question.toLowerCase(), resourceDocument.toLowerCase());\n", - "System.out.println(\"Encoded tokens: \" + token.getTokens());\n", - "System.out.println(\"Encoded token type: \" + token.getTokenTypes());\n", - "System.out.println(\"Valid length: \" + token.getValidLength());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Normally, words and sentences are represented as indices instead of tokens for training. \n", - "They typically work like a vector in a n-dimensional space. In this case, you need to map them into indices.\n", - "DJL provides `Vocabulary` to take care of you vocabulary mapping.\n", - "\n", - "The bert vocab from Huggingface is of the following format.\n", - "```\n", - "[PAD]\n", - "[unused0]\n", - "[unused1]\n", - "[unused2]\n", - "[unused3]\n", - "[unused4]\n", - "[unused5]\n", - "[unused6]\n", - "[unused7]\n", - "[unused8]\n", - "...\n", - "```\n", - "We provide the `bert-base-uncased-vocab.txt` from our pre-trained BERT for demonstration." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/bert-base-uncased-vocab.txt.gz\", \"build/pytorch/bertqa/vocab.txt\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n", - "var vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromTextFile(path)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can easily convert the token to the index using `vocabulary.getIndex(token)` and the other way around using `vocabulary.getToken(index)`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "long index = vocabulary.getIndex(\"car\");\n", - "String token = vocabulary.getToken(2482);\n", - "System.out.println(\"The index of the car is \" + index);\n", - "System.out.println(\"The token of the index 2482 is \" + token);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To properly convert them into `float[]` for `NDArray` creation, here is the helper function:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that you have everything you need, you can create an NDList and populate all of the inputs you formatted earlier. You're done with pre-processing! \n", - "\n", - "#### Construct `Translator`\n", - "\n", - "You need to do this processing within an implementation of the `Translator` interface. `Translator` is designed to do pre-processing and post-processing. You must define the input and output objects. 
It contains the following two override classes:\n", - "- `public NDList processInput(TranslatorContext ctx, I)`\n", - "- `public String processOutput(TranslatorContext ctx, O)`\n", - "\n", - "Every translator takes in input and returns output in the form of generic objects. In this case, the translator takes input in the form of `QAInput` (I) and returns output as a `String` (O). `QAInput` is just an object that holds questions and answer; We have prepared the Input class for you." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Armed with the needed knowledge, you can write an implementation of the `Translator` interface. `BertTranslator` uses the code snippets explained previously to implement the `processInput`method. For more information, see [`NDManager`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDManager.html).\n", - "\n", - "```\n", - "manager.create(Number[] data, Shape)\n", - "manager.create(Number[] data)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "public class BertTranslator implements Translator {\n", - " private List tokens;\n", - " private Vocabulary vocabulary;\n", - " private BertTokenizer tokenizer;\n", - " \n", - " @Override\n", - " public void prepare(TranslatorContext ctx) throws IOException {\n", - " Path path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n", - " vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromTextFile(path)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();\n", - " tokenizer = new BertTokenizer();\n", - " }\n", - " \n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, QAInput input) {\n", - " BertToken token =\n", - " tokenizer.encode(\n", - " input.getQuestion().toLowerCase(),\n", - " input.getParagraph().toLowerCase());\n", - " // get the encoded tokens that would be used in precessOutput\n", - " tokens = token.getTokens();\n", - " 
NDManager manager = ctx.getNDManager();\n", - " // map the tokens(String) to indices(long)\n", - " long[] indices = tokens.stream().mapToLong(vocabulary::getIndex).toArray();\n", - " long[] attentionMask = token.getAttentionMask().stream().mapToLong(i -> i).toArray();\n", - " long[] tokenType = token.getTokenTypes().stream().mapToLong(i -> i).toArray();\n", - " NDArray indicesArray = manager.create(indices);\n", - " NDArray attentionMaskArray =\n", - " manager.create(attentionMask);\n", - " NDArray tokenTypeArray = manager.create(tokenType);\n", - " // The order matters\n", - " return new NDList(indicesArray, attentionMaskArray, tokenTypeArray);\n", - " }\n", - " \n", - " @Override\n", - " public String processOutput(TranslatorContext ctx, NDList list) {\n", - " NDArray startLogits = list.get(0);\n", - " NDArray endLogits = list.get(1);\n", - " int startIdx = (int) startLogits.argMax().getLong();\n", - " int endIdx = (int) endLogits.argMax().getLong();\n", - " return tokens.subList(startIdx, endIdx + 1).toString();\n", - " }\n", - " \n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " return Batchifier.STACK;\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Congrats! You have created your first Translator! We have pre-filled the `processOutput()` function to process the `NDList` and return it in a desired format. `processInput()` and `processOutput()` offer the flexibility to get the predictions from the model in any format you desire. \n", - "\n", - "With the Translator implemented, you need to bring up the predictor that uses your `Translator` to start making predictions. You can find the usage for `Predictor` in the [Predictor Javadoc](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html). Create a translator and use the `question` and `resourceDocument` provided previously." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DownloadUtils.download(\"https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/trace_bertqa.pt.gz\", \"build/pytorch/bertqa/bertqa.pt\", new ProgressBar());" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "BertTranslator translator = new BertTranslator();\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(QAInput.class, String.class)\n", - " .optModelPath(Paths.get(\"build/pytorch/bertqa/\")) // search in local folder\n", - " .optTranslator(translator)\n", - " .optProgress(new ProgressBar()).build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String predictResult = null;\n", - "QAInput input = new QAInput(question, resourceDocument);\n", - "\n", - "// Create a Predictor and use it to predict the output\n", - "try (Predictor predictor = model.newPredictor(translator)) {\n", - " predictResult = predictor.predict(input);\n", - "}\n", - "\n", - "System.out.println(question);\n", - "System.out.println(predictResult);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Based on the input, the following result will be shown:\n", - "```\n", - "[december, 2004]\n", - "```\n", - "That's it! \n", - "\n", - "You can try with more questions and answers. Here are the samples:\n", - "\n", - "**Answer Material**\n", - "\n", - "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. 
They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.\n", - "\n", - "\n", - "**Question**\n", - "\n", - "Q: When were the Normans in Normandy?\n", - "A: 10th and 11th centuries\n", - "\n", - "Q: In what country is Normandy located?\n", - "A: france\n", - "\n", - "For the full source code, see the [DJL repo](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java) and translator implementation [MXNet](https://github.com/deepjavalibrary/djl/blob/master/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator.java) [PyTorch](https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/qa/PtBertQATranslator.java)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb b/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb deleted file mode 100644 index 2edbc6c195f..00000000000 --- a/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb +++ /dev/null @@ -1,473 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Rank Classification using BERT on Amazon Review dataset\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you learn how to train a rank classification model using [Transfer Learning](https://en.wikipedia.org/wiki/Transfer_learning). We will use a pretrained DistilBert model to train on the Amazon review dataset.\n", - "\n", - "## About the dataset and model\n", - "\n", - "[Amazon Customer Review dataset](https://s3.amazonaws.com/amazon-reviews-pds/readme.html) consists of all different valid reviews from amazon.com. We will use the \"Digital_software\" category that consists of 102k valid reviews. As for the pre-trained model, use the DistilBERT[[1]](https://arxiv.org/abs/1910.01108) model. It's a light-weight BERT model already trained on [Wikipedia text corpora](https://en.wikipedia.org/wiki/List_of_text_corpora), a much larger dataset consisting of over millions text. The DistilBERT served as a base layer and we will add some more classification layers to output as rankings (1 - 5).\n", - "\n", - "\n", - "
    Amazon Review example
    \n", - "\n", - "We will use review body as our data input and ranking as label.\n", - "\n", - "\n", - "## Pre-requisites\n", - "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n", - "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n", - "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n", - "\n", - "\n", - "## Getting started\n", - "Load the Deep Java Libarary and its dependencies from Maven. In here, you can choose between MXNet or PyTorch. MXNet is enabled by default. You can uncomment PyTorch dependencies and comment MXNet ones to switch to PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:basicdataset:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "\n", - "// PyTorch\n", - "// %maven ai.djl.pytorch:pytorch-model-zoo:0.23.0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's import the necessary modules:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.basicdataset.tabular.*;\n", - "import ai.djl.basicdataset.tabular.utils.*;\n", - "import ai.djl.basicdataset.utils.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.metric.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.bert.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.nn.*;\n", - 
"import ai.djl.nn.core.*;\n", - "import ai.djl.nn.norm.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.*;\n", - "import ai.djl.training.dataset.*;\n", - "import ai.djl.training.evaluator.*;\n", - "import ai.djl.training.listener.*;\n", - "import ai.djl.training.loss.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.translate.*;\n", - "import java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import org.apache.commons.csv.*;\n", - "\n", - "System.out.println(\"You are using: \" + Engine.getInstance().getEngineName() + \" Engine\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare Dataset\n", - "\n", - "First step is to prepare the dataset for training. Since the original data was in TSV format, we can use CSVDataset to be the dataset container. We will also need to specify how do we want to preprocess the raw data. For BERT model, the input data are required to be tokenized and mapped into indices based on the inputs. In DJL, we defined an interface called Fearurizer, it is designed to allow user customize operation on each selected row/column of a dataset. In our case, we would like to clean and tokenize our sentencies. So let's try to implement it to deal with customer review sentencies." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "final class BertFeaturizer implements Featurizer {\n", - "\n", - " private final BertFullTokenizer tokenizer;\n", - " private final int maxLength; // the cut-off length\n", - "\n", - " public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {\n", - " this.tokenizer = tokenizer;\n", - " this.maxLength = maxLength;\n", - " }\n", - "\n", - " /** {@inheritDoc} */\n", - " @Override\n", - " public void featurize(DynamicBuffer buf, String input) {\n", - " Vocabulary vocab = tokenizer.getVocabulary();\n", - " // convert sentence to tokens (toLowerCase for uncased model)\n", - " List tokens = tokenizer.tokenize(input.toLowerCase());\n", - " // trim the tokens to maxLength\n", - " tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;\n", - " // BERT embedding convention \"[CLS] Your Sentence [SEP]\"\n", - " buf.put(vocab.getIndex(\"[CLS]\"));\n", - " tokens.forEach(token -> buf.put(vocab.getIndex(token)));\n", - " buf.put(vocab.getIndex(\"[SEP]\"));\n", - " }\n", - "\n", - " /** {@inheritDoc} */\n", - " @Override\n", - " public int dataRequired() {\n", - " throw new IllegalStateException(\"BertFeaturizer only support featurize, not deFeaturize\");\n", - " }\n", - "\n", - " /** {@inheritDoc} */\n", - " @Override\n", - " public Object deFeaturize(float[] data) {\n", - " throw new IllegalStateException(\"BertFeaturizer only support featurize, not deFeaturize\");\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Once we got this part done, we can apply the `BertFeaturizer` into our Dataset. We take `review_body` column and apply the Featurizer. We also pick `star_rating` as our label set. Since we go for batch input, we need to tell the dataset to pad our data if it is less than the `maxLength` we defined. `PaddingStackBatchifier` will do the work for you." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {\n", - " String amazonReview =\n", - " \"https://mlrepo.djl.ai/dataset/nlp/ai/djl/basicdataset/amazon_reviews/1.0/amazon_reviews_us_Digital_Software_v1_00.tsv.gz\";\n", - " float paddingToken = tokenizer.getVocabulary().getIndex(\"[PAD]\");\n", - " return CsvDataset.builder()\n", - " .optCsvUrl(amazonReview) // load from Url\n", - " .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format\n", - " .setSampling(batchSize, true) // make sample size and random access\n", - " .optLimit(limit)\n", - " .addFeature(\n", - " new Feature(\n", - " \"review_body\", new BertFeaturizer(tokenizer, maxLength)))\n", - " .addLabel(\n", - " new Feature(\n", - " \"star_rating\", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))\n", - " .optDataBatchifier(\n", - " PaddingStackBatchifier.builder()\n", - " .optIncludeValidLengths(false)\n", - " .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))\n", - " .build()) // define how to pad dataset to a fix length\n", - " .build();\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Construct your model\n", - "\n", - "We will load our pretrained model and prepare the classification. First construct the `criteria` to specify where to load the embedding (DistiledBERT), then call `loadModel` to download that embedding with pre-trained weights. Since this model is built without classification layer, we need to add a classification layer to the end of the model and train it. 
After you are done modifying the block, set it back to model using `setBlock`.\n", - "\n", - "### Load the word embedding\n", - "\n", - "We will download our word embedding and load it to memory (this may take a while)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// MXNet base model\n", - "String modelUrls = \"https://resources.djl.ai/test-models/distilbert.zip\";\n", - "if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n", - " modelUrls = \"https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip\";\n", - "}\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .optApplication(Application.NLP.WORD_EMBEDDING)\n", - " .setTypes(NDList.class, NDList.class)\n", - " .optModelUrls(modelUrls)\n", - " .optProgress(new ProgressBar())\n", - " .build();\n", - "ZooModel embedding = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create classification layers\n", - "\n", - "Then let's build a simple MLP layer to classify the ranks. We set the output of last FullyConnected (Linear) layer to 5 to get the predictions for star 1 to 5. Then all we need to do is to load the block into the model. Before applying the classification layer, we also need to add text embedding to the front. In our case, we just create a Lambda function that do the followings:\n", - "\n", - "1. batch_data (batch size, token indices) -> batch_data + max_length (size of the token indices)\n", - "2. 
generate embedding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor embedder = embedding.newPredictor();\n", - "Block classifier = new SequentialBlock()\n", - " // text embedding layer\n", - " .add(\n", - " ndList -> {\n", - " NDArray data = ndList.singletonOrThrow();\n", - " NDList inputs = new NDList();\n", - " long batchSize = data.getShape().get(0);\n", - " float maxLength = data.getShape().get(1);\n", - "\n", - " if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n", - " inputs.add(data.toType(DataType.INT64, false));\n", - " inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n", - " inputs.add(data.getManager().arange(maxLength)\n", - " .toType(DataType.INT64, false)\n", - " .broadcast(data.getShape()));\n", - " } else {\n", - " inputs.add(data);\n", - " inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n", - " }\n", - " // run embedding\n", - " try {\n", - " return embedder.predict(inputs);\n", - " } catch (TranslateException e) {\n", - " throw new IllegalArgumentException(\"embedding error\", e);\n", - " }\n", - " })\n", - " // classification layer\n", - " .add(Linear.builder().setUnits(768).build()) // pre classifier\n", - " .add(Activation::relu)\n", - " .add(Dropout.builder().optRate(0.2f).build())\n", - " .add(Linear.builder().setUnits(5).build()) // 5 star rating\n", - " .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n", - "Model model = Model.newInstance(\"AmazonReviewRatingClassification\");\n", - "model.setBlock(classifier);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Start Training\n", - "\n", - "Finally, we can start building our training pipeline to train the model.\n", - "\n", - "### Creating Training and Testing dataset\n", - "\n", - "Firstly, we need to create a voabulary that is used to map token to index such as \"hello\" to 1121 (1121 is the index of \"hello\" 
in dictionary). Then we simply feed the vocabulary to the tokenizer that used to tokenize the sentence. Finally, we just need to split the dataset based on the ratio.\n", - "\n", - "Note: we set the cut-off length to 64 which means only the first 64 tokens from the review will be used. You can increase this value to achieve better accuracy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Prepare the vocabulary\n", - "DefaultVocabulary vocabulary = DefaultVocabulary.builder()\n", - " .addFromTextFile(embedding.getArtifact(\"vocab.txt\"))\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();\n", - "// Prepare dataset\n", - "int maxTokenLength = 64; // cutoff tokens length\n", - "int batchSize = 8;\n", - "int limit = Integer.MAX_VALUE;\n", - "// int limit = 512; // uncomment for quick testing\n", - "\n", - "BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);\n", - "CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);\n", - "// split data with 7:3 train:valid ratio\n", - "RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);\n", - "RandomAccessDataset trainingSet = datasets[0];\n", - "RandomAccessDataset validationSet = datasets[1];" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup Trainer and training config\n", - "\n", - "Then, we need to setup our trainer. We set up the accuracy and loss function. The model training logs will be saved to `build/modlel`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "SaveModelTrainingListener listener = new SaveModelTrainingListener(\"build/model\");\n", - " listener.setSaveModelCallback(\n", - " trainer -> {\n", - " TrainingResult result = trainer.getTrainingResult();\n", - " Model model = trainer.getModel();\n", - " // track for accuracy and loss\n", - " float accuracy = result.getValidateEvaluation(\"Accuracy\");\n", - " model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n", - " model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n", - " });\n", - "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type\n", - " .addEvaluator(new Accuracy())\n", - " .optDevices(Engine.getInstance().getDevices(1)) // train using single GPU\n", - " .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n", - " .addTrainingListeners(listener);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Start training\n", - "\n", - "We will start our training process. Training on GPU will takes approximately 10 mins. For CPU, it will take more than 2 hours to finish." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int epoch = 2;\n", - "\n", - "Trainer trainer = model.newTrainer(config);\n", - "trainer.setMetrics(new Metrics());\n", - "Shape encoderInputShape = new Shape(batchSize, maxTokenLength);\n", - "// initialize trainer with proper input shape\n", - "trainer.initialize(encoderInputShape);\n", - "EasyTrain.fit(trainer, epoch, trainingSet, validationSet);\n", - "System.out.println(trainer.getTrainingResult());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Save the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.save(Paths.get(\"build/model\"), \"amazon-review.param\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Verify the model\n", - "\n", - "We can create a predictor from the model to run inference on our customized dataset. Firstly, we can create a `Translator` for the model to do preprocessing and post processing. Similar to what we have done before, we need to tokenize the input sentence and get the output ranking." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MyTranslator implements Translator {\n", - "\n", - " private BertFullTokenizer tokenizer;\n", - " private Vocabulary vocab;\n", - " private List ranks;\n", - "\n", - " public MyTranslator(BertFullTokenizer tokenizer) {\n", - " this.tokenizer = tokenizer;\n", - " vocab = tokenizer.getVocabulary();\n", - " ranks = Arrays.asList(\"1\", \"2\", \"3\", \"4\", \"5\");\n", - " }\n", - "\n", - " @Override\n", - " public Batchifier getBatchifier() { return Batchifier.STACK; }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, String input) {\n", - " List tokens = tokenizer.tokenize(input);\n", - " float[] indices = new float[tokens.size() + 2];\n", - " indices[0] = vocab.getIndex(\"[CLS]\");\n", - " for (int i = 0; i < tokens.size(); i++) {\n", - " indices[i+1] = vocab.getIndex(tokens.get(i));\n", - " }\n", - " indices[indices.length - 1] = vocab.getIndex(\"[SEP]\");\n", - " return new NDList(ctx.getNDManager().create(indices));\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " return new Classifications(ranks, list.singletonOrThrow().softmax(0));\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can create a `Predictor` to run the inference. 
Let's try with a random customer review:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String review = \"It works great, but it takes too long to update itself and slows the system\";\n", - "Predictor predictor = model.newPredictor(new MyTranslator(tokenizer));\n", - "\n", - "predictor.predict(review)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/tensorflow/pneumonia_detection.ipynb b/jupyter/tensorflow/pneumonia_detection.ipynb deleted file mode 100644 index c790ad13f55..00000000000 --- a/jupyter/tensorflow/pneumonia_detection.ipynb +++ /dev/null @@ -1,243 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Detecting Pneumonia from X-ray images using Deep Java Library" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Disclaimer: this blog post is intended for educational purposes only. The application was developed using experimental code. The result should not be used for any medical diagnoses of pneumonia. This content has not been reviewed or approved by any scientists or medical professionals.*\n", - "\n", - "## Introduction\n", - "In this example, we demonstrate how deep learning (DL) can be used to detect pneumonia from chest X-ray images. This work is inspired by the [Chest X-ray Images Challenge](https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia) on Kaggle and a related [paper](https://www.cell.com/cell/fulltext/S0092-8674\\(18\\)30154-5). In this notebook, we illustrates how artificial intelligence can assist clinical decision making with focus on enterprise deployment. 
This work leverages a model trained using Keras and TensorFlow with [this Kaggle kernel](https://www.kaggle.com/aakashnain/beating-everything-with-depthwise-convolution). In this blog post, we will focus on generating predictions with this model using [Deep Java Library](https://djl.ai/) (DJL), an open source library to build and deploy DL in Java." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. To install the Java Kernel, see the [documentation](https://docs.djl.ai/jupyter/index.html).\n", - "\n", - "These are the dependencies we will use:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-api:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-engine:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%loadFromPOM\n", - "\n", - " com.google.protobuf\n", - " protobuf-java\n", - " 3.19.2\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import java packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;\n", - "import java.net.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - 
"source": [ - "### set the model URL" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var modelUrl = \"https://resources.djl.ai/demo/pneumonia-detection-model/saved_model.zip\";" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dive deep into Translator\n", - "\n", - "To successfully run inference, we need to define some preprocessing and post processing logic to achieve the best \n", - "prediction result and understandable output." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MyTranslator implements Translator {\n", - "\n", - " private static final List CLASSES = Arrays.asList(\"Normal\", \"Pneumonia\");\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " NDManager manager = ctx.getNDManager();\n", - " NDArray array = input.toNDArray(manager, Image.Flag.COLOR);\n", - " array = NDImageUtils.resize(array, 224).div(255.0f);\n", - " return new NDList(array);\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " NDArray probabilities = list.singletonOrThrow();\n", - " return new Classifications(CLASSES, probabilities);\n", - " }\n", - "\n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " return Batchifier.STACK;\n", - " }\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see above, the translator resizes the image to 224x224 and normalizes the image by dividing by 255 before feeding it into the model. When doing inference, you need to follow the same pre-processing procedure as was used during training. In this case, we need to match the Keras training code. 
After running prediction, the model outputs probabilities of each class as an [NDArray](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDArray.html). We need to tell the predictor to translate it back to classes, namely “Normal” or \"Pneumonia\".\n", - "\n", - "Until this point, all preparation work is done, we can start working on the prediction logic." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Predict using DJL\n", - "\n", - "### Load the image\n", - "We are going to load an CT scanned image of an infected lung from internet " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var imagePath = \"https://resources.djl.ai/images/chest_xray.jpg\";\n", - "var image = ImageFactory.getInstance().fromUrl(imagePath);\n", - "image.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load your model\n", - "Next, we will download the model from `modelUrl`. This will download the model into the DJL cache location" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria =\n", - " Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optModelUrls(modelUrl)\n", - " .optTranslator(new MyTranslator())\n", - " .optProgress(new ProgressBar())\n", - " .build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Run inference\n", - "Lastly, we will need to create a predictor using our model and translator. Once we have a predictor, we simply need to call the predict method on our test image." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();\n", - "Classifications classifications = predictor.predict(image);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tensorflow/rank_classification_using_BERT_on_Amazon_Review.ipynb b/jupyter/tensorflow/rank_classification_using_BERT_on_Amazon_Review.ipynb deleted file mode 100644 index 1b4647919c1..00000000000 --- a/jupyter/tensorflow/rank_classification_using_BERT_on_Amazon_Review.ipynb +++ /dev/null @@ -1,267 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Rank Classification using BERT on Amazon Review\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you learn how to use a pre-trained Tensorflow model to classifiy a Amazon Review rank. The model was refined on Amazon Review dataset with a pretrained DistilBert model.\n", - "\n", - "### About the dataset and model\n", - "\n", - "[Amazon Customer Review dataset](https://s3.amazonaws.com/amazon-reviews-pds/readme.html) consists of all different valid reviews from amazon.com. We will use the \"Digital_software\" category that consists of 102k valid reviews. As for the pre-trained model, use the DistilBERT[[1]](https://arxiv.org/abs/1910.01108) model. It's a light-weight BERT model already trained on [Wikipedia text corpora](https://en.wikipedia.org/wiki/List_of_text_corpora), a much larger dataset consisting of over millions text. 
The DistilBERT served as a base layer and we will add some more classification layers to output as rankings (1 - 5).\n", - "\n", - "\n", - "
    Amazon Review example
    \n", - "\n", - "\n", - "## Pre-requisites\n", - "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n", - "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n", - "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n", - "\n", - "\n", - "## Getting started\n", - "Load the Deep Java Libarary and its dependencies from Maven. In here, you can choose between MXNet or PyTorch. MXNet is enabled by default. You can uncomment PyTorch dependencies and comment MXNet ones to switch to PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-engine:0.23.0\n", - "%maven ai.djl.tensorflow:tensorflow-api:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%loadFromPOM\n", - "\n", - " com.google.protobuf\n", - " protobuf-java\n", - " 3.19.2\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's import the necessary modules:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.nlp.*;\n", - "import ai.djl.modality.nlp.bert.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;\n", - "\n", - "import 
java.io.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare your model files\n", - "\n", - "You can download pre-trained Tensorflow model from: https://resources.djl.ai/demo/tensorflow/amazon_review_rank_classification.zip." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String modelUrl = \"https://resources.djl.ai/demo/tensorflow/amazon_review_rank_classification.zip\";\n", - "DownloadUtils.download(modelUrl, \"build/amazon_review_rank_classification.zip\", new ProgressBar());\n", - "Path zipFile = Paths.get(\"build/amazon_review_rank_classification.zip\");\n", - "\n", - "Path modelDir = Paths.get(\"build/saved_model\");\n", - "if (Files.notExists(modelDir)) {\n", - " ZipUtils.unzip(Files.newInputStream(zipFile), modelDir); \n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create Translator\n", - "\n", - "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model.\n", - "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide output.\n", - "\n", - "The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface is used to: Pre-processing and Post-processing. The pre-processing\n", - "component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the\n", - "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n", - "`Predictor`. 
The post-processing block allows you to convert the output from the `Predictor` to the desired output\n", - "format.\n", - "\n", - "### Pre-processing\n", - "\n", - "Now, you need to convert the sentences into tokens. We provide a powerful tool [`BertTokenizer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/bert/BertTokenizer.html) that you can use to convert questions and answers into tokens, and batchify your sequence together. Once you have properly formatted tokens, you can use [`Vocabulary`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/nlp/Vocabulary.html) to map your token to BERT index.\n", - "\n", - "The following code block demonstrates tokenizing the question and answer defined earlier into BERT-formatted tokens.\n", - "\n", - "In the zip file, we also bundled the BERT `vocab.txt` file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Prepare the vocabulary\n", - "Path vocabFile = modelDir.resolve(\"vocab.txt\");\n", - "DefaultVocabulary vocabulary = DefaultVocabulary.builder()\n", - " .optMinFrequency(1)\n", - " .addFromTextFile(vocabFile)\n", - " .optUnknownToken(\"[UNK]\")\n", - " .build();\n", - "BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);\n", - "int maxTokenLength = 64; // cutoff tokens length\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MyTranslator implements Translator {\n", - "\n", - " private BertFullTokenizer tokenizer;\n", - " private Vocabulary vocab;\n", - " private List ranks;\n", - " private int length;\n", - "\n", - " public MyTranslator(BertFullTokenizer tokenizer, int length) {\n", - " this.tokenizer = tokenizer;\n", - " this.length = length;\n", - " vocab = tokenizer.getVocabulary();\n", - " ranks = Arrays.asList(\"1\", \"2\", \"3\", \"4\", \"5\");\n", - " }\n", - "\n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " 
return Batchifier.STACK;\n", - " }\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, String input) {\n", - " List tokens = tokenizer.tokenize(input);\n", - " long[] indices = new long[length];\n", - " long[] mask = new long[length];\n", - " long[] segmentIds = new long[length];\n", - " int size = Math.min(length, tokens.size());\n", - " for (int i = 0; i < size; i++) {\n", - " indices[i + 1] = vocab.getIndex(tokens.get(i));\n", - " }\n", - " Arrays.fill(mask, 0, size, 1);\n", - " NDManager m = ctx.getNDManager();\n", - " return new NDList(m.create(indices), m.create(mask), m.create(segmentIds));\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " return new Classifications(ranks, list.singletonOrThrow().softmax(0));\n", - " }\n", - "}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load your model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "MyTranslator translator = new MyTranslator(tokenizer, maxTokenLength);\n", - "\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(String.class, Classifications.class)\n", - " .optModelPath(modelDir) // Load model form model directory\n", - " .optTranslator(translator) // use custom translaotr \n", - " .build();\n", - "\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run inference\n", - "\n", - "Lastly, we will need to create a predictor using our model and translator. Once we have a predictor, we simply need to call the predict method on our test image." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "String review = \"It works great, but it takes too long to update itself and slows the system\";\n", - "\n", - "Predictor predictor = model.newPredictor();\n", - "Classifications classifications = predictor.predict(review);\n", - "\n", - "classifications" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/jupyter/tensorflow_lite/inference_with_tensorflow_lite.ipynb b/jupyter/tensorflow_lite/inference_with_tensorflow_lite.ipynb deleted file mode 100644 index 3fb55f9799a..00000000000 --- a/jupyter/tensorflow_lite/inference_with_tensorflow_lite.ipynb +++ /dev/null @@ -1,156 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Inference with Tensorflow Lite\n", - "\n", - "In this tutorial, you learn how to load an existing TensorFlow Lite model and use it to run a prediction task.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of Java Kernel. For more information on installing the Java Kernel, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.tflite:tflite-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32\n", - "\n", - "// Use secondary engine to help pre-processing and post-processing\n", - "%maven ai.djl.pytorch:pytorch-engine:0.23.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.awt.image.*;\n", - "import java.nio.file.*;\n", - "import ai.djl.*;\n", - "import ai.djl.inference.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.modality.cv.translator.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.translate.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.util.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Load your Tensorflow Lite mode from DJL model zoo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optEngine(\"TFLite\")\n", - " .optFilter(\"dataset\", \"aiyDish\")\n", - " .build();\n", - "ZooModel model = criteria.loadModel();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Create a Predictor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Predictor predictor = model.newPredictor();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Load image for classification" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/sachertorte.jpg\");\n", - "\n", - "img.getWrappedImage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Classifications classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you can load Tensorflow Lite model and run inference.\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/test_notebook.sh b/jupyter/test_notebook.sh deleted file mode 100755 index a4cd2166e9e..00000000000 --- a/jupyter/test_notebook.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash - -# test_notebook.sh [filename] -# If no filename is passed, it runs all files in current directory and subdirectories - -set -e - -function run_test { - base=$(basename $1) - # Workaround on crashes - if [[ "$base" == transfer_learning_on_cifar10* || "$base" == rank_classification_using_BERT* ]]; then - jupyter nbconvert --to notebook --inplace $1 - else - jupyter nbconvert --to notebook --execute --ExecutePreprocessor.timeout=600 --inplace $1 - fi -} - -if [[ $# -eq 0 ]]; then - for f in {**,.}/*.ipynb - do - dir=$(dirname f) - run_test "$f" - done -else - run_test $1 -fi diff --git a/jupyter/transfer_learning_on_cifar10.ipynb b/jupyter/transfer_learning_on_cifar10.ipynb deleted file mode 100644 index 
663a9eafc7f..00000000000 --- a/jupyter/transfer_learning_on_cifar10.ipynb +++ /dev/null @@ -1,285 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Transfer Learning on CIFAR-10 Dataset\n", - "\n", - "\n", - "## Introduction\n", - "\n", - "In this tutorial, you learn how to train an image classification model using [Transfer Learning](https://en.wikipedia.org/wiki/Transfer_learning). Transfer learning is a popular machine learning technique that uses a model trained on one problem and applies it to a second related problem. Compared to training from scratch or designing a model for your specific problem, transfer learning can leverage the features already learned on a similar problem and produce a more robust model in a much shorter time.\n", - "\n", - "Train your model with the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset which consists of 60,000 32x32 color images in 10 classes. As for the pre-trained model, use the ResNet50v1[1] model. It's a 50 layer deep model already trained on [ImageNet](http://www.image-net.org/), a much larger dataset consisting of over 1.2 million images in 1000 classes. Modify it to classify 10 classes from the CIFAR-10 dataset.\n", - "\n", - "![The CIFAR-10 Dataset](https://resources.djl.ai/images/cifar-10.png)\n", - "
    the CIFAR10 dataset
    \n", - "\n", - "\n", - "## Pre-requisites\n", - "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n", - "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n", - "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n", - "\n", - "\n", - "## Getting started\n", - "Load the Deep Java Libarary and its dependencies from Maven:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:basicdataset:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's import the necessary modules." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.basicdataset.cv.classification.*;\n", - "import ai.djl.engine.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.transform.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.nn.*;\n", - "import ai.djl.nn.core.*;\n", - "import ai.djl.repository.zoo.*;\n", - "import ai.djl.training.*;\n", - "import ai.djl.training.dataset.*;\n", - "import ai.djl.training.initializer.*;\n", - "import ai.djl.training.listener.*;\n", - "import ai.djl.training.loss.*;\n", - "import ai.djl.training.evaluator.*;\n", - "import ai.djl.training.optimizer.*;\n", - "import ai.djl.training.tracker.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.translate.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.concurrent.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Construct your model\n", - "\n", - "Load the pre-trained ResNet50V1 model. You can find it in the [Model Zoo](https://github.com/deepjavalibrary/djl/blob/master/docs/model-zoo.md). First construct the `criteria` to specify which ResNet model to load, then call `loadModel` to get a ResNet50V1 model with pre-trained weights. Note this model was trained on ImageNet with 1000 classes; the last layer is a Linear layer with 1000 output channels. Because you are repurposing it on CIFAR10 with 10 classes, you need to remove the last layer and add a new Linear layer with 10 output channels. After you are done modifying the block, set it back to model using `setBlock`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// load model and change last layer\n", - "Criteria criteria = Criteria.builder()\n", - " .setTypes(Image.class, Classifications.class)\n", - " .optProgress(new ProgressBar())\n", - " .optArtifactId(\"resnet\")\n", - " .optFilter(\"layers\", \"50\")\n", - " .optFilter(\"flavor\", \"v1\").build();\n", - "Model model = criteria.loadModel();\n", - "SequentialBlock newBlock = new SequentialBlock();\n", - "SymbolBlock block = (SymbolBlock) model.getBlock();\n", - "block.removeLastBlock();\n", - "newBlock.add(block);\n", - "newBlock.add(Blocks.batchFlattenBlock());\n", - "newBlock.add(Linear.builder().setUnits(10).build());\n", - "model.setBlock(newBlock);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prepare Dataset\n", - "\n", - "After you have the model, the next step is to prepare the dataset for training. You can construct a CIFAR10 builder with your own specifications. You have the options to get the train or test dataset, specify desired batch size, specify whether to shuffle your data during training, and most importantly, specify the pre-process pipeline. \n", - "\n", - "A pipeline consists of a series of transformations to apply on the input data before feeding it to the model. \n", - "\n", - "For example, `ToTensor` can be used to transform colored image NDArrays with shape (32, 32, 3) and values from 0 to 256 to NDArrays with shape (3, 32, 32) and values from 0 to 1. This operation is transposing image data from channels last to channels first format, which is more suitable for GPU computation. \n", - "\n", - "The `Normalize` transformation can normalize input data according to their mean and standard deviation values. This will make different features have similar range and help our model perform better." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int batchSize = 32;\n", - "int limit = Integer.MAX_VALUE; // change this to a small value for a dry run\n", - "// int limit = 160; // limit 160 records in the dataset for a dry run\n", - "Pipeline pipeline = new Pipeline(\n", - " new ToTensor(),\n", - " new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f}, new float[] {0.2023f, 0.1994f, 0.2010f}));\n", - "Cifar10 trainDataset = \n", - " Cifar10.builder()\n", - " .setSampling(batchSize, true)\n", - " .optUsage(Dataset.Usage.TRAIN)\n", - " .optLimit(limit)\n", - " .optPipeline(pipeline)\n", - " .build();\n", - "trainDataset.prepare(new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set up training configuration\n", - "\n", - "You are leveraging a pre-trained model, so you can expect the model to converge quickly. You will only train only ten epochs. As the model converges, you need to reduce the learning rate to get better results. You can use a `Tracker` to reduce the learning rate by 0.1 after two, five, and eight epochs. \n", - "\n", - "Deep Java Library supports training on multiple GPUs. You can use `setDevices` and pass an array of devices you want the model to be trained on. For example, `new Device[]{Device.gpu(0), Device.gpu(1)}` for training on GPU0 and GPU1. You can also call `Engine.getInstancec().getDevices(4)` and pass the number of GPUs you want to train. It will start with GPU0, and use CPU if no GPU is available. To learn more about multi-GPU training, read our multi-GPU [documentation](https://github.com/deepjavalibrary/djl/tree/master/examples/docs).\n", - "\n", - "To complete the training configuration set up, use the `Adam` optimizer, `SoftmaxCrossEntropyLoss`, and `Accuracy` for classification problems." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())\n", - " //softmaxCrossEntropyLoss is a standard loss for classification problems\n", - " .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is\n", - " .optDevices(Engine.getInstance().getDevices(1)) // Limit your GPU, using more GPU actually will slow down coverging\n", - " .addTrainingListeners(TrainingListener.Defaults.logging());\n", - "\n", - "// Now that we have our training configuration, we should create a new trainer for our model\n", - "Trainer trainer = model.newTrainer(config);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train your model\n", - "Now you can start training. This procedure is similar to the one in [Train Your First Model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb). Training requires the following steps:\n", - "1. Initialize a new trainer using the training config you just set up\n", - "2. Initialize the weights in trainer\n", - "3. Using a `for` loop to iterate through the whole dataset 10 times (epochs), resetting the evaluators at the end of each epoch\n", - "4. During each epoch, using a `for` loop to iterate through the dataset in batches and train batch by batch while printing the training accuracy on the progress bar." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int epoch = 10;\n", - "Shape inputShape = new Shape(1, 3, 32, 32);\n", - "trainer.initialize(inputShape);" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for (int i = 0; i < epoch; ++i) {\n", - " int index = 0;\n", - " for (Batch batch : trainer.iterateDataset(trainDataset)) {\n", - " EasyTrain.trainBatch(trainer, batch);\n", - " trainer.step();\n", - " batch.close();\n", - " }\n", - "\n", - " // reset training and validation evaluators at end of epoch\n", - " trainer.notifyListeners(listener -> listener.onEpoch(trainer));\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save your model\n", - "\n", - "Finally, you can save your model after training is done and use it for inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/resnet\");\n", - "Files.createDirectories(modelDir);\n", - "\n", - "model.setProperty(\"Epoch\", String.valueOf(epoch));\n", - "model.save(modelDir, \"resnet\");" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## What's next\n", - "\n", - "1. Try inference using the model you just trained. You can find an airplane image in [test resources](https://github.com/deepjavalibrary/djl/blob/master/examples/src/test/resources/airplane1.png) and follow the inference tutorials in the [Jupyter module](https://github.com/deepjavalibrary/djl/tree/master/jupyter).\n", - "\n", - "2. 
Follow the complete example with multi-GPU support, a validation dataset, and the fit API in the [examples module](https://github.com/deepjavalibrary/djl/tree/master/examples/docs).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## References\n", - "[1] [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)\n", - "\n", - "[2] [Gluon CV model zoo](https://gluon-cv.mxnet.io/model_zoo/classification.html) for pre-trained ResNet50 models" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/01_create_your_first_network.ipynb b/jupyter/tutorial/01_create_your_first_network.ipynb deleted file mode 100644 index 293fde5fec4..00000000000 --- a/jupyter/tutorial/01_create_your_first_network.ipynb +++ /dev/null @@ -1,206 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create your first deep learning neural network\n", - "\n", - "## Introduction\n", - "\n", - "This is the first part of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this part, you will learn how to use the built-in `Block` to create your first neural network - a Multilayer Perceptron.\n", - "\n", - "## Step 1: Setup development environment\n", - "\n", - "### Installation\n", - "\n", - "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Add the snapshot repository to get the DJL snapshot artifacts\n", - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "// Add the maven dependencies\n", - "%maven ai.djl:api:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ai.djl.*;\n", - "import ai.djl.nn.*;\n", - "import ai.djl.nn.core.*;\n", - "import ai.djl.training.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Neural Network\n", - "\n", - "A neural network is a black box function. Instead of coding this function yourself, you provide many sample input/output pairs for this function. Then, we try to train the network to learn how to best approximate the observed behavior of the function given only these input/output pairs. A better model with more data can more accurately approximate the function.\n", - "\n", - "## Application\n", - "\n", - "The first thing to figure out when trying to build a neural network, like building most functions, is what your function signature is. What are your input types and output types? Because most models use relatively consistent signatures, we refer to them as [Applications](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/Application.html). Within the Applications interface, you can find a list of some of the more common model applications used in deep learning.\n", - "\n", - "In this tutorial, we will focus on the image classification application. It is one of the most common first applications and has a significant history with deep learning. In image classification, the input is a single image and it is classified based on the main subject of the image into a number of different possible classes. 
The classes for the image depend on the specific data you are training with." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Application application = Application.CV.IMAGE_CLASSIFICATION;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset\n", - "\n", - "Once you have figured out what application you want to learn, next you need to collect the data you are training with and form it into a dataset. Often, trying to collect and clean up the data is the most troublesome task in the deep learning process. \n", - "\n", - "Using a dataset can either involve collecting custom data from various sources or using one of the many datasets freely available online. The custom data may better suit your use case, but a free dataset is often faster and easier to use. You can read our [dataset guide](http://docs.djl.ai/docs/dataset.html) to learn more about datasets.\n", - "\n", - "### MNIST\n", - "\n", - "The dataset we will be using is [MNIST](https://en.wikipedia.org/wiki/MNIST_database), a database of handwritten digits. Each image contains a black and white digit from 0 to 9 in a 28x28 image. It is commonly used when getting started with deep learning because it is small and fast to train.\n", - "\n", - "![Mnist Image](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)\n", - "\n", - "Once you understand your dataset, you should create an implementation of the [Dataset class](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Dataset.html). In this case, we provide the MNIST dataset built-in to make it easy for you to use it.\n", - "\n", - "## Multilayer Perceptron\n", - "\n", - "Now that we have our dataset, we can choose a model to train with it. For this tutorial, we will build one of the simplest and oldest deep learning networks: a Multilayer Perceptron (MLP).\n", - "\n", - "The MLP is organized into layers. 
The first layer is the input layer which contains your input data and the last layer is the output layer which produces the final result of the network. Between them are layers referred to as hidden layers. Having more hidden layers and larger hidden layers allows the MLP to represent more complex functions.\n", - "\n", - "The example below contains an input of size 3, a single hidden layer of size 3, and an output of size 2. The number and sizes of the hidden layers are usually determined through experimentation. Between each pair of layers is a linear operation (sometimes called a FullyConnected operation because each number in the input is connected to each number in the output by a matrix multiplication). Not pictured, there is also a non-linear activation function after each linear operation. For more information, see the [Multilayer Perceptron chapter of the D2l DJL book](https://d2l.djl.ai/chapter_multilayer-perceptrons/index.html).\n", - "\n", - "![MLP Image](https://upload.wikimedia.org/wikipedia/commons/c/c2/MultiLayerNeuralNetworkBigger_english.png)\n", - "\n", - "\n", - "## Step 2: Determine your input and output size\n", - "\n", - "The MLP model uses a one dimensional vector as the input and the output. You should determine the appropriate size of this vector based on your input data and what you will use the output of the model for.\n", - "\n", - "Our input vector will have size `28x28` because the MNIST input images have a height and width of 28 and it takes only a single number to represent each pixel. For a color image, you would need to further multiply this by `3` for the RGB channels.\n", - "\n", - "Our output vector has size `10` because there are `10` possible classes (0 to 9) for each image." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "long inputSize = 28*28;\n", - "long outputSize = 10;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 3: Create a **SequentialBlock**\n", - "\n", - "### NDArray\n", - "\n", - "The core data type used for working with deep learning is the [NDArray](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDArray.html). An NDArray represents a multidimensional, fixed-size homogeneous array. It has very similar behavior to the Numpy python package with the addition of efficient computing. We also have a helper class, the [NDList](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDList.html) which is a list of NDArrays which can have different sizes and data types.\n", - "\n", - "### Block API\n", - "\n", - "In DJL, [Blocks](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/Block.html) serve a purpose similar to functions that convert an input `NDList` to an output `NDList`. They can represent single operations, parts of a neural network, and even the whole neural network. What makes blocks special is that they contain a number of parameters that are used in their function and are trained during deep learning. As these parameters are trained, the function represented by the blocks get more and more accurate.\n", - "\n", - "When building these block functions, the easiest way is to use composition. Similar to how functions are built by calling other functions, blocks can be built by combining other blocks. We refer to the containing block as the parent and the sub-blocks as the children.\n", - "\n", - "\n", - "We provide several helpers to make it easy to build common block composition structures. 
For the MLP we will use the [SequentialBlock](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/SequentialBlock.html), a container block whose children form a chain of blocks where each child block feeds its output to the next child block in a sequence.\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "SequentialBlock block = new SequentialBlock();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Add blocks to SequentialBlock\n", - "\n", - "An MLP is organized into several layers. Each layer is composed of a [Linear Block](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/core/Linear.html) and a non-linear activation function. If we just had two linear blocks in a row, it would be the same as a combined linear block ($f(x) = W_2(W_1x) = (W_2W_1)x = W_{combined}x$). An activation is used to intersperse between the linear blocks to allow them to represent non-linear functions. We will use the popular [ReLU](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/Activation.html#reluBlock()) as our activation function.\n", - "\n", - "The first layer and last layers have fixed sizes depending on your desired input and output size. However, you are free to choose the number and sizes of the middle layers in the network. We will create a smaller MLP with two middle layers that gradually decrease the size. Typically, you would experiment with different values to see what works the best on your data set." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "block.add(Blocks.batchFlattenBlock(inputSize));\n", - "block.add(Linear.builder().setUnits(128).build());\n", - "block.add(Activation::relu);\n", - "block.add(Linear.builder().setUnits(64).build());\n", - "block.add(Activation::relu);\n", - "block.add(Linear.builder().setUnits(outputSize).build());\n", - "\n", - "block" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now that you've successfully created your first neural network, you can use this network to train your model.\n", - "\n", - "Next chapter: [Train your first model](02_train_your_first_model.ipynb)\n", - "\n", - "You can find the complete source code for this tutorial in the [model zoo](https://github.com/deepjavalibrary/djl/blob/master/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/02_train_your_first_model.ipynb b/jupyter/tutorial/02_train_your_first_model.ipynb deleted file mode 100644 index 4905dadfbb5..00000000000 --- a/jupyter/tutorial/02_train_your_first_model.ipynb +++ /dev/null @@ -1,243 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train your first model\n", - "\n", - "This is the second of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. 
In this tutorial, you will learn how to train an image classification model that can recognize handwritten digits.\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Add the snapshot repository to get the DJL snapshot artifacts\n", - "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "// Add the maven dependencies\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:basicdataset:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.nio.file.*;\n", - "\n", - "import ai.djl.*;\n", - "import ai.djl.basicdataset.cv.classification.Mnist;\n", - "import ai.djl.ndarray.types.*;\n", - "import ai.djl.training.*;\n", - "import ai.djl.training.dataset.*;\n", - "import ai.djl.training.initializer.*;\n", - "import ai.djl.training.loss.*;\n", - "import ai.djl.training.listener.*;\n", - "import ai.djl.training.evaluator.*;\n", - "import ai.djl.training.optimizer.*;\n", - "import ai.djl.training.util.*;\n", - "import ai.djl.basicmodelzoo.cv.classification.*;\n", - "import ai.djl.basicmodelzoo.basic.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 1: Prepare MNIST dataset for training\n", - "\n", - "In order to train, you must create a [Dataset class](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Dataset.html) to contain your training data. A dataset is a collection of sample input/output pairs for the function represented by your neural network. 
Each single input/output is represented by a [Record](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Record.html). Each record could have multiple arrays of inputs or outputs such as an image question and answer dataset where the input is both an image and a question about the image while the output is the answer to the question.\n", - "\n", - "Because data learning is highly parallelizable, training is often done not with a single record at a time, but a [Batch](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Batch.html). This can lead to significant performance gains, especially when working with images\n", - "\n", - "## Sampler\n", - "\n", - "Then, we must decide the parameters for loading data from the dataset. The only parameter we need for MNIST is the choice of [Sampler](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Sampler.html). The sampler decides which and how many element from datasets are part of each batch when iterating through it. We will have it randomly shuffle the elements for the batch and use a batchSize of 32. The batchSize is usually the largest power of 2 that fits within memory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "int batchSize = 32;\n", - "Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();\n", - "mnist.prepare(new ProgressBar());" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 2: Create your Model\n", - "\n", - "Next we will build a model. A [Model](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/Model.html) contains a neural network [Block](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/Block.html) along with additional artifacts used for the training process. It possesses additional information about the inputs, outputs, shapes, and data types you will use. 
Generally, you will use the Model once you have fully completed your Block.\n", - "\n", - "In this part of the tutorial, we will use the built-in Multilayer Perceptron Block from the Model Zoo. To learn how to build it from scratch, see the previous tutorial: [Create Your First Network](01_create_your_first_network.ipynb).\n", - "\n", - "Because images in the MNIST dataset are 28x28 grayscale images, we will create an MLP block with 28 x 28 input. The output will be 10 because there are 10 possible classes (0 to 9) each image could be. For the hidden layers, we have chosen `new int[] {128, 64}` by experimenting with different values." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Model model = Model.newInstance(\"mlp\");\n", - "model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 3: Create a Trainer\n", - "\n", - "Now, you can create a [`Trainer`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/Trainer.html) to train your model. The trainer is the main class to orchestrate the training process. Usually, they will be opened using a try-with-resources and closed after training is over.\n", - "\n", - "The trainer takes an existing model and attempts to optimize the parameters inside the model's Block to best match the dataset. 
Most optimization is based upon [Stochastic Gradient Descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) (SGD).\n", - "\n", - "## Step 3.1: Setup your training configurations\n", - "\n", - "Before you create your trainer, we we will need a [training configuration](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/DefaultTrainingConfig.html) that describes how to train your model.\n", - "\n", - "The following are a few common items you may need to configure your training:\n", - "\n", - "* **REQUIRED** [`Loss`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/loss/Loss.html) function: A loss function is used to measure how well our model matches the dataset. Because the lower value of the function is better, it's called the \"loss\" function. The Loss is the only required argument to the model\n", - "* [`Evaluator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/evaluator/Evaluator.html) function: An evaluator function is also used to measure how well our model matches the dataset. Unlike the loss, they are only there for people to look at and are not used for optimizing the model. Since many losses are not as intuitive, adding other evaluators such as Accuracy can help to understand how your model is doing. If you know of any useful evaluators, we recommend adding them.\n", - "* [`Training Listeners`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/listener/TrainingListener.html): The training listener adds additional functionality to the training process through a listener interface. This can include showing training progress, stopping early if training becomes undefined, or recording performance metrics. We offer several easy sets of [default listeners](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/listener/TrainingListener.Defaults.html).\n", - "\n", - "You can also configure other options such as the Device, Initializer, and Optimizer. 
See [more details](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/TrainingConfig.html)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())\n", - " //softmaxCrossEntropyLoss is a standard loss for classification problems\n", - " .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is\n", - " .addTrainingListeners(TrainingListener.Defaults.logging());\n", - "\n", - "// Now that we have our training configuration, we should create a new trainer for our model\n", - "Trainer trainer = model.newTrainer(config);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 5: Initialize Training\n", - "\n", - "Before training your model, you have to initialize all of the parameters with starting values. You can use the trainer for this initialization by passing in the input shape.\n", - "\n", - "* The first axis of the input shape is the batch size. This won't impact the parameter initialization, so you can use 1 here.\n", - "* The second axis of the input shape of the MLP - the number of pixels in the input image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.initialize(new Shape(1, 28 * 28));" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 6: Train your model\n", - "\n", - "Now, we can train the model.\n", - "\n", - "When training, it is usually organized into epochs where each epoch trains the model on each item in the dataset once. It is slightly faster than training randomly.\n", - "\n", - "Then, we will use the EasyTrain to, as the name promises, make the training easy. 
If you want to see more details about how the training loop works, see [the EasyTrain class](https://github.com/deepjavalibrary/djl/blob/master/api/src/main/java/ai/djl/training/EasyTrain.java) or [read our Dive into Deep Learning book](https://d2l.djl.ai)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.\n", - "int epoch = 2;\n", - "\n", - "EasyTrain.fit(trainer, epoch, mnist, null);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Step 7: Save your model\n", - "\n", - "Once your model is trained, you should save it so that it can be reloaded later. You can also add metadata to it such as training accuracy, number of epochs trained, etc that can be used when loading the model or when examining it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/mlp\");\n", - "Files.createDirectories(modelDir);\n", - "\n", - "model.setProperty(\"Epoch\", String.valueOf(epoch));\n", - "\n", - "model.save(modelDir, \"mlp\");\n", - "\n", - "model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Summary\n", - "\n", - "Now, you've successfully trained a model that can recognize handwritten digits. You'll learn how to apply this model in the next chapter: [Run image classification with your model](03_image_classification_with_your_model.ipynb).\n", - "\n", - "You can find the complete source code for this tutorial in the [examples project](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/TrainMnist.java)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/03_image_classification_with_your_model.ipynb b/jupyter/tutorial/03_image_classification_with_your_model.ipynb deleted file mode 100644 index f8d42d7972e..00000000000 --- a/jupyter/tutorial/03_image_classification_with_your_model.ipynb +++ /dev/null @@ -1,214 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Inference with your model\n", - "\n", - "This is the third and final tutorial of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to execute your image classification model for a production system.\n", - "\n", - "In the [previous tutorial](02_train_your_first_model.ipynb), you successfully trained your model. Now, we will learn how to implement a `Translator` to convert between POJO and `NDArray` as well as a `Predictor` to run inference.\n", - "\n", - "\n", - "## Preparation\n", - "\n", - "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "// Add the snapshot repository to get the DJL snapshot artifacts\n", - "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", - "\n", - "// Add the maven dependencies\n", - "%maven ai.djl:api:0.23.0\n", - "%maven ai.djl:model-zoo:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-engine:0.23.0\n", - "%maven ai.djl.mxnet:mxnet-model-zoo:0.23.0\n", - "%maven org.slf4j:slf4j-simple:1.7.32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import java.awt.image.*;\n", - "import java.nio.file.*;\n", - "import java.util.*;\n", - "import java.util.stream.*;\n", - "import ai.djl.*;\n", - "import ai.djl.basicmodelzoo.basic.*;\n", - "import ai.djl.ndarray.*;\n", - "import ai.djl.modality.*;\n", - "import ai.djl.modality.cv.*;\n", - "import ai.djl.modality.cv.util.NDImageUtils;\n", - "import ai.djl.translate.*;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 1: Load your handwritten digit image\n", - "\n", - "We will start by loading the image that we want to run our model to classify." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/0.png\");\n", - "img.getWrappedImage();" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 2: Load your model\n", - "\n", - "Next, we need to load the model to run inference with. This model should have been saved to the `build/mlp` directory when running the [previous tutorial](02_train_your_first_model.ipynb)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path modelDir = Paths.get(\"build/mlp\");\n", - "Model model = Model.newInstance(\"mlp\");\n", - "model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));\n", - "model.load(modelDir);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In addition to loading a local model, you can also find pretrained models within our [model zoo](http://docs.djl.ai/docs/model-zoo.html). See more options in our [model loading documentation](http://docs.djl.ai/docs/load_model.html).\n", - "\n", - "## Step 3: Create a `Translator`\n", - "\n", - "The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) is used to encapsulate the pre-processing and post-processing functionality of your application. The input to the processInput and processOutput should be single data items, not batches." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Translator translator = new Translator() {\n", - "\n", - " @Override\n", - " public NDList processInput(TranslatorContext ctx, Image input) {\n", - " // Convert Image to NDArray\n", - " NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);\n", - " return new NDList(NDImageUtils.toTensor(array));\n", - " }\n", - "\n", - " @Override\n", - " public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", - " // Create a Classifications with the output probabilities\n", - " NDArray probabilities = list.singletonOrThrow().softmax(0);\n", - " List classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());\n", - " return new Classifications(classNames, probabilities);\n", - " }\n", - " \n", - " @Override\n", - " public Batchifier getBatchifier() {\n", - " // The Batchifier describes how to combine a batch together\n", - " // Stacking, the most common 
batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array\n", - " return Batchifier.STACK;\n", - " }\n", - "};" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 4: Create Predictor\n", - "\n", - "Using the translator, we will create a new [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html). The predictor is the main class to orchestrate the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is NOT thread-safe, so if you want to do prediction in parallel, you should call newPredictor multiple times to create a predictor object for each thread." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var predictor = model.newPredictor(translator);" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Step 5: Run inference\n", - "\n", - "With our predictor, we can simply call the [predict](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html#predict(I)) method to run inference. For better performance, you can also call [batchPredict](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html#batchPredict(java.util.List)) with a list of input items. Afterwards, the same predictor should be used for further inference calls. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "var classifications = predictor.predict(img);\n", - "\n", - "classifications" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "Now, you've successfully built a model, trained it, and run inference. Congratulations on finishing the [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial). 
After this, you should read our other [examples](https://github.com/deepjavalibrary/djl/tree/master/examples) and [jupyter notebooks](https://github.com/deepjavalibrary/djl/tree/master/jupyter) to learn more about DJL.\n", - "\n", - "You can find the complete source code for this tutorial in the [examples project](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Java", - "language": "java", - "name": "java" - }, - "language_info": { - "codemirror_mode": "java", - "file_extension": ".jshell", - "mimetype": "text/x-java-source", - "name": "Java", - "pygments_lexer": "java", - "version": "14.0.2+12" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/jupyter/tutorial/README.md b/jupyter/tutorial/README.md deleted file mode 100644 index 4c53b0f41e8..00000000000 --- a/jupyter/tutorial/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# DJL - Beginner Tutorial - -Our beginner tutorial takes you through creating your first network, training it, and using it in a real system. This is a good place to start if you are new to DJL or to deep learning. - -1. [Create your first neural network](01_create_your_first_network.ipynb) -2. [Train your first model](02_train_your_first_model.ipynb) -3. 
[Run image classification with your first model](03_image_classification_with_your_model.ipynb) diff --git a/model-zoo/README.md b/model-zoo/README.md index 11ae15c5505..b8f2a8fd124 100644 --- a/model-zoo/README.md +++ b/model-zoo/README.md @@ -33,7 +33,7 @@ You can pull the model zoo from the central Maven repository by including the fo ai.djl model-zoo - 0.23.0 + 0.26.0 ``` diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java index 7cdfc040c12..543ab5f1f21 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java @@ -43,8 +43,8 @@ public String getGroupId() { public Set getSupportedEngines() { Set set = new HashSet<>(); set.add("MXNet"); + set.add("PyTorch"); // TODO Currently WIP in supporting these two engines in the basic model zoo - // set.add("PyTorch"); // set.add("TensorFlow"); return set; } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java index 05f171c1ec2..2869e42f55d 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/Mlp.java @@ -56,6 +56,7 @@ public Mlp(int input, int output, int[] hidden) { * @param hidden the sizes of all of the hidden layers * @param activation the activation function to use */ + @SuppressWarnings("this-escape") public Mlp(int input, int output, int[] hidden, Function activation) { add(Blocks.batchFlattenBlock(input)); for (int hiddenSize : hidden) { diff --git a/settings.gradle b/settings.gradle index 75a1f854ef8..ff6967fc308 100644 --- a/settings.gradle +++ b/settings.gradle @@ -2,6 +2,7 @@ rootProject.name = 'djl' include ':api' include ':basicdataset' include ':djl-zero' +include ':engines:llama' include ':engines:ml:xgboost' include ':engines:ml:lightgbm' include ':engines:mxnet:jnarator' 
@@ -34,7 +35,9 @@ include ':extensions:sentencepiece' include ':extensions:tokenizers' include ':extensions:tablesaw' include ':extensions:timeseries' -include ':extensions:spark' +if (JavaVersion.current() < JavaVersion.VERSION_21) { + include ':extensions:spark' +} include ':integration' include ':model-zoo' include ':testing' diff --git a/testing/src/main/java/ai/djl/testing/TestRequirements.java b/testing/src/main/java/ai/djl/testing/TestRequirements.java index bf57d64bd7c..32f242589b9 100644 --- a/testing/src/main/java/ai/djl/testing/TestRequirements.java +++ b/testing/src/main/java/ai/djl/testing/TestRequirements.java @@ -13,6 +13,7 @@ package ai.djl.testing; import ai.djl.engine.Engine; +import ai.djl.util.Utils; import org.testng.SkipException; @@ -45,7 +46,7 @@ public static void weekly() { /** Requires a test not be run in offline mode. */ public static void notOffline() { - if (Boolean.getBoolean("offline")) { + if (Utils.isOfflineMode()) { throw new SkipException("This test can not run while offline"); } } diff --git a/tools/gradle/publish.gradle b/tools/gradle/publish.gradle index 663f847e95b..0baa3d5a2c1 100644 --- a/tools/gradle/publish.gradle +++ b/tools/gradle/publish.gradle @@ -1,7 +1,8 @@ -configure([ +def projects = [ project(':api'), project(':basicdataset'), project(':djl-zero'), + project(':engines:llama'), project(':engines:ml:xgboost'), project(':engines:ml:lightgbm'), project(':engines:mxnet:mxnet-engine'), @@ -27,8 +28,13 @@ configure([ project(':extensions:tablesaw'), project(':extensions:timeseries'), project(':extensions:tokenizers'), - project(':extensions:spark'), - project(':model-zoo')]) { + project(':model-zoo') +] +if (JavaVersion.current() < JavaVersion.VERSION_21) { + projects.add(project(':extensions:spark')) +} + +configure(projects) { apply plugin: "maven-publish" apply plugin: "signing" diff --git a/tools/scripts/build_ft_deps.sh b/tools/scripts/build_ft_deps.sh deleted file mode 100755 index 4d3cb94a103..00000000000 --- 
a/tools/scripts/build_ft_deps.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -FT_VERSION=$1 -NVIDIA_TRITON_SERVER_VERSION=$2 -IS_LLAMA_BUILD=$3 - -apt-get update && apt-get install -y rapidjson-dev - -pushd /tmp - -git clone https://github.com/NVIDIA/FasterTransformer.git -b ${FT_VERSION} - -export FT_DIR=/tmp/FasterTransformer -mkdir -p /tmp/binaries - -# Build FasterTransformer Triton library -if [ "$IS_LLAMA_BUILD" = "false" ] ; then - git clone https://github.com/triton-inference-server/fastertransformer_backend.git -else - echo "cloning forked FT backend repo with llama support" - git clone https://github.com/rohithkrn/fastertransformer_backend.git -b llama_void_main -fi -mkdir -p fastertransformer_backend/build -cd fastertransformer_backend/build -cmake \ - -D CMAKE_EXPORT_COMPILE_COMMANDS=1 \ - -D CMAKE_BUILD_TYPE=Release \ - -D ENABLE_FP8=OFF \ - -D CMAKE_INSTALL_PREFIX=/opt/tritonserver \ - -D TRITON_COMMON_REPO_TAG="${NVIDIA_TRITON_SERVER_VERSION}" \ - -D TRITON_CORE_REPO_TAG="${NVIDIA_TRITON_SERVER_VERSION}" \ - -D TRITON_BACKEND_REPO_TAG="${NVIDIA_TRITON_SERVER_VERSION}" \ - .. -make -j$(nproc) install -cp /opt/tritonserver/backends/fastertransformer/*.so /tmp/binaries/ -cd ../../ - -# Build FasterTransformer TH Ops library -mkdir -p FasterTransformer/build -cd FasterTransformer/build -git submodule init && git submodule update -cmake -DCMAKE_BUILD_TYPE=Release -DSM=70,75,80,86 -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON .. -make -j$(nproc) -cp lib/libth_transformer.so /tmp/binaries/ -cd ../../ - -popd diff --git a/website/js/index.js b/website/js/index.js index 70b313acea3..605e1c04228 100644 --- a/website/js/index.js +++ b/website/js/index.js @@ -27,7 +27,7 @@ let app = new Vue({ }, { name: 'Tutorial', - url: 'https://docs.djl.ai/jupyter/tutorial/index.html' + url: 'https://docs.djl.ai/docs/demos/jupyter/tutorial/index.html' }, { name: 'Examples',

+ The Llama Engine module contains the Llama.cpp implementation of the DJL EngineProvider. + See here for more details. +