diff --git a/BotNet.Services/BotCommands/Art.cs b/BotNet.Services/BotCommands/Art.cs index 77f7bea..68b8740 100644 --- a/BotNet.Services/BotCommands/Art.cs +++ b/BotNet.Services/BotCommands/Art.cs @@ -51,19 +51,19 @@ await botClient.SendPhotoAsync( photo: new InputFileStream(imageStream, "art.png"), replyToMessageId: message.MessageId, cancellationToken: cancellationToken); - } catch (ContentFilteredException) { - await botClient.SendTextMessageAsync( + } catch (ContentFilteredException exc) { + await botClient.EditMessageTextAsync( chatId: message.Chat.Id, - text: "Content filtered", + messageId: busyMessage.MessageId, + text: $"{exc.Message ?? "Content filtered."}", parseMode: ParseMode.Html, - replyToMessageId: message.MessageId, cancellationToken: cancellationToken); } catch { - await botClient.SendTextMessageAsync( + await botClient.EditMessageTextAsync( chatId: message.Chat.Id, + messageId: busyMessage.MessageId, text: "Could not generate art", parseMode: ParseMode.Html, - replyToMessageId: message.MessageId, cancellationToken: cancellationToken); } } catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) { diff --git a/BotNet.Services/BotCommands/OpenAI.cs b/BotNet.Services/BotCommands/OpenAI.cs index 72569fd..0256f38 100644 --- a/BotNet.Services/BotCommands/OpenAI.cs +++ b/BotNet.Services/BotCommands/OpenAI.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.IO; using System.Linq; @@ -13,7 +12,6 @@ using BotNet.Services.OpenAI.Skills; using BotNet.Services.RateLimit; using BotNet.Services.Stability.Models; -using BotNet.Services.Stability.Skills; using Microsoft.Extensions.DependencyInjection; using RG.Ninja; using SkiaSharp; @@ -873,11 +871,18 @@ await botClient.SendPhotoAsync( replyToMessageId: message.MessageId, cancellationToken: cancellationToken ); - } catch (ContentFilteredException) { + } catch (ContentFilteredException exc) { await botClient.EditMessageTextAsync( chatId: busyMessage.Chat.Id, messageId: busyMessage.MessageId, - text: "Content filtered.", + text: $"{exc.Message ?? "Content filtered."}", + parseMode: ParseMode.Html + ); + } catch { + await botClient.EditMessageTextAsync( + chatId: busyMessage.Chat.Id, + messageId: busyMessage.MessageId, + text: "Failed to generate image.", parseMode: ParseMode.Html ); } diff --git a/BotNet.Services/Stability/Models/ContentFilteredException.cs b/BotNet.Services/Stability/Models/ContentFilteredException.cs index 29cd174..701b284 100644 --- a/BotNet.Services/Stability/Models/ContentFilteredException.cs +++ b/BotNet.Services/Stability/Models/ContentFilteredException.cs @@ -1,5 +1,11 @@ using System; namespace BotNet.Services.Stability.Models { - public sealed class ContentFilteredException : Exception { } + public sealed class ContentFilteredException : Exception { + public ContentFilteredException() { + } + + public ContentFilteredException(string? message) : base(message) { + } + } } diff --git a/BotNet.Services/Stability/Models/ErrorResponse.cs b/BotNet.Services/Stability/Models/ErrorResponse.cs new file mode 100644 index 0000000..fe8d47f --- /dev/null +++ b/BotNet.Services/Stability/Models/ErrorResponse.cs @@ -0,0 +1,7 @@ +namespace BotNet.Services.Stability.Models { + public sealed record ErrorResponse( + string? Id, + string? Message, + string? Name + ); +} diff --git a/BotNet.Services/Stability/ServiceCollectionExtensions.cs b/BotNet.Services/Stability/ServiceCollectionExtensions.cs index 8c995c8..60f1223 100644 --- a/BotNet.Services/Stability/ServiceCollectionExtensions.cs +++ b/BotNet.Services/Stability/ServiceCollectionExtensions.cs @@ -6,6 +6,7 @@ public static class ServiceCollectionExtensions { public static IServiceCollection AddStabilityClient(this IServiceCollection services) { services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); return services; } } diff --git a/BotNet.Services/Stability/Skills/ImageVariationBot.cs b/BotNet.Services/Stability/Skills/ImageVariationBot.cs new file mode 100644 index 0000000..2ace9bb --- /dev/null +++ b/BotNet.Services/Stability/Skills/ImageVariationBot.cs @@ -0,0 +1,23 @@ +using System.Threading; +using System.Threading.Tasks; + +namespace BotNet.Services.Stability.Skills { + public sealed class ImageVariationBot( + StabilityClient stabilityClient + ) { + private readonly StabilityClient _stabilityClient = stabilityClient; + + public async Task ModifyImageAsync( + byte[] image, + string prompt, + CancellationToken cancellationToken + ) { + return await _stabilityClient.ModifyImageAsync( + engine: "stable-diffusion-xl-1024-v1-0", + promptImage: image, + promptText: prompt, + cancellationToken: cancellationToken + ); + } + } +} diff --git a/BotNet.Services/Stability/StabilityClient.cs b/BotNet.Services/Stability/StabilityClient.cs index f8730cc..76b400e 100644 --- a/BotNet.Services/Stability/StabilityClient.cs +++ b/BotNet.Services/Stability/StabilityClient.cs @@ -1,4 +1,5 @@ using System; +using System.Net; using System.Net.Http; using System.Net.Http.Json; using System.Text.Json; @@ -16,6 +17,8 @@ public sealed class StabilityClient( ILogger logger ) { private const string TEXT_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/text-to-image"; + private const string IMAGE_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/image-to-image"; + private static readonly JsonSerializerOptions SNAKE_CASE_SERIALIZER_OPTIONS = new() { PropertyNamingPolicy = new SnakeCaseNamingPolicy() }; @@ -57,6 +60,74 @@ CancellationToken cancellationToken options: SNAKE_CASE_SERIALIZER_OPTIONS ); using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); + if (!response.IsSuccessStatusCode) { + string error = await response.Content.ReadAsStringAsync(cancellationToken); + if (response.StatusCode == HttpStatusCode.BadRequest) { + ErrorResponse? errorResponse = JsonSerializer.Deserialize(error, SNAKE_CASE_SERIALIZER_OPTIONS); + throw new ContentFilteredException(errorResponse?.Message); + } + _logger.LogError("Unable to generate image: {0}, HTTP Status Code: {1}", error, (int)response.StatusCode); + response.EnsureSuccessStatusCode(); + } + + string responseJson = await response.Content.ReadAsStringAsync(cancellationToken); + + TextToImageResponse? responseData = JsonSerializer.Deserialize(responseJson, CAMEL_CASE_SERIALIZER_OPTIONS); + + if (responseData is { Artifacts: [Artifact { FinishReason: "CONTENT_FILTERED" }] }) { + throw new ContentFilteredException(); + } + + if (responseData is not { Artifacts: [Artifact { FinishReason: "SUCCESS", Base64: var base64 }] }) { + throw new HttpRequestException(); + } + + return Convert.FromBase64String(base64); + } + + public async Task ModifyImageAsync( + string engine, + byte[] promptImage, + string promptText, + CancellationToken cancellationToken + ) { + string url = string.Format(IMAGE_TO_IMAGE_URL_TEMPLATE, engine); + using HttpRequestMessage request = new(HttpMethod.Post, url); + request.Headers.Add("Authorization", $"Bearer {_apiKey}"); + request.Headers.Add("Accept", "application/json"); + using MultipartFormDataContent formData = new(); + using ByteArrayContent promptImageContent = new(promptImage); + formData.Add( + content: promptImageContent, + name: "init_image", + fileName: "init_image.png" + ); + using StringContent initImageMode = new("IMAGE_STRENGTH"); + using StringContent imageStrength = new("0.35"); + using StringContent steps = new("40"); + using StringContent width = new("1024"); + using StringContent height = new("1024"); + using StringContent seed = new("0"); + using StringContent cfgScale = new("5"); + using StringContent samples = new("1"); + using StringContent textPrompts0Text = new(promptText); + using StringContent textPrompts0Weight = new("1"); + using StringContent textPrompts1Text = new("blurry, bad, saturated, high contrast, watermark, signature, label, worst quality, normal quality, low quality, low res, extra digits, cropped, jpeg artifacts, username, error, duplicate, ugly, monochrome, mutation, disgusting, bad anatomy, bad hands, three hands, three legs, bad arms, missing legs, missing arms"); + using StringContent textPrompts1Weight = new("-1"); + formData.Add(initImageMode, "init_image_mode"); + formData.Add(imageStrength, "image_strength"); + formData.Add(steps, "steps"); + formData.Add(width, "width"); + formData.Add(height, "height"); + formData.Add(seed, "seed"); + formData.Add(cfgScale, "cfg_scale"); + formData.Add(samples, "samples"); + formData.Add(textPrompts0Text, "text_prompts[0][text]"); + formData.Add(textPrompts0Weight, "text_prompts[0][weight]"); + formData.Add(textPrompts1Text, "text_prompts[1][text]"); + formData.Add(textPrompts1Weight, "text_prompts[1][weight]"); + request.Content = formData; + using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); if (!response.IsSuccessStatusCode) { string error = await response.Content.ReadAsStringAsync(cancellationToken); _logger.LogError("Unable to generate image: {0}, HTTP Status Code: {1}", error, (int)response.StatusCode);