From e3d0c2304f58c0d528482a8614589b1e911e1cd2 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:07:04 +0100 Subject: [PATCH 01/11] Add sagemaker finetuning client --- poetry.lock | 790 +++++++++++++- pyproject.toml | 1 + .../cohere_aws/__init__.py | 3 + .../manually_maintained/cohere_aws/chat.py | 325 ++++++ .../cohere_aws/classification.py | 60 ++ .../manually_maintained/cohere_aws/client.py | 974 ++++++++++++++++++ .../cohere_aws/embeddings.py | 26 + .../manually_maintained/cohere_aws/error.py | 23 + .../cohere_aws/generation.py | 107 ++ .../manually_maintained/cohere_aws/mode.py | 6 + .../manually_maintained/cohere_aws/rerank.py | 66 ++ .../cohere_aws/response.py | 11 + .../manually_maintained/cohere_aws/summary.py | 16 + src/cohere/sagemaker_client.py | 6 +- 14 files changed, 2408 insertions(+), 6 deletions(-) create mode 100644 src/cohere/manually_maintained/cohere_aws/__init__.py create mode 100644 src/cohere/manually_maintained/cohere_aws/chat.py create mode 100644 src/cohere/manually_maintained/cohere_aws/classification.py create mode 100644 src/cohere/manually_maintained/cohere_aws/client.py create mode 100644 src/cohere/manually_maintained/cohere_aws/embeddings.py create mode 100644 src/cohere/manually_maintained/cohere_aws/error.py create mode 100644 src/cohere/manually_maintained/cohere_aws/generation.py create mode 100644 src/cohere/manually_maintained/cohere_aws/mode.py create mode 100644 src/cohere/manually_maintained/cohere_aws/rerank.py create mode 100644 src/cohere/manually_maintained/cohere_aws/response.py create mode 100644 src/cohere/manually_maintained/cohere_aws/summary.py diff --git a/poetry.lock b/poetry.lock index 4941469a3..47d0ddd20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "annotated-types" @@ -36,6 +36,25 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.21.0b1)"] trio = ["trio (>=0.26.1)"] +[[package]] +name = "attrs" +version = "23.2.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.7" +files = [ + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, +] + +[package.extras] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] + [[package]] name = "boto3" version = "1.35.27" @@ -70,8 +89,8 @@ files = [ jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" urllib3 = [ - {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, ] [package.extras] @@ -187,6 +206,17 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "cloudpickle" +version = "2.2.1" +description = "Extended pickling support for Python objects" +optional = false +python-versions = ">=3.6" +files = [ + {file = "cloudpickle-2.2.1-py3-none-any.whl", 
hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"}, + {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"}, +] + [[package]] name = "colorama" version = "0.4.6" @@ -198,6 +228,43 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "dill" +version = "0.3.9" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, + {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + +[[package]] +name = "docker" +version = "7.1.0" +description = "A Python library for the Docker Engine API." +optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, + {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, +] + +[package.dependencies] +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] +docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -313,6 +380,21 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", 
"fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "google-pasta" +version = "0.2.0" +description = "pasta is an AST-based Python refactoring library" +optional = false +python-versions = "*" +files = [ + {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"}, + {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"}, + {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "h11" version = "0.14.0" @@ -429,6 +511,47 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "importlib-metadata" +version = "6.11.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", 
"pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + +[[package]] +name = "importlib-resources" +version = "6.4.5" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, + {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -451,6 +574,119 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "jsonschema" +version = "4.23.0" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} +jsonschema-specifications = ">=2023.03.6" +pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""} +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + 
+[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} +referencing = ">=0.31.0" + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + +[[package]] +name = "mock" +version = "4.0.3" +description = "Rolling backport of unittest.mock for all Pythons" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mock-4.0.3-py3-none-any.whl", hash = "sha256:122fcb64ee37cfad5b3f48d7a7d51875d7031aaf3d8be7c42e2bee25044eee62"}, + {file = "mock-4.0.3.tar.gz", hash = "sha256:7d3fbbde18228f4ff2f1f119a45cdffa458b4c0dee32eb4d2bb2f82554bac7bc"}, +] + +[package.extras] +build = ["blurb", "twine", "wheel"] +docs = ["sphinx"] +test = ["pytest (<5.4)", "pytest-cov"] + +[[package]] +name = "multiprocess" +version = "0.70.16" +description = "better multiprocessing and multithreading in Python" +optional = false 
+python-versions = ">=3.8" +files = [ + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, +] + +[package.dependencies] +dill = ">=0.3.8" + [[package]] name = "mypy" version = "1.0.1" @@ -508,6 +744,43 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = 
"sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "numpy" +version = "1.24.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = 
"sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = 
"numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, +] + [[package]] name = "packaging" version = "24.1" @@ -519,6 +792,73 @@ files = [ {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] +[[package]] +name = "pandas" +version = "2.0.3" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"}, + {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"}, + {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"}, + 
{file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"}, + {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"}, + {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"}, + {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"}, + {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"}, + {file = 
"pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"}, + {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"}, + {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"}, + {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.1" + +[package.extras] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +aws = ["s3fs (>=2021.08.0)"] +clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] +compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] +computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb 
(>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] +feather = ["pyarrow (>=7.0.0)"] +fss = ["fsspec (>=2021.07.0)"] +gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] +hdf5 = ["tables (>=3.6.1)"] +html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] +mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"] +parquet = ["pyarrow (>=7.0.0)"] +performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"] +plot = ["matplotlib (>=3.6.1)"] +postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] +spss = ["pyreadstat (>=1.1.2)"] +sql-other = ["SQLAlchemy (>=1.4.16)"] +test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.6.3)"] + [[package]] name = "parameterized" version = "0.9.0" @@ -533,6 +873,50 @@ files = [ [package.extras] dev = ["jinja2"] +[[package]] +name = "pathos" +version = "0.3.2" +description = "parallel graph management and execution in heterogeneous computing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathos-0.3.2-py3-none-any.whl", hash = "sha256:d669275e6eb4b3fbcd2846d7a6d1bba315fe23add0c614445ba1408d8b38bafe"}, + {file = "pathos-0.3.2.tar.gz", hash = "sha256:4f2a42bc1e10ccf0fe71961e7145fc1437018b6b21bd93b2446abc3983e49a7a"}, +] + +[package.dependencies] +dill = ">=0.3.8" +multiprocess = ">=0.70.16" +pox = ">=0.3.4" +ppft = ">=1.7.6.8" + +[[package]] +name = "pkgutil-resolve-name" +version = "1.3.10" +description = "Resolve a name to an object." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"}, + {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"}, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." +optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, + {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] + [[package]] name = "pluggy" version = "1.5.0" @@ -548,6 +932,80 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pox" +version = "0.3.5" +description = "utilities for filesystem exploration and automated builds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pox-0.3.5-py3-none-any.whl", hash = "sha256:9e82bcc9e578b43e80a99cad80f0d8f44f4d424f0ee4ee8d4db27260a6aa365a"}, + {file = "pox-0.3.5.tar.gz", hash = "sha256:8120ee4c94e950e6e0483e050a4f0e56076e590ba0a9add19524c254bd23c2d1"}, +] + +[[package]] +name = "ppft" +version = "1.7.6.9" +description = "distributed and parallel Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ppft-1.7.6.9-py3-none-any.whl", hash = "sha256:dab36548db5ca3055067fbe6b1a17db5fee29f3c366c579a9a27cebb52ed96f0"}, + {file = "ppft-1.7.6.9.tar.gz", hash = 
"sha256:73161c67474ea9d81d04bcdad166d399cff3f084d5d2dc21ebdd46c075bbc265"}, +] + +[package.extras] +dill = ["dill (>=0.3.9)"] + +[[package]] +name = "protobuf" +version = "4.25.5" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-4.25.5-cp310-abi3-win32.whl", hash = "sha256:5e61fd921603f58d2f5acb2806a929b4675f8874ff5f330b7d6f7e2e784bbcd8"}, + {file = "protobuf-4.25.5-cp310-abi3-win_amd64.whl", hash = "sha256:4be0571adcbe712b282a330c6e89eae24281344429ae95c6d85e79e84780f5ea"}, + {file = "protobuf-4.25.5-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b2fde3d805354df675ea4c7c6338c1aecd254dfc9925e88c6d31a2bcb97eb173"}, + {file = "protobuf-4.25.5-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:919ad92d9b0310070f8356c24b855c98df2b8bd207ebc1c0c6fcc9ab1e007f3d"}, + {file = "protobuf-4.25.5-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:fe14e16c22be926d3abfcb500e60cab068baf10b542b8c858fa27e098123e331"}, + {file = "protobuf-4.25.5-cp38-cp38-win32.whl", hash = "sha256:98d8d8aa50de6a2747efd9cceba361c9034050ecce3e09136f90de37ddba66e1"}, + {file = "protobuf-4.25.5-cp38-cp38-win_amd64.whl", hash = "sha256:b0234dd5a03049e4ddd94b93400b67803c823cfc405689688f59b34e0742381a"}, + {file = "protobuf-4.25.5-cp39-cp39-win32.whl", hash = "sha256:abe32aad8561aa7cc94fc7ba4fdef646e576983edb94a73381b03c53728a626f"}, + {file = "protobuf-4.25.5-cp39-cp39-win_amd64.whl", hash = "sha256:7a183f592dc80aa7c8da7ad9e55091c4ffc9497b3054452d629bb85fa27c2a45"}, + {file = "protobuf-4.25.5-py3-none-any.whl", hash = "sha256:0aebecb809cae990f8129ada5ca273d9d670b76d9bfc9b1809f0a9c02b7dbf41"}, + {file = "protobuf-4.25.5.tar.gz", hash = "sha256:7f8249476b4a9473645db7f8ab42b02fe1488cbe5fb72fddd445e0665afd8584"}, +] + +[[package]] +name = "psutil" +version = "6.0.0" +description = "Cross-platform lib for process and system monitoring in Python." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = 
"sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + [[package]] name = "pydantic" version = "2.9.2" @@ -563,8 +1021,8 @@ files = [ annotated-types = ">=0.6.0" pydantic-core = "2.23.4" typing-extensions = [ - {version = ">=4.6.1", markers = "python_version < \"3.13\""}, {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, ] [package.extras] @@ -672,6 +1130,20 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pygments" +version = "2.18.0" +description = "Pygments is a syntax highlighting package written in Python." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pytest" version = "7.4.4" @@ -726,6 +1198,40 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytz" +version = "2024.2" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, +] + +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = 
"pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -788,6 +1294,21 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "requests" version = "2.32.3" @@ -809,6 +1330,137 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.8.1" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = 
"rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, + {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + +[[package]] +name = "rpds-py" +version = "0.20.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = 
"rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = 
"rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = 
"rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = 
"rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, +] + [[package]] name = "ruff" version = "0.5.7" @@ -853,6 +1505,84 @@ botocore = ">=1.33.2,<2.0a.0" [package.extras] crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] +[[package]] +name = "sagemaker" +version = "2.232.1" +description = "Open source library for training and deploying models on Amazon SageMaker." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker-2.232.1-py3-none-any.whl", hash = "sha256:64b92639918613b8042ddbc13f34cfac65145d916ff0c0001b249f9f33012cb1"}, + {file = "sagemaker-2.232.1.tar.gz", hash = "sha256:e59e1ac79bc31235b8d5c766abee5c6b2fd526814d1dde62b98e7dca9654503c"}, +] + +[package.dependencies] +attrs = ">=23.1.0,<24" +boto3 = ">=1.34.142,<2.0" +cloudpickle = "2.2.1" +docker = "*" +google-pasta = "*" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "*" +numpy = ">=1.9.0,<2.0" +packaging = ">=20.0" +pandas = "*" +pathos = "*" +platformdirs = "*" +protobuf = ">=3.12,<5.0" +psutil = "*" +pyyaml = ">=6.0,<7.0" +requests = "*" +sagemaker-core = ">=1.0.0,<2.0.0" +schema = "*" +smdebug-rulesconfig = "1.0.1" +tblib = ">=1.7.0,<4" +tqdm = "*" +urllib3 = ">=1.26.8,<3.0.0" + +[package.extras] +all = ["accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "fastapi (>=0.111.0)", "nest-asyncio", "pyspark (==3.3.1)", "pyyaml (>=5.4.1,<7)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)"] +feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3-3"] +huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "fastapi (>=0.111.0)", "nest-asyncio", "sagemaker-schema-inference-artifacts (>=0.0.5)", "uvicorn (>=0.30.1)"] +local = ["docker (>=5.0.2,<8.0.0)", "pyyaml (>=5.4.1,<7)", "urllib3 (>=1.26.8,<3.0.0)"] +scipy = ["scipy (==1.10.1)"] +test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "build[virtualenv] (==1.2.1)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "fastapi (>=0.111.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "jinja2 (==3.1.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat 
(>=5.9,<6)", "nest-asyncio", "numpy (>=1.24.0)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "pyyaml (==6.0)", "pyyaml (>=5.4.1,<7)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)", "xgboost (>=1.6.2,<=1.7.6)"] + +[[package]] +name = "sagemaker-core" +version = "1.0.9" +description = "An python package for sagemaker core functionalities" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker_core-1.0.9-py3-none-any.whl", hash = "sha256:7a22c46cf93594f8d44e3523d4ba98407911f3530af68a8ffdde5082d3b26fa3"}, + {file = "sagemaker_core-1.0.9.tar.gz", hash = "sha256:664115faf797412553fb81b97a4777e78e51dfd4454c32edb2c8371bf203c535"}, +] + +[package.dependencies] +boto3 = ">=1.34.0,<2.0.0" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "<5.0.0" +mock = ">4.0,<5.0" +platformdirs = ">=4.0.0,<5.0.0" +pydantic = ">=1.7.0,<3.0.0" +PyYAML = ">=6.0,<7.0" +rich = ">=13.0.0,<14.0.0" + +[package.extras] +codegen = ["black (>=24.3.0,<25.0.0)", "pandas (>=2.0.0,<3.0.0)", "pylint (>=3.0.0,<4.0.0)", "pytest (>=8.0.0,<9.0.0)"] + +[[package]] +name = "schema" +version = "0.7.7" +description = "Simple data validation library" +optional = false +python-versions = "*" +files = [ + {file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"}, + {file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"}, +] + [[package]] name = "six" version = 
"1.16.0" @@ -864,6 +1594,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "smdebug-rulesconfig" +version = "1.0.1" +description = "SMDebug RulesConfig" +optional = false +python-versions = ">=2.7" +files = [ + {file = "smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl", hash = "sha256:104da3e6931ecf879dfc687ca4bbb3bee5ea2bc27f4478e9dbb3ee3655f1ae61"}, + {file = "smdebug_rulesconfig-1.0.1.tar.gz", hash = "sha256:7a19e6eb2e6bcfefbc07e4a86ef7a88f32495001a038bf28c7d8e77ab793fcd6"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -875,6 +1616,17 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "tblib" +version = "3.0.0" +description = "Traceback serialization library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"}, + {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"}, +] + [[package]] name = "tokenizers" version = "0.20.0" @@ -1084,6 +1836,17 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "tzdata" +version = "2024.2" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, +] + [[package]] name = "urllib3" version = "1.26.20" @@ -1117,7 +1880,26 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "zipp" 
+version = "3.20.2" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, + {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] + [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "64ad018aa9a30f1a565f720939b5f423557ddb112c5e2e39c9303d3ecc40a594" +content-hash = "a7846fa74e0faed852edf752d5fe80cb9861f040a1b8ce838bc2f6251ac22a1b" diff --git a/pyproject.toml b/pyproject.toml index fe2fc4191..f4370c6c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ requests = "^2.0.0" tokenizers = ">=0.15,<1" types-requests = "^2.0.0" typing_extensions = ">= 4.0.0" +sagemaker = "^2.232.1" [tool.poetry.dev-dependencies] mypy = "1.0.1" diff --git a/src/cohere/manually_maintained/cohere_aws/__init__.py b/src/cohere/manually_maintained/cohere_aws/__init__.py new file mode 100644 index 000000000..04899930a --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/__init__.py @@ -0,0 +1,3 @@ +from .client import Client +from .error import CohereError +from .mode import Mode diff --git a/src/cohere/manually_maintained/cohere_aws/chat.py b/src/cohere/manually_maintained/cohere_aws/chat.py new file mode 100644 index 000000000..b16303442 --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/chat.py @@ -0,0 +1,325 @@ +from 
cohere_aws.response import CohereObject +from cohere_aws.error import CohereError +from cohere_aws.mode import Mode +from typing import List, Optional, Generator, Dict, Any, Union +from enum import Enum +import json + +# Tools + +class ToolParameterDefinitionsValue(CohereObject, dict): + def __init__( + self, + type: str, + description: str, + required: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.__dict__ = self + self.type = type + self.description = description + if required is not None: + self.required = required + + +class Tool(CohereObject, dict): + def __init__( + self, + name: str, + description: str, + parameter_definitions: Optional[Dict[str, ToolParameterDefinitionsValue]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.__dict__ = self + self.name = name + self.description = description + if parameter_definitions is not None: + self.parameter_definitions = parameter_definitions + + +class ToolCall(CohereObject, dict): + def __init__( + self, + name: str, + parameters: Dict[str, Any], + generation_id: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.__dict__ = self + self.name = name + self.parameters = parameters + self.generation_id = generation_id + + @classmethod + def from_dict(cls, tool_call_res: Dict[str, Any]) -> "ToolCall": + return cls( + name=tool_call_res.get("name"), + parameters=tool_call_res.get("parameters"), + generation_id=tool_call_res.get("generation_id"), + ) + + @classmethod + def from_list(cls, tool_calls_res: Optional[List[Dict[str, Any]]]) -> Optional[List["ToolCall"]]: + if tool_calls_res is None or not isinstance(tool_calls_res, list): + return None + + return [ToolCall.from_dict(tc) for tc in tool_calls_res] + +# Chat + +class Chat(CohereObject): + def __init__( + self, + response_id: str, + generation_id: str, + text: str, + chat_history: Optional[List[Dict[str, Any]]] = None, + preamble: Optional[str] = None, + finish_reason: Optional[str] = None, + 
token_count: Optional[Dict[str, int]] = None, + tool_calls: Optional[List[ToolCall]] = None, + citations: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Dict[str, Any]]] = None, + search_results: Optional[List[Dict[str, Any]]] = None, + search_queries: Optional[List[Dict[str, Any]]] = None, + is_search_required: Optional[bool] = None, + ) -> None: + self.response_id = response_id + self.generation_id = generation_id + self.text = text + self.chat_history = chat_history + self.preamble = preamble + self.finish_reason = finish_reason + self.token_count = token_count + self.tool_calls = tool_calls + self.citations = citations + self.documents = documents + self.search_results = search_results + self.search_queries = search_queries + self.is_search_required = is_search_required + + @classmethod + def from_dict(cls, response: Dict[str, Any]) -> "Chat": + return cls( + response_id=response["response_id"], + generation_id=response.get("generation_id"), # optional + text=response.get("text"), + chat_history=response.get("chat_history"), # optional + preamble=response.get("preamble"), # optional + token_count=response.get("token_count"), + is_search_required=response.get("is_search_required"), # optional + citations=response.get("citations"), # optional + documents=response.get("documents"), # optional + search_results=response.get("search_results"), # optional + search_queries=response.get("search_queries"), # optional + finish_reason=response.get("finish_reason"), + tool_calls=ToolCall.from_list(response.get("tool_calls")), # optional + ) + +# ---------------| +# Steaming event | +# ---------------| + +class StreamEvent(str, Enum): + STREAM_START = "stream-start" + SEARCH_QUERIES_GENERATION = "search-queries-generation" + SEARCH_RESULTS = "search-results" + TEXT_GENERATION = "text-generation" + TOOL_CALLS_GENERATION = "tool-calls-generation" + CITATION_GENERATION = "citation-generation" + STREAM_END = "stream-end" + +class StreamResponse(CohereObject): 
+ def __init__( + self, + is_finished: bool, + event_type: Union[StreamEvent, str], + index: Optional[int], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.is_finished = is_finished + self.index = index + self.event_type = event_type + + +class StreamStart(StreamResponse): + def __init__( + self, + generation_id: str, + conversation_id: Optional[str], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.generation_id = generation_id + self.conversation_id = conversation_id + + +class StreamTextGeneration(StreamResponse): + def __init__( + self, + text: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.text = text + + +class StreamCitationGeneration(StreamResponse): + def __init__( + self, + citations: Optional[List[Dict[str, Any]]], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.citations = citations + + +class StreamQueryGeneration(StreamResponse): + def __init__( + self, + search_queries: Optional[List[Dict[str, Any]]], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.search_queries = search_queries + + +class StreamSearchResults(StreamResponse): + def __init__( + self, + search_results: Optional[List[Dict[str, Any]]], + documents: Optional[List[Dict[str, Any]]], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.search_results = search_results + self.documents = documents + + +class StreamEnd(StreamResponse): + def __init__( + self, + finish_reason: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.finish_reason = finish_reason + + +class ChatToolCallsGenerationEvent(StreamResponse): + def __init__( + self, + tool_calls: Optional[List[ToolCall]], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.tool_calls = tool_calls + +class StreamingChat(CohereObject): + def __init__(self, stream_response, mode): + self.stream_response = stream_response + self.text = None + self.response_id = None + self.generation_id = None + self.preamble = None + self.prompt = None + 
self.chat_history = None + self.finish_reason = None + self.token_count = None + self.is_search_required = None + self.citations = None + self.documents = None + self.search_results = None + self.search_queries = None + self.tool_calls = None + + self.bytes = bytearray() + if mode == Mode.SAGEMAKER: + self.payload_key = "PayloadPart" + self.bytes_key = "Bytes" + elif mode == Mode.BEDROCK: + self.payload_key = "chunk" + self.bytes_key = "bytes" + + def _make_response_item(self, index, streaming_item) -> Any: + event_type = streaming_item.get("event_type") + + if event_type == StreamEvent.STREAM_START: + self.conversation_id = streaming_item.get("conversation_id") + self.generation_id = streaming_item.get("generation_id") + return StreamStart( + conversation_id=self.conversation_id, + generation_id=self.generation_id, + is_finished=False, + event_type=event_type, + index=index, + ) + elif event_type == StreamEvent.SEARCH_QUERIES_GENERATION: + search_queries = streaming_item.get("search_queries") + return StreamQueryGeneration( + search_queries=search_queries, is_finished=False, event_type=event_type, index=index + ) + elif event_type == StreamEvent.SEARCH_RESULTS: + search_results = streaming_item.get("search_results") + documents = streaming_item.get("documents") + return StreamSearchResults( + search_results=search_results, + documents=documents, + is_finished=False, + event_type=event_type, + index=index, + ) + elif event_type == StreamEvent.TEXT_GENERATION: + text = streaming_item.get("text") + return StreamTextGeneration(text=text, is_finished=False, event_type=event_type, index=index) + elif event_type == StreamEvent.CITATION_GENERATION: + citations = streaming_item.get("citations") + return StreamCitationGeneration(citations=citations, is_finished=False, event_type=event_type, index=index) + elif event_type == StreamEvent.TOOL_CALLS_GENERATION: + tool_calls = ToolCall.from_list(streaming_item.get("tool_calls")) + return ChatToolCallsGenerationEvent( + 
tool_calls=tool_calls, is_finished=False, event_type=event_type, index=index
+            )
+        elif event_type == StreamEvent.STREAM_END:
+            response = streaming_item.get("response")
+            finish_reason = streaming_item.get("finish_reason")
+            self.finish_reason = finish_reason
+
+            if response is None:
+                return None
+
+            self.response_id = response.get("response_id")
+            self.conversation_id = response.get("conversation_id")
+            self.text = response.get("text")
+            self.generation_id = response.get("generation_id")
+            self.preamble = response.get("preamble")
+            self.prompt = response.get("prompt")
+            self.chat_history = response.get("chat_history")
+            self.token_count = response.get("token_count")
+            self.is_search_required = response.get("is_search_required")  # optional
+            self.citations = response.get("citations")  # optional
+            self.documents = response.get("documents")  # optional
+            self.search_results = response.get("search_results")  # optional
+            self.search_queries = response.get("search_queries")  # optional
+            self.tool_calls = ToolCall.from_list(response.get("tool_calls"))  # optional
+            return StreamEnd(finish_reason=finish_reason, is_finished=True, event_type=event_type, index=index)
+        return None
+
+    def __iter__(self) -> Generator[StreamResponse, None, None]:
+        index = 0
+        for payload in self.stream_response:
+            self.bytes.extend(payload[self.payload_key][self.bytes_key])
+            try:
+                item = self._make_response_item(index, json.loads(self.bytes))
+            except json.decoder.JSONDecodeError:
+                # payload contained only a partial JSON object
+                continue
+
+            self.bytes = bytearray()
+            if item is not None:
+                index += 1
+                yield item
diff --git a/src/cohere/manually_maintained/cohere_aws/classification.py b/src/cohere/manually_maintained/cohere_aws/classification.py
new file mode 100644
index 000000000..a10371532
--- /dev/null
+++ b/src/cohere/manually_maintained/cohere_aws/classification.py
@@ -0,0 +1,60 @@
+from cohere_aws.response import CohereObject
+from typing import Any, Dict, Iterator, List, Literal, 
Union + +Prediction = Union[str, int, List[str], List[int]] +ClassificationDict = Dict[Literal["prediction", "confidence", "text"], Any] + + +class Classification(CohereObject): + def __init__(self, classification: Union[Prediction, ClassificationDict]) -> None: + # Prediction is the old format (version 1 of classification-finetuning) + # ClassificationDict is the new format (version 2 of classification-finetuning). + # It also contains the original text and the labels' confidence scores of the prediction + self.classification = classification + + def is_multilabel(self) -> bool: + if isinstance(self.classification, list): + return True + elif isinstance(self.classification, (int, str)): + return False + return isinstance(self.classification["prediction"], list) + + @property + def prediction(self) -> Prediction: + if isinstance(self.classification, (list, int, str)): + return self.classification + return self.classification["prediction"] + + @property + def confidence(self) -> List[float]: + if isinstance(self.classification, (list, int, str)): + raise ValueError( + "Confidence scores are not available for version prior to 2.0 of Cohere Classification Finetuning AWS package" + ) + return self.classification["confidence"] + + @property + def text(self) -> str: + if isinstance(self.classification, (list, int, str)): + raise ValueError( + "Original text is not available for version prior to 2.0 of Cohere Classification Finetuning AWS package" + ) + return self.classification["text"] + + +class Classifications(CohereObject): + def __init__(self, classifications: List[Classification]) -> None: + self.classifications = classifications + if len(self.classifications) > 0: + assert all( + [c.is_multilabel() == self.is_multilabel() for c in self.classifications] + ), "All classifications must be of the same type (single-label or multi-label)" + + def __iter__(self) -> Iterator: + return iter(self.classifications) + + def __len__(self) -> int: + return 
len(self.classifications) + + def is_multilabel(self) -> bool: + return len(self.classifications) > 0 and self.classifications[0].is_multilabel() diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py new file mode 100644 index 000000000..1e74a8b2d --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -0,0 +1,974 @@ +import json +import os +import tarfile +import tempfile +import time +from typing import Any, Dict, List, Optional, Tuple, Union + +import boto3 +import sagemaker as sage +from botocore.exceptions import (ClientError, EndpointConnectionError, + ParamValidationError) +from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url + +from cohere_aws.classification import Classification, Classifications +from cohere_aws.embeddings import Embeddings +from cohere_aws.error import CohereError +from cohere_aws.generation import (Generation, Generations, + StreamingGenerations, + TokenLikelihood) +from cohere_aws.chat import Chat, StreamingChat +from cohere_aws.rerank import Reranking +from cohere_aws.summary import Summary +from cohere_aws.mode import Mode + + +class Client: + def __init__(self, endpoint_name: Optional[str] = None, + region_name: Optional[str] = None, + mode: Optional[Mode] = Mode.SAGEMAKER): + """ + By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with + `aws configure set region us-west-2` or override it with `region_name` parameter. 
+ """ + self._endpoint_name = endpoint_name # deprecated, should use self.connect_to_endpoint() instead + + if mode == Mode.SAGEMAKER: + self._client = boto3.client("sagemaker-runtime", region_name=region_name) + self._service_client = boto3.client("sagemaker", region_name=region_name) + self._sess = sage.Session(sagemaker_client=self._service_client) + elif mode == Mode.BEDROCK: + if not region_name: + region_name = boto3.Session().region_name + self._client = boto3.client( + service_name="bedrock-runtime", + region_name=region_name, + ) + self._service_client = boto3.client("bedrock", region_name=region_name) + else: + raise CohereError("Unsupported mode") + self.mode = mode + + + def _does_endpoint_exist(self, endpoint_name: str) -> bool: + try: + self._service_client.describe_endpoint(EndpointName=endpoint_name) + except ClientError: + return False + return True + + def connect_to_endpoint(self, endpoint_name: str) -> None: + """Connects to an existing SageMaker endpoint. + + Args: + endpoint_name (str): The name of the endpoint. + + Raises: + CohereError: Connection to the endpoint failed. + """ + if not self._does_endpoint_exist(endpoint_name): + raise CohereError(f"Endpoint {endpoint_name} does not exist.") + self._endpoint_name = endpoint_name + + def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str: + """ + Compress an S3 folder which contains one or several fine-tuned models to a tar file. + If the S3 folder contains only one fine-tuned model, it simply returns the path to that model. + If the S3 folder contains several fine-tuned models, it download all models, aggregates them into a single + tar.gz file. + + Args: + s3_models_dir (str): S3 URI pointing to a folder + + Returns: + str: S3 URI pointing to the `models.tar.gz` file + """ + + s3_models_dir = s3_models_dir.rstrip("/") + "/" + + # Links of all fine-tuned models in s3_models_dir. 
Their format should be .tar.gz + s3_tar_models = [ + s3_path + for s3_path in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) + if ( + s3_path.endswith(".tar.gz") # only .tar.gz files + and (s3_path.split("/")[-1] != "models.tar.gz") # exclude the .tar.gz file we are creating + and (s3_path.rsplit("/", 1)[0] == s3_models_dir[:-1]) # only files at the root of s3_models_dir + ) + ] + + if len(s3_tar_models) == 0: + raise CohereError(f"No fine-tuned models found in {s3_models_dir}") + elif len(s3_tar_models) == 1: + print(f"Found one fine-tuned model: {s3_tar_models[0]}") + return s3_tar_models[0] + + # More than one fine-tuned model found, need to aggregate them into a single .tar.gz file + with tempfile.TemporaryDirectory() as tmpdir: + local_tar_models_dir = os.path.join(tmpdir, "tar") + local_models_dir = os.path.join(tmpdir, "models") + + # Download and extract all fine-tuned models + for s3_tar_model in s3_tar_models: + print(f"Adding fine-tuned model: {s3_tar_model}") + S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess) + with tarfile.open(os.path.join(local_tar_models_dir, s3_tar_model.split("/")[-1])) as tar: + tar.extractall(local_models_dir) + + # Compress local_models_dir to a tar.gz file + model_tar = os.path.join(tmpdir, "models.tar.gz") + with tarfile.open(model_tar, "w:gz") as tar: + tar.add(local_models_dir, arcname=".") + + # Upload the new tarfile containing all models to s3 + # Very important to remove the trailing slash from s3_models_dir otherwise it just doesn't upload + model_tar_s3 = S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess) + + # sanity check + assert s3_models_dir + "models.tar.gz" in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess) + + return model_tar_s3 + + def create_endpoint( + self, + arn: str, + endpoint_name: str, + s3_models_dir: Optional[str] = None, + instance_type: str = "ml.g4dn.xlarge", + n_instances: int = 1, + recreate: bool 
= False,
+        role: Optional[str] = None,
+    ) -> None:
+        """Creates and deploys a SageMaker endpoint.
+
+        Args:
+            arn (str): The product ARN. Refers to a ready-to-use model (model package) or a fine-tuned model
+                (algorithm).
+            endpoint_name (str): The name of the endpoint.
+            s3_models_dir (str, optional): S3 URI pointing to the folder containing fine-tuned models. Defaults to None.
+            instance_type (str, optional): The EC2 instance type to deploy the endpoint to. Defaults to "ml.g4dn.xlarge".
+            n_instances (int, optional): Number of endpoint instances. Defaults to 1.
+            recreate (bool, optional): Force re-creation of endpoint if it already exists. Defaults to False.
+            role (str, optional): The IAM role to use for the endpoint. If not provided, sagemaker.get_execution_role()
+                will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
+                out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
+        """
+        # First, check if endpoint already exists
+        if self._does_endpoint_exist(endpoint_name):
+            if recreate:
+                self.connect_to_endpoint(endpoint_name)
+                self.delete_endpoint()
+            else:
+                raise CohereError(f"Endpoint {endpoint_name} already exists and recreate={recreate}.")
+
+        kwargs = {}
+        model_data = None
+        validation_params = dict()
+        if s3_models_dir is not None:
+            # If s3_models_dir is given, we assume to have custom fine-tuned models -> Algorithm
+            kwargs["algorithm_arn"] = arn
+            model_data = self._s3_models_dir_to_tarfile(s3_models_dir)
+        else:
+            # If no s3_models_dir is given, we assume to use a pre-trained model -> ModelPackage
+            kwargs["model_package_arn"] = arn
+
+            # For now only non-finetuned models can use these timeouts
+            validation_params = dict(
+                model_data_download_timeout=2400,
+                container_startup_health_check_timeout=2400
+            )
+
+        # Out of precaution, check if there is an endpoint config and delete it if that's the case
+        # Otherwise it might block deployment
+        try:
+            
self._service_client.delete_endpoint_config(EndpointConfigName=endpoint_name) + except ClientError: + pass + + if role is None: + try: + role = sage.get_execution_role() + except ValueError: + print("Using default role: 'ServiceRoleSagemaker'.") + role = "ServiceRoleSagemaker" + + model = sage.ModelPackage( + role=role, + model_data=model_data, + sagemaker_session=self._sess, # makes sure the right region is used + **kwargs + ) + + try: + model.deploy( + n_instances, + instance_type, + endpoint_name=endpoint_name, + **validation_params + ) + except ParamValidationError: + # For at least some versions of python 3.6, SageMaker SDK does not support the validation_params + model.deploy(n_instances, instance_type, endpoint_name=endpoint_name) + self.connect_to_endpoint(endpoint_name) + + def chat( + self, + message: str, + stream: Optional[bool] = False, + preamble: Optional[str] = None, + chat_history: Optional[List[Dict[str, Any]]] = None, + # should only be passed for stacked finetune deployment + model: Optional[str] = None, + # should only be passed for Bedrock mode; ignored otherwise + model_id: Optional[str] = None, + temperature: Optional[float] = None, + p: Optional[float] = None, + k: Optional[float] = None, + max_tokens: Optional[int] = None, + search_queries_only: Optional[bool] = None, + documents: Optional[List[Dict[str, Any]]] = None, + prompt_truncation: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_results: Optional[List[Dict[str, Any]]] = None, + raw_prompting: Optional[bool] = False, + return_prompt: Optional[bool] = False, + variant: Optional[str] = None, + ) -> Union[Chat, StreamingChat]: + """Returns a Chat object with the query reply. + + Args: + message (str): The message to send to the chatbot. + + stream (bool): Return streaming tokens. + + preamble (str): (Optional) A string to override the preamble. + chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. 
If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state. + + model (str): (Optional) The model to use for generating the response. Should only be passed for stacked finetune deployment. + model_id (str): (Optional) The model to use for generating the response. Should only be passed for Bedrock mode; ignored otherwise. + temperature (float): (Optional) The temperature to use for the response. The higher the temperature, the more random the response. + p (float): (Optional) The nucleus sampling probability. + k (float): (Optional) The top-k sampling probability. + max_tokens (int): (Optional) The max tokens generated for the next reply. + + search_queries_only (bool): (Optional) When true, the response will only contain a list of generated `search_queries`, no reply from the model to the user's message will be generated. + documents (List[Dict[str, str]]): (Optional) Documents to use to generate grounded response with citations. Example: + documents=[ + { + "id": "national_geographic_everest", + "title": "Height of Mount Everest", + "snippet": "The height of Mount Everest is 29,035 feet", + "url": "https://education.nationalgeographic.org/resource/mount-everest/", + }, + { + "id": "national_geographic_mariana", + "title": "Depth of the Mariana Trench", + "snippet": "The depth of the Mariana Trench is 36,070 feet", + "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth", + }, + ], + prompt_truncation (str) (Optional): Defaults to `OFF`. Dictates how the prompt will be constructed. With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. 
With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be raised.
+        Returns:
+            a Chat object if stream=False, or a StreamingChat object if stream=True
+
+        Examples:
+            A simple chat message:
+                >>> res = co.chat(message="Hey! How are you doing today?")
+                >>> print(res.text)
+            Streaming chat:
+                >>> res = co.chat(
+                >>>     message="Hey! How are you doing today?",
+                >>>     stream=True)
+                >>> for token in res:
+                >>>     print(token)
+            Stateless chat with chat history:
+                >>> res = co.chat(
+                >>>     message="Tell me a joke!",
+                >>>     chat_history=[
+                >>>         {'role': 'User', 'message': 'Hey! How are you doing today?'},
+                >>>         {'role': 'Chatbot', 'message': 'I am doing great! How can I help you?'},
+                >>>     ])
+                >>> print(res.text)
+            Chat message with documents to use to generate the response:
+                >>> res = co.chat(
+                >>>     "How deep in the Mariana Trench",
+                >>>     documents=[
+                >>>         {
+                >>>             "id": "national_geographic_everest",
+                >>>             "title": "Height of Mount Everest",
+                >>>             "snippet": "The height of Mount Everest is 29,035 feet",
+                >>>             "url": "https://education.nationalgeographic.org/resource/mount-everest/",
+                >>>         },
+                >>>         {
+                >>>             "id": "national_geographic_mariana",
+                >>>             "title": "Depth of the Mariana Trench",
+                >>>             "snippet": "The depth of the Mariana Trench is 36,070 feet",
+                >>>             "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth",
+                >>>         },
+                >>>     ])
+                >>> print(res.text)
+                >>> print(res.citations)
+                >>> print(res.documents)
+            Generate search queries for fetching documents to use in chat:
+                >>> res = co.chat(
+                >>>     "What is the height of Mount Everest?",
+                >>>     search_queries_only=True)
+                >>> if res.is_search_required:
+                >>>     print(res.search_queries)
+        """
+
+        if self.mode == Mode.SAGEMAKER and self._endpoint_name is None:
+            raise CohereError("No endpoint connected. 
" + "Run connect_to_endpoint() first.") + json_params = { + "model": model, + "message": message, + "chat_history": chat_history, + "preamble": preamble, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": stream, + "p": p, + "k": k, + "tools": tools, + "tool_results": tool_results, + "search_queries_only": search_queries_only, + "documents": documents, + "raw_prompting": raw_prompting, + "return_prompt": return_prompt, + "prompt_truncation": prompt_truncation + } + + for key, value in list(json_params.items()): + if value is None: + del json_params[key] + + if self.mode == Mode.SAGEMAKER: + return self._sagemaker_chat(json_params, variant) + elif self.mode == Mode.BEDROCK: + return self._bedrock_chat(json_params, model_id) + else: + raise CohereError("Unsupported mode") + + def _sagemaker_chat(self, json_params: Dict[str, Any], variant: str) : + json_body = json.dumps(json_params) + params = { + 'EndpointName': self._endpoint_name, + 'ContentType': 'application/json', + 'Body': json_body, + } + if variant: + params['TargetVariant'] = variant + + try: + if json_params['stream']: + result = self._client.invoke_endpoint_with_response_stream( + **params) + return StreamingChat(result['Body'], self.mode) + else: + result = self._client.invoke_endpoint(**params) + return Chat.from_dict(json.loads(result['Body'].read().decode())) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. 
when variant is bad + raise CohereError(str(e)) + + def _bedrock_chat(self, json_params: Dict[str, Any], model_id: str) : + if not model_id: + raise CohereError("must supply model_id arg when calling bedrock") + if json_params['stream']: + stream = json_params['stream'] + else: + stream = False + # Bedrock does not expect the stream key to be present in the body, use invoke_model_with_response_stream to indicate stream mode + del json_params['stream'] + + json_body = json.dumps(json_params) + params = { + 'body': json_body, + 'modelId': model_id, + } + + try: + if stream: + result = self._client.invoke_model_with_response_stream( + **params) + return StreamingChat(result['body'], self.mode) + else: + result = self._client.invoke_model(**params) + return Chat.from_dict( + json.loads(result['body'].read().decode())) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + def generate( + self, + prompt: str, + # should only be passed for stacked finetune deployment + model: Optional[str] = None, + # should only be passed for Bedrock mode; ignored otherwise + model_id: Optional[str] = None, + # requires DB with presets + # preset: str = None, + num_generations: int = 1, + max_tokens: int = 400, + temperature: float = 1.0, + k: int = 0, + p: float = 0.75, + stop_sequences: Optional[List[str]] = None, + return_likelihoods: Optional[str] = None, + truncate: Optional[str] = None, + variant: Optional[str] = None, + stream: Optional[bool] = True, + ) -> Union[Generations, StreamingGenerations]: + if self.mode == Mode.SAGEMAKER and self._endpoint_name is None: + raise CohereError("No endpoint connected. 
" + "Run connect_to_endpoint() first.") + + json_params = { + 'model': model, + 'prompt': prompt, + 'max_tokens': max_tokens, + 'temperature': temperature, + 'k': k, + 'p': p, + 'stop_sequences': stop_sequences, + 'return_likelihoods': return_likelihoods, + 'truncate': truncate, + 'stream': stream, + } + for key, value in list(json_params.items()): + if value is None: + del json_params[key] + + if self.mode == Mode.SAGEMAKER: + # TODO: Bedrock should support this param too + json_params['num_generations'] = num_generations + return self._sagemaker_generations(json_params, variant) + elif self.mode == Mode.BEDROCK: + return self._bedrock_generations(json_params, model_id) + else: + raise CohereError("Unsupported mode") + + def _sagemaker_generations(self, json_params: Dict[str, Any], variant: str) : + json_body = json.dumps(json_params) + params = { + 'EndpointName': self._endpoint_name, + 'ContentType': 'application/json', + 'Body': json_body, + } + if variant: + params['TargetVariant'] = variant + + try: + if json_params['stream']: + result = self._client.invoke_endpoint_with_response_stream( + **params) + return StreamingGenerations(result['Body'], self.mode) + else: + result = self._client.invoke_endpoint(**params) + return Generations( + json.loads(result['Body'].read().decode())['generations']) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. 
when variant is bad + raise CohereError(str(e)) + + def _bedrock_generations(self, json_params: Dict[str, Any], model_id: str) : + if not model_id: + raise CohereError("must supply model_id arg when calling bedrock") + json_body = json.dumps(json_params) + params = { + 'body': json_body, + 'modelId': model_id, + } + + try: + if json_params['stream']: + result = self._client.invoke_model_with_response_stream( + **params) + return StreamingGenerations(result['body'], self.mode) + else: + result = self._client.invoke_model(**params) + return Generations( + json.loads(result['body'].read().decode())['generations']) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + def embed( + self, + texts: List[str], + truncate: Optional[str] = None, + variant: Optional[str] = None, + input_type: Optional[str] = None, + model_id: Optional[str] = None, + ) -> Embeddings: + json_params = { + 'texts': texts, + 'truncate': truncate, + "input_type": input_type + } + for key, value in list(json_params.items()): + if value is None: + del json_params[key] + + if self.mode == Mode.SAGEMAKER: + return self._sagemaker_embed(json_params, variant) + elif self.mode == Mode.BEDROCK: + return self._bedrock_embed(json_params, model_id) + else: + raise CohereError("Unsupported mode") + + def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str): + if self._endpoint_name is None: + raise CohereError("No endpoint connected. 
" + "Run connect_to_endpoint() first.") + + json_body = json.dumps(json_params) + params = { + 'EndpointName': self._endpoint_name, + 'ContentType': 'application/json', + 'Body': json_body, + } + if variant: + params['TargetVariant'] = variant + + try: + result = self._client.invoke_endpoint(**params) + response = json.loads(result['Body'].read().decode()) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + return Embeddings(response['embeddings']) + + def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str): + if not model_id: + raise CohereError("must supply model_id arg when calling bedrock") + json_body = json.dumps(json_params) + params = { + 'body': json_body, + 'modelId': model_id, + } + + try: + result = self._client.invoke_model(**params) + response = json.loads(result['body'].read().decode()) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + return Embeddings(response['embeddings']) + + + def rerank(self, + query: str, + documents: Union[List[str], List[Dict[str, Any]]], + top_n: Optional[int] = None, + variant: Optional[str] = None, + max_chunks_per_doc: Optional[int] = None, + rank_fields: Optional[List[str]] = None) -> Reranking: + """Returns an ordered list of documents oridered by their relevance to the provided query + Args: + query (str): The search query + documents (list[str], list[dict]): The documents to rerank + top_n (int): (optional) The number of results to return, defaults to return all results + max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document + rank_fields (list[str]): (optional) The fields used for reranking. 
This parameter is only supported for rerank v3 models + """ + + if self._endpoint_name is None: + raise CohereError("No endpoint connected. " + "Run connect_to_endpoint() first.") + + parsed_docs = [] + for doc in documents: + if isinstance(doc, str): + parsed_docs.append({'text': doc}) + elif isinstance(doc, dict): + parsed_docs.append(doc) + else: + raise CohereError( + message='invalid format for documents, must be a list of strings or dicts') + + json_params = { + "query": query, + "documents": parsed_docs, + "top_n": top_n, + "return_documents": False, + "max_chunks_per_doc" : max_chunks_per_doc, + "rank_fields": rank_fields + } + json_body = json.dumps(json_params) + + params = { + 'EndpointName': self._endpoint_name, + 'ContentType': 'application/json', + 'Body': json_body, + } + if variant is not None: + params['TargetVariant'] = variant + + try: + result = self._client.invoke_endpoint(**params) + response = json.loads(result['Body'].read().decode()) + reranking = Reranking(response) + for rank in reranking.results: + rank.document = parsed_docs[rank.index] + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + return reranking + + def classify(self, input: List[str], name: str) -> Classifications: + + if self._endpoint_name is None: + raise CohereError("No endpoint connected. 
" + "Run connect_to_endpoint() first.") + + json_params = {"texts": input, "model_id": name} + json_body = json.dumps(json_params) + + params = { + "EndpointName": self._endpoint_name, + "ContentType": "application/json", + "Body": json_body, + } + + try: + result = self._client.invoke_endpoint(**params) + response = json.loads(result["Body"].read().decode()) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + return Classifications([Classification(classification) for classification in response]) + + def create_finetune( + self, + name: str, + train_data: str, + s3_models_dir: str, + arn: Optional[str] = None, + eval_data: Optional[str] = None, + instance_type: str = "ml.g4dn.xlarge", + training_parameters: Dict[str, Any] = {}, # Optional, training algorithm specific hyper-parameters + role: Optional[str] = None, + base_model_id: Optional[str] = None, + ) -> Optional[str]: + """Creates a fine-tuning job and returns an optional finetune job ID. + + Args: + name (str): The name to give to the fine-tuned model. + train_data (str): An S3 path pointing to the training data. + s3_models_dir (str): An S3 path pointing to the directory where the fine-tuned model will be saved. + arn (str, optional): The product ARN of the fine-tuning package. Required in Sagemaker mode and ignored otherwise + eval_data (str, optional): An S3 path pointing to the eval data. Defaults to None. + instance_type (str, optional): The EC2 instance type to use for training. Defaults to "ml.g4dn.xlarge". + training_parameters (Dict[str, Any], optional): Additional training parameters. Defaults to {}. + role (str, optional): The IAM role to use for the endpoint. + In Bedrock this mode is required and is used to access s3 input and output data. 
+ If not provided in sagemaker, sagemaker.get_execution_role()will be used to get the role. + This should work when one uses the client inside SageMaker. If this errors + out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker. + base_model_id (str, optional): The ID of the Bedrock base model to finetune with. Required in Bedrock mode and ignored otherwise. + """ + assert name != "model", "name cannot be 'model'" + + if self.mode == Mode.BEDROCK: + return self._bedrock_create_finetune(name=name, train_data=train_data, s3_models_dir=s3_models_dir, base_model=base_model_id, eval_data=eval_data, training_parameters=training_parameters, role=role) + + s3_models_dir = s3_models_dir.rstrip("/") + "/" + + if role is None: + try: + role = sage.get_execution_role() + except ValueError: + print("Using default role: 'ServiceRoleSagemaker'.") + role = "ServiceRoleSagemaker" + + training_parameters.update({"name": name}) + estimator = sage.algorithm.AlgorithmEstimator( + algorithm_arn=arn, + role=role, + instance_count=1, + instance_type=instance_type, + sagemaker_session=self._sess, + output_path=s3_models_dir, + hyperparameters=training_parameters, + ) + + inputs = {} + if not train_data.startswith("s3:"): + raise ValueError("train_data must point to an S3 location.") + inputs["training"] = train_data + if eval_data is not None: + if not eval_data.startswith("s3:"): + raise ValueError("eval_data must point to an S3 location.") + inputs["evaluation"] = eval_data + estimator.fit(inputs=inputs) + job_name = estimator.latest_training_job.name + + current_filepath = f"{s3_models_dir}{job_name}/output/model.tar.gz" + + s3_resource = boto3.resource("s3") + + # Copy new model to root of output_model_dir + bucket, old_key = parse_s3_url(current_filepath) + _, new_key = parse_s3_url(f"{s3_models_dir}{name}.tar.gz") + s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key}) + + # Delete old dir + bucket, 
old_short_key = parse_s3_url(s3_models_dir + job_name) + s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete() + + def export_finetune( + self, + name: str, + s3_checkpoint_dir: str, + s3_output_dir: str, + arn: str, + instance_type: str = "ml.p4de.24xlarge", + role: Optional[str] = None, + ) -> None: + """Export the merged weights to the TensorRT-LLM inference engine. + + Args: + name (str): The name used while writing the exported model to the output directory. + s3_checkpoint_dir (str): An S3 path pointing to the directory of the model checkpoint (merged weights). + s3_output_dir (str): An S3 path pointing to the directory where the TensorRT-LLM engine will be saved. + arn (str): The product ARN of the bring your own finetuning algorithm. + instance_type (str, optional): The EC2 instance type to use for export. Defaults to "ml.p4de.24xlarge". + role (str, optional): The IAM role to use for export. + If not provided, sagemaker.get_execution_role() will be used to get the role. + This should work when one uses the client inside SageMaker. If this errors out, + the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker. 
+ """ + if name == "model": + raise ValueError("name cannot be 'model'") + + s3_output_dir = s3_output_dir.rstrip("/") + "/" + + if role is None: + try: + role = sage.get_execution_role() + except ValueError: + print("Using default role: 'ServiceRoleSagemaker'.") + role = "ServiceRoleSagemaker" + + export_parameters = {"name": name} + + estimator = sage.algorithm.AlgorithmEstimator( + algorithm_arn=arn, + role=role, + instance_count=1, + instance_type=instance_type, + sagemaker_session=self._sess, + output_path=s3_output_dir, + hyperparameters=export_parameters, + ) + + if not s3_checkpoint_dir.startswith("s3:"): + raise ValueError("s3_checkpoint_dir must point to an S3 location.") + inputs = {"checkpoint": s3_checkpoint_dir} + + estimator.fit(inputs=inputs) + + job_name = estimator.latest_training_job.name + current_filepath = f"{s3_output_dir}{job_name}/output/model.tar.gz" + + s3_resource = boto3.resource("s3") + + # Copy the exported TensorRT-LLM engine to the root of s3_output_dir + bucket, old_key = parse_s3_url(current_filepath) + _, new_key = parse_s3_url(f"{s3_output_dir}{name}.tar.gz") + s3_resource.Object(bucket, new_key).copy(CopySource={"Bucket": bucket, "Key": old_key}) + + # Delete the old S3 directory + bucket, old_short_key = parse_s3_url(f"{s3_output_dir}{job_name}") + s3_resource.Bucket(bucket).objects.filter(Prefix=old_short_key).delete() + + def wait_for_finetune_job(self, job_id: str, timeout: int = 2*60*60) -> str: + """Waits for a finetune job to complete and returns a model arn if complete. 
Throws an exception if timeout occurs or if job does not complete successfully + Args: + job_id (str): The arn of the model customization job + timeout(int, optional): Timeout in seconds + """ + end = time.time() + timeout + while True: + customization_job = self._service_client.get_model_customization_job(jobIdentifier=job_id) + job_status = customization_job["status"] + if job_status in ["Completed", "Failed", "Stopped"]: + break + if time.time() > end: + raise CohereError("could not complete finetune within timeout") + time.sleep(10) + + if job_status != "Completed": + raise CohereError(f"finetune did not finish successfully, ended with {job_status} status") + return customization_job["outputModelArn"] + + def provision_throughput( + self, + model_id: str, + name: str, + model_units: int, + commitment_duration: Optional[str] = None + ) -> str: + """Returns the provisioned model arn + Args: + model_id (str): The ID or ARN of the model to provision + name (str): Name of the provisioned throughput model + model_units (int): Number of units to provision + commitment_duration (str, optional): Commitment duration, one of ("OneMonth", "SixMonths"), defaults to no commitment if unspecified + """ + if self.mode != Mode.BEDROCK: + raise ValueError("can only provision throughput in bedrock") + kwargs = {} + if commitment_duration: + kwargs["commitmentDuration"] = commitment_duration + + response = self._service_client.create_provisioned_model_throughput( + provisionedModelName=name, + modelId=model_id, + modelUnits=model_units, + **kwargs + ) + return response["provisionedModelArn"] + + def _bedrock_create_finetune( + self, + name: str, + train_data: str, + s3_models_dir: str, + base_model: str, + eval_data: Optional[str] = None, + training_parameters: Dict[str, Any] = {}, # Optional, training algorithm specific hyper-parameters + role: Optional[str] = None, + ) -> None: + if not name: + raise ValueError("name must not be empty") + if not role: + raise ValueError("must 
provide a role ARN for bedrock finetuning (https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-iam-role.html)") + if not train_data.startswith("s3:"): + raise ValueError("train_data must point to an S3 location.") + if eval_data: + if not eval_data.startswith("s3:"): + raise ValueError("eval_data must point to an S3 location.") + validationDataConfig = { + "validators": [{ + "s3Uri": eval_data + }] + } + + job_name = f"{name}-job" + customization_job = self._service_client.create_model_customization_job( + jobName=job_name, + customModelName=name, + roleArn=role, + baseModelIdentifier=base_model, + trainingDataConfig={"s3Uri": train_data}, + validationDataConfig=validationDataConfig, + outputDataConfig={"s3Uri": s3_models_dir}, + hyperParameters=training_parameters + ) + return customization_job["jobArn"] + + + def summarize( + self, + text: str, + length: Optional[str] = "auto", + format_: Optional[str] = "auto", + # Only summarize-xlarge is supported on Sagemaker + # model: Optional[str] = "summarize-xlarge", + extractiveness: Optional[str] = "auto", + temperature: Optional[float] = 0.3, + additional_command: Optional[str] = "", + variant: Optional[str] = None + ) -> Summary: + + if self._endpoint_name is None: + raise CohereError("No endpoint connected. 
" + "Run connect_to_endpoint() first.") + + json_params = { + 'text': text, + 'length': length, + 'format': format_, + 'extractiveness': extractiveness, + 'temperature': temperature, + 'additional_command': additional_command, + } + for key, value in list(json_params.items()): + if value is None: + del json_params[key] + json_body = json.dumps(json_params) + + params = { + 'EndpointName': self._endpoint_name, + 'ContentType': 'application/json', + 'Body': json_body, + } + if variant is not None: + params['TargetVariant'] = variant + + try: + result = self._client.invoke_endpoint(**params) + response = json.loads(result['Body'].read().decode()) + summary = Summary(response) + except EndpointConnectionError as e: + raise CohereError(str(e)) + except Exception as e: + # TODO should be client error - distinct type from CohereError? + # ValidationError, e.g. when variant is bad + raise CohereError(str(e)) + + return summary + + + def delete_endpoint(self) -> None: + if self._endpoint_name is None: + raise CohereError("No endpoint connected.") + try: + self._service_client.delete_endpoint(EndpointName=self._endpoint_name) + except: + print("Endpoint not found, skipping deletion.") + + try: + self._service_client.delete_endpoint_config(EndpointConfigName=self._endpoint_name) + except: + print("Endpoint config not found, skipping deletion.") + + def close(self) -> None: + try: + self._client.close() + self._service_client.close() + except AttributeError: + print("SageMaker client could not be closed. 
This might be because you are using an old version of SageMaker.") + raise diff --git a/src/cohere/manually_maintained/cohere_aws/embeddings.py b/src/cohere/manually_maintained/cohere_aws/embeddings.py new file mode 100644 index 000000000..40de3e947 --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/embeddings.py @@ -0,0 +1,26 @@ +from cohere_aws.response import CohereObject +from typing import Iterator, List + + +class Embedding(CohereObject): + + def __init__(self, embedding: List[float]) -> None: + self.embedding = embedding + + def __iter__(self) -> Iterator: + return iter(self.embedding) + + def __len__(self) -> int: + return len(self.embedding) + + +class Embeddings(CohereObject): + + def __init__(self, embeddings: List[Embedding]) -> None: + self.embeddings = embeddings + + def __iter__(self) -> Iterator: + return iter(self.embeddings) + + def __len__(self) -> int: + return len(self.embeddings) diff --git a/src/cohere/manually_maintained/cohere_aws/error.py b/src/cohere/manually_maintained/cohere_aws/error.py new file mode 100644 index 000000000..19b784c94 --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/error.py @@ -0,0 +1,23 @@ +class CohereError(Exception): + def __init__( + self, + message=None, + http_status=None, + headers=None, + ) -> None: + super(CohereError, self).__init__(message) + + self.message = message + self.http_status = http_status + self.headers = headers or {} + + def __str__(self) -> str: + msg = self.message or '' + return msg + + def __repr__(self) -> str: + return '%s(message=%r, http_status=%r)' % ( + self.__class__.__name__, + self.message, + self.http_status, + ) diff --git a/src/cohere/manually_maintained/cohere_aws/generation.py b/src/cohere/manually_maintained/cohere_aws/generation.py new file mode 100644 index 000000000..0ddabdbcd --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/generation.py @@ -0,0 +1,107 @@ +from cohere_aws.response import CohereObject +from cohere_aws.mode import Mode 
+from typing import List, Optional, NamedTuple, Generator, Dict, Any +import json + + +class TokenLikelihood(CohereObject): + def __init__(self, token: str, likelihood: float) -> None: + self.token = token + self.likelihood = likelihood + + +class Generation(CohereObject): + def __init__(self, + text: str, + token_likelihoods: List[TokenLikelihood]) -> None: + self.text = text + self.token_likelihoods = token_likelihoods + + +class Generations(CohereObject): + def __init__(self, + generations: List[Generation]) -> None: + self.generations = generations + self.iterator = iter(generations) + + @classmethod + def from_dict(cls, response: Dict[str, Any]) -> List[Generation]: + generations: List[Generation] = [] + for gen in response['generations']: + token_likelihoods = None + + if 'token_likelihoods' in gen: + token_likelihoods = [] + for likelihoods in gen['token_likelihoods']: + if 'likelihood' in likelihoods: + token_likelihood = likelihoods['likelihood'] + else: + token_likelihood = None + token_likelihoods.append(TokenLikelihood( + likelihoods['token'], token_likelihood)) + generations.append(Generation(gen['text'], token_likelihoods)) + return cls(generations) + + def __iter__(self) -> iter: + return self.iterator + + def __next__(self) -> next: + return next(self.iterator) + + +StreamingText = NamedTuple("StreamingText", + [("index", Optional[int]), + ("text", str), + ("is_finished", bool)]) + + +class StreamingGenerations(CohereObject): + def __init__(self, stream, mode): + self.stream = stream + self.id = None + self.generations = None + self.finish_reason = None + self.bytes = bytearray() + + if mode == Mode.SAGEMAKER: + self.payload_key = "PayloadPart" + self.bytes_key = "Bytes" + elif mode == Mode.BEDROCK: + self.payload_key = "chunk" + self.bytes_key = "bytes" + else: + raise CohereError("Unsupported mode") + + def _make_response_item(self, streaming_item) -> Optional[StreamingText]: + is_finished = streaming_item.get("is_finished") + + if not 
is_finished: + index = streaming_item.get("index", 0) + text = streaming_item.get("text") + if text is None: + return None + return StreamingText( + text=text, is_finished=is_finished, index=index) + + self.finish_reason = streaming_item.get("finish_reason") + generation_response = streaming_item.get("response") + + if generation_response is None: + return None + + self.id = generation_response.get("id") + self.generations = Generations.from_dict(generation_response) + return None + + def __iter__(self) -> Generator[StreamingText, None, None]: + for payload in self.stream: + self.bytes.extend(payload[self.payload_key][self.bytes_key]) + try: + item = self._make_response_item(json.loads(self.bytes)) + except json.decoder.JSONDecodeError: + # payload contained only a partial JSON object + continue + + self.bytes = bytearray() + if item is not None: + yield item diff --git a/src/cohere/manually_maintained/cohere_aws/mode.py b/src/cohere/manually_maintained/cohere_aws/mode.py new file mode 100644 index 000000000..8a6fe4749 --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/mode.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class Mode(Enum): + SAGEMAKER = 1 + BEDROCK = 2 diff --git a/src/cohere/manually_maintained/cohere_aws/rerank.py b/src/cohere/manually_maintained/cohere_aws/rerank.py new file mode 100644 index 000000000..153ef21ae --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/rerank.py @@ -0,0 +1,66 @@ +from typing import Any, Dict, Iterator, List, NamedTuple, Optional + +from cohere_aws.response import CohereObject + +RerankDocument = NamedTuple("Document", [("text", str)]) +RerankDocument.__doc__ = """ +Returned by co.rerank, +dict which always contains text but can also contain arbitrary fields +""" + + +class RerankResult(CohereObject): + + def __init__(self, + document: Dict[str, Any] = None, + index: int = None, + relevance_score: float = None, + *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.document = 
document + self.index = index + self.relevance_score = relevance_score + + def __repr__(self) -> str: + score = self.relevance_score + index = self.index + if self.document is None: + return f"RerankResult" + elif 'text' in self.document: + text = self.document['text'] + return f"RerankResult" + else: + return f"RerankResult" + + +class Reranking(CohereObject): + + def __init__(self, + response: Optional[Dict[str, Any]] = None, + **kwargs) -> None: + super().__init__(**kwargs) + assert response is not None + self.results = self._results(response) + + def _results(self, response: Dict[str, Any]) -> List[RerankResult]: + results = [] + for res in response['results']: + if 'document' in res.keys(): + results.append( + RerankResult(res['document'], res['index'], res['relevance_score'])) + else: + results.append( + RerankResult(index=res['index'], relevance_score=res['relevance_score'])) + return results + + def __str__(self) -> str: + return str(self.results) + + def __repr__(self) -> str: + return self.results.__repr__() + + def __iter__(self) -> Iterator: + return iter(self.results) + + def __getitem__(self, index) -> RerankResult: + return self.results[index] diff --git a/src/cohere/manually_maintained/cohere_aws/response.py b/src/cohere/manually_maintained/cohere_aws/response.py new file mode 100644 index 000000000..a46125f64 --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/response.py @@ -0,0 +1,11 @@ +class CohereObject(): + def __repr__(self) -> str: + contents = '' + exclude_list = ['iterator'] + + for k in self.__dict__.keys(): + if k not in exclude_list: + contents += f'\t{k}: {self.__dict__[k]}\n' + + output = f'cohere.{type(self).__name__} {{\n{contents}}}' + return output diff --git a/src/cohere/manually_maintained/cohere_aws/summary.py b/src/cohere/manually_maintained/cohere_aws/summary.py new file mode 100644 index 000000000..f982214bc --- /dev/null +++ b/src/cohere/manually_maintained/cohere_aws/summary.py @@ -0,0 +1,16 @@ +from 
cohere_aws.error import CohereError +from cohere_aws.response import CohereObject +from typing import Any, Dict, Optional + + +class Summary(CohereObject): + def __init__(self, + response: Optional[Dict[str, Any]] = None) -> None: + assert response is not None + if not response["summary"]: + raise CohereError("Response lacks a summary") + + self.result = response["summary"] + + def __str__(self) -> str: + return self.result diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py index 8e80439ed..41ebd1de7 100644 --- a/src/cohere/sagemaker_client.py +++ b/src/cohere/sagemaker_client.py @@ -1,11 +1,12 @@ import typing -from tokenizers import Tokenizer # type: ignore - from .aws_client import AwsClient +from .manually_maintained.cohere_aws.client import Client class SagemakerClient(AwsClient): + finetuning: Client + def __init__( self, *, @@ -24,3 +25,4 @@ def __init__( aws_region=aws_region, timeout=timeout, ) + self.finetuning = Client(region_name=self._aws_region) \ No newline at end of file From 89b8799b6700be31eb8f68acba5447fcd39bb0ec Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:09:21 +0100 Subject: [PATCH 02/11] Bump version --- pyproject.toml | 2 +- src/cohere/core/client_wrapper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f4370c6c2..46ed9405b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "5.10.0" +version = "5.11.0a0" description = "" readme = "README.md" authors = [] diff --git a/src/cohere/core/client_wrapper.py b/src/cohere/core/client_wrapper.py index bad0ef86f..fd0ba72b0 100644 --- a/src/cohere/core/client_wrapper.py +++ b/src/cohere/core/client_wrapper.py @@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]: headers: typing.Dict[str, str] = { "X-Fern-Language": "Python", "X-Fern-SDK-Name": "cohere", - "X-Fern-SDK-Version": "5.10.0", + "X-Fern-SDK-Version": "5.11.0a0", } if 
self._client_name is not None: headers["X-Client-Name"] = self._client_name From 83d55dc43877dfbac5965ab8898b68cde8caaea9 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:12:06 +0100 Subject: [PATCH 03/11] Fixes --- mypy.ini | 2 ++ src/cohere/sagemaker_client.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..1330b2c18 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +exclude = src/cohere/manually_maintained/cohere_aws diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py index 41ebd1de7..53ad4a8ef 100644 --- a/src/cohere/sagemaker_client.py +++ b/src/cohere/sagemaker_client.py @@ -5,7 +5,7 @@ class SagemakerClient(AwsClient): - finetuning: Client + sagemaker_finetuning: Client def __init__( self, @@ -25,4 +25,4 @@ def __init__( aws_region=aws_region, timeout=timeout, ) - self.finetuning = Client(region_name=self._aws_region) \ No newline at end of file + self.sagemaker_finetuning = Client(region_name=aws_region) \ No newline at end of file From c7965952ed2098fa66d6b7bd2ce801ec38a71016 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:31:13 +0100 Subject: [PATCH 04/11] Fix imports --- .../manually_maintained/cohere_aws/chat.py | 6 +++--- .../cohere_aws/classification.py | 2 +- .../manually_maintained/cohere_aws/client.py | 16 ++++++++-------- .../manually_maintained/cohere_aws/embeddings.py | 2 +- .../manually_maintained/cohere_aws/generation.py | 4 ++-- .../manually_maintained/cohere_aws/rerank.py | 2 +- .../manually_maintained/cohere_aws/summary.py | 4 ++-- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/cohere/manually_maintained/cohere_aws/chat.py b/src/cohere/manually_maintained/cohere_aws/chat.py index b16303442..a15bca43a 100644 --- a/src/cohere/manually_maintained/cohere_aws/chat.py +++ b/src/cohere/manually_maintained/cohere_aws/chat.py @@ -1,6 +1,6 @@ -from 
cohere_aws.response import CohereObject -from cohere_aws.error import CohereError -from cohere_aws.mode import Mode +from .response import CohereObject +from .error import CohereError +from .mode import Mode from typing import List, Optional, Generator, Dict, Any, Union from enum import Enum import json diff --git a/src/cohere/manually_maintained/cohere_aws/classification.py b/src/cohere/manually_maintained/cohere_aws/classification.py index a10371532..a090fdaa0 100644 --- a/src/cohere/manually_maintained/cohere_aws/classification.py +++ b/src/cohere/manually_maintained/cohere_aws/classification.py @@ -1,4 +1,4 @@ -from cohere_aws.response import CohereObject +from .response import CohereObject from typing import Any, Dict, Iterator, List, Literal, Union Prediction = Union[str, int, List[str], List[int]] diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 1e74a8b2d..870cf6240 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -11,16 +11,16 @@ ParamValidationError) from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url -from cohere_aws.classification import Classification, Classifications -from cohere_aws.embeddings import Embeddings -from cohere_aws.error import CohereError -from cohere_aws.generation import (Generation, Generations, +from .classification import Classification, Classifications +from .embeddings import Embeddings +from .error import CohereError +from .generation import (Generation, Generations, StreamingGenerations, TokenLikelihood) -from cohere_aws.chat import Chat, StreamingChat -from cohere_aws.rerank import Reranking -from cohere_aws.summary import Summary -from cohere_aws.mode import Mode +from .chat import Chat, StreamingChat +from .rerank import Reranking +from .summary import Summary +from .mode import Mode class Client: diff --git 
a/src/cohere/manually_maintained/cohere_aws/embeddings.py b/src/cohere/manually_maintained/cohere_aws/embeddings.py index 40de3e947..86b2043ff 100644 --- a/src/cohere/manually_maintained/cohere_aws/embeddings.py +++ b/src/cohere/manually_maintained/cohere_aws/embeddings.py @@ -1,4 +1,4 @@ -from cohere_aws.response import CohereObject +from .response import CohereObject from typing import Iterator, List diff --git a/src/cohere/manually_maintained/cohere_aws/generation.py b/src/cohere/manually_maintained/cohere_aws/generation.py index 0ddabdbcd..145877d0a 100644 --- a/src/cohere/manually_maintained/cohere_aws/generation.py +++ b/src/cohere/manually_maintained/cohere_aws/generation.py @@ -1,5 +1,5 @@ -from cohere_aws.response import CohereObject -from cohere_aws.mode import Mode +from .response import CohereObject +from .mode import Mode from typing import List, Optional, NamedTuple, Generator, Dict, Any import json diff --git a/src/cohere/manually_maintained/cohere_aws/rerank.py b/src/cohere/manually_maintained/cohere_aws/rerank.py index 153ef21ae..55a00c205 100644 --- a/src/cohere/manually_maintained/cohere_aws/rerank.py +++ b/src/cohere/manually_maintained/cohere_aws/rerank.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Iterator, List, NamedTuple, Optional -from cohere_aws.response import CohereObject +from .response import CohereObject RerankDocument = NamedTuple("Document", [("text", str)]) RerankDocument.__doc__ = """ diff --git a/src/cohere/manually_maintained/cohere_aws/summary.py b/src/cohere/manually_maintained/cohere_aws/summary.py index f982214bc..bfacc722f 100644 --- a/src/cohere/manually_maintained/cohere_aws/summary.py +++ b/src/cohere/manually_maintained/cohere_aws/summary.py @@ -1,5 +1,5 @@ -from cohere_aws.error import CohereError -from cohere_aws.response import CohereObject +from .error import CohereError +from .response import CohereObject from typing import Any, Dict, Optional From 91d7cc1abfd2fd7fe6377f4a258caa55195cdde0 Mon Sep 17 00:00:00 
2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:41:37 +0100 Subject: [PATCH 05/11] Fix imports --- src/cohere/sagemaker_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py index 53ad4a8ef..0d97740e2 100644 --- a/src/cohere/sagemaker_client.py +++ b/src/cohere/sagemaker_client.py @@ -2,6 +2,7 @@ from .aws_client import AwsClient from .manually_maintained.cohere_aws.client import Client +from .manually_maintained.cohere_aws.mode import Mode class SagemakerClient(AwsClient): @@ -25,4 +26,4 @@ def __init__( aws_region=aws_region, timeout=timeout, ) - self.sagemaker_finetuning = Client(region_name=aws_region) \ No newline at end of file + self.sagemaker_finetuning = Client(region_name=aws_region, mode=Mode.SAGEMAKER) \ No newline at end of file From 50aa0909ce2b51fc470f338daea5eff6b55689ca Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:45:38 +0100 Subject: [PATCH 06/11] Fix region --- tests/test_aws_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_aws_client.py b/tests/test_aws_client.py index a6a60b4df..81789a679 100644 --- a/tests/test_aws_client.py +++ b/tests/test_aws_client.py @@ -40,7 +40,7 @@ "platform": "sagemaker", "client": cohere.SagemakerClient( timeout=10000, - aws_region="us-east-1", + aws_region="us-east-2", aws_access_key="...", aws_secret_key="...", aws_session_token="...", From 33c8b912eaa81724879becfc01fecbb50aa40eea Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 11:57:50 +0100 Subject: [PATCH 07/11] Fixes --- .../manually_maintained/cohere_aws/client.py | 33 +++++++------------ src/cohere/sagemaker_client.py | 2 +- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 870cf6240..d759937db 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ 
b/src/cohere/manually_maintained/cohere_aws/client.py @@ -21,33 +21,22 @@ from .rerank import Reranking from .summary import Summary from .mode import Mode +import typing - -class Client: - def __init__(self, endpoint_name: Optional[str] = None, - region_name: Optional[str] = None, - mode: Optional[Mode] = Mode.SAGEMAKER): +class Client: + def __init__( + self, + aws_region: typing.Optional[str] = None, + ): """ By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with `aws configure set region us-west-2` or override it with `region_name` parameter. """ - self._endpoint_name = endpoint_name # deprecated, should use self.connect_to_endpoint() instead - - if mode == Mode.SAGEMAKER: - self._client = boto3.client("sagemaker-runtime", region_name=region_name) - self._service_client = boto3.client("sagemaker", region_name=region_name) - self._sess = sage.Session(sagemaker_client=self._service_client) - elif mode == Mode.BEDROCK: - if not region_name: - region_name = boto3.Session().region_name - self._client = boto3.client( - service_name="bedrock-runtime", - region_name=region_name, - ) - self._service_client = boto3.client("bedrock", region_name=region_name) - else: - raise CohereError("Unsupported mode") - self.mode = mode + self._client = boto3.client("sagemaker-runtime", region_name=aws_region) + self._service_client = boto3.client("sagemaker", region_name=aws_region) + self._sess = sage.Session(sagemaker_client=self._service_client) + self.mode = Mode.SAGEMAKER + def _does_endpoint_exist(self, endpoint_name: str) -> bool: diff --git a/src/cohere/sagemaker_client.py b/src/cohere/sagemaker_client.py index 0d97740e2..6d4236d53 100644 --- a/src/cohere/sagemaker_client.py +++ b/src/cohere/sagemaker_client.py @@ -26,4 +26,4 @@ def __init__( aws_region=aws_region, timeout=timeout, ) - self.sagemaker_finetuning = Client(region_name=aws_region, mode=Mode.SAGEMAKER) \ No newline at end of file + self.sagemaker_finetuning = 
Client(aws_region=aws_region) \ No newline at end of file From 1b71abee8c216eb0478783ffeba1f5492bca97ef Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 12:01:21 +0100 Subject: [PATCH 08/11] Fixes --- src/cohere/manually_maintained/cohere_aws/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index d759937db..8cd355576 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -34,7 +34,7 @@ def __init__( """ self._client = boto3.client("sagemaker-runtime", region_name=aws_region) self._service_client = boto3.client("sagemaker", region_name=aws_region) - self._sess = sage.Session(sagemaker_client=self._service_client) + self._sess = sage.Session(sagemaker_client=self._service_client, region_name=aws_region) self.mode = Mode.SAGEMAKER From 05e38410e9e962b66795951cebf58cabc442067e Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 12:07:07 +0100 Subject: [PATCH 09/11] Sesh --- src/cohere/manually_maintained/cohere_aws/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 8cd355576..31a9b2ec8 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -34,7 +34,7 @@ def __init__( """ self._client = boto3.client("sagemaker-runtime", region_name=aws_region) self._service_client = boto3.client("sagemaker", region_name=aws_region) - self._sess = sage.Session(sagemaker_client=self._service_client, region_name=aws_region) + self._sess = sage.Session(boto3.session.Session()) self.mode = Mode.SAGEMAKER From e89f5435f58603b353ebedaed7c6958bc726fdb4 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 12:12:34 +0100 Subject: [PATCH 10/11] 
Patch AWS_DEFAULT_REGION --- src/cohere/manually_maintained/cohere_aws/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 31a9b2ec8..bacb3a97e 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -34,7 +34,9 @@ def __init__( """ self._client = boto3.client("sagemaker-runtime", region_name=aws_region) self._service_client = boto3.client("sagemaker", region_name=aws_region) - self._sess = sage.Session(boto3.session.Session()) + if os.environ.get('AWS_DEFAULT_REGION') is None: + os.environ['AWS_DEFAULT_REGION'] = aws_region + self._sess = sage.Session(sagemaker_client=self._service_client) self.mode = Mode.SAGEMAKER From 6bebeac6b40b29dc75b3f5905effa6fc0dd1cf49 Mon Sep 17 00:00:00 2001 From: Billy Trend Date: Mon, 30 Sep 2024 18:06:43 +0100 Subject: [PATCH 11/11] 5.11.0 --- pyproject.toml | 2 +- src/cohere/core/client_wrapper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 46ed9405b..fa8e05378 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "5.11.0a0" +version = "5.11.0" description = "" readme = "README.md" authors = [] diff --git a/src/cohere/core/client_wrapper.py b/src/cohere/core/client_wrapper.py index fd0ba72b0..c92dd1d42 100644 --- a/src/cohere/core/client_wrapper.py +++ b/src/cohere/core/client_wrapper.py @@ -24,7 +24,7 @@ def get_headers(self) -> typing.Dict[str, str]: headers: typing.Dict[str, str] = { "X-Fern-Language": "Python", "X-Fern-SDK-Name": "cohere", - "X-Fern-SDK-Version": "5.11.0a0", + "X-Fern-SDK-Version": "5.11.0", } if self._client_name is not None: headers["X-Client-Name"] = self._client_name