diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84acf2f471b..44c2ced28cc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,6 @@ repos: ] additional_dependencies: [ - "types-requests", "pydantic", "overrides", "hypothesis", diff --git a/bin/windows_upgrade_sqlite.py b/bin/windows_upgrade_sqlite.py index 1b27011cd12..fde0841ab25 100644 --- a/bin/windows_upgrade_sqlite.py +++ b/bin/windows_upgrade_sqlite.py @@ -1,4 +1,4 @@ -import requests +import httpx import zipfile import io import os @@ -10,7 +10,7 @@ if __name__ == "__main__": # Download and extract the DLL - r = requests.get(DLL_URL) + r = httpx.get(DLL_URL) z = zipfile.ZipFile(io.BytesIO(r.content)) z.extractall(".") # Print current Python path diff --git a/chromadb/api/async_fastapi.py b/chromadb/api/async_fastapi.py index 349de99833c..7a742b5198d 100644 --- a/chromadb/api/async_fastapi.py +++ b/chromadb/api/async_fastapi.py @@ -2,11 +2,10 @@ from uuid import UUID import urllib.parse import orjson as json -from typing import Any, Optional, TypeVar, cast, Tuple, Sequence, Dict +from typing import Any, Optional, cast, Tuple, Sequence, Dict import logging import httpx from overrides import override -from chromadb import errors from chromadb.api import AsyncServerAPI from chromadb.api.base_http_client import BaseHTTPClient from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, Settings @@ -43,15 +42,6 @@ ) -# requests removes None values from the built query string, but httpx includes it as an empty value -T = TypeVar("T", bound=Dict[Any, Any]) - - -def clean_params(params: T) -> T: - """Remove None values from kwargs.""" - return {k: v for k, v in params.items() if v is not None} # type: ignore - - logger = logging.getLogger(__name__) @@ -130,8 +120,8 @@ async def _make_request( escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None) url = self._api_url + escaped_path - response = await self._get_client().request(method, url, **kwargs) - await raise_chroma_error(response) + response = await self._get_client().request(method, url, **cast(Any, kwargs)) + BaseHTTPClient._raise_chroma_error(response) return json.loads(response.text) @trace_method("AsyncFastAPI.heartbeat", OpenTelemetryGranularity.OPERATION) @@ -202,7 +192,7 @@ async def list_collections( resp_json = await self._make_request( "get", "/collections", - params=clean_params( + params=BaseHTTPClient._clean_params( { "tenant": tenant, "database": database, @@ -605,30 +595,3 @@ async def get_max_batch_size(self) -> int: resp_json = await self._make_request("get", "/pre-flight-checks") self._max_batch_size = cast(int, resp_json["max_batch_size"]) return self._max_batch_size - - -async def raise_chroma_error(resp: httpx.Response) -> Any: - """Raises an error if the response is not ok, using a ChromaError if possible.""" - try: - resp.raise_for_status() - return - except httpx.HTTPStatusError: - pass - - chroma_error = None - try: - body = json.loads(resp.text) - if "error" in body: - if body["error"] in errors.error_types: - chroma_error = errors.error_types[body["error"]](body["message"]) - - except BaseException: - pass - - if chroma_error: - raise chroma_error - - try: - resp.raise_for_status() - except httpx.HTTPStatusError: - raise (Exception(resp.text)) diff --git a/chromadb/api/base_http_client.py b/chromadb/api/base_http_client.py index 21590a6d1ac..eec8bc2e2e8 100644 --- a/chromadb/api/base_http_client.py +++ b/chromadb/api/base_http_client.py @@ -1,7 +1,10 @@ -from typing import Optional +from typing import Any, Dict, Optional, TypeVar from urllib.parse import quote, urlparse, urlunparse import logging +import orjson as json +import httpx +import chromadb.errors as errors from chromadb.config import Settings logger = logging.getLogger(__name__) @@ -57,3 +60,38 @@ def resolve_url( ) return full_url + + # requests removes None values from the built query string, but httpx includes it as an empty value + T = TypeVar("T", bound=Dict[Any, Any]) + + @staticmethod + def _clean_params(params: T) -> T: + """Remove None values from provided dict.""" + return {k: v for k, v in params.items() if v is not None} # type: ignore + + @staticmethod + def _raise_chroma_error(resp: httpx.Response) -> None: + """Raises an error if the response is not ok, using a ChromaError if possible.""" + try: + resp.raise_for_status() + return + except httpx.HTTPStatusError: + pass + + chroma_error = None + try: + body = json.loads(resp.text) + if "error" in body: + if body["error"] in errors.error_types: + chroma_error = errors.error_types[body["error"]](body["message"]) + + except BaseException: + pass + + if chroma_error: + raise chroma_error + + try: + resp.raise_for_status() + except httpx.HTTPStatusError: + raise (Exception(resp.text)) diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 16a40cd5694..6d1d51f95ba 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -2,7 +2,7 @@ from uuid import UUID from overrides import override -import requests +import httpx from chromadb.api import AdminAPI, ClientAPI, ServerAPI from chromadb.api.shared_system_client import SharedSystemClient from chromadb.api.types import ( @@ -349,7 +349,7 @@ def set_database(self, database: str) -> None: def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_tenant(name=tenant) - except requests.exceptions.ConnectionError: + except httpx.ConnectError: raise ValueError( "Could not connect to a Chroma server. Are you sure it is running?" ) @@ -363,7 +363,7 @@ def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_database(name=database, tenant=tenant) - except requests.exceptions.ConnectionError: + except httpx.ConnectError: raise ValueError( "Could not connect to a Chroma server. Are you sure it is running?" ) diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index e2a31eed34d..18beec10a11 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,13 +1,13 @@ import orjson as json import logging -from typing import Optional, cast, Tuple +from typing import Any, Dict, Optional, cast, Tuple from typing import Sequence from uuid import UUID -import requests +import httpx +import urllib.parse from overrides import override from chromadb.api.base_http_client import BaseHTTPClient -import chromadb.errors as errors from chromadb.types import Database, Tenant import chromadb.utils.embedding_functions as ef from chromadb.api import ServerAPI @@ -62,13 +62,13 @@ def __init__(self, system: System): default_api_path=system.settings.chroma_server_api_default_path, ) - self._session = requests.Session() + self._session = httpx.Client(timeout=None) self._header = system.settings.chroma_server_headers if self._header is not None: self._session.headers.update(self._header) if self._settings.chroma_server_ssl_verify is not None: - self._session.verify = self._settings.chroma_server_ssl_verify + self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify) if system.settings.chroma_client_auth_provider: self._auth_provider = self.require(ClientAuthProvider) @@ -76,13 +76,21 @@ def __init__(self, system: System): for header, value in _headers.items(): self._session.headers[header] = value.get_secret_value() + def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any: + # Unlike requests, httpx does not automatically escape the path + escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None) + url = self._api_url + escaped_path + + response = self._session.request(method, url, **cast(Any, kwargs)) + BaseHTTPClient._raise_chroma_error(response) + return json.loads(response.text) + @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION) @override def heartbeat(self) -> int: """Returns the current server time in nanoseconds to check if the server is alive""" - resp = self._session.get(self._api_url) - raise_chroma_error(resp) - return int(json.loads(resp.text)["nanosecond heartbeat"]) + resp_json = self._make_request("get", "/heartbeat") + return int(resp_json["nanosecond heartbeat"]) @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION) @override @@ -92,12 +100,12 @@ def create_database( tenant: str = DEFAULT_TENANT, ) -> None: """Creates a database""" - resp = self._session.post( - self._api_url + "/databases", - data=json.dumps({"name": name}), + self._make_request( + "post", + "/databases", + json={"name": name}, params={"tenant": tenant}, ) - raise_chroma_error(resp) @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION) @override @@ -107,12 +115,11 @@ def get_database( tenant: str = DEFAULT_TENANT, ) -> Database: """Returns a database""" - resp = self._session.get( - self._api_url + "/databases/" + name, + resp_json = self._make_request( + "get", + "/databases/" + name, params={"tenant": tenant}, ) - raise_chroma_error(resp) - resp_json = json.loads(resp.text) return Database( id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"] ) @@ -120,20 +127,12 @@ def get_database( @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION) @override def create_tenant(self, name: str) -> None: - resp = self._session.post( - self._api_url + "/tenants", - data=json.dumps({"name": name}), - ) - raise_chroma_error(resp) + self._make_request("post", "/tenants", json={"name": name}) @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION) @override def get_tenant(self, name: str) -> Tenant: - resp = self._session.get( - self._api_url + "/tenants/" + name, - ) - raise_chroma_error(resp) - resp_json = json.loads(resp.text) + resp_json = self._make_request("get", "/tenants/" + name) return Tenant(name=resp_json["name"]) @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION) @@ -146,17 +145,18 @@ def list_collections( database: str = DEFAULT_DATABASE, ) -> Sequence[Collection]: """Returns a list of all collections""" - resp = self._session.get( - self._api_url + "/collections", - params={ - "tenant": tenant, - "database": database, - "limit": limit, - "offset": offset, - }, + json_collections = self._make_request( + "get", + "/collections", + params=BaseHTTPClient._clean_params( + { + "tenant": tenant, + "database": database, + "limit": limit, + "offset": offset, + } + ), ) - raise_chroma_error(resp) - json_collections = json.loads(resp.text) collections = [] for json_collection in json_collections: model = CollectionModel( @@ -178,12 +178,12 @@ def count_collections( self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE ) -> int: """Returns a count of collections""" - resp = self._session.get( - self._api_url + "/count_collections", + resp_json = self._make_request( + "get", + "/count_collections", params={"tenant": tenant, "database": database}, ) - raise_chroma_error(resp) - return cast(int, json.loads(resp.text)) + return cast(int, resp_json) @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION) @override @@ -200,19 +200,17 @@ def create_collection( database: str = DEFAULT_DATABASE, ) -> Collection: """Creates a collection""" - resp = self._session.post( - self._api_url + "/collections", - data=json.dumps( - { - "name": name, - "metadata": metadata, - "get_or_create": get_or_create, - } - ), + resp_json = self._make_request( + "post", + "/collections", + json={ + "name": name, + "metadata": metadata, + "get_or_create": get_or_create, + }, params={"tenant": tenant, "database": database}, ) - raise_chroma_error(resp) - resp_json = json.loads(resp.text) + model = CollectionModel( id=resp_json["id"], name=resp_json["name"], @@ -249,11 +247,13 @@ def get_collection( _params = {"tenant": tenant, "database": database} if id is not None: _params["type"] = str(id) - resp = self._session.get( - self._api_url + "/collections/" + name if name else str(id), params=_params + + resp_json = self._make_request( + "get", + "/collections/" + name if name else str(id), + params=_params, ) - raise_chroma_error(resp) - resp_json = json.loads(resp.text) + model = CollectionModel( id=resp_json["id"], name=resp_json["name"], @@ -285,17 +285,14 @@ def get_or_create_collection( tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - return cast( - Collection, - self.create_collection( - name=name, - metadata=metadata, - embedding_function=embedding_function, - data_loader=data_loader, - get_or_create=True, - tenant=tenant, - database=database, - ), + return self.create_collection( + name=name, + metadata=metadata, + embedding_function=embedding_function, + data_loader=data_loader, + get_or_create=True, + tenant=tenant, + database=database, ) @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION) @@ -307,11 +304,11 @@ def _modify( new_metadata: Optional[CollectionMetadata] = None, ) -> None: """Updates a collection""" - resp = self._session.put( - self._api_url + "/collections/" + str(id), - data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}), + self._make_request( + "put", + "/collections/" + str(id), + json={"new_metadata": new_metadata, "new_name": new_name}, ) - raise_chroma_error(resp) @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION) @override @@ -322,11 +319,11 @@ def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: """Deletes a collection""" - resp = self._session.delete( - self._api_url + "/collections/" + name, + self._make_request( + "delete", + "/collections/" + name, params={"tenant": tenant, "database": database}, ) - raise_chroma_error(resp) @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION) @override @@ -335,11 +332,11 @@ def _count( collection_id: UUID, ) -> int: """Returns the number of embeddings in the database""" - resp = self._session.get( - self._api_url + "/collections/" + str(collection_id) + "/count" + resp_json = self._make_request( + "get", + "/collections/" + str(collection_id) + "/count", ) - raise_chroma_error(resp) - return cast(int, json.loads(resp.text)) + return cast(int, resp_json) @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION) @override @@ -376,31 +373,28 @@ def _get( offset = (page - 1) * page_size limit = page_size - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/get", - data=json.dumps( - { - "ids": ids, - "where": where, - "sort": sort, - "limit": limit, - "offset": offset, - "where_document": where_document, - "include": include, - } - ), + resp_json = self._make_request( + "post", + "/collections/" + str(collection_id) + "/get", + json={ + "ids": ids, + "where": where, + "sort": sort, + "limit": limit, + "offset": offset, + "where_document": where_document, + "include": include, + }, ) - raise_chroma_error(resp) - body = json.loads(resp.text) return GetResult( - ids=body["ids"], - embeddings=body.get("embeddings", None), - metadatas=body.get("metadatas", None), - documents=body.get("documents", None), + ids=resp_json["ids"], + embeddings=resp_json.get("embeddings", None), + metadatas=resp_json.get("metadatas", None), + documents=resp_json.get("documents", None), data=None, - uris=body.get("uris", None), - included=body["included"], + uris=resp_json.get("uris", None), + included=resp_json["included"], ) @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION) @@ -413,15 +407,16 @@ def _delete( where_document: Optional[WhereDocument] = {}, ) -> IDs: """Deletes embeddings from the database""" - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/delete", - data=json.dumps( - {"where": where, "ids": ids, "where_document": where_document} - ), + resp_json = self._make_request( + "post", + "/collections/" + str(collection_id) + "/delete", + json={ + "ids": ids, + "where": where, + "where_document": where_document, + }, ) - - raise_chroma_error(resp) - return cast(IDs, json.loads(resp.text)) + return cast(IDs, resp_json) @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL) def _submit_batch( @@ -434,23 +429,21 @@ def _submit_batch( Optional[URIs], ], url: str, - ) -> requests.Response: + ) -> None: """ Submits a batch of embeddings to the database """ - resp = self._session.post( - self._api_url + url, - data=json.dumps( - { - "ids": batch[0], - "embeddings": batch[1], - "metadatas": batch[2], - "documents": batch[3], - "uris": batch[4], - } - ), + self._make_request( + "post", + url, + json={ + "ids": batch[0], + "embeddings": batch[1], + "metadatas": batch[2], + "documents": batch[3], + "uris": batch[4], + }, ) - return resp @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL) @override @@ -469,8 +462,7 @@ def _add( """ batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.get_max_batch_size()}) - resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") - raise_chroma_error(resp) + self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") return True @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL) @@ -490,10 +482,7 @@ def _update( """ batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.get_max_batch_size()}) - resp = self._submit_batch( - batch, "/collections/" + str(collection_id) + "/update" - ) - raise_chroma_error(resp) + self._submit_batch(batch, "/collections/" + str(collection_id) + "/update") return True @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL) @@ -513,10 +502,7 @@ def _upsert( """ batch = (ids, embeddings, metadatas, documents, uris) validate_batch(batch, {"max_batch_size": self.get_max_batch_size()}) - resp = self._submit_batch( - batch, "/collections/" + str(collection_id) + "/upsert" - ) - raise_chroma_error(resp) + self._submit_batch(batch, "/collections/" + str(collection_id) + "/upsert") return True @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL) @@ -531,48 +517,42 @@ def _query( include: Include = ["metadatas", "documents", "distances"], ) -> QueryResult: """Gets the nearest neighbors of a single embedding""" - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/query", - data=json.dumps( - { - "query_embeddings": query_embeddings, - "n_results": n_results, - "where": where, - "where_document": where_document, - "include": include, - } - ), + resp_json = self._make_request( + "post", + "/collections/" + str(collection_id) + "/query", + json={ + "query_embeddings": query_embeddings, + "n_results": n_results, + "where": where, + "where_document": where_document, + "include": include, + }, ) - raise_chroma_error(resp) - body = json.loads(resp.text) - return QueryResult( - ids=body["ids"], - distances=body.get("distances", None), - embeddings=body.get("embeddings", None), - metadatas=body.get("metadatas", None), - documents=body.get("documents", None), - uris=body.get("uris", None), + ids=resp_json["ids"], + distances=resp_json.get("distances", None), + embeddings=resp_json.get("embeddings", None), + metadatas=resp_json.get("metadatas", None), + documents=resp_json.get("documents", None), + uris=resp_json.get("uris", None), data=None, - included=body["included"], + included=resp_json["included"], ) @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL) @override def reset(self) -> bool: """Resets the database""" - resp = self._session.post(self._api_url + "/reset") - raise_chroma_error(resp) - return cast(bool, json.loads(resp.text)) + resp_json = self._make_request("post", "/reset") + return cast(bool, resp_json) @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION) @override def get_version(self) -> str: """Returns the version of the server""" - resp = self._session.get(self._api_url + "/version") - raise_chroma_error(resp) - return cast(str, json.loads(resp.text)) + resp_json = self._make_request("get", "/version") + return cast(str, resp_json) @override def get_settings(self) -> Settings: @@ -583,31 +563,6 @@ def get_settings(self) -> Settings: @override def get_max_batch_size(self) -> int: if self._max_batch_size == -1: - resp = self._session.get(self._api_url + "/pre-flight-checks") - raise_chroma_error(resp) - self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"]) + resp_json = self._make_request("get", "/pre-flight-checks") + self._max_batch_size = cast(int, resp_json["max_batch_size"]) return self._max_batch_size - - -def raise_chroma_error(resp: requests.Response) -> None: - """Raises an error if the response is not ok, using a ChromaError if possible""" - if resp.ok: - return - - chroma_error = None - try: - body = json.loads(resp.text) - if "error" in body: - if body["error"] in errors.error_types: - chroma_error = errors.error_types[body["error"]](body["message"]) - - except BaseException: - pass - - if chroma_error: - raise chroma_error - - try: - resp.raise_for_status() - except requests.HTTPError: - raise (Exception(resp.text)) diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 3f5c7ec8468..4768f49b880 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -20,7 +20,6 @@ import hypothesis import pytest import uvicorn -from requests.exceptions import ConnectionError from httpx import ConnectError from typing_extensions import Protocol @@ -233,9 +232,7 @@ def _run_server( def _await_server(api: ServerAPI, attempts: int = 0) -> None: try: api.heartbeat() - # First error is from requests, second is from httpx - # todo: use httpx for both? - except (ConnectionError, ConnectError) as e: + except ConnectError as e: if attempts > 15: raise e else: diff --git a/chromadb/test/ef/test_ollama_ef.py b/chromadb/test/ef/test_ollama_ef.py index d44f1e8e6d1..413bf091bf7 100644 --- a/chromadb/test/ef/test_ollama_ef.py +++ b/chromadb/test/ef/test_ollama_ef.py @@ -1,9 +1,8 @@ import os import pytest -import requests -from requests import HTTPError -from requests.exceptions import ConnectionError +import httpx +from httpx import HTTPError, ConnectError from chromadb.utils.embedding_functions import OllamaEmbeddingFunction @@ -21,10 +20,10 @@ def test_ollama() -> None: "OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test." ) try: - response = requests.get(os.environ.get("OLLAMA_SERVER_URL", "")) + response = httpx.get(os.environ.get("OLLAMA_SERVER_URL", "")) # If the response was successful, no Exception will be raised response.raise_for_status() - except (HTTPError, ConnectionError): + except (HTTPError, ConnectError): pytest.skip("Ollama server not running. Skipping test.") ef = OllamaEmbeddingFunction( model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text", diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 7331ce11bff..48574504796 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -1,7 +1,6 @@ # type: ignore import traceback -import requests -from urllib3.connectionpool import InsecureRequestWarning +import httpx import chromadb from chromadb.api.fastapi import FastAPI @@ -203,7 +202,7 @@ def test_pre_flight_checks(api): if not isinstance(api, FastAPI): pytest.skip("Not a FastAPI instance") - resp = requests.get(f"{api._api_url}/pre-flight-checks") + resp = httpx.get(f"{api._api_url}/pre-flight-checks") assert resp.status_code == 200 assert resp.json() is not None assert "max_batch_size" in resp.json().keys() @@ -1614,19 +1613,3 @@ def test_ssl_self_signed_without_ssl_verify(client_ssl): ) client_ssl.clear_system_cache() assert "CERTIFICATE_VERIFY_FAILED" in "".join(stack_trace) - - -def test_ssl_self_signed_with_verify_false(client_ssl): - if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"): - pytest.skip("Skipping test for integration test") - client_ssl.heartbeat() - _port = client_ssl._server._settings.chroma_server_http_port - with pytest.warns(InsecureRequestWarning) as record: - client = chromadb.HttpClient( - ssl=True, - port=_port, - settings=chromadb.Settings(chroma_server_ssl_verify=False), - ) - client.heartbeat() - client_ssl.clear_system_cache() - assert "Unverified HTTPS request" in str(record[0].message) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index ef8421580dc..737a26909a2 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -20,7 +20,7 @@ from pathlib import Path import os import tarfile -import requests +import httpx from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast import numpy as np import numpy.typing as npt @@ -265,7 +265,7 @@ def __init__( model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2". """ self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}" - self._session = requests.Session() + self._session = httpx.Client() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) def __call__(self, input: Documents) -> Embeddings: @@ -309,7 +309,7 @@ def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en") """ self._model_name = model_name self._api_url = "https://api.jina.ai/v1/embeddings" - self._session = requests.Session() + self._session = httpx.Client() self._session.headers.update( {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"} ) @@ -434,18 +434,18 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)), ) def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: - resp = requests.get(url, stream=True) - total = int(resp.headers.get("content-length", 0)) - with open(fname, "wb") as file, self.tqdm( - desc=str(fname), - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for data in resp.iter_content(chunk_size=chunk_size): - size = file.write(data) - bar.update(size) + with httpx.stream("GET", url) as resp: + total = int(resp.headers.get("content-length", 0)) + with open(fname, "wb") as file, self.tqdm( + desc=str(fname), + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in resp.iter_bytes(chunk_size=chunk_size): + size = file.write(data) + bar.update(size) if not _verify_sha256(fname, self._MODEL_SHA256): # if the integrity of the file is not verified, remove it os.remove(fname) @@ -666,7 +666,7 @@ def __init__( region: str = "us-central1", ): self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/goole/models/{model_name}:predict" - self._session = requests.Session() + self._session = httpx.Client() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) def __call__(self, input: Documents) -> Embeddings: @@ -782,7 +782,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: }, } - res = requests.post( + res = httpx.post( f"{self._api_url}/clip/embed_image?api_key={self._api_key}", json=infer_clip_payload, ) @@ -796,7 +796,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: "text": input, } - res = requests.post( + res = httpx.post( f"{self._api_url}/clip/embed_text?api_key={self._api_key}", json=infer_clip_payload, ) @@ -869,13 +869,13 @@ def __init__(self, url: str): url (str): The URL of the HuggingFace Embedding Server. """ try: - import requests + import httpx except ImportError: raise ValueError( - "The requests python package is not installed. Please install it with `pip install requests`" + "The httpx python package is not installed. Please install it with `pip install httpx`" ) self._api_url = f"{url}" - self._session = requests.Session() + self._session = httpx.Client() def __call__(self, input: Documents) -> Embeddings: """ @@ -976,14 +976,14 @@ def __init__(self, url: str, model_name: str) -> None: model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). """ try: - import requests + import httpx except ImportError: raise ValueError( - "The requests python package is not installed. Please install it with `pip install requests`" + "The httpx python package is not installed. Please install it with `pip install httpx`" ) self._api_url = f"{url}" self._model_name = model_name - self._session = requests.Session() + self._session = httpx.Client() def __call__(self, input: Documents) -> Embeddings: """ diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 028d8fd1824..c7e1486d0c1 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -3,16 +3,16 @@ name = "chromadb-client" dynamic = ["version"] authors = [ - { name="Jeff Huber", email="jeff@trychroma.com" }, - { name="Anton Troynikov", email="anton@trychroma.com" } + { name = "Jeff Huber", email = "jeff@trychroma.com" }, + { name = "Anton Troynikov", email = "anton@trychroma.com" }, ] description = "Chroma Client." readme = "README.md" requires-python = ">=3.8" classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", ] dependencies = [ 'numpy >= 1.22.5, < 2.0.0', @@ -22,16 +22,16 @@ dependencies = [ 'overrides >= 7.3.1', 'posthog >= 2.4.0', 'pydantic>=1.9', - 'requests >= 2.28', 'typing_extensions >= 4.5.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', 'orjson>=3.9.12', + 'httpx>=0.27.0', ] [tool.black] line-length = 88 -required-version = "23.3.0" # Black will refuse to run if it's not this version. +required-version = "23.3.0" # Black will refuse to run if it's not this version. target-version = ['py38', 'py39', 'py310', 'py311'] [tool.pytest.ini_options] @@ -46,7 +46,7 @@ requires = ["setuptools>=61.0", "setuptools_scm[toml]>=6.2"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] -local_scheme="no-local-version" +local_scheme = "no-local-version" [tool.setuptools] packages = ["chromadb"] diff --git a/clients/python/requirements.txt b/clients/python/requirements.txt index 057af2f963c..0c83eb72520 100644 --- a/clients/python/requirements.txt +++ b/clients/python/requirements.txt @@ -1,3 +1,4 @@ +httpx>=0.27.0 numpy >= 1.22.5, < 2.0.0 opentelemetry-api>=1.2.0 opentelemetry-exporter-otlp-proto-grpc>=1.2.0 @@ -7,6 +8,5 @@ overrides >= 7.3.1 posthog >= 2.4.0 pydantic>=1.9 PyYAML>=6.0.0 -requests >= 2.28 tenacity>=8.2.3 typing_extensions >= 4.5.0 diff --git a/docs/docs.trychroma.com/pages/integrations/hugging-face-server.md b/docs/docs.trychroma.com/pages/integrations/hugging-face-server.md index f43c16dfda3..e89b90fce26 100644 --- a/docs/docs.trychroma.com/pages/integrations/hugging-face-server.md +++ b/docs/docs.trychroma.com/pages/integrations/hugging-face-server.md @@ -34,8 +34,6 @@ The above docker command will run the server with the `BAAI/bge-small-en-v1.5` m {% tabs group="code-lang" hideTabs=true %} {% tab label="Python" %} -This embedding function relies on the `requests` python package, which you can install with `pip install requests`. - ```python from chromadb.utils.embedding_functions import HuggingFaceEmbeddingServer huggingface_ef = HuggingFaceEmbeddingServer(url="http://localhost:8001/embed") diff --git a/docs/docs.trychroma.com/pages/integrations/hugging-face.md b/docs/docs.trychroma.com/pages/integrations/hugging-face.md index 418913b1e79..37e15581a2c 100644 --- a/docs/docs.trychroma.com/pages/integrations/hugging-face.md +++ b/docs/docs.trychroma.com/pages/integrations/hugging-face.md @@ -14,8 +14,6 @@ Chroma also provides a convenient wrapper around HuggingFace's embedding API. Th {% tabs group="code-lang" hideTabs=true %} {% tab label="Python" %} -This embedding function relies on the `requests` python package, which you can install with `pip install requests`. - ```python import chromadb.utils.embedding_functions as embedding_functions huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction( diff --git a/docs/docs.trychroma.com/pages/integrations/jinaai.md b/docs/docs.trychroma.com/pages/integrations/jinaai.md index f2377a39718..cce47c7c81f 100644 --- a/docs/docs.trychroma.com/pages/integrations/jinaai.md +++ b/docs/docs.trychroma.com/pages/integrations/jinaai.md @@ -14,8 +14,6 @@ Chroma provides a convenient wrapper around JinaAI's embedding API. This embeddi {% tabs group="code-lang" hideTabs=true %} {% tab label="Python" %} -This embedding function relies on the `requests` python package, which you can install with `pip install requests`. - ```python import chromadb.utils.embedding_functions as embedding_functions jinaai_ef = embedding_functions.JinaEmbeddingFunction( diff --git a/pyproject.toml b/pyproject.toml index f4e35264e03..ffc6184c716 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ classifiers = [ ] dependencies = [ 'build >= 1.0.3', - 'requests >= 2.28', 'pydantic >= 1.9', 'chroma-hnswlib==0.7.3', 'fastapi >= 0.95.2', diff --git a/requirements.txt b/requirements.txt index 413ae4b3cbd..f4b98114b1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,6 @@ posthog>=2.4.0 pydantic>=1.9 pypika>=0.48.9 PyYAML>=6.0.0 -requests>=2.28.1 tenacity>=8.2.3 tokenizers>=0.13.2 tqdm>=4.65.0 diff --git a/requirements_dev.txt b/requirements_dev.txt index c1612f7db7e..53d311409fe 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -11,4 +11,3 @@ pytest-asyncio pytest-xdist setuptools_scm types-protobuf -types-requests==2.30.0.0