From dfbc7ebce293c4ea27f0c6e40d3e9d789b93b39d Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Thu, 11 May 2023 18:32:39 -0700 Subject: [PATCH] [spark] Change implicit conversions to asScala --- .../scala/ai/djl/spark/task/audio/SpeechRecognizer.scala | 5 ++--- .../scala/ai/djl/spark/task/binary/BinaryPredictor.scala | 5 ++--- .../main/scala/ai/djl/spark/task/text/TextClassifier.scala | 6 +++--- .../scala/ai/djl/spark/task/vision/ImageClassifier.scala | 7 +++---- .../scala/ai/djl/spark/task/vision/ImageEmbedder.scala | 5 ++--- .../scala/ai/djl/spark/task/vision/InstanceSegmenter.scala | 7 +++---- .../scala/ai/djl/spark/task/vision/ObjectDetector.scala | 7 +++---- .../scala/ai/djl/spark/task/vision/SemanticSegmenter.scala | 5 ++--- 8 files changed, 20 insertions(+), 27 deletions(-) diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala index 63ee0e73b84..c28a34f1e67 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType} import java.io.ByteArrayInputStream -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * SpeechRecognizer performs speech recognition on audio. @@ -120,7 +119,7 @@ class SpeechRecognizer(override val uid: String) extends BaseAudioPredictor[Stri }.asJava // Batch predict - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ out) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala index 8ac641cb628..4301ed99cff 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala @@ -19,8 +19,7 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * BinaryPredictor performs prediction on binary input. @@ -74,7 +73,7 @@ class BinaryPredictor(override val uid: String) extends BasePredictor[Array[Byte val predictor = model.newPredictor() iter.grouped($(batchSize)).flatMap { batch => val inputs = batch.map(_.getAs[Array[Byte]](inputColIndex)).asJava - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ out) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala index 8d01294e7e5..61d92a9bd76 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala @@ -17,10 +17,10 @@ import ai.djl.modality.Classifications import org.apache.spark.ml.param.Param import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.types.{ArrayType, DoubleType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.collectionAsScalaIterableConverter /** * TextClassifier performs text classification on text. @@ -88,7 +88,7 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[Array[S val output = predictor.predict(inputs) batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), - out.topK[Classifications.Classification]().map(_.toString))) + out.topK[Classifications.Classification]().asScala.map(_.toString))) } } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala index 1b80681de3f..1eae397e35a 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala @@ -22,8 +22,7 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * ImageClassifier performs image classification on images. @@ -91,10 +90,10 @@ class ImageClassifier(override val uid: String) extends BaseImagePredictor[Class val inputs = batch.map(row => ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), - out.topK[Classifications.Classification]().map(_.toString))) + out.topK[Classifications.Classification]().asScala.map(_.toString))) } } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala index 8c787fe48a1..a0cea701901 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala @@ -20,8 +20,7 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, ByteType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * ImageEmbedder performs image embedding on images. @@ -60,7 +59,7 @@ class ImageEmbedder(override val uid: String) extends BaseImagePredictor[Array[B val inputs = batch.map(row => ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ out) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala index 6d3e47cd838..b895bfc9340 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala @@ -22,8 +22,7 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * InstanceSegmenter performs instance segmentation on images. @@ -72,10 +71,10 @@ class InstanceSegmenter(override val uid: String) extends BaseImagePredictor[Det val inputs = batch.map(row => ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), - out.items[DetectedObject]().map(_.getBoundingBox.toString))) + out.items[DetectedObject]().asScala.map(_.getBoundingBox.toString))) } } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala index 4f4184449bf..d831623529b 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala @@ -22,8 +22,7 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * ObjectDetector performs object detection on images. @@ -62,10 +61,10 @@ class ObjectDetector(override val uid: String) extends BaseImagePredictor[Detect val inputs = batch.map(row => ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), - out.items[DetectedObject]().map(_.getBoundingBox.toString))) + out.items[DetectedObject]().asScala.map(_.getBoundingBox.toString))) } } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala index af2347ad306..3a9fb71bf0c 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala @@ -21,8 +21,7 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` -import scala.jdk.CollectionConverters.seqAsJavaListConverter +import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} /** * SemanticSegmenter performs semantic segmentation on images. @@ -71,7 +70,7 @@ class SemanticSegmenter(override val uid: String) extends BaseImagePredictor[Cat val inputs = batch.map(row => ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava - val output = predictor.batchPredict(inputs) + val output = predictor.batchPredict(inputs).asScala batch.zip(output).map { case (row, out) => Row.fromSeq(row.toSeq :+ Row(out.getClasses.toArray, out.getMask)) }