diff --git a/docker/spark/Dockerfile b/docker/spark/Dockerfile index 54edbaa6681..da9c861378f 100644 --- a/docker/spark/Dockerfile +++ b/docker/spark/Dockerfile @@ -65,4 +65,5 @@ RUN echo "export HUGGINGFACE_HUB_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh RUN echo "export TRANSFORMERS_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh RUN echo "spark.yarn.appMasterEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf RUN echo "spark.executorEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf +RUN echo "spark.sql.execution.arrow.maxRecordsPerBatch 500" >> /opt/hadoop-config/spark-defaults.conf RUN echo "spark.hadoop.fs.s3a.connection.maximum 1000" >> /opt/hadoop-config/spark-defaults.conf diff --git a/extensions/spark/setup/djl_spark/task/audio/__init__.py b/extensions/spark/setup/djl_spark/task/audio/__init__.py index bed43fbd52f..6187f794a47 100644 --- a/extensions/spark/setup/djl_spark/task/audio/__init__.py +++ b/extensions/spark/setup/djl_spark/task/audio/__init__.py @@ -10,12 +10,18 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. +"""DJL Spark Tasks Audio API.""" -"""DJL Spark Tasks Text API.""" - -from . import whisper_speech_recognizer +from . import ( + speech_recognizer, + whisper_speech_recognizer, +) +SpeechRecognizer = speech_recognizer.SpeechRecognizer WhisperSpeechRecognizer = whisper_speech_recognizer.WhisperSpeechRecognizer # Remove unnecessary modules to avoid duplication in API. -del whisper_speech_recognizer \ No newline at end of file +del ( + speech_recognizer, + whisper_speech_recognizer, +) diff --git a/extensions/spark/setup/djl_spark/task/audio/speech_recognizer.py b/extensions/spark/setup/djl_spark/task/audio/speech_recognizer.py new file mode 100644 index 00000000000..cd7071e2483 --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/audio/speech_recognizer.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class SpeechRecognizer: + + def __init__(self, + input_col: str, + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None, + channels: Optional[int] = None, + sample_rate: Optional[int] = None, + sample_format: Optional[int] = None): + """ + Initializes the SpeechRecognizer. + + :param input_col: The input column + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is SpeechRecognitionTranslatorFactory. + :param batchifier (optional): The batchifier. 
Valid values include "none" (default), + "stack", and "padding". + :param channels (optional): The number of channels + :param sample_rate (optional): The audio sample rate + :param sample_format (optional): The audio sample format + """ + self.input_col = input_col + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + self.channels = channels + self.sample_rate = sample_rate + self.sample_format = sample_format + + def recognize(self, dataset): + """ + Performs speech recognition on the provided dataset. + + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + recognizer = sc._jvm.ai.djl.spark.task.audio.SpeechRecognizer() \ + .setInputCol(self.input_col) \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.engine is not None: + recognizer = recognizer.setEngine(self.engine) + if self.batch_size is not None: + recognizer = recognizer.setBatchSize(self.batch_size) + if self.translator_factory is not None: + recognizer = recognizer.setTranslatorFactory( + self.translator_factory) + if self.batchifier is not None: + recognizer = recognizer.setBatchifier(self.batchifier) + return DataFrame(recognizer.recognize(dataset._jdf), + dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py b/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py index 15dce039ccb..dd56010bc70 100644 --- a/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py +++ b/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py @@ -16,11 +16,10 @@ import io import librosa import pandas as pd -from typing import Iterator +from typing import Iterator, Optional from transformers import pipeline from ...util import files_util, dependency_util - TASK = "automatic-speech-recognition" APPLICATION = "audio/automatic_speech_recognition" GROUP_ID = "ai/djl/huggingface/pytorch" @@ -28,7 +27,13 @@ class WhisperSpeechRecognizer: - def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"): + def __init__(self, + input_col: str, + output_col: str, + model_url: Optional[str] = None, + hf_model_id: Optional[str] = None, + engine: Optional[str] = "PyTorch", + batch_size: Optional[int] = 10): """ Initializes the WhisperSpeechRecognizer. @@ -37,12 +42,14 @@ def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engi :param model_url: The model URL :param hf_model_id: The Huggingface model ID :param engine: The engine. Currently only PyTorch is supported. 
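A minimal usage sketch for the new SpeechRecognizer wrapper introduced above; it is illustrative only and assumes an active SparkSession named `spark` with the DJL Spark JAR on the classpath, audio files readable through Spark's binaryFile source, and a placeholder model URL.

```python
from djl_spark.task.audio import SpeechRecognizer

# Assumed input: a DataFrame whose "audio" column holds raw audio bytes.
df = spark.read.format("binaryFile") \
    .load("s3://my-bucket/audio/") \
    .withColumnRenamed("content", "audio")

recognizer = SpeechRecognizer(input_col="audio",
                              output_col="transcript",
                              model_url="https://example.com/models/asr.zip",  # placeholder URL
                              batch_size=10)  # rows handed to each batchPredict call
recognizer.recognize(df).select("transcript").show(truncate=False)
```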
+ :param batch_size: The batch size """ self.input_col = input_col self.output_col = output_col self.model_url = model_url self.hf_model_id = hf_model_id self.engine = engine + self.batch_size = batch_size def recognize(self, dataset, generate_kwargs=None, **kwargs): """ @@ -57,24 +64,33 @@ def recognize(self, dataset, generate_kwargs=None, **kwargs): raise ValueError("Only PyTorch engine is supported.") if self.model_url: - cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url) + cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, + self.model_url) files_util.download_and_extract(self.model_url, cache_dir) dependency_util.install(cache_dir) model_id_or_path = cache_dir elif self.hf_model_id: model_id_or_path = self.hf_model_id else: - raise ValueError("Either model_url or hf_model_id must be provided.") + raise ValueError( + "Either model_url or hf_model_id must be provided.") @pandas_udf(StringType()) def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: - pipe = pipeline(TASK, generate_kwargs=generate_kwargs, - model=model_id_or_path, chunk_length_s=30, **kwargs) + pipe = pipeline(TASK, + generate_kwargs=generate_kwargs, + model=model_id_or_path, + batch_size=self.batch_size, + chunk_length_s=30, + **kwargs) for s in iterator: # Model expects single channel, 16000 sample rate audio - batch = [librosa.load(io.BytesIO(d), mono=True, sr=16000)[0] for d in s] + batch = [ + librosa.load(io.BytesIO(d), mono=True, sr=16000)[0] + for d in s + ] output = pipe(batch) - text = map(lambda x: x["text"], output) + text = [o["text"] for o in output] yield pd.Series(text) - return dataset.withColumn(self.output_col, predict_udf(self.input_col)) \ No newline at end of file + return dataset.withColumn(self.output_col, predict_udf(self.input_col)) diff --git a/extensions/spark/setup/djl_spark/task/binary/__init__.py b/extensions/spark/setup/djl_spark/task/binary/__init__.py index 16bebea74ba..1d0828f8d9d 100644 --- a/extensions/spark/setup/djl_spark/task/binary/__init__.py +++ b/extensions/spark/setup/djl_spark/task/binary/__init__.py @@ -10,9 +10,11 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. +"""DJL Spark Tasks Binary API.""" + from . import binary_predictor BinaryPredictor = binary_predictor.BinaryPredictor # Remove unnecessary modules to avoid duplication in API. -del binary_predictor \ No newline at end of file +del binary_predictor diff --git a/extensions/spark/setup/djl_spark/task/binary/binary_predictor.py b/extensions/spark/setup/djl_spark/task/binary/binary_predictor.py index 593ffe95c28..9651f3f8bde 100644 --- a/extensions/spark/setup/djl_spark/task/binary/binary_predictor.py +++ b/extensions/spark/setup/djl_spark/task/binary/binary_predictor.py @@ -12,36 +12,47 @@ # the specific language governing permissions and limitations under the License. from pyspark import SparkContext -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import DataFrame +from typing import Optional class BinaryPredictor: """BinaryPredictor performs prediction on binary input. 
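An illustrative sketch of the WhisperSpeechRecognizer path shown above: the iterator-style pandas UDF builds the transformers pipeline once per task, decodes each audio payload with librosa, and feeds the pipeline batches of the configured batch_size. The DataFrame, column names, and Hugging Face model ID are assumptions.

```python
from djl_spark.task.audio import WhisperSpeechRecognizer

# Assumed input: a DataFrame with a binary "audio" column, as in the previous sketch.
df = spark.read.format("binaryFile") \
    .load("s3://my-bucket/audio/") \
    .withColumnRenamed("content", "audio")

recognizer = WhisperSpeechRecognizer(input_col="audio",
                                     output_col="text",
                                     hf_model_id="openai/whisper-tiny",  # assumed checkpoint
                                     batch_size=4)
# generate_kwargs are forwarded to the underlying transformers pipeline.
result = recognizer.recognize(df, generate_kwargs={"task": "transcribe"})
result.select("text").show(truncate=False)
```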
""" - def __init__(self, input_col, output_col, model_url, engine=None, - input_class=None, output_class=None, translator=None, - batchifier="none"): + def __init__(self, + input_col: str, + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + input_class=None, + output_class=None, + translator_factory=None, + batchifier: Optional[str] = None): """ Initializes the BinaryPredictor. - :param input_col: The input column. - :param output_col: The output column. - :param model_url: The model URL. - :param engine (optional): The engine. + :param input_col: The input column + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size :param input_class (optional): The input class. Default is byte array. :param output_class (optional): The output class. Default is byte array. - :param translator (optional): The translator. Default is NpBinaryTranslator. - :param batchifier (optional): The batchifier. Valid values include none (default), - stack, and padding. + :param translator_factory (optional): The translator factory. + Default is NpBinaryTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". """ self.input_col = input_col self.output_col = output_col self.model_url = model_url self.engine = engine + self.batch_size = batch_size self.input_class = input_class self.output_class = output_class - self.translator = translator + self.translator_factory = translator_factory self.batchifier = batchifier def predict(self, dataset): @@ -52,18 +63,20 @@ def predict(self, dataset): :return: output dataset """ sc = SparkContext._active_spark_context - - predictor = sc._jvm.ai.djl.spark.task.binary.BinaryPredictor() + predictor = sc._jvm.ai.djl.spark.task.binary.BinaryPredictor() \ + .setInputCol(self.input_col) \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.engine is not None: + predictor = predictor.setEngine(self.engine) + if self.batch_size is not None: + predictor = predictor.setBatchSize(self.batch_size) if self.input_class is not None: predictor = predictor.setinputClass(self.input_class) if self.output_class is not None: predictor = predictor.setOutputClass(self.output_class) - if self.translator is not None: - self.translator = predictor.setTranslator(self.translator) - predictor = predictor.setInputCol(self.input_col) \ - .setOutputCol(self.output_col) \ - .setModelUrl(self.model_url) \ - .setEngine(self.engine) \ - .setBatchifier(self.batchifier) - return DataFrame(predictor.predict(dataset._jdf), - dataset.sparkSession) + if self.translator_factory is not None: + predictor = predictor.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + predictor = predictor.setBatchifier(self.batchifier) + return DataFrame(predictor.predict(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/__init__.py b/extensions/spark/setup/djl_spark/task/text/__init__.py index f8aaba072f4..54ad80dd9d5 100644 --- a/extensions/spark/setup/djl_spark/task/text/__init__.py +++ b/extensions/spark/setup/djl_spark/task/text/__init__.py @@ -10,22 +10,36 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. 
See the License for # the specific language governing permissions and limitations under the License. - """DJL Spark Tasks Text API.""" -from . import text_decoder, text_encoder, text_tokenizer, text_embedder, text2text_generator, text_generator +from . import ( + question_answerer, + text2text_generator, + text_classifier, + text_decoder, + text_embedder, + text_encoder, + text_generator, + text_tokenizer, +) +QuestionAnswerer = question_answerer.QuestionAnswerer +Text2TextGenerator = text2text_generator.Text2TextGenerator +TextClassifier = text_classifier.TextClassifier TextDecoder = text_decoder.TextDecoder -TextEncoder = text_encoder.TextEncoder -TextTokenizer = text_tokenizer.TextTokenizer TextEmbedder = text_embedder.TextEmbedder -Text2TextGenerator = text2text_generator.Text2TextGenerator +TextEncoder = text_encoder.TextEncoder TextGenerator = text_generator.TextGenerator +TextTokenizer = text_tokenizer.TextTokenizer # Remove unnecessary modules to avoid duplication in API. -del text_decoder -del text_encoder -del text_tokenizer -del text_embedder -del text2text_generator -del text_generator \ No newline at end of file +del ( + question_answerer, + text2text_generator, + text_classifier, + text_decoder, + text_embedder, + text_encoder, + text_generator, + text_tokenizer, +) diff --git a/extensions/spark/setup/djl_spark/task/text/question_answerer.py b/extensions/spark/setup/djl_spark/task/text/question_answerer.py new file mode 100644 index 00000000000..b3c48e941fe --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/text/question_answerer.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class QuestionAnswerer: + + def __init__(self, + input_cols: list[str], + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): + """ + Initializes the QuestionAnswerer. + + :param input_cols: The input columns + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is QuestionAnsweringTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". + """ + self.input_cols = input_cols + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + + def answer(self, dataset): + """ + Performs question answering on the provided dataset. 
+ + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + answerer = sc._jvm.ai.djl.spark.task.text.QuestionAnswerer() \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.input_cols is not None: + # Convert the input_cols to Java array + input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String, + len(self.input_cols)) + input_cols_arr[:] = [col for col in self.input_cols] + answerer = answerer.setInputCols(input_cols_arr) + if self.engine is not None: + answerer = answerer.setEngine(self.engine) + if self.batch_size is not None: + answerer = answerer.setBatchSize(self.batch_size) + if self.translator_factory is not None: + answerer = answerer.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + answerer = answerer.setBatchifier(self.batchifier) + return DataFrame(answerer.answer(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text2text_generator.py b/extensions/spark/setup/djl_spark/task/text/text2text_generator.py index 95ff094760a..88a15168e9a 100644 --- a/extensions/spark/setup/djl_spark/task/text/text2text_generator.py +++ b/extensions/spark/setup/djl_spark/task/text/text2text_generator.py @@ -14,7 +14,7 @@ import pandas as pd from pyspark.sql.functions import pandas_udf from pyspark.sql.types import StringType -from typing import Iterator +from typing import Iterator, Optional from transformers import pipeline from ...util import files_util, dependency_util @@ -25,7 +25,13 @@ class Text2TextGenerator: - def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"): + def __init__(self, + input_col: str, + output_col: str, + model_url: Optional[str] = None, + hf_model_id: Optional[str] = None, + engine: Optional[str] = "PyTorch", + batch_size: Optional[str] = 100): """ Initializes the Text2TextGenerator. @@ -34,12 +40,14 @@ def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engi :param model_url: The model URL :param hf_model_id: The Huggingface model ID :param engine: The engine. Currently only PyTorch is supported. + :param batch_size: The batch size. 
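A usage sketch for the QuestionAnswerer added above, under the same SparkSession assumption; the column names and model URL are placeholders. The wrapper converts input_cols into a Java String[] before configuring the Scala transformer.

```python
from djl_spark.task.text import QuestionAnswerer

df = spark.createDataFrame(
    [("Who maintains DJL?", "DJL is maintained by the Deep Java Library community.")],
    ["question", "context"])

answerer = QuestionAnswerer(input_cols=["question", "context"],
                            output_col="answer",
                            model_url="https://example.com/models/qa.zip",  # placeholder
                            batch_size=32)
answerer.answer(df).select("question", "answer").show(truncate=False)
```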
""" self.input_col = input_col self.output_col = output_col self.model_url = model_url self.hf_model_id = hf_model_id self.engine = engine + self.batch_size = batch_size def generate(self, dataset, **kwargs): """ @@ -52,21 +60,26 @@ def generate(self, dataset, **kwargs): raise ValueError("Only PyTorch engine is supported.") if self.model_url: - cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url) + cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, + self.model_url) files_util.download_and_extract(self.model_url, cache_dir) dependency_util.install(cache_dir) model_id_or_path = cache_dir elif self.hf_model_id: model_id_or_path = self.hf_model_id else: - raise ValueError("Either model_url or hf_model_id must be provided.") + raise ValueError( + "Either model_url or hf_model_id must be provided.") @pandas_udf(StringType()) def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: - pipe = pipeline(TASK, model=model_id_or_path, **kwargs) + pipe = pipeline(TASK, + model=model_id_or_path, + batch_size=self.batch_size, + **kwargs) for s in iterator: output = pipe(s.tolist()) - text = map(lambda x: x["generated_text"], output) + text = [o["generated_text"] for o in output] yield pd.Series(text) return dataset.withColumn(self.output_col, predict_udf(self.input_col)) diff --git a/extensions/spark/setup/djl_spark/task/text/text_classifier.py b/extensions/spark/setup/djl_spark/task/text/text_classifier.py new file mode 100644 index 00000000000..97e99b36170 --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/text/text_classifier.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class TextClassifier: + + def __init__(self, + input_col: str, + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): + """ + Initializes the TextClassifier. + + :param input_col: The input column + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is TextClassificationTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". + """ + self.input_col = input_col + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + + def classify(self, dataset): + """ + Performs text classification on the provided dataset. 
+ + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + classifier = sc._jvm.ai.djl.spark.task.text.TextClassifier() \ + .setInputCol(self.input_col) \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.engine is not None: + classifier = classifier.setEngine(self.engine) + if self.batch_size is not None: + classifier = classifier.setBatchSize(self.batch_size) + if self.translator_factory is not None: + classifier = classifier.setTranslatorFactory( + self.translator_factory) + if self.batchifier is not None: + classifier = classifier.setBatchifier(self.batchifier) + return DataFrame(classifier.classify(dataset._jdf), + dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text_decoder.py b/extensions/spark/setup/djl_spark/task/text/text_decoder.py index 8112b6cff1d..096408c2b83 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_decoder.py +++ b/extensions/spark/setup/djl_spark/task/text/text_decoder.py @@ -13,21 +13,28 @@ from pyspark import SparkContext from pyspark.sql import DataFrame +from typing import Optional class TextDecoder: - def __init__(self, input_col, output_col, hf_model_id): + def __init__(self, + input_col: str, + output_col: str, + hf_model_id: str, + batch_size: Optional[int] = None): """ Initializes the TextDecoder. :param input_col: The input column :param output_col: The output column :param hf_model_id: The Huggingface model ID + :param batch_size (optional): The batch size """ self.input_col = input_col self.output_col = output_col self.hf_model_id = hf_model_id + self.batch_size = batch_size def decode(self, dataset): """ @@ -41,5 +48,6 @@ def decode(self, dataset): .setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ .setHfModelId(self.hf_model_id) - return DataFrame(decoder.decode(dataset._jdf), - dataset.sparkSession) + if self.batch_size is not None: + decoder = decoder.setBatchSize(self.batch_size) + return DataFrame(decoder.decode(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text_embedder.py b/extensions/spark/setup/djl_spark/task/text/text_embedder.py index f0593f9b9f4..7b283151a29 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_embedder.py +++ b/extensions/spark/setup/djl_spark/task/text/text_embedder.py @@ -13,12 +13,19 @@ from pyspark import SparkContext from pyspark.sql import DataFrame +from typing import Optional class TextEmbedder: - def __init__(self, input_col, output_col, model_url, engine=None, - output_class=None, translator_factory=None): + def __init__(self, + input_col: str, + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): """ Initializes the TextEmbedder. @@ -26,15 +33,19 @@ def __init__(self, input_col, output_col, model_url, engine=None, :param output_col: The output column :param model_url: The model URL :param engine (optional): The engine - :param output_class (optional): The output class - :param translator_factory (optional): The translator factory. Default is TextEmbeddingTranslatorFactory. + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is TextEmbeddingTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". 
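A sketch of the new TextClassifier wrapper whose classify method appears above, with a hypothetical "text" column and placeholder model URL.

```python
from djl_spark.task.text import TextClassifier

df = spark.createDataFrame([("DJL on Spark makes batch inference simple.",)], ["text"])

classifier = TextClassifier(input_col="text",
                            output_col="prediction",
                            model_url="https://example.com/models/sentiment.zip",  # placeholder
                            batch_size=64)
classifier.classify(df).select("text", "prediction").show(truncate=False)
```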
""" self.input_col = input_col self.output_col = output_col - self.engine = engine self.model_url = model_url - self.output_class = output_class + self.engine = engine + self.batch_size = batch_size self.translator_factory = translator_factory + self.batchifier = batchifier def embed(self, dataset): """ @@ -44,14 +55,16 @@ def embed(self, dataset): :return: output dataset """ sc = SparkContext._active_spark_context - embedder = sc._jvm.ai.djl.spark.task.text.TextEmbedder() - if self.output_class is not None: - embedder = embedder.setOutputClass(self.output_class) - if self.translator_factory is not None: - embedder = embedder.setTranslatorFactory(self.translator_factory) - embedder = embedder.setInputCol(self.input_col) \ + embedder = sc._jvm.ai.djl.spark.task.text.TextEmbedder() \ + .setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ - .setEngine(self.engine) \ .setModelUrl(self.model_url) - return DataFrame(embedder.embed(dataset._jdf), - dataset.sparkSession) + if self.engine is not None: + embedder = embedder.setEngine(self.engine) + if self.batch_size is not None: + embedder = embedder.setBatchSize(self.batch_size) + if self.translator_factory is not None: + embedder = embedder.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + embedder = embedder.setBatchifier(self.batchifier) + return DataFrame(embedder.embed(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text_encoder.py b/extensions/spark/setup/djl_spark/task/text/text_encoder.py index 44378e5930a..5bfcb8305ad 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_encoder.py +++ b/extensions/spark/setup/djl_spark/task/text/text_encoder.py @@ -13,21 +13,28 @@ from pyspark import SparkContext from pyspark.sql import DataFrame +from typing import Optional class TextEncoder: - def __init__(self, input_col, output_col, hf_model_id): + def __init__(self, + input_col: str, + output_col: str, + hf_model_id: str, + batch_size: Optional[int] = None): """ Initializes the TextEncoder. 
:param input_col: The input column :param output_col: The output column :param hf_model_id: The Huggingface model ID + :param batch_size (optional): The batch size """ self.input_col = input_col self.output_col = output_col self.hf_model_id = hf_model_id + self.batch_size = batch_size def encode(self, dataset): """ @@ -41,5 +48,6 @@ def encode(self, dataset): .setInputCol(self.input_col) \ .setOutputCol(self.output_col) \ .setHfModelId(self.hf_model_id) - return DataFrame(encoder.encode(dataset._jdf), - dataset.sparkSession) + if self.batch_size is not None: + encoder = encoder.setBatchSize(self.batch_size) + return DataFrame(encoder.encode(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/text/text_generator.py b/extensions/spark/setup/djl_spark/task/text/text_generator.py index 0f0dfafae41..0d74a8bbe9b 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_generator.py +++ b/extensions/spark/setup/djl_spark/task/text/text_generator.py @@ -14,11 +14,10 @@ import pandas as pd from pyspark.sql.functions import pandas_udf from pyspark.sql.types import StringType -from typing import Iterator +from typing import Iterator, Optional from transformers import pipeline from ...util import files_util, dependency_util - TASK = "text-generation" APPLICATION = "nlp/text_generation" GROUP_ID = "ai/djl/huggingface/pytorch" @@ -26,7 +25,13 @@ class TextGenerator: - def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"): + def __init__(self, + input_col: str, + output_col: str, + model_url: Optional[str] = None, + hf_model_id: Optional[str] = None, + engine: Optional[str] = "PyTorch", + batch_size: Optional[str] = 100): """ Initializes the TextGenerator. @@ -35,12 +40,14 @@ def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engi :param model_url: The model URL :param hf_model_id: The Huggingface model ID :param engine: The engine. Currently only PyTorch is supported. + :param batch_size: The batch size. 
""" self.input_col = input_col self.output_col = output_col self.model_url = model_url self.hf_model_id = hf_model_id self.engine = engine + self.batch_size = batch_size def generate(self, dataset, **kwargs): """ @@ -53,21 +60,26 @@ def generate(self, dataset, **kwargs): raise ValueError("Only PyTorch engine is supported.") if self.model_url: - cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url) + cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, + self.model_url) files_util.download_and_extract(self.model_url, cache_dir) dependency_util.install(cache_dir) model_id_or_path = cache_dir elif self.hf_model_id: model_id_or_path = self.hf_model_id else: - raise ValueError("Either model_url or hf_model_id must be provided.") + raise ValueError( + "Either model_url or hf_model_id must be provided.") @pandas_udf(StringType()) def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: - pipe = pipeline(TASK, model=model_id_or_path, **kwargs) + pipe = pipeline(TASK, + model=model_id_or_path, + batch_size=self.batch_size, + **kwargs) for s in iterator: output = pipe(s.tolist()) - text = map(lambda x: x["generated_text"], output[0]) + text = [o[0]["generated_text"] for o in output] yield pd.Series(text) return dataset.withColumn(self.output_col, predict_udf(self.input_col)) diff --git a/extensions/spark/setup/djl_spark/task/text/text_tokenizer.py b/extensions/spark/setup/djl_spark/task/text/text_tokenizer.py index 7a8e9ce129c..0a0b1b0726c 100644 --- a/extensions/spark/setup/djl_spark/task/text/text_tokenizer.py +++ b/extensions/spark/setup/djl_spark/task/text/text_tokenizer.py @@ -17,7 +17,7 @@ class TextTokenizer: - def __init__(self, input_col, output_col, hf_model_id): + def __init__(self, input_col: str, output_col: str, hf_model_id: str): """ Initializes the TextTokenizer. diff --git a/extensions/spark/setup/djl_spark/task/vision/__init__.py b/extensions/spark/setup/djl_spark/task/vision/__init__.py index ed671ab556a..dc41b172149 100644 --- a/extensions/spark/setup/djl_spark/task/vision/__init__.py +++ b/extensions/spark/setup/djl_spark/task/vision/__init__.py @@ -10,12 +10,27 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. - """DJL Spark Tasks Vision API.""" -from . import image_classifier +from . import ( + image_classifier, + image_embedder, + instance_segmenter, + object_detector, + semantic_segmenter, +) ImageClassifier = image_classifier.ImageClassifier +ImageEmbedder = image_embedder.ImageEmbedder +InstanceSegmenter = instance_segmenter.InstanceSegmenter +ObjectDetector = object_detector.ObjectDetector +SemanticSegmenter = semantic_segmenter.SemanticSegmenter # Remove unnecessary modules to avoid duplication in API. 
-del image_classifier \ No newline at end of file +del ( + image_classifier, + image_embedder, + instance_segmenter, + object_detector, + semantic_segmenter, +) diff --git a/extensions/spark/setup/djl_spark/task/vision/image_classifier.py b/extensions/spark/setup/djl_spark/task/vision/image_classifier.py index 2495290510e..3b9a484de23 100644 --- a/extensions/spark/setup/djl_spark/task/vision/image_classifier.py +++ b/extensions/spark/setup/djl_spark/task/vision/image_classifier.py @@ -13,32 +13,47 @@ from pyspark import SparkContext from pyspark.sql import DataFrame +from typing import Optional class ImageClassifier: """ImageClassifier performs image classification on images. """ - def __init__(self, input_cols, output_col, engine, model_url, - output_class=None, translator=None, topK=5): + def __init__(self, + input_cols: list[str], + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None, + apply_softmax: Optional[bool] = None, + top_k: Optional[int] = None): """ Initializes the ImageClassifier. :param input_cols: The input columns :param output_col: The output column - :param engine (optional): The engine :param model_url: The model URL - :param output_class (optional): The output class - :param translator (optional): The translator. Default is ImageClassificationTranslator. - :param topK (optional): The number of classes to return. Default is 5. + :param engine (optional): The engine + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is ImageClassificationTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". + :param apply_softmax (optional): Whether to apply softmax when processing output. + :param top_k (optional): The number of classes to return. 
""" self.input_cols = input_cols self.output_col = output_col - self.engine = engine self.model_url = model_url - self.output_class = output_class - self.translator = translator - self.topK = topK + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + self.apply_softmax = apply_softmax + self.top_k = top_k def classify(self, dataset): """ @@ -48,25 +63,27 @@ def classify(self, dataset): :return: output dataset """ sc = SparkContext._active_spark_context - - # Convert the input_cols to Java array - input_cols_arr = None + classifier = sc._jvm.ai.djl.spark.task.vision.ImageClassifier() \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) if self.input_cols is not None: + # Convert the input_cols to Java array input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String, len(self.input_cols)) - for i in range(len(self.input_cols)): - input_cols_arr[i] = self.input_cols[i] - - classifier = sc._jvm.ai.djl.spark.task.vision.ImageClassifier() - if input_cols_arr is not None: + input_cols_arr[:] = [col for col in self.input_cols] classifier = classifier.setInputCols(input_cols_arr) - if self.output_class is not None: - classifier = classifier.setOutputClass(self.output_class) - if self.translator is not None: - classifier = classifier.setTranslator(self.translator) - classifier = classifier.setOutputCol(self.output_col) \ - .setEngine(self.engine) \ - .setModelUrl(self.model_url) \ - .setTopK(self.topK) + if self.engine is not None: + classifier = classifier.setEngine(self.engine) + if self.batch_size is not None: + classifier = classifier.setBatchSize(self.batch_size) + if self.translator_factory is not None: + classifier = classifier.setTranslatorFactory( + self.translator_factory) + if self.batchifier is not None: + classifier = classifier.setBatchifier(self.batchifier) + if self.apply_softmax is not None: + classifier = classifier.setApplySoftmax(self.apply_softmax) + if self.top_k is not None: + classifier = classifier.setTopK(self.top_k) return DataFrame(classifier.classify(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/vision/image_embedder.py b/extensions/spark/setup/djl_spark/task/vision/image_embedder.py new file mode 100644 index 00000000000..ba46882712c --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/vision/image_embedder.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class ImageEmbedder: + """ImageEmbedder performs image embedding on images. + """ + + def __init__(self, + input_cols: list[str], + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): + """ + Initializes the ImageEmbedder. 
+ + :param input_cols: The input columns + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is ImageClassificationTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". + """ + self.input_cols = input_cols + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + + def embed(self, dataset): + """ + Performs image classification on the provided dataset. + + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + embedder = sc._jvm.ai.djl.spark.task.vision.ImageEmbedder() \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.input_cols is not None: + # Convert the input_cols to Java array + input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String, + len(self.input_cols)) + input_cols_arr[:] = [col for col in self.input_cols] + embedder = embedder.setInputCols(input_cols_arr) + if self.engine is not None: + embedder = embedder.setEngine(self.engine) + if self.batch_size is not None: + embedder = embedder.setBatchSize(self.batch_size) + if self.translator_factory is not None: + embedder = embedder.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + embedder = embedder.setBatchifier(self.batchifier) + return DataFrame(embedder.embed(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/vision/instance_segmenter.py b/extensions/spark/setup/djl_spark/task/vision/instance_segmenter.py new file mode 100644 index 00000000000..e03873689c1 --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/vision/instance_segmenter.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class InstanceSegmenter: + """InstanceSegmenter performs instance segmentation on images. + """ + + def __init__(self, + input_cols: list[str], + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): + """ + Initializes the InstanceSegmenter. + + :param input_cols: The input columns + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size. Note that to enable batch predict + by setting batch size greater than 1, + we expect the input images to have the same size. + :param translator_factory (optional): The translator factory. + Default is ImageClassificationTranslatorFactory. + :param batchifier (optional): The batchifier. 
Valid values include "none" (default), + "stack", and "padding". + """ + self.input_cols = input_cols + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + + def segment(self, dataset): + """ + Performs instance segmentation on the provided dataset. + + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + segmenter = sc._jvm.ai.djl.spark.task.vision.InstanceSegmenter() \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.input_cols is not None: + # Convert the input_cols to Java array + input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String, + len(self.input_cols)) + input_cols_arr[:] = [col for col in self.input_cols] + segmenter = segmenter.setInputCols(input_cols_arr) + if self.engine is not None: + segmenter = segmenter.setEngine(self.engine) + if self.batch_size is not None: + segmenter = segmenter.setBatchSize(self.batch_size) + if self.translator_factory is not None: + segmenter = segmenter.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + segmenter = segmenter.setBatchifier(self.batchifier) + return DataFrame(segmenter.segment(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/vision/object_detector.py b/extensions/spark/setup/djl_spark/task/vision/object_detector.py new file mode 100644 index 00000000000..511560c26bd --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/vision/object_detector.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class ObjectDetector: + """ObjectDetector performs object detection on images. + """ + + def __init__(self, + input_cols: list[str], + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): + """ + Initializes the ObjectDetector. + + :param input_cols: The input columns + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size + :param translator_factory (optional): The translator factory. + Default is ImageClassificationTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". + """ + self.input_cols = input_cols + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + + def detect(self, dataset): + """ + Performs object detection on the provided dataset. 
+ + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + detector = sc._jvm.ai.djl.spark.task.vision.ObjectDetector() \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.input_cols is not None: + # Convert the input_cols to Java array + input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String, + len(self.input_cols)) + input_cols_arr[:] = [col for col in self.input_cols] + detector = detector.setInputCols(input_cols_arr) + if self.engine is not None: + detector = detector.setEngine(self.engine) + if self.batch_size is not None: + detector = detector.setBatchSize(self.batch_size) + if self.translator_factory is not None: + detector = detector.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + detector = detector.setBatchifier(self.batchifier) + return DataFrame(detector.detect(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/task/vision/semantic_segmenter.py b/extensions/spark/setup/djl_spark/task/vision/semantic_segmenter.py new file mode 100644 index 00000000000..ba23f3aa9f9 --- /dev/null +++ b/extensions/spark/setup/djl_spark/task/vision/semantic_segmenter.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from typing import Optional + + +class SemanticSegmenter: + """SemanticSegmenter performs semantic segmentation on images. + """ + + def __init__(self, + input_cols: list[str], + output_col: str, + model_url: str, + engine: Optional[str] = None, + batch_size: Optional[int] = None, + translator_factory=None, + batchifier: Optional[str] = None): + """ + Initializes the SemanticSegmenter. + + :param input_cols: The input columns + :param output_col: The output column + :param model_url: The model URL + :param engine (optional): The engine + :param batch_size (optional): The batch size. Note that to enable batch predict + by setting batch size greater than 1, + we expect the input images to have the same size. + :param translator_factory (optional): The translator factory. + Default is ImageClassificationTranslatorFactory. + :param batchifier (optional): The batchifier. Valid values include "none" (default), + "stack", and "padding". + """ + self.input_cols = input_cols + self.output_col = output_col + self.model_url = model_url + self.engine = engine + self.batch_size = batch_size + self.translator_factory = translator_factory + self.batchifier = batchifier + + def segment(self, dataset): + """ + Performs semantic segmentation on the provided dataset. 
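The detector and segmenter wrappers follow the same shape; a short sketch for ObjectDetector under the image-source assumptions used above.

```python
from djl_spark.task.vision import ObjectDetector

df = spark.read.format("image").load("s3://my-bucket/images/").select("image.*")

detector = ObjectDetector(
    input_cols=["origin", "height", "width", "nChannels", "mode", "data"],
    output_col="detections",
    model_url="https://example.com/models/ssd.zip",  # placeholder
    batch_size=4)
detector.detect(df).select("origin", "detections").show(truncate=False)
```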
+ + :param dataset: input dataset + :return: output dataset + """ + sc = SparkContext._active_spark_context + segmenter = sc._jvm.ai.djl.spark.task.vision.SemanticSegmenter() \ + .setOutputCol(self.output_col) \ + .setModelUrl(self.model_url) + if self.input_cols is not None: + # Convert the input_cols to Java array + input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String, + len(self.input_cols)) + input_cols_arr[:] = [col for col in self.input_cols] + segmenter = segmenter.setInputCols(input_cols_arr) + if self.engine is not None: + segmenter = segmenter.setEngine(self.engine) + if self.batch_size is not None: + segmenter = segmenter.setBatchSize(self.batch_size) + if self.translator_factory is not None: + segmenter = segmenter.setTranslatorFactory(self.translator_factory) + if self.batchifier is not None: + segmenter = segmenter.setBatchifier(self.batchifier) + return DataFrame(segmenter.segment(dataset._jdf), dataset.sparkSession) diff --git a/extensions/spark/setup/djl_spark/util/dependency_util.py b/extensions/spark/setup/djl_spark/util/dependency_util.py index ea7fc787fab..be81961cada 100644 --- a/extensions/spark/setup/djl_spark/util/dependency_util.py +++ b/extensions/spark/setup/djl_spark/util/dependency_util.py @@ -21,7 +21,10 @@ def install(path): :param path: The path to find the requirements.txt. """ if os.path.exists(os.path.join(path, "requirements.txt")): - cmd = [python_executable(), "-m", "pip", "install", "-r", os.path.join(path, "requirements.txt")] + cmd = [ + python_executable(), "-m", "pip", "install", "-r", + os.path.join(path, "requirements.txt") + ] try: subprocess.run(cmd, stderr=subprocess.STDOUT, check=True) except subprocess.CalledProcessError as e: @@ -34,5 +37,6 @@ def python_executable(): :return: The path of the Python executable. """ if not sys.executable: - raise RuntimeError("Failed to retrieve the path of the Python executable.") + raise RuntimeError( + "Failed to retrieve the path of the Python executable.") return sys.executable diff --git a/extensions/spark/setup/djl_spark/util/files_util.py b/extensions/spark/setup/djl_spark/util/files_util.py index 9938a377540..5e31fc9e177 100644 --- a/extensions/spark/setup/djl_spark/util/files_util.py +++ b/extensions/spark/setup/djl_spark/util/files_util.py @@ -28,7 +28,8 @@ def get_cache_dir(application, group_id, url): :param group_id: The group ID. :param url: The url of the file to store to the cache. 
""" - base_dir = os.environ.get("DJL_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".djl.ai")) + base_dir = os.environ.get("DJL_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".djl.ai")) h = hashlib.sha256(url.encode('UTF-8')).hexdigest()[:40] return os.path.join(base_dir, "cache/repo/model", application, group_id, h) @@ -55,7 +56,8 @@ def s3_download(url, path): url = urlparse(url) if url.scheme != "s3": - raise ValueError("Expecting 's3' scheme, got: %s in %s" % (url.scheme, url)) + raise ValueError("Expecting 's3' scheme, got: %s in %s" % + (url.scheme, url)) bucket, key = url.netloc, url.path.lstrip("/") s3 = boto3.client("s3") diff --git a/extensions/spark/setup/setup.py b/extensions/spark/setup/setup.py index 7ff242b93ca..a0050588779 100755 --- a/extensions/spark/setup/setup.py +++ b/extensions/spark/setup/setup.py @@ -36,9 +36,11 @@ def run(self): if __name__ == '__main__': version = detect_version() - requirements = ['packaging', 'wheel'] + requirements = [ + 'packaging', 'wheel', 'pillow', 'pandas', 'numpy', 'pyarrow' + ] - test_requirements = ['numpy', 'requests', 'Pillow'] + test_requirements = ['numpy', 'requests'] setup(name='djl_spark', version=version, diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/BasePredictor.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/BasePredictor.scala index 55713b8ce88..9ced6482453 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/BasePredictor.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/BasePredictor.scala @@ -15,7 +15,7 @@ package ai.djl.spark.task import ai.djl.spark.ModelLoader import ai.djl.translate.TranslatorFactory import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -30,18 +30,26 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer def this() = this(Identifiable.randomUID("BasePredictor")) - final val engine = new Param[String](this, "engine", "The engine") final val modelUrl = new Param[String](this, "modelUrl", "The model URL") + final val engine = new Param[String](this, "engine", "The engine") + final val batchSize = new IntParam(this, "batchSize", "The batch size") final val inputClass = new Param[Class[A]](this, "inputClass", "The input class") final val outputClass = new Param[Class[B]](this, "outputClass", "The output class") + final val translatorFactory = new Param[TranslatorFactory](this, "translatorFactory", "The translator factory") final val batchifier = new Param[String](this, "batchifier", "The batchifier. Valid values include none (default), stack, and padding.") - final val translatorFactory = new Param[TranslatorFactory](this, "translatorFactory", "The translator factory") protected var model: ModelLoader[A, B] = _ protected var arguments: java.util.Map[String, AnyRef] = new java.util.HashMap[String, AnyRef] protected var outputSchema: StructType = _ + /** + * Sets the modelUrl parameter. + * + * @param value the value of the parameter + */ + def setModelUrl(value: String): this.type = set(modelUrl, value) + /** * Sets the engine parameter. * @@ -50,11 +58,11 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer def setEngine(value: String): this.type = set(engine, value) /** - * Sets the modelUrl parameter. 
+ * Sets the batchSize parameter. * * @param value the value of the parameter */ - def setModelUrl(value: String): this.type = set(modelUrl, value) + def setBatchSize(value: Int): this.type = set(batchSize, value) /** * Sets the input class. @@ -71,29 +79,30 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer def setOutputClass(value: Class[B]): this.type = set(outputClass, value) /** - * Sets the batchifier parameter. + * Sets the translatorFactory parameter. * * @param value the value of the parameter */ - def setBatchifier(value: String): this.type = set(batchifier, value) + def setTranslatorFactory(value: TranslatorFactory): this.type = set(translatorFactory, value) /** - * Sets the translatorFactory parameter. + * Sets the batchifier parameter. * * @param value the value of the parameter */ - def setTranslatorFactory(value: TranslatorFactory): this.type = set(translatorFactory, value) + def setBatchifier(value: String): this.type = set(batchifier, value) - setDefault(engine, null) setDefault(modelUrl, null) + setDefault(engine, null) + setDefault(batchSize, 10) /** @inheritdoc */ override def transform(dataset: Dataset[_]): DataFrame = { if (isDefined(batchifier)) { arguments.put("batchifier", $(batchifier)) } - model = new ModelLoader[A, B]($(engine), $(modelUrl), $(inputClass), $(outputClass), $(translatorFactory), - arguments) + model = new ModelLoader[A, B]($(engine), $(modelUrl), $(inputClass), + $(outputClass), $(translatorFactory), arguments) validateInputType(dataset.schema) outputSchema = transformSchema(dataset.schema) val outputDf = dataset.toDF() @@ -117,7 +126,7 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer * * @param schema the schema to validate */ - def validateInputType(schema: StructType): Unit + protected def validateInputType(schema: StructType): Unit /** * Validate data type. diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/BaseAudioPredictor.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/BaseAudioPredictor.scala index 143defe28fd..67e18aacd9a 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/BaseAudioPredictor.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/BaseAudioPredictor.scala @@ -25,5 +25,6 @@ abstract class BaseAudioPredictor[B](override val uid: String) extends BasePredi def this() = this(Identifiable.randomUID("BaseAudioPredictor")) + setDefault(batchSize, 10) setDefault(inputClass, classOf[Audio]) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala index 7c0b8dbbc1f..63ee0e73b84 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/SpeechRecognizer.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType} import java.io.ByteArrayInputStream +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter /** * SpeechRecognizer performs speech recognition on audio. @@ -92,33 +94,41 @@ class SpeechRecognizer(override val uid: String) extends BaseAudioPredictor[Stri super.transform(dataset) } - /** - * Transforms the rows. 
- * - * @param iter the rows to transform - * @return the transformed rows - */ + /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val data = row.getAs[Array[Byte]](inputColIndex) - val audioFactory = AudioFactory.newInstance - if (isDefined(channels)) { - audioFactory.setChannels($(channels)) - } - if (isDefined(sampleRate)) { - audioFactory.setSampleRate($(sampleRate)) - } - if (isDefined(sampleFormat)) { - audioFactory.setSampleFormat($(sampleFormat)) + val audioFactory = AudioFactory.newInstance() + if (isDefined(channels)) { + audioFactory.setChannels($(channels)) + } + if (isDefined(sampleRate)) { + audioFactory.setSampleRate($(sampleRate)) + } + if (isDefined(sampleFormat)) { + audioFactory.setSampleFormat($(sampleFormat)) + } + iter.grouped($(batchSize)).flatMap { batch => + // Read inputs + val inputs = batch.map { row => + val data = row.getAs[Array[Byte]](inputColIndex) + val is = new ByteArrayInputStream(data) + try { + audioFactory.fromInputStream(is) + } finally { + is.close() + } + }.asJava + + // Batch predict + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ out) } - val audio = audioFactory.fromInputStream(new ByteArrayInputStream(data)) - Row.fromSeq(row.toSeq :+ predictor.predict(audio)) - }) + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), BinaryType) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala index 200d12e9ae4..8ac641cb628 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/binary/BinaryPredictor.scala @@ -19,6 +19,9 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter + /** * BinaryPredictor performs prediction on binary input. 
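Every predictor in this patch now buffers its partition with Iterator.grouped and hands each block to predictor.batchPredict, as in the SpeechRecognizer above. A minimal, self-contained sketch of that pattern in plain Scala (illustrative only, not part of the patch) shows why each partition keeps its original row count and order:

    // grouped(n) yields blocks of at most n elements; the last block may be smaller.
    // zip pairs every output with the input it came from, so nothing is dropped or reordered.
    val rows = (1 to 25).iterator
    val paired = rows.grouped(10).flatMap { batch =>
      val outputs = batch.map(_ * 2)   // stands in for predictor.batchPredict(inputs)
      batch.zip(outputs)
    }
    // paired.toList == List((1,2), (2,4), ..., (25,50)): 25 pairs, original order preserved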
* @@ -45,6 +48,7 @@ class BinaryPredictor(override val uid: String) extends BasePredictor[Array[Byte */ def setOutputCol(value: String): this.type = set(outputCol, value) + setDefault(batchSize, 10) setDefault(inputClass, classOf[Array[Byte]]) setDefault(outputClass, classOf[Array[Byte]]) setDefault(translatorFactory, new NpBinaryTranslatorFactory()) @@ -61,7 +65,6 @@ class BinaryPredictor(override val uid: String) extends BasePredictor[Array[Byte /** @inheritdoc */ override def transform(dataset: Dataset[_]): DataFrame = { - arguments.put("batchifier", $(batchifier)) inputColIndex = dataset.schema.fieldIndex($(inputCol)) super.transform(dataset) } @@ -69,13 +72,17 @@ class BinaryPredictor(override val uid: String) extends BasePredictor[Array[Byte /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - Row.fromSeq(row.toSeq :+ predictor.predict(row.getAs[Array[Byte]](inputColIndex))) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(_.getAs[Array[Byte]](inputColIndex)).asJava + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ out) + } + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), BinaryType) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/BaseTextPredictor.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/BaseTextPredictor.scala index da468342b75..c50508ec8b3 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/BaseTextPredictor.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/BaseTextPredictor.scala @@ -23,4 +23,6 @@ import org.apache.spark.ml.util.Identifiable abstract class BaseTextPredictor[A, B](override val uid: String) extends BasePredictor[A, B] { def this() = this(Identifiable.randomUID("BaseTextPredictor")) + + setDefault(batchSize, 100) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/QuestionAnswerer.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/QuestionAnswerer.scala index aefe4d6557b..30cbccb6485 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/QuestionAnswerer.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/QuestionAnswerer.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} * * @param uid An immutable unique ID for the object and its derivatives. 
*/ -class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInput, String] +class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[Array[QAInput], Array[String]] with HasInputCols with HasOutputCol { def this() = this(Identifiable.randomUID("QuestionAnswerer")) @@ -46,8 +46,8 @@ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInp */ def setOutputCol(value: String): this.type = set(outputCol, value) - setDefault(inputClass, classOf[QAInput]) - setDefault(outputClass, classOf[String]) + setDefault(inputClass, classOf[Array[QAInput]]) + setDefault(outputClass, classOf[Array[String]]) setDefault(translatorFactory, new QuestionAnsweringTranslatorFactory()) /** @@ -70,14 +70,18 @@ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInp /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - Row.fromSeq(row.toSeq :+ predictor.predict(new QAInput(row.getString(inputColIndices(0)), - row.getString(inputColIndices(1))))) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(row => new QAInput(row.getString(inputColIndices(0)), + row.getString(inputColIndices(1)))).toArray + val output = predictor.predict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ out) + } + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { assert($(inputCols).length == 2, "inputCols must have 2 columns") validateType(schema($(inputCols)(0)), StringType) validateType(schema($(inputCols)(1)), StringType) diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala index 36c44acdc4f..8d01294e7e5 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextClassifier.scala @@ -14,21 +14,20 @@ package ai.djl.spark.task.text import ai.djl.huggingface.translator.TextClassificationTranslatorFactory import ai.djl.modality.Classifications -import ai.djl.modality.Classifications.Classification import org.apache.spark.ml.param.Param import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DoubleType, MapType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} -import scala.collection.mutable +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` /** * TextClassifier performs text classification on text. * * @param uid An immutable unique ID for the object and its derivatives. 
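QuestionAnswerer now feeds whole batches of QAInput to the predictor in one call. A hedged usage sketch (illustrative only; it assumes the class's existing setInputCols/setOutputCol setters, and the column names and model URL are placeholders):

    val answerer = new QuestionAnswerer()
      .setInputCols(Array("question", "context"))   // exactly two string columns
      .setOutputCol("answer")
      .setModelUrl("<model url>")                   // placeholder
      .setBatchSize(50)                             // overrides the text default of 100
    val answered = answerer.transform(qaDf)         // qaDf: any DataFrame with the two columns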
*/ -class TextClassifier(override val uid: String) extends BaseTextPredictor[String, Classifications] +class TextClassifier(override val uid: String) extends BaseTextPredictor[Array[String], Array[Classifications]] with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("TextClassifier")) @@ -58,9 +57,8 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, */ def setTopK(value: Int): this.type = set(topK, value) - setDefault(inputClass, classOf[String]) - setDefault(outputClass, classOf[Classifications]) - setDefault(topK, 3) + setDefault(inputClass, classOf[Array[String]]) + setDefault(outputClass, classOf[Array[Classifications]]) setDefault(translatorFactory, new TextClassificationTranslatorFactory()) /** @@ -75,8 +73,9 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, /** @inheritdoc */ override def transform(dataset: Dataset[_]): DataFrame = { - arguments.put("batchifier", $(batchifier)) - arguments.put("topK", $(topK).toString) + if (isDefined(topK)) { + arguments.put("topK", $(topK).toString) + } inputColIndex = dataset.schema.fieldIndex($(inputCol)) super.transform(dataset) } @@ -84,21 +83,18 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val prediction: Classifications = predictor.predict(row.getString(inputColIndex)) - val top = mutable.LinkedHashMap[String, Double]() - val it: java.util.Iterator[Classification] = prediction.topK($(topK)).iterator() - while (it.hasNext) { - val t = it.next() - top += (t.getClassName -> t.getProbability) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(_.getString(inputColIndex)).toArray + val output = predictor.predict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), + out.topK[Classifications.Classification]().map(_.toString))) } - Row.fromSeq(row.toSeq :+ Row(prediction.getClassNames.toArray, - prediction.getProbabilities.toArray, top)) - }) + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), StringType) } @@ -107,7 +103,7 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, val outputSchema = StructType(schema.fields :+ StructField($(outputCol), StructType(Seq(StructField("class_names", ArrayType(StringType)), StructField("probabilities", ArrayType(DoubleType)), - StructField("topK", MapType(StringType, DoubleType)))))) + StructField("top_k", ArrayType(StringType)))))) outputSchema } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala index 405c69e581c..01175eb30f2 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextDecoder.scala @@ -75,15 +75,19 @@ class TextDecoder(override val uid: String) extends BaseTextPredictor[Array[Long } /** @inheritdoc */ - override def transformRows(iter: Iterator[Row]): Iterator[Row] = { + override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val tokenizer = HuggingFaceTokenizer.newInstance($(hfModelId)) - iter.map(row => { - 
Row.fromSeq(row.toSeq :+ tokenizer.decode(row.getAs[Seq[Long]]($(inputCol)).toArray)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(_.getAs[Seq[Long]](inputColIndex)).map(_.toArray).toArray + val output = tokenizer.batchDecode(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ out) + } + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), ArrayType(LongType)) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEmbedder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEmbedder.scala index 000cfb9b433..075eb3d8444 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEmbedder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEmbedder.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} * * @param uid An immutable unique ID for the object and its derivatives. */ -class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, Array[Float]] +class TextEmbedder(override val uid: String) extends BaseTextPredictor[Array[String], Array[Array[Float]]] with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("TextEmbedder")) @@ -44,8 +44,8 @@ class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, A */ def setOutputCol(value: String): this.type = set(outputCol, value) - setDefault(inputClass, classOf[String]) - setDefault(outputClass, classOf[Array[Float]]) + setDefault(inputClass, classOf[Array[String]]) + setDefault(outputClass, classOf[Array[Array[Float]]]) setDefault(translatorFactory, new TextEmbeddingTranslatorFactory()) /** @@ -67,13 +67,17 @@ class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, A /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - Row.fromSeq(row.toSeq :+ predictor.predict(row.getString(inputColIndex))) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(_.getString(inputColIndex)).toArray + val output = predictor.predict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ out) + } + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), StringType) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala index f02e8fbe640..a9849f37851 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextEncoder.scala @@ -75,16 +75,19 @@ class TextEncoder(override val uid: String) extends BaseTextPredictor[String, En } /** @inheritdoc */ - override def transformRows(iter: Iterator[Row]): Iterator[Row] = { + override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val tokenizer = HuggingFaceTokenizer.newInstance($(hfModelId)) - iter.map(row => { - val encoding = tokenizer.encode(row.getString(inputColIndex)) - Row.fromSeq(row.toSeq :+ Row(encoding.getIds, encoding.getTypeIds, encoding.getAttentionMask)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(_.getString(inputColIndex)).toArray + 
val output = tokenizer.batchEncode(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ Row(out.getIds, out.getTypeIds, out.getAttentionMask)) + } + } } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), StringType) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala index 9bc8361ff1b..e288cf03c59 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/text/TextTokenizer.scala @@ -83,7 +83,7 @@ class TextTokenizer(override val uid: String) extends BaseTextPredictor[String, } /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { validateType(schema($(inputCol)), StringType) } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/BaseImagePredictor.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/BaseImagePredictor.scala index 2578da9eb2f..18001c1bc1f 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/BaseImagePredictor.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/BaseImagePredictor.scala @@ -35,11 +35,11 @@ abstract class BaseImagePredictor[B](override val uid: String) extends BasePredi */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) + setDefault(batchSize, 10) setDefault(inputClass, classOf[Image]) - setDefault(batchifier, "stack") /** @inheritdoc */ - def validateInputType(schema: StructType): Unit = { + override protected def validateInputType(schema: StructType): Unit = { assert($(inputCols).length == 6, "inputCols must have 6 columns") validateType(schema($(inputCols)(0)), StringType) validateType(schema($(inputCols)(1)), IntegerType) diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala index 49648927650..1b80681de3f 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageClassifier.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructFiel import org.apache.spark.sql.{DataFrame, Dataset, Row} import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter /** * ImageClassifier performs image classification on images. @@ -61,8 +62,6 @@ class ImageClassifier(override val uid: String) extends BaseImagePredictor[Class setDefault(outputClass, classOf[Classifications]) setDefault(translatorFactory, new ImageClassificationTranslatorFactory()) - setDefault(applySoftmax, true) - setDefault(topK, 5) /** * Performs image classification on the provided dataset. 
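TextEncoder and TextDecoder above now delegate to the tokenizer's batchEncode and batchDecode instead of handling rows one at a time. A short illustrative sketch of those calls outside Spark (the model ID is a placeholder):

    import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer

    val tokenizer = HuggingFaceTokenizer.newInstance("<hf model id>")   // placeholder
    val encodings = tokenizer.batchEncode(Array("Hello, world", "DJL on Spark"))
    // encodings(i).getIds, getTypeIds and getAttentionMask all correspond to input i
    val decoded = tokenizer.batchDecode(encodings.map(_.getIds))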
@@ -76,22 +75,28 @@ class ImageClassifier(override val uid: String) extends BaseImagePredictor[Class /** @inheritdoc */ override def transform(dataset: Dataset[_]): DataFrame = { - arguments.put("applySoftmax", $(applySoftmax).toString) - arguments.put("topK", $(topK).toString) + if (isDefined(applySoftmax)) { + arguments.put("applySoftmax", $(applySoftmax).toString) + } + if (isDefined(topK)) { + arguments.put("topK", $(topK).toString) + } super.transform(dataset) } /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), - ImageSchema.getWidth(row), ImageSchema.getHeight(row)) - val prediction = predictor.predict(image) - val top = prediction.topK[Classifications.Classification]($(topK)).map(item => item.toString) - Row.fromSeq(row.toSeq :+ Row(prediction.getClassNames.toArray, - prediction.getProbabilities.toArray, top)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(row => + ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), + ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), + out.topK[Classifications.Classification]().map(_.toString))) + } + } } /** @inheritdoc */ @@ -99,7 +104,7 @@ class ImageClassifier(override val uid: String) extends BaseImagePredictor[Class val outputSchema = StructType(schema.fields :+ StructField($(outputCol), StructType(Seq(StructField("class_names", ArrayType(StringType)), StructField("probabilities", ArrayType(DoubleType)), - StructField("topK", ArrayType(StringType)))))) + StructField("top_k", ArrayType(StringType)))))) outputSchema } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala index 8b472524bbb..8c787fe48a1 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ImageEmbedder.scala @@ -20,6 +20,9 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, ByteType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter + /** * ImageEmbedder performs image embedding on images. 
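ImageClassifier now batches by default (batch size 10) and only forwards topK and applySoftmax to the translator when they are explicitly set. An illustrative end-to-end sketch, assuming an active SparkSession named spark and the class's existing setOutputCol setter; the path, model URL, and column values are placeholders, and the six input columns follow Spark's ImageSchema as validateInputType requires:

    val images = spark.read.format("image")
      .option("dropInvalid", true)
      .load("<path to images>")
      .select("image.*")
    val classifier = new ImageClassifier()
      .setInputCols(Array("origin", "height", "width", "nChannels", "mode", "data"))
      .setOutputCol("prediction")
      .setModelUrl("<model url>")
      .setBatchSize(10)   // the new default for image tasks, shown here for clarity
    val predictions = classifier.transform(images)
    predictions.select("prediction.class_names", "prediction.probabilities", "prediction.top_k")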
* @@ -53,11 +56,15 @@ class ImageEmbedder(override val uid: String) extends BaseImagePredictor[Array[B /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), - ImageSchema.getWidth(row), ImageSchema.getHeight(row)) - Row.fromSeq(row.toSeq :+ predictor.predict(image)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(row => + ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), + ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ out) + } + } } /** @inheritdoc */ diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala index 3e1c482f501..6d3e47cd838 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/InstanceSegmenter.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructFiel import org.apache.spark.sql.{DataFrame, Dataset, Row} import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter /** * InstanceSegmenter performs instance segmentation on images. @@ -41,6 +42,16 @@ class InstanceSegmenter(override val uid: String) extends BaseImagePredictor[Det */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Sets the batchSize parameter. Note that to enable batch predict by + * setting batch size greater than 1, we expect the input images to + * have the same size. 
+ * + * @param value the value of the parameter + */ + override def setBatchSize(value: Int): this.type = set(batchSize, value) + + setDefault(batchSize, 1) setDefault(outputClass, classOf[DetectedObjects]) setDefault(translatorFactory, new InstanceSegmentationTranslatorFactory()) @@ -57,14 +68,16 @@ class InstanceSegmenter(override val uid: String) extends BaseImagePredictor[Det /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), - ImageSchema.getWidth(row), ImageSchema.getHeight(row)) - val prediction = predictor.predict(image) - val boundingBoxes = prediction.items[DetectedObject].map(item => item.getBoundingBox.toString) - Row.fromSeq(row.toSeq :+ Row(prediction.getClassNames.toArray, - prediction.getProbabilities.toArray, boundingBoxes)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(row => + ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), + ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), + out.items[DetectedObject]().map(_.getBoundingBox.toString))) + } + } } /** @inheritdoc */ @@ -72,7 +85,7 @@ class InstanceSegmenter(override val uid: String) extends BaseImagePredictor[Det val outputSchema = StructType(schema.fields :+ StructField($(outputCol), StructType(Seq(StructField("class_names", ArrayType(StringType)), StructField("probabilities", ArrayType(DoubleType)), - StructField("boundingBoxes", ArrayType(StringType)))))) + StructField("bounding_boxes", ArrayType(StringType)))))) outputSchema } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala index 809de7ce74d..4f4184449bf 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/ObjectDetector.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructFiel import org.apache.spark.sql.{DataFrame, Dataset, Row} import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter /** * ObjectDetector performs object detection on images. 
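InstanceSegmenter keeps batchSize at 1 on purpose: as its setBatchSize scaladoc above states, a batch size greater than 1 is only valid when all input images have the same size. An illustrative usage sketch (the model URL is a placeholder; images stands for a DataFrame with the ImageSchema columns used above):

    val segmenter = new InstanceSegmenter()
      .setInputCols(Array("origin", "height", "width", "nChannels", "mode", "data"))
      .setOutputCol("segments")
      .setModelUrl("<model url>")
    // segmenter.setBatchSize(8)   // only when every image shares the same dimensions
    val segmented = segmenter.transform(images)
    segmented.select("segments.class_names", "segments.probabilities", "segments.bounding_boxes")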
@@ -57,14 +58,16 @@ class ObjectDetector(override val uid: String) extends BaseImagePredictor[Detect /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), - ImageSchema.getWidth(row), ImageSchema.getHeight(row)) - val prediction = predictor.predict(image) - val boundingBoxes = prediction.items[DetectedObject].map(item => item.getBoundingBox.toString) - Row.fromSeq(row.toSeq :+ Row(prediction.getClassNames.toArray, - prediction.getProbabilities.toArray, boundingBoxes)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(row => + ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), + ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(), + out.items[DetectedObject]().map(_.getBoundingBox.toString))) + } + } } /** @inheritdoc */ @@ -72,7 +75,7 @@ class ObjectDetector(override val uid: String) extends BaseImagePredictor[Detect val outputSchema = StructType(schema.fields :+ StructField($(outputCol), StructType(Seq(StructField("class_names", ArrayType(StringType)), StructField("probabilities", ArrayType(DoubleType)), - StructField("boundingBoxes", ArrayType(StringType)))))) + StructField("bounding_boxes", ArrayType(StringType)))))) outputSchema } } diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala index 0b90fc19b62..af2347ad306 100644 --- a/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala +++ b/extensions/spark/src/main/scala/ai/djl/spark/task/vision/SemanticSegmenter.scala @@ -21,6 +21,9 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Dataset, Row} +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` +import scala.jdk.CollectionConverters.seqAsJavaListConverter + /** * SemanticSegmenter performs semantic segmentation on images. * @@ -38,6 +41,16 @@ class SemanticSegmenter(override val uid: String) extends BaseImagePredictor[Cat */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Sets the batchSize parameter. Note that to enable batch predict by + * setting batch size greater than 1, we expect the input images to + * have the same size. 
+ * + * @param value the value of the parameter + */ + override def setBatchSize(value: Int): this.type = set(batchSize, value) + + setDefault(batchSize, 1) setDefault(outputClass, classOf[CategoryMask]) setDefault(translatorFactory, new SemanticSegmentationTranslatorFactory()) @@ -54,12 +67,15 @@ class SemanticSegmenter(override val uid: String) extends BaseImagePredictor[Cat /** @inheritdoc */ override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = { val predictor = model.newPredictor() - iter.map(row => { - val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), - ImageSchema.getWidth(row), ImageSchema.getHeight(row)) - val prediction = predictor.predict(image) - Row.fromSeq(row.toSeq :+ Row(prediction.getClasses.toArray, prediction.getMask)) - }) + iter.grouped($(batchSize)).flatMap { batch => + val inputs = batch.map(row => + ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)), + ImageSchema.getWidth(row), ImageSchema.getHeight(row))).asJava + val output = predictor.batchPredict(inputs) + batch.zip(output).map { case (row, out) => + Row.fromSeq(row.toSeq :+ Row(out.getClasses.toArray, out.getMask)) + } + } } /** @inheritdoc */
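For reference, the batch-size defaults this change establishes across the tasks, together with an illustrative SemanticSegmenter call (placeholder model URL, same ImageSchema input columns as above):

    // Defaults set via setDefault(batchSize, ...) in this patch:
    //   audio, binary, and image predictors -> 10
    //   text predictors                     -> 100
    //   instance and semantic segmenters    -> 1
    val masks = new SemanticSegmenter()
      .setInputCols(Array("origin", "height", "width", "nChannels", "mode", "data"))
      .setOutputCol("mask")
      .setModelUrl("<model url>")
      .transform(images)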