Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1021381: Remove pkg_resources #2371

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
- Added support for `Session.stored_procedure_profiler`.

#### Improvements
- Removed the use of `pkg_resources`, which is deprecated. No more deprecation warnings about it. Used `packaging` and `importlib.metadata` to replace `pkg_resources`.


#### Bug Fixes

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"pyyaml",
"cloudpickle>=1.6.0,<=2.2.1,!=2.1.0,!=2.2.0;python_version<'3.11'",
"cloudpickle==2.2.1;python_version~='3.11'", # backend only supports cloudpickle 2.2.1 + python 3.11 at the moment
"packaging",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need approval from the legal team?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

]
REQUIRED_PYTHON_VERSION = ">=3.8, <3.12"

Expand Down
62 changes: 45 additions & 17 deletions src/snowflake/snowpark/_internal/packaging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import subprocess
import sys
import zipfile
from importlib.metadata import PackageNotFoundError, distribution
from logging import getLogger
from pathlib import Path
from typing import AnyStr, Dict, List, Optional, Set, Tuple

import pkg_resources
import yaml
from pkg_resources import Requirement
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion

_logger = getLogger(__name__)
PIP_ENVIRONMENT_VARIABLE: str = "PIP_NAME"
Expand All @@ -34,6 +36,36 @@
}


def get_distribution_version(name: str) -> Optional[str]:
"""
Get the installed version of a package.

Args:
name (str): The name of the package.

Returns:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This looks like a duplicate of the description, so it may not be needed.

Optional[str]: The version of the package's installed distribution.
"""
return distribution(name).version


def contains_version(specifier: SpecifierSet, version: str) -> bool:
    """
    Tell whether a version string is accepted by a version specifier set.

    Args:
        specifier (SpecifierSet): The specifier set to test against, e.g. the
            ``specifier`` attribute of a ``Requirement``.
        version (str): The candidate version string.

    Returns:
        bool: The result of ``specifier.contains(version)``, or False when that
            call raises ``InvalidVersion``.
    """
    # NOTE(review): pre-release handling is whatever ``SpecifierSet.contains``
    # does by default; confirm this matches the behaviour callers relied on
    # before the move away from pkg_resources.
    try:
        matched = specifier.contains(version)
    except InvalidVersion:
        # An invalid version is reported as "not contained" instead of
        # propagating the error to the caller.
        matched = False
    return matched


