diff --git a/README.md b/README.md
index ed585b1..11e1a57 100644
--- a/README.md
+++ b/README.md
@@ -139,44 +139,6 @@ answer
 > :white_check_mark: **answer: It will still take 1 hour to dry 30 shirts under the sun,**
 > **assuming they are all laid out properly to receive equal sunlight.**
 
-### Updates:
-
-**29th Sept 2024**:
-
-We are introducing a new engine based on [litellm](https://github.com/BerriAI/litellm). This should allow
-you to use any model you like, as long as it is supported by litellm. This means that now
-**Bedrock, Together, Gemini and even more** are all supported by TextGrad!
-
-In addition to this, with the new engines it should be easy to enable and disable caching.
-
-We are in the process of testing these new engines and deprecating the old engines. If you have any issues, please let us know!
-
-The new litellm engines can be loaded with the following code:
-
-An example of loading a litellm engine:
-```python
-engine = get_engine("experimental:gpt-4o", cache=False)
-
-# this also works with
-
-set_backward_engine("experimental:gpt-4o", cache=False)
-```
-
-An example of forward pass:
-```python
-
-import httpx
-from textgrad.engine_experimental.litellm import LiteLLMEngine
-
-LiteLLMEngine("gpt-4o", cache=True).generate(content="hello, what's 3+4", system_prompt="you are an assistant")
-
-image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
-image_data = httpx.get(image_url).content
-```
-
-In the examples folder you will find two new notebooks that show how to use the new engines.
-
-
 We have many more examples around how TextGrad can optimize all kinds of variables -- code, solutions to problems, molecules, prompts, and all that!
 
 ### Tutorials
diff --git a/tests/test_basics.py b/tests/test_basics.py
index dba4fa6..0f06b8c 100644
--- a/tests/test_basics.py
+++ b/tests/test_basics.py
@@ -68,6 +68,57 @@ def test_openai_engine():
     os.environ['OPENAI_API_KEY'] = "fake_key"
     engine = ChatOpenAI()
 
+
+def test_set_backward_engine():
+    from textgrad.config import set_backward_engine, SingletonBackwardEngine
+    from textgrad.engine.openai import ChatOpenAI
+    from textgrad.engine_experimental.litellm import LiteLLMEngine
+
+    engine = ChatOpenAI()
+    set_backward_engine(engine, override=False)
+    assert SingletonBackwardEngine().get_engine() == engine
+
+    new_engine = LiteLLMEngine(model_string="gpt-3.5-turbo-0613")
+    set_backward_engine(new_engine, True)
+    assert SingletonBackwardEngine().get_engine() == new_engine
+
+    with pytest.raises(Exception):
+        set_backward_engine(engine, False)
+
+def test_get_engine():
+    from textgrad.engine import get_engine
+    from textgrad.engine.openai import ChatOpenAI
+    from textgrad.engine_experimental.litellm import LiteLLMEngine
+
+    engine = get_engine("gpt-3.5-turbo-0613")
+    assert isinstance(engine, ChatOpenAI)
+
+    engine = get_engine("experimental:claude-3-opus-20240229")
+    assert isinstance(engine, LiteLLMEngine)
+
+    engine = get_engine("experimental:claude-3-opus-20240229", cache=True)
+    assert isinstance(engine, LiteLLMEngine)
+
+    engine = get_engine("experimental:claude-3-opus-20240229", cache=False)
+    assert isinstance(engine, LiteLLMEngine)
+
+    # get local diskcache
+    from diskcache import Cache
+    cache = Cache("./cache")
+
+    engine = get_engine("experimental:claude-3-opus-20240229", cache=cache)
+    assert isinstance(engine, LiteLLMEngine)
+
+    with pytest.raises(ValueError):
+        get_engine("invalid-engine")
+
+    with pytest.raises(ValueError):
+        get_engine("experimental:claude-3-opus-20240229", cache=[1,2,3])
+
+    with pytest.raises(ValueError):
+        get_engine("gpt-4o", cache=True)
+
+
 # Test importing main components from textgrad
 def test_import_main_components():
     from textgrad import Variable, TextualGradientDescent, EngineLM
diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py
index ea7c6c1..2697faf 100644
--- a/textgrad/engine/__init__.py
+++ b/textgrad/engine/__init__.py
@@ -35,6 +35,9 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
     if "seed" in kwargs and "gpt-4" not in engine_name and "gpt-3.5" not in engine_name and "gpt-35" not in engine_name:
         raise ValueError(f"Seed is currently supported only for OpenAI engines, not {engine_name}")
 
+    if "cache" in kwargs and "experimental" not in engine_name:
+        raise ValueError(f"Cache is currently supported only for LiteLLM engines, not {engine_name}")
+
     # check if engine_name starts with "experimental:"
     if engine_name.startswith("experimental:"):
         engine_name = engine_name.split("experimental:")[1]
diff --git a/textgrad/engine_experimental/litellm.py b/textgrad/engine_experimental/litellm.py
index 48b4c10..e28845c 100644
--- a/textgrad/engine_experimental/litellm.py
+++ b/textgrad/engine_experimental/litellm.py
@@ -25,7 +25,7 @@ def __init__(self,
                  model_string: str,
                  system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                  is_multimodal: bool = True,
-                 cache=Union[dc.Cache, bool]):
+                 cache: Union[dc.Cache, bool] = False):
 
         super().__init__(
             model_string=model_string,