Skip to content

Commit

Permalink
tests and readme
Browse files Browse the repository at this point in the history
  • Loading branch information
vinid committed Oct 6, 2024
1 parent 12c62c9 commit 4999f1f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 39 deletions.
38 changes: 0 additions & 38 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,44 +139,6 @@ answer
> :white_check_mark: **answer: It will still take 1 hour to dry 30 shirts under the sun,**
> **assuming they are all laid out properly to receive equal sunlight.**
### Updates:

**29th Sept 2024**:

We are introducing a new engine based on [litellm](https://github.com/BerriAI/litellm). This should allow
you to use any model you like, as long as it is supported by litellm. This means that now
**Bedrock, Together, Gemini and even more** are all supported by TextGrad!

In addition to this, with the new engines it should be easy to enable and disable caching.

We are in the process of testing these new engines and deprecating the old engines. If you have any issues, please let us know!

The new litellm engines can be loaded with the following code:

An example of loading a litellm engine:
```python
engine = get_engine("experimental:gpt-4o", cache=False)

# this also works with

set_backward_engine("experimental:gpt-4o", cache=False)
```

An example of forward pass:
```python

import httpx
from textgrad.engine_experimental.litellm import LiteLLMEngine

LiteLLMEngine("gpt-4o", cache=True).generate(content="hello, what's 3+4", system_prompt="you are an assistant")

image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
image_data = httpx.get(image_url).content
```

In the examples folder you will find two new notebooks that show how to use the new engines.


We have many more examples around how TextGrad can optimize all kinds of variables -- code, solutions to problems, molecules, prompts, and all that!

### Tutorials
Expand Down
51 changes: 51 additions & 0 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,57 @@ def test_openai_engine():
os.environ['OPENAI_API_KEY'] = "fake_key"
engine = ChatOpenAI()


def test_set_backward_engine():
    """Registering a backward engine works, override replaces it, and
    re-setting without override raises."""
    from textgrad.config import set_backward_engine, SingletonBackwardEngine
    from textgrad.engine.openai import ChatOpenAI
    from textgrad.engine_experimental.litellm import LiteLLMEngine

    # First registration: no engine is set yet, so override=False succeeds.
    initial_engine = ChatOpenAI()
    set_backward_engine(initial_engine, override=False)
    assert SingletonBackwardEngine().get_engine() == initial_engine

    # Explicit override swaps in a different engine implementation.
    replacement_engine = LiteLLMEngine(model_string="gpt-3.5-turbo-0613")
    set_backward_engine(replacement_engine, True)
    assert SingletonBackwardEngine().get_engine() == replacement_engine

    # Once an engine is registered, setting again without override must fail.
    with pytest.raises(Exception):
        set_backward_engine(initial_engine, False)

def test_get_engine():
    """Exercise the ``get_engine`` factory: dispatch to the right engine
    class, acceptance of the ``cache`` kwarg for experimental engines, and
    rejection of invalid engine names / cache arguments.

    Fix over the previous version: the diskcache test used ``Cache("./cache")``,
    which left a ``./cache`` directory in the working tree and never closed the
    cache (leaking an open SQLite handle). The cache now lives in a temporary
    directory and is closed explicitly.
    """
    import tempfile

    from textgrad.engine import get_engine
    from textgrad.engine.openai import ChatOpenAI
    from textgrad.engine_experimental.litellm import LiteLLMEngine

    # Plain model names dispatch to the classic OpenAI engine.
    engine = get_engine("gpt-3.5-turbo-0613")
    assert isinstance(engine, ChatOpenAI)

    # "experimental:" names dispatch to LiteLLM, with or without a bool cache.
    engine = get_engine("experimental:claude-3-opus-20240229")
    assert isinstance(engine, LiteLLMEngine)

    engine = get_engine("experimental:claude-3-opus-20240229", cache=True)
    assert isinstance(engine, LiteLLMEngine)

    engine = get_engine("experimental:claude-3-opus-20240229", cache=False)
    assert isinstance(engine, LiteLLMEngine)

    # A user-supplied diskcache.Cache instance is also accepted; keep it in a
    # temp dir so the test leaves no ./cache directory behind, and close it so
    # no SQLite handle stays open.
    from diskcache import Cache
    with tempfile.TemporaryDirectory() as cache_dir:
        cache = Cache(cache_dir)
        try:
            engine = get_engine("experimental:claude-3-opus-20240229", cache=cache)
            assert isinstance(engine, LiteLLMEngine)
        finally:
            cache.close()

    # Unknown engine names are rejected.
    with pytest.raises(ValueError):
        get_engine("invalid-engine")

    # cache must be a bool or a diskcache.Cache, not an arbitrary object.
    with pytest.raises(ValueError):
        get_engine("experimental:claude-3-opus-20240229", cache=[1, 2, 3])

    # Non-experimental engines do not support the cache kwarg at all.
    with pytest.raises(ValueError):
        get_engine("gpt-4o", cache=True)


# Test importing main components from textgrad
def test_import_main_components():
from textgrad import Variable, TextualGradientDescent, EngineLM
Expand Down
3 changes: 3 additions & 0 deletions textgrad/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
if "seed" in kwargs and "gpt-4" not in engine_name and "gpt-3.5" not in engine_name and "gpt-35" not in engine_name:
raise ValueError(f"Seed is currently supported only for OpenAI engines, not {engine_name}")

if "cache" in kwargs and "experimental" not in engine_name:
raise ValueError(f"Cache is currently supported only for LiteLLM engines, not {engine_name}")

# check if engine_name starts with "experimental:"
if engine_name.startswith("experimental:"):
engine_name = engine_name.split("experimental:")[1]
Expand Down
2 changes: 1 addition & 1 deletion textgrad/engine_experimental/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self,
model_string: str,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool = True,
cache=Union[dc.Cache, bool]):
cache: Union[dc.Cache, bool] = False):

super().__init__(
model_string=model_string,
Expand Down

0 comments on commit 4999f1f

Please sign in to comment.