Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error handling of filtered content #89

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions BotNet.Services/BotCommands/Art.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ await botClient.SendPhotoAsync(
photo: new InputFileStream(imageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch (ContentFilteredException) {
await botClient.SendTextMessageAsync(
} catch (ContentFilteredException exc) {
await botClient.EditMessageTextAsync(
chatId: message.Chat.Id,
text: "<code>Content filtered</code>",
messageId: busyMessage.MessageId,
text: $"<code>{exc.Message ?? "Content filtered."}</code>",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch {
await botClient.SendTextMessageAsync(
await botClient.EditMessageTextAsync(
chatId: message.Chat.Id,
messageId: busyMessage.MessageId,
text: "<code>Could not generate art</code>",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
Expand Down
13 changes: 9 additions & 4 deletions BotNet.Services/BotCommands/OpenAI.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
Expand All @@ -13,7 +12,6 @@
using BotNet.Services.OpenAI.Skills;
using BotNet.Services.RateLimit;
using BotNet.Services.Stability.Models;
using BotNet.Services.Stability.Skills;
using Microsoft.Extensions.DependencyInjection;
using RG.Ninja;
using SkiaSharp;
Expand Down Expand Up @@ -873,11 +871,18 @@
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
} catch (ContentFilteredException) {
} catch (ContentFilteredException exc) {
await botClient.EditMessageTextAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
text: "<code>Content filtered.</code>",
text: $"<code>{exc.Message ?? "Content filtered."}</code>",
parseMode: ParseMode.Html
);
} catch {
await botClient.EditMessageTextAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
text: "<code>Failed to generate image.</code>",
parseMode: ParseMode.Html
);
}
Expand Down Expand Up @@ -911,7 +916,7 @@
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch (HttpRequestException exc) {

Check warning on line 919 in BotNet.Services/BotCommands/OpenAI.cs

View workflow job for this annotation

GitHub Actions / build

The variable 'exc' is declared but never used

Check warning on line 919 in BotNet.Services/BotCommands/OpenAI.cs

View workflow job for this annotation

GitHub Actions / build

The variable 'exc' is declared but never used
await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "<code>Unknown error.</code>",
Expand Down
8 changes: 7 additions & 1 deletion BotNet.Services/Stability/Models/ContentFilteredException.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
using System;

namespace BotNet.Services.Stability.Models {
public sealed class ContentFilteredException : Exception { }
public sealed class ContentFilteredException : Exception {
public ContentFilteredException() {
}

public ContentFilteredException(string? message) : base(message) {
}
}
}
7 changes: 7 additions & 0 deletions BotNet.Services/Stability/Models/ErrorResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace BotNet.Services.Stability.Models {
public sealed record ErrorResponse(
string? Id,
string? Message,
string? Name
);
}
1 change: 1 addition & 0 deletions BotNet.Services/Stability/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public static class ServiceCollectionExtensions {
public static IServiceCollection AddStabilityClient(this IServiceCollection services) {
services.AddSingleton<StabilityClient>();
services.AddSingleton<ImageGenerationBot>();
services.AddSingleton<ImageVariationBot>();
return services;
}
}
Expand Down
23 changes: 23 additions & 0 deletions BotNet.Services/Stability/Skills/ImageVariationBot.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System.Threading;
using System.Threading.Tasks;

namespace BotNet.Services.Stability.Skills {
public sealed class ImageVariationBot(
StabilityClient stabilityClient
) {
private readonly StabilityClient _stabilityClient = stabilityClient;

public async Task<byte[]> ModifyImageAsync(
byte[] image,
string prompt,
CancellationToken cancellationToken
) {
return await _stabilityClient.ModifyImageAsync(
engine: "stable-diffusion-xl-1024-v1-0",
promptImage: image,
promptText: prompt,
cancellationToken: cancellationToken
);
}
}
}
71 changes: 71 additions & 0 deletions BotNet.Services/Stability/StabilityClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Net;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text.Json;
Expand All @@ -16,6 +17,8 @@ public sealed class StabilityClient(
ILogger<StabilityClient> logger
) {
private const string TEXT_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/text-to-image";
private const string IMAGE_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/image-to-image";

private static readonly JsonSerializerOptions SNAKE_CASE_SERIALIZER_OPTIONS = new() {
PropertyNamingPolicy = new SnakeCaseNamingPolicy()
};
Expand Down Expand Up @@ -57,6 +60,74 @@ CancellationToken cancellationToken
options: SNAKE_CASE_SERIALIZER_OPTIONS
);
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode) {
string error = await response.Content.ReadAsStringAsync(cancellationToken);
if (response.StatusCode == HttpStatusCode.BadRequest) {
ErrorResponse? errorResponse = JsonSerializer.Deserialize<ErrorResponse>(error, SNAKE_CASE_SERIALIZER_OPTIONS);
throw new ContentFilteredException(errorResponse?.Message);
}
_logger.LogError("Unable to generate image: {0}, HTTP Status Code: {1}", error, (int)response.StatusCode);
response.EnsureSuccessStatusCode();
}

string responseJson = await response.Content.ReadAsStringAsync(cancellationToken);

TextToImageResponse? responseData = JsonSerializer.Deserialize<TextToImageResponse>(responseJson, CAMEL_CASE_SERIALIZER_OPTIONS);

if (responseData is { Artifacts: [Artifact { FinishReason: "CONTENT_FILTERED" }] }) {
throw new ContentFilteredException();
}

if (responseData is not { Artifacts: [Artifact { FinishReason: "SUCCESS", Base64: var base64 }] }) {
throw new HttpRequestException();
}

return Convert.FromBase64String(base64);
}

public async Task<byte[]> ModifyImageAsync(
string engine,
byte[] promptImage,
string promptText,
CancellationToken cancellationToken
) {
string url = string.Format(IMAGE_TO_IMAGE_URL_TEMPLATE, engine);
using HttpRequestMessage request = new(HttpMethod.Post, url);
request.Headers.Add("Authorization", $"Bearer {_apiKey}");
request.Headers.Add("Accept", "application/json");
using MultipartFormDataContent formData = new();
using ByteArrayContent promptImageContent = new(promptImage);
formData.Add(
content: promptImageContent,
name: "init_image",
fileName: "init_image.png"
);
using StringContent initImageMode = new("IMAGE_STRENGTH");
using StringContent imageStrength = new("0.35");
using StringContent steps = new("40");
using StringContent width = new("1024");
using StringContent height = new("1024");
using StringContent seed = new("0");
using StringContent cfgScale = new("5");
using StringContent samples = new("1");
using StringContent textPrompts0Text = new(promptText);
using StringContent textPrompts0Weight = new("1");
using StringContent textPrompts1Text = new("blurry, bad, saturated, high contrast, watermark, signature, label, worst quality, normal quality, low quality, low res, extra digits, cropped, jpeg artifacts, username, error, duplicate, ugly, monochrome, mutation, disgusting, bad anatomy, bad hands, three hands, three legs, bad arms, missing legs, missing arms");
using StringContent textPrompts1Weight = new("-1");
formData.Add(initImageMode, "init_image_mode");
formData.Add(imageStrength, "image_strength");
formData.Add(steps, "steps");
formData.Add(width, "width");
formData.Add(height, "height");
formData.Add(seed, "seed");
formData.Add(cfgScale, "cfg_scale");
formData.Add(samples, "samples");
formData.Add(textPrompts0Text, "text_prompts[0][text]");
formData.Add(textPrompts0Weight, "text_prompts[0][weight]");
formData.Add(textPrompts1Text, "text_prompts[1][text]");
formData.Add(textPrompts1Weight, "text_prompts[1][weight]");
request.Content = formData;
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode) {
string error = await response.Content.ReadAsStringAsync(cancellationToken);
_logger.LogError("Unable to generate image: {0}, HTTP Status Code: {1}", error, (int)response.StatusCode);
Expand Down
Loading