From 447c7edaca6315540e2ac4d52743939481dde139 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Sun, 29 Sep 2024 21:49:06 -0700 Subject: [PATCH 1/7] Remove pkg_resource --- setup.py | 1 + .../snowpark/_internal/packaging_utils.py | 42 +++++++++++------ src/snowflake/snowpark/session.py | 42 ++++++++--------- tests/unit/test_packaging_utils.py | 46 +++++++++---------- 4 files changed, 73 insertions(+), 58 deletions(-) diff --git a/setup.py b/setup.py index 12857b9e3c..584343111f 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ "pyyaml", "cloudpickle>=1.6.0,<=2.2.1,!=2.1.0,!=2.2.0;python_version<'3.11'", "cloudpickle==2.2.1;python_version~='3.11'", # backend only supports cloudpickle 2.2.1 + python 3.11 at the moment + "packaging", ] REQUIRED_PYTHON_VERSION = ">=3.8, <3.12" diff --git a/src/snowflake/snowpark/_internal/packaging_utils.py b/src/snowflake/snowpark/_internal/packaging_utils.py index ace8355468..61021c2ced 100644 --- a/src/snowflake/snowpark/_internal/packaging_utils.py +++ b/src/snowflake/snowpark/_internal/packaging_utils.py @@ -15,9 +15,12 @@ from pathlib import Path from typing import AnyStr, Dict, List, Optional, Set, Tuple -import pkg_resources import yaml -from pkg_resources import Requirement + +from packaging.requirements import Requirement +from importlib.metadata import distribution, PackageNotFoundError + +from packaging.specifiers import SpecifierSet _logger = getLogger(__name__) PIP_ENVIRONMENT_VARIABLE: str = "PIP_NAME" @@ -34,6 +37,19 @@ } +def get_distribution_version(name: str) -> Optional[str]: + """ + Get the distribution of a package. + + Args: + name (str): The name of the package. + + Returns: + Optional[str]: The distribution of the package. + """ + return distribution(name).version + + def parse_requirements_text_file(file_path: str) -> Tuple[List[str], List[str]]: """ Parses a requirements.txt file to obtain a list of packages and file/folder imports. Returns a tuple of packages @@ -223,7 +239,7 @@ def map_python_packages_to_files_and_folders( # Create Requirement objects and store in map package_name_to_record_entries_map[ - Requirement.parse(package) + Requirement(package) ] = included_record_entries return package_name_to_record_entries_map @@ -264,20 +280,20 @@ def identify_supported_packages( for package in packages: package_name: str = package.name - package_version_required: Optional[str] = ( - package.specs[0][1] if package.specs else None + package_specifier: Optional[SpecifierSet] = ( + package.specifier if package.specifier else None ) version_text = ( - f"(version {package_version_required})" - if package_version_required is not None + f"(version {package_specifier})" + if package_specifier is not None else "" ) if package_name in valid_packages: # Detect supported packages if ( - package_version_required is None - or package_version_required in valid_packages[package_name] + package_specifier is None + or any(package_specifier.contains(x) for x in valid_packages[package_name]) ): supported_dependencies.append(package) _logger.info( @@ -291,7 +307,7 @@ def identify_supported_packages( f"Package {package_name}{version_text} contains native code, switching to latest available version " f"in Snowflake instead." ) - new_dependencies.append(Requirement.parse(package_name)) + new_dependencies.append(Requirement(package_name)) dropped_dependencies.append(package) else: @@ -488,9 +504,9 @@ def add_snowpark_package( if SNOWPARK_PACKAGE_NAME not in package_dict: package_dict[SNOWPARK_PACKAGE_NAME] = SNOWPARK_PACKAGE_NAME try: - package_client_version = pkg_resources.get_distribution( + package_client_version = get_distribution_version( SNOWPARK_PACKAGE_NAME - ).version + ) if package_client_version in valid_packages[SNOWPARK_PACKAGE_NAME]: package_dict[ SNOWPARK_PACKAGE_NAME @@ -501,7 +517,7 @@ def add_snowpark_package( f"{package_client_version}, which is not available in Snowflake. Your UDF might not work when " f"the package version is different between the server and your local environment." ) - except pkg_resources.DistributionNotFound: + except PackageNotFoundError: _logger.warning( f"Package '{SNOWPARK_PACKAGE_NAME}' is not installed in the local environment. " f"Your UDF might not work when the package is installed on the server " diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index bc85a6096c..b1b2fd5f19 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -33,8 +33,8 @@ ) import cloudpickle -import pkg_resources - +from importlib.metadata import distribution, distributions, PackageNotFoundError +from packaging.requirements import Requirement from snowflake.connector import ProgrammingError, SnowflakeConnection from snowflake.connector.options import installed_pandas, pandas from snowflake.connector.pandas_tools import write_pandas @@ -1245,7 +1245,7 @@ def remove_package(self, package: str) -> None: >>> len(session.get_packages()) 0 """ - package_name = pkg_resources.Requirement.parse(package).key + package_name = Requirement(package).name if package_name in self._packages: self._packages.pop(package_name) else: @@ -1365,38 +1365,38 @@ def replicate_local_environment( ignore_packages = {} if ignore_packages is None else ignore_packages packages = [] - for package in pkg_resources.working_set: - if package.key in ignore_packages: - _logger.info(f"{package.key} found in environment, ignoring...") + for package in distributions(): + if package.name in ignore_packages: + _logger.info(f"{package.name} found in environment, ignoring...") continue - if package.key in DEFAULT_PACKAGES: - _logger.info(f"{package.key} is available by default, ignoring...") + if package.name in DEFAULT_PACKAGES: + _logger.info(f"{package.name} is availna meble by default, ignoring...") continue version_text = ( - "==" + package.version if package.has_version() and not relax else "" + "==" + package.version if package.version and not relax else "" ) - packages.append(f"{package.key}{version_text}") + packages.append(f"{package.name}{version_text}") self.add_packages(packages) @staticmethod def _parse_packages( packages: List[Union[str, ModuleType]] - ) -> Dict[str, Tuple[str, bool, pkg_resources.Requirement]]: + ) -> Dict[str, Tuple[str, bool, Requirement]]: package_dict = dict() for package in packages: if isinstance(package, ModuleType): package_name = MODULE_NAME_TO_PACKAGE_NAME_MAP.get( package.__name__, package.__name__ ) - package = f"{package_name}=={pkg_resources.get_distribution(package_name).version}" + package = f"{package_name}=={distribution(package_name).version}" use_local_version = True else: package = package.strip().lower() if package.startswith("#"): continue use_local_version = False - package_req = pkg_resources.Requirement.parse(package) + package_req = Requirement(package) # get the standard package name if there is no underscore # underscores are discouraged in package names, but are still used in Anaconda channel # pkg_resources.Requirement.parse will convert all underscores to dashes @@ -1417,12 +1417,12 @@ def _parse_packages( def _get_dependency_packages( self, - package_dict: Dict[str, Tuple[str, bool, pkg_resources.Requirement]], + package_dict: Dict[str, Tuple[str, bool, Requirement]], validate_package: bool, package_table: str, current_packages: Dict[str, str], statement_params: Optional[Dict[str, str]] = None, - ) -> List[pkg_resources.Requirement]: + ) -> List[Requirement]: # Keep track of any package errors errors = [] @@ -1487,7 +1487,7 @@ def _get_dependency_packages( continue elif not use_local_version: try: - package_client_version = pkg_resources.get_distribution( + package_client_version = distribution( package_name ).version if package_client_version not in valid_packages[package_name]: @@ -1497,7 +1497,7 @@ def _get_dependency_packages( f"requirement '{package}'. Your UDF might not work when the package version " f"is different between the server and your local environment." ) - except pkg_resources.DistributionNotFound: + except PackageNotFoundError: _logger.warning( f"Package '{package_name}' is not installed in the local environment. " f"Your UDF might not work when the package is installed on the server " @@ -1527,7 +1527,7 @@ def _get_dependency_packages( elif len(errors) > 0: raise RuntimeError(errors) - dependency_packages: List[pkg_resources.Requirement] = [] + dependency_packages: List[Requirement] = [] if len(unsupported_packages) != 0: _logger.warning( f"The following packages are not available in Snowflake: {unsupported_packages}." @@ -1681,7 +1681,7 @@ def _upload_unsupported_packages( package_table: str, package_dict: Dict[str, str], custom_package_usage_config: Dict[str, Any], - ) -> List[pkg_resources.Requirement]: + ) -> List[Requirement]: """ Uploads a list of Pypi packages, which are unavailable in Snowflake, to session stage. @@ -1839,7 +1839,7 @@ def _is_anaconda_terms_acknowledged(self) -> bool: def _load_unsupported_packages_from_stage( self, environment_signature: str, cache_path: str - ) -> List[pkg_resources.Requirement]: + ) -> List[Requirement]: """ Uses specified stage path to auto-import a group of unsupported packages, along with its dependencies. This saves time spent on pip install, native package detection and zip upload to stage. @@ -1893,7 +1893,7 @@ def _load_unsupported_packages_from_stage( } dependency_packages = [ - pkg_resources.Requirement.parse(package) + Requirement(package) for package in metadata[environment_signature] ] _logger.info( diff --git a/tests/unit/test_packaging_utils.py b/tests/unit/test_packaging_utils.py index 98331194cb..942038c690 100644 --- a/tests/unit/test_packaging_utils.py +++ b/tests/unit/test_packaging_utils.py @@ -7,9 +7,9 @@ from subprocess import TimeoutExpired from unittest.mock import patch -import pkg_resources import pytest -from pkg_resources import Requirement +from importlib import metadata +from packaging.requirements import Requirement from snowflake.snowpark._internal.packaging_utils import ( SNOWPARK_PACKAGE_NAME, @@ -216,10 +216,10 @@ def test_identify_supported_packages_vanilla(): Assert that the most straightforward usage of identify_supported_packages() works """ packages = [ - Requirement.parse("package1==1.0.0"), - Requirement.parse("package2==2.0.0"), - Requirement.parse("package3"), - Requirement.parse("package4==2.1.2"), + Requirement("package1==1.0.0"), + Requirement("package2==2.0.0"), + Requirement("package3"), + Requirement("package4==2.1.2"), ] valid_packages = { "package1": ["1.0.0", "1.1.0"], @@ -238,7 +238,7 @@ def test_identify_supported_packages_vanilla(): assert len(dropped_deps) == 1 assert packages[3] in dropped_deps assert len(new_deps) == 1 - assert Requirement.parse("package4") in new_deps + assert Requirement("package4") in new_deps def test_identify_supported_packages_all_cases(): @@ -250,7 +250,7 @@ def test_identify_supported_packages_all_cases(): # Case 1: All packages supported native_packages = {"pandas"} - packages = [Requirement.parse("numpy==1.2"), Requirement.parse("pandas")] + packages = [Requirement("numpy==1.2"), Requirement("pandas")] supported, dropped, new = identify_supported_packages( packages, valid_packages, native_packages, {} ) @@ -261,29 +261,29 @@ def test_identify_supported_packages_all_cases(): # Case 2: One non-native package, version not supported native_packages = {"pandas"} - packages = [Requirement.parse("numpy==10.0"), Requirement.parse("pandas")] + packages = [Requirement("numpy==10.0"), Requirement("pandas")] supported, dropped, new = identify_supported_packages( packages, valid_packages, native_packages, {} ) - assert supported == [Requirement.parse("pandas")] + assert supported == [Requirement("pandas")] assert dropped == [] assert new == [] assert native_packages == set() # Case 3: Native package version not available, should switch to latest available version native_packages = {"numpy", "pandas"} - packages = [Requirement.parse("numpy==10.0"), Requirement.parse("pandas")] + packages = [Requirement("numpy==10.0"), Requirement("pandas")] supported, dropped, new = identify_supported_packages( packages, valid_packages, native_packages, {} ) - assert supported == [Requirement.parse("pandas")] - assert dropped == [Requirement.parse("numpy==10.0")] - assert new == [Requirement.parse("numpy")] + assert supported == [Requirement("pandas")] + assert dropped == [Requirement("numpy==10.0")] + assert new == [Requirement("numpy")] assert native_packages == set() # Case 4: Package not in valid_packages and not a native package either native_packages = {"numpy", "pandas"} - packages = [Requirement.parse("somepackage")] + packages = [Requirement("somepackage")] supported, dropped, new = identify_supported_packages( packages, valid_packages, native_packages, {} ) @@ -332,8 +332,8 @@ def test_no_pip(monkeypatch, temp_directory): def test_detect_native_dependencies(): target = "/path/to/target" downloaded_packages_dict = { - Requirement.parse("numpy"): ["numpy"], - Requirement.parse("pandas"): ["pandas"], + Requirement("numpy"): ["numpy"], + Requirement("pandas"): ["pandas"], } # Mock the glob.glob function to return specific paths @@ -355,21 +355,19 @@ def test_detect_native_dependencies(): def test_add_snowpark_package(): - version = "1.3.0" + version = metadata.distribution(SNOWPARK_PACKAGE_NAME).version valid_packages = {SNOWPARK_PACKAGE_NAME: [version]} result_dict = {} - with patch("pkg_resources.get_distribution") as mock_get_distribution: - mock_get_distribution.return_value.version = version - add_snowpark_package(result_dict, valid_packages) - assert result_dict == {SNOWPARK_PACKAGE_NAME: f"{SNOWPARK_PACKAGE_NAME}==1.3.0"} + add_snowpark_package(result_dict, valid_packages) + assert result_dict == {SNOWPARK_PACKAGE_NAME: f"{SNOWPARK_PACKAGE_NAME}=={version}"} def test_add_snowpark_package_if_missing(): version = "1.3.0" valid_packages = {SNOWPARK_PACKAGE_NAME: [version]} result_dict = {} - with patch("pkg_resources.get_distribution") as mock_get_distribution: - mock_get_distribution.side_effect = pkg_resources.DistributionNotFound( + with patch("importlib.metadata.distribution") as mock_get_distribution: + mock_get_distribution.side_effect = metadata.PackageNotFoundError( "Package not found" ) add_snowpark_package(result_dict, valid_packages) # Should not raise any error From 5d07a297086303478fb8c80b9efb6e1e5bcde165 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Sun, 29 Sep 2024 22:14:02 -0700 Subject: [PATCH 2/7] Remove pkg_resources in type hints --- src/snowflake/snowpark/_internal/packaging_utils.py | 2 +- src/snowflake/snowpark/session.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/_internal/packaging_utils.py b/src/snowflake/snowpark/_internal/packaging_utils.py index 61021c2ced..1dfb196b42 100644 --- a/src/snowflake/snowpark/_internal/packaging_utils.py +++ b/src/snowflake/snowpark/_internal/packaging_utils.py @@ -499,7 +499,7 @@ def add_snowpark_package( channel. Raises: - pkg_resources.DistributionNotFound: If the Snowpark Python Package is not installed in the local environment. + importlib.metadata.PackageNotFoundError: If the Snowpark Python Package is not installed in the local environment. """ if SNOWPARK_PACKAGE_NAME not in package_dict: package_dict[SNOWPARK_PACKAGE_NAME] = SNOWPARK_PACKAGE_NAME diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index b1b2fd5f19..426bb5759a 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1407,7 +1407,7 @@ def _parse_packages( # It must start and end with a letter or number. # however, we don't validate the pkg name as this is done by pkg_resources.Requirement.parse # find the index of the first char which is not an valid package name character - package_name = package_req.key + package_name = package_req.name if not use_local_version and "_" in package: reg_match = re.search(r"[^0-9a-zA-Z\-_.]", package) package_name = package[: reg_match.start()] if reg_match else package @@ -1692,7 +1692,7 @@ def _upload_unsupported_packages( been added explicitly so far using add_packages() or other such methods. Returns: - List[pkg_resources.Requirement]: List of package dependencies (present in Snowflake) that would need to be added + List[packaging.requirements.Requirement]: List of package dependencies (present in Snowflake) that would need to be added to the package dictionary. Raises: @@ -1860,7 +1860,7 @@ def _load_unsupported_packages_from_stage( environment_signature (str): Unique hash signature for a set of unsupported packages, computed by hashing a sorted tuple of unsupported package requirements (package versioning included). Returns: - Optional[List[pkg_resources.Requirement]]: A list of package dependencies for the set of unsupported packages requested. + Optional[List[packaging.requirements.Requirement]]: A list of package dependencies for the set of unsupported packages requested. """ # Ensure that metadata file exists metadata_file = f"{ENVIRONMENT_METADATA_FILE_NAME}.txt" From 976d96e7ebc7ff324d943fb2138d221073bf5b4c Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Sun, 29 Sep 2024 22:20:12 -0700 Subject: [PATCH 3/7] Fix lint --- .../snowpark/_internal/packaging_utils.py | 17 +++++------------ src/snowflake/snowpark/session.py | 10 ++++------ tests/unit/test_packaging_utils.py | 2 +- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/snowflake/snowpark/_internal/packaging_utils.py b/src/snowflake/snowpark/_internal/packaging_utils.py index 1dfb196b42..e9398a3b97 100644 --- a/src/snowflake/snowpark/_internal/packaging_utils.py +++ b/src/snowflake/snowpark/_internal/packaging_utils.py @@ -11,15 +11,13 @@ import subprocess import sys import zipfile +from importlib.metadata import PackageNotFoundError, distribution from logging import getLogger from pathlib import Path from typing import AnyStr, Dict, List, Optional, Set, Tuple import yaml - from packaging.requirements import Requirement -from importlib.metadata import distribution, PackageNotFoundError - from packaging.specifiers import SpecifierSet _logger = getLogger(__name__) @@ -284,16 +282,13 @@ def identify_supported_packages( package.specifier if package.specifier else None ) version_text = ( - f"(version {package_specifier})" - if package_specifier is not None - else "" + f"(version {package_specifier})" if package_specifier is not None else "" ) if package_name in valid_packages: # Detect supported packages - if ( - package_specifier is None - or any(package_specifier.contains(x) for x in valid_packages[package_name]) + if package_specifier is None or any( + package_specifier.contains(x) for x in valid_packages[package_name] ): supported_dependencies.append(package) _logger.info( @@ -504,9 +499,7 @@ def add_snowpark_package( if SNOWPARK_PACKAGE_NAME not in package_dict: package_dict[SNOWPARK_PACKAGE_NAME] = SNOWPARK_PACKAGE_NAME try: - package_client_version = get_distribution_version( - SNOWPARK_PACKAGE_NAME - ) + package_client_version = get_distribution_version(SNOWPARK_PACKAGE_NAME) if package_client_version in valid_packages[SNOWPARK_PACKAGE_NAME]: package_dict[ SNOWPARK_PACKAGE_NAME diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 426bb5759a..9c04665bd8 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -16,6 +16,7 @@ import warnings from array import array from functools import reduce +from importlib.metadata import PackageNotFoundError, distribution, distributions from logging import getLogger from threading import RLock from types import ModuleType @@ -33,8 +34,8 @@ ) import cloudpickle -from importlib.metadata import distribution, distributions, PackageNotFoundError from packaging.requirements import Requirement + from snowflake.connector import ProgrammingError, SnowflakeConnection from snowflake.connector.options import installed_pandas, pandas from snowflake.connector.pandas_tools import write_pandas @@ -1487,9 +1488,7 @@ def _get_dependency_packages( continue elif not use_local_version: try: - package_client_version = distribution( - package_name - ).version + package_client_version = distribution(package_name).version if package_client_version not in valid_packages[package_name]: _logger.warning( f"The version of package '{package_name}' in the local environment is " @@ -1893,8 +1892,7 @@ def _load_unsupported_packages_from_stage( } dependency_packages = [ - Requirement(package) - for package in metadata[environment_signature] + Requirement(package) for package in metadata[environment_signature] ] _logger.info( f"Loading dependency packages list - {metadata[environment_signature]}." diff --git a/tests/unit/test_packaging_utils.py b/tests/unit/test_packaging_utils.py index 942038c690..64e0481f0d 100644 --- a/tests/unit/test_packaging_utils.py +++ b/tests/unit/test_packaging_utils.py @@ -4,11 +4,11 @@ import os import zipfile +from importlib import metadata from subprocess import TimeoutExpired from unittest.mock import patch import pytest -from importlib import metadata from packaging.requirements import Requirement from snowflake.snowpark._internal.packaging_utils import ( From f3c55439ca63549a4dba4fa72e9c31684a8ac047 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 30 Sep 2024 00:16:16 -0700 Subject: [PATCH 4/7] Fix test errors --- .../snowpark/_internal/packaging_utils.py | 21 ++++++++++++++++++- src/snowflake/snowpark/session.py | 18 +++++++++------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/packaging_utils.py b/src/snowflake/snowpark/_internal/packaging_utils.py index e9398a3b97..40412a942d 100644 --- a/src/snowflake/snowpark/_internal/packaging_utils.py +++ b/src/snowflake/snowpark/_internal/packaging_utils.py @@ -19,6 +19,7 @@ import yaml from packaging.requirements import Requirement from packaging.specifiers import SpecifierSet +from packaging.version import InvalidVersion _logger = getLogger(__name__) PIP_ENVIRONMENT_VARIABLE: str = "PIP_NAME" @@ -48,6 +49,23 @@ def get_distribution_version(name: str) -> Optional[str]: return distribution(name).version +def contains_version(specifier: SpecifierSet, version: str) -> bool: + """ + Check if a requirement contains a specific version. + + Args: + specifier (SpecifierSet): The requirement to check. + version (str): The version to check for. + + Returns: + bool: True if the requirement contains the version, False otherwise. + """ + try: + return specifier.contains(version) + except InvalidVersion: + return False + + def parse_requirements_text_file(file_path: str) -> Tuple[List[str], List[str]]: """ Parses a requirements.txt file to obtain a list of packages and file/folder imports. Returns a tuple of packages @@ -288,7 +306,8 @@ def identify_supported_packages( if package_name in valid_packages: # Detect supported packages if package_specifier is None or any( - package_specifier.contains(x) for x in valid_packages[package_name] + contains_version(package_specifier, x) + for x in valid_packages[package_name] ): supported_dependencies.append(package) _logger.info( diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 9c04665bd8..15d7a80963 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -65,6 +65,7 @@ DEFAULT_PACKAGES, ENVIRONMENT_METADATA_FILE_NAME, IMPLICIT_ZIP_FILE_NAME, + contains_version, delete_files_belonging_to_packages, detect_native_dependencies, get_signature, @@ -1439,16 +1440,19 @@ def _get_dependency_packages( unsupported_packages: List[str] = [] for package, package_info in package_dict.items(): package_name, use_local_version, package_req = package_info - package_version_req = package_req.specs[0][1] if package_req.specs else None + package_specifier = package_req.specifier if package_req.specifier else None if validate_package: if package_name not in valid_packages or ( - package_version_req - and not any(v in package_req for v in valid_packages[package_name]) + package_specifier + and not any( + contains_version(package_specifier, v) + for v in valid_packages[package_name] + ) ): version_text = ( - f"(version {package_version_req})" - if package_version_req is not None + f"(version {package_specifier})" + if package_specifier is not None else "" ) if is_in_stored_procedure(): # pragma: no cover @@ -1649,14 +1653,14 @@ def _resolve_packages( # Add dependency packages for package in dependency_packages: name = package.name - version = package.specs[0][1] if package.specs else None + version = package.specifier if package.specifier else None if name in result_dict: if version is not None: added_package_has_version = "==" in result_dict[name] if added_package_has_version and result_dict[name] != str(package): raise ValueError( - f"Cannot add dependency package '{name}=={version}' " + f"Cannot add dependency package '{name}{version}' " f"because {result_dict[name]} is already added." ) result_dict[name] = str(package) From 46f9032ad9f5e6cd48891524cd80ee866879ee91 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Wed, 2 Oct 2024 13:48:18 -0700 Subject: [PATCH 5/7] Update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4f6f3b601..55e3a3cf3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ - Added support for constructing `Series` and `DataFrame` objects with `index` and `column` values not present in `DataFrame`/`Series` `data`. #### Improvements +- Removed the use of `pkg_resources`, which is deprecated. No more deprecation warnings about it. Used `packaging` and `importlib.metadata` to replace `pkg_resources`. + #### Bug Fixes From 83250e66979ab8454331f304f4a855c1075f572c Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Wed, 2 Oct 2024 13:48:57 -0700 Subject: [PATCH 6/7] Remove code that converts _ to - --- src/snowflake/snowpark/session.py | 17 +---------------- tests/integ/test_stored_procedure.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 15d7a80963..39d74a8266 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1399,22 +1399,7 @@ def _parse_packages( continue use_local_version = False package_req = Requirement(package) - # get the standard package name if there is no underscore - # underscores are discouraged in package names, but are still used in Anaconda channel - # pkg_resources.Requirement.parse will convert all underscores to dashes - # the regexp is to deal with case that "_" is in the package requirement as well as version restrictions - # we only extract the valid package name from the string by following: - # https://packaging.python.org/en/latest/specifications/name-normalization/ - # A valid name consists only of ASCII letters and numbers, period, underscore and hyphen. - # It must start and end with a letter or number. - # however, we don't validate the pkg name as this is done by pkg_resources.Requirement.parse - # find the index of the first char which is not an valid package name character - package_name = package_req.name - if not use_local_version and "_" in package: - reg_match = re.search(r"[^0-9a-zA-Z\-_.]", package) - package_name = package[: reg_match.start()] if reg_match else package - - package_dict[package] = (package_name, use_local_version, package_req) + package_dict[package] = (package_req.name, use_local_version, package_req) return package_dict def _get_dependency_packages( diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 5cd8c42d93..93c74685c8 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -1991,3 +1991,18 @@ def test_register_sproc_after_switch_schema(session): Utils.drop_database(session, db) session.use_database(current_database) session.use_schema(current_schema) + +def test_register_sproc_with_package_underscore_name(session): + session.sproc.register( + lambda session_, x, y: session_.create_dataframe([[x + y]]).collect()[0][0], + return_type=IntegerType(), + input_types=[IntegerType(), IntegerType()], + packages=["huggingface_hub"] + ) + + session.sproc.register( + lambda session_, x, y: session_.create_dataframe([[x + y]]).collect()[0][0], + return_type=IntegerType(), + input_types=[IntegerType(), IntegerType()], + packages=["huggingface_hub>0.15.1"] + ) From 6c1e1e50ffbe03bb5a711b867c87f4e6636e5aca Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 7 Oct 2024 13:50:27 -0700 Subject: [PATCH 7/7] fix lint --- tests/integ/test_stored_procedure.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 93c74685c8..4a4746d1a7 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -1992,17 +1992,21 @@ def test_register_sproc_after_switch_schema(session): session.use_database(current_database) session.use_schema(current_schema) + def test_register_sproc_with_package_underscore_name(session): + """Test registering a stored procedure with a package that has an underscore in its name. + It passes if no exception is thrown. + """ session.sproc.register( lambda session_, x, y: session_.create_dataframe([[x + y]]).collect()[0][0], return_type=IntegerType(), input_types=[IntegerType(), IntegerType()], - packages=["huggingface_hub"] + packages=["huggingface_hub"], ) session.sproc.register( lambda session_, x, y: session_.create_dataframe([[x + y]]).collect()[0][0], return_type=IntegerType(), input_types=[IntegerType(), IntegerType()], - packages=["huggingface_hub>0.15.1"] + packages=["huggingface_hub>0.15.1"], )