Check that the Java / Scala package is installed when needed (#250)
Provides a meaningful error message when a user accesses a spark extension
function in Python that requires the Java / Scala package:

    RuntimeError: Java / Scala package not found! You need to add the Maven
    spark-extension package to your PySpark environment:
    https://github.com/G-Research/spark-extension#python

Before, the error was:

    TypeError: 'JavaPackage' object is not callable

Improves #242
Supersedes #244
EnricoMi authored Aug 16, 2024
1 parent 4f4838d commit e9e508d
Showing 2 changed files with 59 additions and 12 deletions.
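
Before the diffs, a quick illustration of the user-visible change. A minimal sketch, assuming an active classic (non-Connect) SparkSession named spark whose classpath lacks the spark-extension jar; the sample data is made up, and any extension function that needs the JVM package behaves the same:

    # diff() needs the Java / Scala side, so it triggers the check in _get_jvm
    from gresearch.spark.diff import *

    left = spark.createDataFrame([(1, "a")], ["id", "value"])
    right = spark.createDataFrame([(1, "b")], ["id", "value"])

    left.diff(right, "id")
    # before this commit: TypeError: 'JavaPackage' object is not callable
    # after this commit:  RuntimeError: Java / Scala package not found! ...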
50 changes: 40 additions & 10 deletions python/gresearch/spark/__init__.py
@@ -45,26 +45,56 @@
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
+_java_pkg_is_installed: Optional[bool] = None
+
+
+def _check_java_pkg_is_installed(jvm: JVMView) -> bool:
+    """Check that the Java / Scala package is installed."""
+    try:
+        jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").VersionString()
+        return True
+    except TypeError as e:
+        print(e.args)
+        return False
+    except:
+        # any other exception indicate some problem, be safe and do not fail fast here
+        return True
+
+
 def _get_jvm(obj: Any) -> JVMView:
     """
     Provides easy access to the JVMView provided by Spark, and raises meaningful error message if that is not available.
+    Also checks that the Java / Scala package is accessible via this JVMView.
     """
     if obj is None:
         if SparkContext._active_spark_context is None:
             raise RuntimeError("This method must be called inside an active Spark session")
         else:
             raise ValueError("Cannot provide access to JVM from None")
 
-    # helper method to assert the JVM is accessible and provide a useful error message
     if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)):
-        raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server')
+        raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. '
+                           'https://github.com/G-Research/spark-extension#spark-connect-server')
 
     if isinstance(obj, DataFrame):
-        return _get_jvm(obj._sc)
-    if isinstance(obj, DataFrameReader):
-        return _get_jvm(obj._spark)
-    if isinstance(obj, SparkSession):
-        return _get_jvm(obj.sparkContext)
-    if isinstance(obj, (SparkContext, SQLContext)):
-        return obj._jvm
-    raise RuntimeError(f'Unsupported class: {type(obj)}')
+        jvm = _get_jvm(obj._sc)
+    elif isinstance(obj, DataFrameReader):
+        jvm = _get_jvm(obj._spark)
+    elif isinstance(obj, SparkSession):
+        jvm = _get_jvm(obj.sparkContext)
+    elif isinstance(obj, (SparkContext, SQLContext)):
+        jvm = obj._jvm
+    else:
+        raise RuntimeError(f'Unsupported class: {type(obj)}')
+
+    global _java_pkg_is_installed
+    if _java_pkg_is_installed is None:
+        _java_pkg_is_installed = _check_java_pkg_is_installed(jvm)
+        if not _java_pkg_is_installed:
+            raise RuntimeError("Java / Scala package not found! You need to add the Maven spark-extension package "
+                               "to your PySpark environment: https://github.com/G-Research/spark-extension#python")
+
+    return jvm
 
 
 def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:
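
A note on the probe in _check_java_pkg_is_installed: the Scala package object of uk.co.gresearch.spark compiles to a class named package$ whose singleton lives in a static MODULE$ field, and since $ cannot appear in Python attribute syntax the code spells the lookups out via __getattr__. When the jar is missing, py4j resolves every unknown name to a JavaPackage placeholder and only fails once that placeholder is called, which is why the probe catches TypeError. A standalone sketch of the same probe, assuming an active classic SparkSession named spark:

    # Mirrors _check_java_pkg_is_installed above; for illustration only.
    jvm = spark.sparkContext._jvm

    # With the jar absent, every attribute lookup still succeeds and returns
    # a py4j JavaPackage placeholder; only calling it raises TypeError.
    module = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
    try:
        print(module.VersionString())  # prints the installed spark-extension version
    except TypeError:
        print("spark-extension jar is not on the driver classpath")

Since _get_jvm memoizes the outcome in the module-level _java_pkg_is_installed, the JVM round trip happens at most once per Python process.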
21 changes: 19 additions & 2 deletions python/test/test_jvm.py
@@ -16,8 +16,10 @@
 
 from pyspark.sql.functions import sum
 
-from gresearch.spark import _get_jvm, dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
-    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, histogram, job_description, append_description
+from gresearch.spark import _get_jvm, \
+    dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
+    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, \
+    histogram, job_description, append_description
 from gresearch.spark.diff import *
 from gresearch.spark.parquet import *
 from spark_common import SparkTest
@@ -58,6 +60,21 @@ def test_get_jvm_connect(self):
             _get_jvm(object())
         self.assertEqual(("Unsupported class: <class 'object'>", ), e.exception.args)
 
+    @skipIf(SparkTest.is_spark_connect, "Spark classic client tests")
+    def test_get_jvm_check_java_pkg_is_installed(self):
+        from gresearch import spark
+
+        is_installed = spark._java_pkg_is_installed
+
+        try:
+            spark._java_pkg_is_installed = False
+            with self.assertRaises(RuntimeError) as e:
+                _get_jvm(self.spark)
+            self.assertEqual(("Java / Scala package not found! You need to add the Maven spark-extension package "
+                              "to your PySpark environment: https://github.com/G-Research/spark-extension#python", ), e.exception.args)
+        finally:
+            spark._java_pkg_is_installed = is_installed
+
     @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
     def test_diff(self):
         for label, func in {
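
For completeness, the remedy the new error message asks for is adding the Maven package when the session is built. A sketch using Spark's spark.jars.packages option; the Scala suffix and version in the coordinates are illustrative examples only, see the README link in the error message for the current ones:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .config("spark.jars.packages",
                "uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.5") \
        .getOrCreate()

The same coordinates work on the command line, e.g. pyspark --packages uk.co.gresearch.spark:spark-extension_2.12:2.11.0-3.5.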
