
Commit

[spark] Format python
xyang16 committed May 2, 2023
1 parent a33d70a commit 7694e78
Showing 23 changed files with 91 additions and 92 deletions.
1 change: 0 additions & 1 deletion extensions/spark/setup/djl_spark/task/audio/__init__.py
@@ -10,7 +10,6 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-
 """DJL Spark Tasks Audio API."""
 
 from . import (
14 changes: 7 additions & 7 deletions extensions/spark/setup/djl_spark/task/audio/speech_recognizer.py
@@ -64,18 +64,18 @@ def recognize(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        recognizer = (
-            sc._jvm.ai.djl.spark.task.audio.SpeechRecognizer()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        recognizer = sc._jvm.ai.djl.spark.task.audio.SpeechRecognizer() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.engine is not None:
             recognizer = recognizer.setEngine(self.engine)
         if self.batch_size is not None:
             recognizer = recognizer.setBatchSize(self.batch_size)
         if self.translator_factory is not None:
-            recognizer = recognizer.setTranslatorFactory(self.translator_factory)
+            recognizer = recognizer.setTranslatorFactory(
+                self.translator_factory)
         if self.batchifier is not None:
             recognizer = recognizer.setBatchifier(self.batchifier)
-        return DataFrame(recognizer.recognize(dataset._jdf), dataset.sparkSession)
+        return DataFrame(recognizer.recognize(dataset._jdf),
+                         dataset.sparkSession)
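
Note: this change (and the analogous ones in the other task wrappers below) only swaps one line-continuation style for another; the parenthesized fluent chain and the backslash-continued chain build the same Py4J setter sequence. A minimal sketch of the two equivalent layouts, using a hypothetical builder class purely for illustration:

# Hypothetical builder, used only to show the two continuation styles.
class Builder:

    def set_a(self, a):
        self.a = a
        return self

    def set_b(self, b):
        self.b = b
        return self


# Old layout: implicit line continuation inside parentheses.
old_style = (
    Builder()
    .set_a(1)
    .set_b(2)
)

# New layout: explicit backslash continuations.
new_style = Builder() \
    .set_a(1) \
    .set_b(2)

assert (old_style.a, old_style.b) == (new_style.a, new_style.b)
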
@@ -20,7 +20,6 @@
 from transformers import pipeline
 from ...util import files_util, dependency_util
 
-
 TASK = "automatic-speech-recognition"
 APPLICATION = "audio/automatic_speech_recognition"
 GROUP_ID = "ai/djl/huggingface/pytorch"
@@ -65,22 +64,31 @@ def recognize(self, dataset, generate_kwargs=None, **kwargs):
             raise ValueError("Only PyTorch engine is supported.")
 
         if self.model_url:
-            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url)
+            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID,
+                                                 self.model_url)
             files_util.download_and_extract(self.model_url, cache_dir)
             dependency_util.install(cache_dir)
             model_id_or_path = cache_dir
         elif self.hf_model_id:
             model_id_or_path = self.hf_model_id
         else:
-            raise ValueError("Either model_url or hf_model_id must be provided.")
+            raise ValueError(
+                "Either model_url or hf_model_id must be provided.")
 
         @pandas_udf(StringType())
         def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
-            pipe = pipeline(TASK, generate_kwargs=generate_kwargs, model=model_id_or_path,
-                            batch_size=self.batch_size, chunk_length_s=30, **kwargs)
+            pipe = pipeline(TASK,
+                            generate_kwargs=generate_kwargs,
+                            model=model_id_or_path,
+                            batch_size=self.batch_size,
+                            chunk_length_s=30,
+                            **kwargs)
             for s in iterator:
                 # Model expects single channel, 16000 sample rate audio
-                batch = [librosa.load(io.BytesIO(d), mono=True, sr=16000)[0] for d in s]
+                batch = [
+                    librosa.load(io.BytesIO(d), mono=True, sr=16000)[0]
+                    for d in s
+                ]
                 output = pipe(batch)
                 text = [o["text"] for o in output]
                 yield pd.Series(text)
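
Note: the HuggingFace-backed recognize method above relies on Spark's iterator-of-Series pandas UDF, so the expensive pipeline setup happens once per partition iterator rather than once per row. A self-contained sketch of that UDF pattern with a trivial stand-in for the model (all column and function names here are hypothetical, not part of the DJL API):

from typing import Iterator

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("hello",), ("world",)], ["text"])


@pandas_udf(StringType())
def shout(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Expensive setup (e.g. loading a model) would go here, once per partition.
    suffix = "!"
    for batch in iterator:
        # Each batch is a pandas Series of rows; transform it vectorized.
        yield batch.str.upper() + suffix


df.withColumn("shouted", shout("text")).show()
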
1 change: 0 additions & 1 deletion extensions/spark/setup/djl_spark/task/binary/__init__.py
@@ -10,7 +10,6 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-
 """DJL Spark Tasks Binary API."""
 
 from . import binary_predictor
@@ -63,12 +63,10 @@ def predict(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        predictor = (
-            sc._jvm.ai.djl.spark.task.binary.BinaryPredictor()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        predictor = sc._jvm.ai.djl.spark.task.binary.BinaryPredictor() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.engine is not None:
             predictor = predictor.setEngine(self.engine)
         if self.batch_size is not None:
1 change: 0 additions & 1 deletion extensions/spark/setup/djl_spark/task/text/__init__.py
@@ -10,7 +10,6 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-
 """DJL Spark Tasks Text API."""
 
 from . import (
@@ -55,11 +55,9 @@ def answer(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        answerer = (
-            sc._jvm.ai.djl.spark.task.text.QuestionAnswerer()
-            .setOutputCol(self.output_col)
+        answerer = sc._jvm.ai.djl.spark.task.text.QuestionAnswerer() \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.input_cols is not None:
             # Convert the input_cols to Java array
             input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String,
@@ -60,18 +60,23 @@ def generate(self, dataset, **kwargs):
             raise ValueError("Only PyTorch engine is supported.")
 
         if self.model_url:
-            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url)
+            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID,
+                                                 self.model_url)
             files_util.download_and_extract(self.model_url, cache_dir)
             dependency_util.install(cache_dir)
             model_id_or_path = cache_dir
         elif self.hf_model_id:
             model_id_or_path = self.hf_model_id
         else:
-            raise ValueError("Either model_url or hf_model_id must be provided.")
+            raise ValueError(
+                "Either model_url or hf_model_id must be provided.")
 
         @pandas_udf(StringType())
         def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
-            pipe = pipeline(TASK, model=model_id_or_path, batch_size=self.batch_size, **kwargs)
+            pipe = pipeline(TASK,
+                            model=model_id_or_path,
+                            batch_size=self.batch_size,
+                            **kwargs)
             for s in iterator:
                 output = pipe(s.tolist())
                 text = [o["generated_text"] for o in output]
14 changes: 7 additions & 7 deletions extensions/spark/setup/djl_spark/task/text/text_classifier.py
@@ -55,18 +55,18 @@ def classify(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        classifier = (
-            sc._jvm.ai.djl.spark.task.text.TextClassifier()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        classifier = sc._jvm.ai.djl.spark.task.text.TextClassifier() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.engine is not None:
             classifier = classifier.setEngine(self.engine)
         if self.batch_size is not None:
             classifier = classifier.setBatchSize(self.batch_size)
         if self.translator_factory is not None:
-            classifier = classifier.setTranslatorFactory(self.translator_factory)
+            classifier = classifier.setTranslatorFactory(
+                self.translator_factory)
         if self.batchifier is not None:
             classifier = classifier.setBatchifier(self.batchifier)
-        return DataFrame(classifier.classify(dataset._jdf), dataset.sparkSession)
+        return DataFrame(classifier.classify(dataset._jdf),
+                         dataset.sparkSession)
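
Note: for orientation, a usage sketch of the TextClassifier wrapper above. The keyword-argument constructor is assumed from the attributes this file sets, and the column names and model URL are illustrative placeholders, not values from this commit:

from pyspark.sql import SparkSession
from djl_spark.task.text.text_classifier import TextClassifier

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Spark is great",), ("This is terrible",)],
                           ["text"])

# Column names and model URL are illustrative placeholders.
classifier = TextClassifier(
    input_col="text",
    output_col="prediction",
    engine="PyTorch",
    model_url="djl://ai.djl.huggingface.pytorch/"
              "distilbert-base-uncased-finetuned-sst-2-english")
result = classifier.classify(df)
result.printSchema()
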
8 changes: 3 additions & 5 deletions extensions/spark/setup/djl_spark/task/text/text_decoder.py
@@ -44,12 +44,10 @@ def decode(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        decoder = (
-            sc._jvm.ai.djl.spark.task.text.TextDecoder()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        decoder = sc._jvm.ai.djl.spark.task.text.TextDecoder() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setHfModelId(self.hf_model_id)
-        )
         if self.batch_size is not None:
             decoder = decoder.setBatchSize(self.batch_size)
         return DataFrame(decoder.decode(dataset._jdf), dataset.sparkSession)
8 changes: 3 additions & 5 deletions extensions/spark/setup/djl_spark/task/text/text_embedder.py
@@ -55,12 +55,10 @@ def embed(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        embedder = (
-            sc._jvm.ai.djl.spark.task.text.TextEmbedder()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        embedder = sc._jvm.ai.djl.spark.task.text.TextEmbedder() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.engine is not None:
             embedder = embedder.setEngine(self.engine)
         if self.batch_size is not None:
8 changes: 3 additions & 5 deletions extensions/spark/setup/djl_spark/task/text/text_encoder.py
@@ -44,12 +44,10 @@ def encode(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        encoder = (
-            sc._jvm.ai.djl.spark.task.text.TextEncoder()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        encoder = sc._jvm.ai.djl.spark.task.text.TextEncoder() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setHfModelId(self.hf_model_id)
-        )
         if self.batch_size is not None:
             encoder = encoder.setBatchSize(self.batch_size)
         return DataFrame(encoder.encode(dataset._jdf), dataset.sparkSession)
12 changes: 8 additions & 4 deletions extensions/spark/setup/djl_spark/task/text/text_generator.py
@@ -18,7 +18,6 @@
 from transformers import pipeline
 from ...util import files_util, dependency_util
 
-
 TASK = "text-generation"
 APPLICATION = "nlp/text_generation"
 GROUP_ID = "ai/djl/huggingface/pytorch"
@@ -61,18 +60,23 @@ def generate(self, dataset, **kwargs):
             raise ValueError("Only PyTorch engine is supported.")
 
         if self.model_url:
-            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url)
+            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID,
+                                                 self.model_url)
             files_util.download_and_extract(self.model_url, cache_dir)
             dependency_util.install(cache_dir)
             model_id_or_path = cache_dir
         elif self.hf_model_id:
             model_id_or_path = self.hf_model_id
         else:
-            raise ValueError("Either model_url or hf_model_id must be provided.")
+            raise ValueError(
+                "Either model_url or hf_model_id must be provided.")
 
         @pandas_udf(StringType())
         def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
-            pipe = pipeline(TASK, model=model_id_or_path, batch_size=self.batch_size, **kwargs)
+            pipe = pipeline(TASK,
+                            model=model_id_or_path,
+                            batch_size=self.batch_size,
+                            **kwargs)
             for s in iterator:
                 output = pipe(s.tolist())
                 text = [o[0]["generated_text"] for o in output]
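
Note: a usage sketch of the TextGenerator above, exercising the hf_model_id path. The model id, column names, and the forwarded max_new_tokens argument are illustrative, and the constructor keywords are assumed from the attributes used in this file:

from pyspark.sql import SparkSession
from djl_spark.task.text.text_generator import TextGenerator

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("The weather today is",)], ["text"])

# "gpt2" is only an illustrative Hugging Face model id.
generator = TextGenerator(input_col="text",
                          output_col="generated",
                          engine="PyTorch",
                          hf_model_id="gpt2")
# Extra keyword arguments are forwarded to transformers.pipeline (see diff above).
result = generator.generate(df, max_new_tokens=20)
result.show(truncate=False)
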
16 changes: 6 additions & 10 deletions extensions/spark/setup/djl_spark/task/text/text_tokenizer.py
@@ -17,10 +17,7 @@
 
 class TextTokenizer:
 
-    def __init__(self,
-                 input_col: str,
-                 output_col: str,
-                 hf_model_id: str):
+    def __init__(self, input_col: str, output_col: str, hf_model_id: str):
         """
         Initializes the TextTokenizer.
@@ -40,10 +37,9 @@ def tokenize(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        tokenizer = (
-            sc._jvm.ai.djl.spark.task.text.TextTokenizer()
-            .setInputCol(self.input_col)
-            .setOutputCol(self.output_col)
+        tokenizer = sc._jvm.ai.djl.spark.task.text.TextTokenizer() \
+            .setInputCol(self.input_col) \
+            .setOutputCol(self.output_col) \
             .setHfModelId(self.hf_model_id)
-        )
-        return DataFrame(tokenizer.tokenize(dataset._jdf), dataset.sparkSession)
+        return DataFrame(tokenizer.tokenize(dataset._jdf),
+                         dataset.sparkSession)
1 change: 0 additions & 1 deletion extensions/spark/setup/djl_spark/task/vision/__init__.py
@@ -10,7 +10,6 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-
 """DJL Spark Tasks Vision API."""
 
 from . import (
12 changes: 6 additions & 6 deletions extensions/spark/setup/djl_spark/task/vision/image_classifier.py
@@ -63,11 +63,9 @@ def classify(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        classifier = (
-            sc._jvm.ai.djl.spark.task.vision.ImageClassifier()
-            .setOutputCol(self.output_col)
+        classifier = sc._jvm.ai.djl.spark.task.vision.ImageClassifier() \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.input_cols is not None:
             # Convert the input_cols to Java array
             input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String,
@@ -79,11 +77,13 @@ def classify(self, dataset):
         if self.batch_size is not None:
             classifier = classifier.setBatchSize(self.batch_size)
         if self.translator_factory is not None:
-            classifier = classifier.setTranslatorFactory(self.translator_factory)
+            classifier = classifier.setTranslatorFactory(
+                self.translator_factory)
         if self.batchifier is not None:
             classifier = classifier.setBatchifier(self.batchifier)
         if self.apply_softmax is not None:
             classifier = classifier.setApplySoftmax(self.apply_softmax)
         if self.top_k is not None:
             classifier = classifier.setTopK(self.top_k)
-        return DataFrame(classifier.classify(dataset._jdf), dataset.sparkSession)
+        return DataFrame(classifier.classify(dataset._jdf),
+                         dataset.sparkSession)
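
Note: a usage sketch of the ImageClassifier above, assuming images loaded through Spark's built-in image data source and flattened so the struct fields become top-level columns. The model URL and top_k value are illustrative, and the constructor keywords are assumed from the attributes this file sets:

from pyspark.sql import SparkSession
from djl_spark.task.vision.image_classifier import ImageClassifier

spark = SparkSession.builder.getOrCreate()
# Spark's image data source yields one struct column named "image".
df = spark.read.format("image").option("dropInvalid", True) \
    .load("path/to/images").select("image.*")

classifier = ImageClassifier(
    input_cols=["origin", "height", "width", "nChannels", "mode", "data"],
    output_col="prediction",
    engine="PyTorch",
    model_url="djl://ai.djl.pytorch/resnet",  # illustrative model zoo URL
    top_k=2)
result = classifier.classify(df)
result.show(truncate=False)
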
@@ -57,11 +57,9 @@ def embed(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        embedder = (
-            sc._jvm.ai.djl.spark.task.vision.ImageEmbedder()
-            .setOutputCol(self.output_col)
+        embedder = sc._jvm.ai.djl.spark.task.vision.ImageEmbedder() \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.input_cols is not None:
             # Convert the input_cols to Java array
             input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String,
@@ -59,11 +59,9 @@ def segment(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        segmenter = (
-            sc._jvm.ai.djl.spark.task.vision.InstanceSegmenter()
-            .setOutputCol(self.output_col)
+        segmenter = sc._jvm.ai.djl.spark.task.vision.InstanceSegmenter() \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.input_cols is not None:
             # Convert the input_cols to Java array
             input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String,
@@ -57,11 +57,9 @@ def detect(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        detector = (
-            sc._jvm.ai.djl.spark.task.vision.ObjectDetector()
-            .setOutputCol(self.output_col)
+        detector = sc._jvm.ai.djl.spark.task.vision.ObjectDetector() \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.input_cols is not None:
             # Convert the input_cols to Java array
             input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String,
@@ -59,11 +59,9 @@ def segment(self, dataset):
         :return: output dataset
         """
         sc = SparkContext._active_spark_context
-        segmenter = (
-            sc._jvm.ai.djl.spark.task.vision.SemanticSegmenter()
-            .setOutputCol(self.output_col)
+        segmenter = sc._jvm.ai.djl.spark.task.vision.SemanticSegmenter() \
+            .setOutputCol(self.output_col) \
             .setModelUrl(self.model_url)
-        )
         if self.input_cols is not None:
             # Convert the input_cols to Java array
             input_cols_arr = sc._gateway.new_array(sc._jvm.java.lang.String,
