diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java index aebd4910658..606676e246f 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/jni/JniUtils.java @@ -83,8 +83,15 @@ public static Pair inferenceMat( SWIGTYPE_p_p_void model, int iterations, LgbmNDArray a) { SWIGTYPE_p_long_long outLength = lightgbmlib.new_int64_tp(); SWIGTYPE_p_double outBuffer = null; + SWIGTYPE_p_int numClasses = lightgbmlib.new_intp(); try { - outBuffer = lightgbmlib.new_doubleArray(2L * a.getRows()); + int outFlag = + lightgbmlib.LGBM_BoosterGetNumClasses( + lightgbmlib.voidpp_value(model), numClasses); + checkCall(outFlag); + int classes = lightgbmlib.intp_value(numClasses); + + outBuffer = lightgbmlib.new_doubleArray((long) classes * a.getRows()); int result = lightgbmlib.LGBM_BoosterPredictForMat( lightgbmlib.voidpp_value(model), @@ -130,6 +137,7 @@ public static Pair inferenceMat( if (outBuffer != null) { lightgbmlib.delete_doubleArray(outBuffer); } + lightgbmlib.delete_intp(numClasses); } }