-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
97 lines (80 loc) · 3.29 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import base64
import urllib.parse
from typing import List
import os
import requests
from io import BytesIO
from PIL import Image
from cog import BasePredictor, Input, File, emit_metric, Path
from api_client import APIClient
class Predictor(BasePredictor):
    """Cog predictor that forwards text-to-image requests to ``APIClient``.

    ``setup`` decodes an API key from a data URI and builds the client;
    ``predict`` asks the client for an image URL, downloads the image and
    returns it as a local file.
    """

    # Created in ``setup`` from the decoded API key.
    client: APIClient

    # (width, height) sent to the client for each supported aspect ratio.
    _ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "3:2": (1216, 832),
        "2:3": (832, 1216),
        "4:5": (896, 1088),
        "5:4": (1088, 896),
        "9:16": (768, 1344),
    }

    # Upper bound (seconds) on the image download so a stalled connection
    # cannot hang a prediction forever.
    _DOWNLOAD_TIMEOUT_SECONDS = 60

    def setup(self, weights: str) -> None:
        """Decode the API key from ``weights`` and create the API client.

        Args:
            weights: A ``data:`` URI whose payload (everything after the
                first comma) is the base64-encoded, UTF-8 API key.

        Raises:
            ValueError: If ``weights`` is empty, is not a ``data:`` URI, has
                no comma-separated payload, or the payload cannot be decoded.
        """
        if not weights:
            raise ValueError(
                "API token must be provided. "
                "Set COG_WEIGHTS environment variable to a "
                "base64-encoded data URI containing the API key."
            )

        parsed_uri = urllib.parse.urlparse(weights)
        if parsed_uri.scheme != "data":
            raise ValueError(
                "Invalid data URI. Expected a data URI with a base64-encoded API key."
            )

        # A data URI is "<mediatype>,<payload>"; without the comma there is
        # no payload to decode, so fail with the same clear message instead
        # of a tuple-unpacking error.
        _, separator, data = parsed_uri.path.partition(",")
        if not separator:
            raise ValueError(
                "Invalid data URI. Expected a data URI with a base64-encoded API key."
            )

        try:
            api_key = base64.b64decode(data).decode("utf-8")
        except Exception as e:
            raise ValueError(f"Failed to decode API key: {e}") from e

        self.client = APIClient(api_key)

    def aspect_ratio_to_width_height(self, aspect_ratio: str):
        """Return the ``(width, height)`` tuple for ``aspect_ratio``.

        Returns ``None`` when the aspect ratio is not one of the supported
        keys (e.g. ``"1:1"``, ``"16:9"``).
        """
        return self._ASPECT_RATIOS.get(aspect_ratio)

    async def predict(
        self,
        prompt: str = Input(description="Text prompt for image generation"),
        aspect_ratio: str = Input(
            description="Aspect ratio for the generated image",
            choices=["1:1", "16:9", "2:3", "3:2", "4:5", "5:4", "9:16"],
            default="1:1",
        ),
        steps: int = Input(
            description="Number of diffusion steps", ge=1, le=50, default=25
        ),
        guidance: float = Input(description="Controls the balance between adherence to the text prompt and image quality/diversity. Higher values make the output more closely match the prompt but may reduce overall image quality. Lower values allow for more creative freedom but might produce results less relevant to the prompt.", ge=2, le=5, default=3),
        seed: int = Input(description="Random seed. Set for reproducible generation", default=None)
    ) -> Path:
        """Generate one image for ``prompt`` and return it as ``./output.jpg``.

        The client is awaited for an image URL, the image is downloaded,
        saved locally, and width/height/steps/num_images metrics are emitted.

        Args:
            prompt: Text prompt forwarded to the client.
            aspect_ratio: Key selecting the output width and height.
            steps: Number of diffusion steps forwarded to the client.
            guidance: Guidance value forwarded to the client.
            seed: Seed forwarded to the client. When ``None`` a random 16-bit
                seed is drawn; an explicit ``0`` is honoured.

        Returns:
            Path to the saved image file.

        Raises:
            ValueError: If any step of generation, download or saving fails;
                the original exception is chained as the cause.
        """
        # Compare against None rather than truthiness so that seed=0 remains
        # a valid, reproducible seed.
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        self.log(f"Using seed: {seed}\n")

        try:
            self.log("Running prediction... \n")

            dimensions = self.aspect_ratio_to_width_height(aspect_ratio)
            if dimensions is None:
                raise ValueError(f"Unsupported aspect ratio: {aspect_ratio}")
            width, height = dimensions

            image_url = await self.client.predict(
                prompt=prompt,
                width=width,
                height=height,
                steps=steps,
                guidance=guidance,
                seed=seed,
                log=self.log,
            )

            # NOTE: requests is blocking; the timeout bounds how long this
            # coroutine can stall on the download.
            response = requests.get(
                image_url, timeout=self._DOWNLOAD_TIMEOUT_SECONDS
            )
            # Surface HTTP errors here instead of letting an error page reach
            # the image decoder.
            response.raise_for_status()

            img_path = "./output.jpg"
            with Image.open(BytesIO(response.content)) as img:
                # JPEG cannot store alpha or palette images; convert so the
                # save does not fail for PNG/WebP-style responses.
                if img.mode in ("RGBA", "LA", "P"):
                    img = img.convert("RGB")
                img.save(img_path)

            emit_metric("width", width)
            emit_metric("height", height)
            emit_metric("steps", steps)
            emit_metric("num_images", 1)

            return Path(img_path)
        except Exception as e:
            raise ValueError(f"Error generating image: {e}") from e