Skip to content

Commit

Permalink
AIException is replaced by SKException
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyMenshykh committed Jul 20, 2023
1 parent 7c9e131 commit e78c188
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 374 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.Services;
Expand Down Expand Up @@ -197,7 +197,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
private static AsyncRetryPolicy GetPolicy(ILogger log)
{
return Policy
.Handle<AIException>(ex => ex.ErrorCode == AIException.ErrorCodes.Throttling)
.Handle<HttpOperationException>(ex => ex.StatusCode == System.Net.HttpStatusCode.TooManyRequests)
.WaitAndRetryAsync(new[]
{
TimeSpan.FromSeconds(2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Azure;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion;
using Microsoft.SemanticKernel.Diagnostics;
using RepoUtils;

// ReSharper disable once InconsistentNaming
Expand Down Expand Up @@ -80,8 +79,7 @@ string OutputExceptionDetail(Exception? exception)
{
return exception switch
{
RequestFailedException requestException => new { requestException.Status, requestException.Message }.AsJson(),
AIException aiException => new { ErrorCode = aiException.ErrorCode.ToString(), aiException.Message, aiException.Detail }.AsJson(),
HttpOperationException httpException => new { StatusCode = httpException.StatusCode.ToString(), httpException.Message, httpException.ResponseContent }.AsJson(),
{ } e => new { e.Message }.AsJson(),
_ => string.Empty
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Diagnostics;

Expand Down Expand Up @@ -119,7 +118,7 @@ private async Task<IReadOnlyList<ITextStreamingResult>> ExecuteGetCompletionsAsy

if (completionResponse is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Unexpected response from model")
throw new SKException("Unexpected response from model")
{
Data = { { "ResponseData", responseContent } },
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Diagnostics;

Expand Down Expand Up @@ -76,9 +75,7 @@ public HuggingFaceTextEmbeddingGeneration(string model, HttpClient httpClient, s

if (httpClient.BaseAddress == null && string.IsNullOrEmpty(endpoint))
{
throw new AIException(
AIException.ErrorCodes.InvalidConfiguration,
"The HttpClient BaseAddress and endpoint are both null or empty. Please ensure at least one is provided.");
throw new ArgumentException("The HttpClient BaseAddress and endpoint are both null or empty. Please ensure at least one is provided.");
}
}

Expand All @@ -96,33 +93,34 @@ public async Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(IList<string>
/// <param name="data">Data to embed.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of generated embeddings.</returns>
/// <exception cref="AIException">Exception when backend didn't respond with generated embeddings.</exception>
/// <exception cref="HttpOperationException">Exception when backend didn't respond with generated embeddings.</exception>
private async Task<IList<Embedding<float>>> ExecuteEmbeddingRequestAsync(IList<string> data, CancellationToken cancellationToken)
{
try
var embeddingRequest = new TextEmbeddingRequest
{
var embeddingRequest = new TextEmbeddingRequest
{
Input = data
};
Input = data
};

using var httpRequestMessage = HttpRequest.CreatePostRequest(this.GetRequestUri(), embeddingRequest);
using var httpRequestMessage = HttpRequest.CreatePostRequest(this.GetRequestUri(), embeddingRequest);

httpRequestMessage.Headers.Add("User-Agent", HttpUserAgent);
httpRequestMessage.Headers.Add("User-Agent", HttpUserAgent);

var response = await this._httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
var body = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
var response = await this._httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

var embeddingResponse = JsonSerializer.Deserialize<TextEmbeddingResponse>(body);
var responseContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

return embeddingResponse?.Embeddings?.Select(l => new Embedding<float>(l.Embedding!, transferOwnership: true)).ToList()!;
try
{
response.EnsureSuccessStatusCode();
}
catch (Exception e) when (e is not AIException && !e.IsCriticalException())
catch (HttpRequestException e)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
throw new HttpOperationException(response.StatusCode, responseContent, e.Message, e);
}

var embeddingResponse = JsonSerializer.Deserialize<TextEmbeddingResponse>(responseContent);

return embeddingResponse?.Embeddings?.Select(l => new Embedding<float>(l.Embedding!, transferOwnership: true)).ToList()!;
}

/// <summary>
Expand All @@ -145,7 +143,7 @@ private Uri GetRequestUri()
}
else
{
throw new AIException(AIException.ErrorCodes.InvalidConfiguration, "No endpoint or HTTP client base address has been provided");
throw new SKException("No endpoint or HTTP client base address has been provided");
}

return new Uri($"{baseUrl!.TrimEnd('/')}/{this._model}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Diagnostics;
Expand Down Expand Up @@ -37,7 +36,7 @@ public async Task<ChatMessageBase> GetChatMessageAsync(CancellationToken cancell

if (chatMessage is null)
{
throw new AIException(AIException.ErrorCodes.UnknownError, "Unable to get chat message from stream");
throw new SKException("Unable to get chat message from stream");
}

return new SKChatMessage(chatMessage);
Expand Down
19 changes: 8 additions & 11 deletions dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Threading.Tasks;
using Azure;
using Azure.AI.OpenAI;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.AI.TextCompletion;
Expand Down Expand Up @@ -59,14 +58,14 @@ private protected async Task<IReadOnlyList<ITextResult>> InternalGetTextResultsA

if (response == null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Text completions null response");
throw new SKException("Text completions null response");
}

var responseData = response.Value;

if (responseData.Choices.Count == 0)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Text completions not found");
throw new SKException("Text completions not found");
}

return responseData.Choices.Select(choice => new TextResult(responseData, choice)).ToList();
Expand Down Expand Up @@ -119,12 +118,12 @@ private protected async Task<IList<Embedding<float>>> InternalGetEmbeddingsAsync

if (response == null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Text embedding null response");
throw new SKException("Text embedding null response");
}

if (response.Value.Data.Count == 0)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Text embedding not found");
throw new SKException("Text embedding not found");
}

EmbeddingItem x = response.Value.Data[0];
Expand Down Expand Up @@ -158,12 +157,12 @@ private protected async Task<IReadOnlyList<IChatResult>> InternalGetChatResultsA

if (response == null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions null response");
throw new SKException("Chat completions null response");
}

if (response.Value.Choices.Count == 0)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions not found");
throw new SKException("Chat completions not found");
}

return response.Value.Choices.Select(chatChoice => new ChatResult(response.Value, chatChoice)).ToList();
Expand Down Expand Up @@ -193,7 +192,7 @@ private protected async IAsyncEnumerable<IChatStreamingResult> InternalGetChatSt

if (response is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Chat completions null response");
throw new SKException("Chat completions null response");
}

using StreamingChatCompletions streamingChatCompletions = response.Value;
Expand Down Expand Up @@ -351,9 +350,7 @@ private static void ValidateMaxTokens(int? maxTokens)
{
if (maxTokens.HasValue && maxTokens < 1)
{
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"MaxTokens {maxTokens} is not valid, the value must be greater than zero");
throw new SKException($"MaxTokens {maxTokens} is not valid, the value must be greater than zero");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
Expand Down Expand Up @@ -45,7 +44,6 @@ private protected virtual void AddRequestHeaders(HttpRequestMessage request)
/// <param name="requestBody">Request payload</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of text embeddings</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
string url,
string requestBody,
Expand All @@ -54,9 +52,7 @@ private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingReques
var result = await this.ExecutePostRequestAsync<TextEmbeddingResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
if (result.Embeddings is not { Count: >= 1 })
{
throw new AIException(
AIException.ErrorCodes.InvalidResponseContent,
"Embeddings not found");
throw new SKException("Embeddings not found");
}

return result.Embeddings.Select(e => new Embedding<float>(e.Values, transferOwnership: true)).ToList();
Expand All @@ -70,7 +66,6 @@ private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingReques
/// <param name="extractResponseFunc">Function to invoke to extract the desired portion of the image generation response.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of image URLs</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
private protected async Task<IList<string>> ExecuteImageGenerationRequestAsync(
string url,
string requestBody,
Expand Down Expand Up @@ -110,7 +105,7 @@ private protected T JsonDeserialize<T>(string responseJson)
var result = Json.Deserialize<T>(responseJson);
if (result is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response JSON parse error");
throw new SKException("Response JSON parse error");
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.CustomClient;
using Microsoft.SemanticKernel.Diagnostics;
Expand Down Expand Up @@ -87,9 +86,7 @@ public AzureOpenAIImageGeneration(string apiKey, HttpClient httpClient, string?

if (httpClient.BaseAddress == null && string.IsNullOrEmpty(endpoint))
{
throw new AIException(
AIException.ErrorCodes.InvalidConfiguration,
"The HttpClient BaseAddress and endpoint are both null or empty. Please ensure at least one is provided.");
throw new ArgumentException("The HttpClient BaseAddress and endpoint are both null or empty. Please ensure at least one is provided.");
}

endpoint = !string.IsNullOrEmpty(endpoint) ? endpoint! : httpClient.BaseAddress!.AbsoluteUri;
Expand All @@ -109,12 +106,12 @@ public async Task<string> GenerateImageAsync(string description, int width, int

if (result.Result == null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Azure Image Generation null response");
throw new SKException("Azure Image Generation null response");
}

if (result.Result.Images.Count == 0)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Azure Image Generation result not found");
throw new SKException("Azure Image Generation result not found");
}

return result.Result.Images.First().Url;
Expand Down Expand Up @@ -148,7 +145,7 @@ private async Task<string> StartImageGenerationAsync(string description, int wid

if (result == null || string.IsNullOrWhiteSpace(result.Id))
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response not contains result");
throw new SKException("Response not contains result");
}

return result.Id;
Expand All @@ -165,42 +162,34 @@ private async Task<AzureImageGenerationResponse> GetImageGenerationResultAsync(s
var operationLocation = this.GetUri(GetImageOperation, operationId);

var retryCount = 0;
try

while (true)
{
while (true)
if (this._maxRetryCount == retryCount)
{
if (this._maxRetryCount == retryCount)
{
throw new AIException(AIException.ErrorCodes.RequestTimeout, "Reached maximum retry attempts");
}

using var response = await this.ExecuteRequestAsync(operationLocation, HttpMethod.Get, null, cancellationToken).ConfigureAwait(false);
var responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
var result = this.JsonDeserialize<AzureImageGenerationResponse>(responseJson);

if (result.Status.Equals(AzureImageOperationStatus.Succeeded, StringComparison.OrdinalIgnoreCase))
{
return result;
}
else if (this.IsFailedOrCancelled(result.Status))
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, $"Azure OpenAI image generation {result.Status}");
}

if (response.Headers.TryGetValues("retry-after", out var afterValues) && long.TryParse(afterValues.FirstOrDefault(), out var after))
{
await Task.Delay(TimeSpan.FromSeconds(after), cancellationToken).ConfigureAwait(false);
}

// increase retry count
retryCount++;
throw new SKException("Reached maximum retry attempts");
}
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);

using var response = await this.ExecuteRequestAsync(operationLocation, HttpMethod.Get, null, cancellationToken).ConfigureAwait(false);
var responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
var result = this.JsonDeserialize<AzureImageGenerationResponse>(responseJson);

if (result.Status.Equals(AzureImageOperationStatus.Succeeded, StringComparison.OrdinalIgnoreCase))
{
return result;
}
else if (this.IsFailedOrCancelled(result.Status))
{
throw new SKException($"Azure OpenAI image generation {result.Status}");
}

if (response.Headers.TryGetValues("retry-after", out var afterValues) && long.TryParse(afterValues.FirstOrDefault(), out var after))
{
await Task.Delay(TimeSpan.FromSeconds(after), cancellationToken).ConfigureAwait(false);
}

// increase retry count
retryCount++;
}
}

Expand Down
Loading

0 comments on commit e78c188

Please sign in to comment.