def parse_requirements_text_file(file_path: str) -> Tuple[List[str], List[str]]:
"""
Parses a requirements.txt file to obtain a list of packages and file/folder imports. Returns a tuple of packages
Expand Down Expand Up @@ -223,7 +255,7 @@ def map_python_packages_to_files_and_folders(

# Create Requirement objects and store in map
package_name_to_record_entries_map[
Requirement.parse(package)
Requirement(package)
] = included_record_entries

return package_name_to_record_entries_map
Expand Down Expand Up @@ -264,20 +296,18 @@ def identify_supported_packages(

for package in packages:
package_name: str = package.name
package_version_required: Optional[str] = (
package.specs[0][1] if package.specs else None
package_specifier: Optional[SpecifierSet] = (
package.specifier if package.specifier else None
)
version_text = (
f"(version {package_version_required})"
if package_version_required is not None
else ""
f"(version {package_specifier})" if package_specifier is not None else ""
)

if package_name in valid_packages:
# Detect supported packages
if (
package_version_required is None
or package_version_required in valid_packages[package_name]
if package_specifier is None or any(
contains_version(package_specifier, x)
for x in valid_packages[package_name]
):
supported_dependencies.append(package)
_logger.info(
Expand All @@ -291,7 +321,7 @@ def identify_supported_packages(
f"Package {package_name}{version_text} contains native code, switching to latest available version "
f"in Snowflake instead."
)
new_dependencies.append(Requirement.parse(package_name))
new_dependencies.append(Requirement(package_name))
dropped_dependencies.append(package)

else:
Expand Down Expand Up @@ -483,14 +513,12 @@ def add_snowpark_package(
channel.

Raises:
pkg_resources.DistributionNotFound: If the Snowpark Python Package is not installed in the local environment.
importlib.metadata.PackageNotFoundError: If the Snowpark Python Package is not installed in the local environment.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we need to mention this error-behavior change in the changelog?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an internal API. But we do need to check if any public APIs expose it.

"""
if SNOWPARK_PACKAGE_NAME not in package_dict:
package_dict[SNOWPARK_PACKAGE_NAME] = SNOWPARK_PACKAGE_NAME
try:
package_client_version = pkg_resources.get_distribution(
SNOWPARK_PACKAGE_NAME
).version
package_client_version = get_distribution_version(SNOWPARK_PACKAGE_NAME)
if package_client_version in valid_packages[SNOWPARK_PACKAGE_NAME]:
package_dict[
SNOWPARK_PACKAGE_NAME
Expand All @@ -501,7 +529,7 @@ def add_snowpark_package(
f"{package_client_version}, which is not available in Snowflake. Your UDF might not work when "
f"the package version is different between the server and your local environment."
)
except pkg_resources.DistributionNotFound:
except PackageNotFoundError:
_logger.warning(
f"Package '{SNOWPARK_PACKAGE_NAME}' is not installed in the local environment. "
f"Your UDF might not work when the package is installed on the server "
Expand Down
83 changes: 35 additions & 48 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import warnings
from array import array
from functools import reduce
from importlib.metadata import PackageNotFoundError, distribution, distributions
from logging import getLogger
from threading import RLock
from types import ModuleType
Expand All @@ -33,7 +34,7 @@
)

import cloudpickle
import pkg_resources
from packaging.requirements import Requirement

from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
Expand Down Expand Up @@ -64,6 +65,7 @@
DEFAULT_PACKAGES,
ENVIRONMENT_METADATA_FILE_NAME,
IMPLICIT_ZIP_FILE_NAME,
contains_version,
delete_files_belonging_to_packages,
detect_native_dependencies,
get_signature,
Expand Down Expand Up @@ -1249,7 +1251,7 @@ def remove_package(self, package: str) -> None:
>>> len(session.get_packages())
0
"""
package_name = pkg_resources.Requirement.parse(package).key
package_name = Requirement(package).name
if package_name in self._packages:
self._packages.pop(package_name)
else:
Expand Down Expand Up @@ -1369,64 +1371,49 @@ def replicate_local_environment(
ignore_packages = {} if ignore_packages is None else ignore_packages

packages = []
for package in pkg_resources.working_set:
if package.key in ignore_packages:
_logger.info(f"{package.key} found in environment, ignoring...")
for package in distributions():
if package.name in ignore_packages:
_logger.info(f"{package.name} found in environment, ignoring...")
continue
if package.key in DEFAULT_PACKAGES:
_logger.info(f"{package.key} is available by default, ignoring...")
if package.name in DEFAULT_PACKAGES:
_logger.info(f"{package.name} is availna meble by default, ignoring...")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
_logger.info(f"{package.name} is availna meble by default, ignoring...")
_logger.info(f"{package.name} is available by default, ignoring...")

continue
version_text = (
"==" + package.version if package.has_version() and not relax else ""
"==" + package.version if package.version and not relax else ""
)
packages.append(f"{package.key}{version_text}")
packages.append(f"{package.name}{version_text}")

self.add_packages(packages)

@staticmethod
def _parse_packages(
packages: List[Union[str, ModuleType]]
) -> Dict[str, Tuple[str, bool, pkg_resources.Requirement]]:
) -> Dict[str, Tuple[str, bool, Requirement]]:
package_dict = dict()
for package in packages:
if isinstance(package, ModuleType):
package_name = MODULE_NAME_TO_PACKAGE_NAME_MAP.get(
package.__name__, package.__name__
)
package = f"{package_name}=={pkg_resources.get_distribution(package_name).version}"
package = f"{package_name}=={distribution(package_name).version}"
use_local_version = True
else:
package = package.strip().lower()
if package.startswith("#"):
continue
use_local_version = False
package_req = pkg_resources.Requirement.parse(package)
# get the standard package name if there is no underscore
# underscores are discouraged in package names, but are still used in Anaconda channel
# pkg_resources.Requirement.parse will convert all underscores to dashes
# the regexp is to deal with case that "_" is in the package requirement as well as version restrictions
# we only extract the valid package name from the string by following:
# https://packaging.python.org/en/latest/specifications/name-normalization/
# A valid name consists only of ASCII letters and numbers, period, underscore and hyphen.
# It must start and end with a letter or number.
# however, we don't validate the pkg name as this is done by pkg_resources.Requirement.parse
# find the index of the first char which is not an valid package name character
package_name = package_req.key
if not use_local_version and "_" in package:
reg_match = re.search(r"[^0-9a-zA-Z\-_.]", package)
package_name = package[: reg_match.start()] if reg_match else package

package_dict[package] = (package_name, use_local_version, package_req)
package_req = Requirement(package)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it's directly handled?

package_dict[package] = (package_req.name, use_local_version, package_req)
return package_dict

def _get_dependency_packages(
self,
package_dict: Dict[str, Tuple[str, bool, pkg_resources.Requirement]],
package_dict: Dict[str, Tuple[str, bool, Requirement]],
validate_package: bool,
package_table: str,
current_packages: Dict[str, str],
statement_params: Optional[Dict[str, str]] = None,
) -> List[pkg_resources.Requirement]:
) -> List[Requirement]:
# Keep track of any package errors
errors = []

Expand All @@ -1442,16 +1429,19 @@ def _get_dependency_packages(
unsupported_packages: List[str] = []
for package, package_info in package_dict.items():
package_name, use_local_version, package_req = package_info
package_version_req = package_req.specs[0][1] if package_req.specs else None
package_specifier = package_req.specifier if package_req.specifier else None

if validate_package:
if package_name not in valid_packages or (
package_version_req
and not any(v in package_req for v in valid_packages[package_name])
package_specifier
and not any(
contains_version(package_specifier, v)
for v in valid_packages[package_name]
)
):
version_text = (
f"(version {package_version_req})"
if package_version_req is not None
f"(version {package_specifier})"
if package_specifier is not None
else ""
)
if is_in_stored_procedure(): # pragma: no cover
Expand Down Expand Up @@ -1491,17 +1481,15 @@ def _get_dependency_packages(
continue
elif not use_local_version:
try:
package_client_version = pkg_resources.get_distribution(
package_name
).version
package_client_version = distribution(package_name).version
if package_client_version not in valid_packages[package_name]:
_logger.warning(
f"The version of package '{package_name}' in the local environment is "
f"{package_client_version}, which does not fit the criteria for the "
f"requirement '{package}'. Your UDF might not work when the package version "
f"is different between the server and your local environment."
)
except pkg_resources.DistributionNotFound:
except PackageNotFoundError:
_logger.warning(
f"Package '{package_name}' is not installed in the local environment. "
f"Your UDF might not work when the package is installed on the server "
Expand Down Expand Up @@ -1531,7 +1519,7 @@ def _get_dependency_packages(
elif len(errors) > 0:
raise RuntimeError(errors)

dependency_packages: List[pkg_resources.Requirement] = []
dependency_packages: List[Requirement] = []
if len(unsupported_packages) != 0:
_logger.warning(
f"The following packages are not available in Snowflake: {unsupported_packages}."
Expand Down Expand Up @@ -1654,14 +1642,14 @@ def _resolve_packages(
# Add dependency packages
for package in dependency_packages:
name = package.name
version = package.specs[0][1] if package.specs else None
version = package.specifier if package.specifier else None

if name in result_dict:
if version is not None:
added_package_has_version = "==" in result_dict[name]
if added_package_has_version and result_dict[name] != str(package):
raise ValueError(
f"Cannot add dependency package '{name}=={version}' "
f"Cannot add dependency package '{name}{version}' "
f"because {result_dict[name]} is already added."
)
result_dict[name] = str(package)
Expand All @@ -1685,7 +1673,7 @@ def _upload_unsupported_packages(
package_table: str,
package_dict: Dict[str, str],
custom_package_usage_config: Dict[str, Any],
) -> List[pkg_resources.Requirement]:
) -> List[Requirement]:
"""
Uploads a list of Pypi packages, which are unavailable in Snowflake, to session stage.

Expand All @@ -1696,7 +1684,7 @@ def _upload_unsupported_packages(
been added explicitly so far using add_packages() or other such methods.

Returns:
List[pkg_resources.Requirement]: List of package dependencies (present in Snowflake) that would need to be added
List[packaging.requirements.Requirement]: List of package dependencies (present in Snowflake) that would need to be added
to the package dictionary.

Raises:
Expand Down Expand Up @@ -1843,7 +1831,7 @@ def _is_anaconda_terms_acknowledged(self) -> bool:

def _load_unsupported_packages_from_stage(
self, environment_signature: str, cache_path: str
) -> List[pkg_resources.Requirement]:
) -> List[Requirement]:
"""
Uses specified stage path to auto-import a group of unsupported packages, along with its dependencies. This
saves time spent on pip install, native package detection and zip upload to stage.
Expand All @@ -1864,7 +1852,7 @@ def _load_unsupported_packages_from_stage(
environment_signature (str): Unique hash signature for a set of unsupported packages, computed by hashing
a sorted tuple of unsupported package requirements (package versioning included).
Returns:
Optional[List[pkg_resources.Requirement]]: A list of package dependencies for the set of unsupported packages requested.
Optional[List[packaging.requirements.Requirement]]: A list of package dependencies for the set of unsupported packages requested.
"""
# Ensure that metadata file exists
metadata_file = f"{ENVIRONMENT_METADATA_FILE_NAME}.txt"
Expand Down Expand Up @@ -1897,8 +1885,7 @@ def _load_unsupported_packages_from_stage(
}

dependency_packages = [
pkg_resources.Requirement.parse(package)
for package in metadata[environment_signature]
Requirement(package) for package in metadata[environment_signature]
]
_logger.info(
f"Loading dependency packages list - {metadata[environment_signature]}."
Expand Down
19 changes: 19 additions & 0 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,3 +1991,22 @@ def test_register_sproc_after_switch_schema(session):
Utils.drop_database(session, db)
session.use_database(current_database)
session.use_schema(current_schema)


def test_register_sproc_with_package_underscore_name(session):
    """Register stored procedures that depend on a package with an underscore in its name.

    The package is requested once by bare name and once with a version
    constraint. The test has no assertions: it passes as long as neither
    registration raises.
    """
    package_specs = ("huggingface_hub", "huggingface_hub>0.15.1")
    for package_spec in package_specs:
        session.sproc.register(
            lambda session_, x, y: session_.create_dataframe([[x + y]]).collect()[0][0],
            return_type=IntegerType(),
            input_types=[IntegerType(), IntegerType()],
            packages=[package_spec],
        )
Loading
Loading