Speeding Up Your AI‐Powered Search with JAI Async

Richard Hightower edited this page Jul 13, 2023 · 1 revision

Introduction

Businesses today must deal with a massive amount of constantly generated data. Data collected from various sources helps organizations better understand their customers' needs and preferences. With so much data available, search can become time-consuming and slow down business operations, so an efficient search system is essential. Our previous developer notebook discussed how combining ChatGPT with retrieval and re-ranking methods can improve search accuracy. By retrieving the most related content through cosine similarity to a hypothetical answer, you can put a fast and efficient search function in place. By using async calls, you can decrease the search time even further, resulting in a 50% to 200% increase in speed! Generating the search queries and retrieving articles can run in parallel with generating the hypothetical answer and its embedding, using the async interface for the OpenAI API provided by JAI. In this post, we will dive into how you can speed up your search with JAI Async methods.

Speeding things up with JAI Async

In the previous developer notebook, we discussed how to improve search accuracy by combining ChatGPT with retrieval and re-ranking methods. This technique can be implemented on top of any existing search system, including Elasticsearch, Solr, or a custom search engine application.

To implement this approach, we followed these steps:

  1. Ask the user a question and generate a list of potential queries based on the question.
  2. Execute the search queries and retrieve relevant articles.
  3. Create an ideal answer.
  4. Create an ideal answer embedding.
  5. Score each article by comparing its embedding with the embedding of the hypothetical ideal answer, using a dot product to calculate cosine similarity.
  6. Sort and filter the articles based on the similarity obtained from the embeddings of the articles vs. the embedding of the ideal answer.
  7. Generate an answer to the user's question, including references and links.

The most related content, as measured by cosine similarity to the hypothetical answer (HyDE), is retrieved using this approach. It is fast and can be added to a search function you already have without managing a vector database.
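
Steps 5 and 6 can be sketched with plain arrays. This is a minimal illustration, not the notebook's code: the class SimilaritySketch and the toy three-dimensional vectors are made up for the example. It assumes the embeddings are unit length, which is what lets a dot product stand in for cosine similarity; the cosine helper shows the general form for vectors that are not normalized.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical helper class for illustration; not part of JAI.
class SimilaritySketch {

    // Dot product of two vectors of the same length.
    static float dot(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    // Full cosine similarity; equals dot(a, b) when both vectors have length 1.
    static float cosine(float[] a, float[] b) {
        return (float) (dot(a, b) / (Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b))));
    }

    // Returns article indexes ordered from most to least similar to the answer.
    static List<Integer> rank(float[] answer, List<float[]> articles) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < articles.size(); i++) {
            order.add(i);
        }
        order.sort(Comparator.comparingDouble(
                (Integer i) -> dot(answer, articles.get(i))).reversed());
        return order;
    }

    public static void main(String[] args) {
        float[] answer = {1f, 0f, 0f};
        List<float[]> articles = List.of(
                new float[]{0f, 1f, 0f},      // orthogonal to the answer
                new float[]{0.6f, 0.8f, 0f},  // partially aligned
                new float[]{1f, 0f, 0f});     // same direction as the answer
        System.out.println(rank(answer, articles)); // prints [2, 1, 0]
    }
}
```

Because scoring happens in memory over the articles a search just returned, no vector database is needed for this step.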

We can speed up this process by using async calls. The speedup is roughly 30% to 200%.

Steps 1 and 2 can be done in parallel with steps 3 and 4. Since JAI has an async interface for accessing the OpenAI API in Java, we can easily do this.
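
Before looking at the real code, here is a rough sketch of the overall shape using only java.util.concurrent. The four static methods are stand-ins that return canned values instead of calling OpenAI or a search API; their names and return values are assumptions made for this illustration and are not JAI methods.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

// Stand-ins for the real calls; names and return values are made up for illustration.
class ParallelPipelinesSketch {

    static CompletableFuture<List<String>> generateQueries(String question) {
        return CompletableFuture.supplyAsync(
                () -> List.of(question + " results", question + " winner"));
    }

    static CompletableFuture<List<String>> searchArticles(List<String> queries) {
        return CompletableFuture.supplyAsync(() -> queries.stream()
                .map(query -> "article for " + query)
                .collect(Collectors.toList()));
    }

    static CompletableFuture<String> idealAnswer(String question) {
        return CompletableFuture.supplyAsync(() -> "ideal answer to: " + question);
    }

    static CompletableFuture<float[]> embed(String text) {
        // A fake one-dimensional "embedding" so the example needs no network access.
        return CompletableFuture.supplyAsync(() -> new float[]{text.length()});
    }

    static String run(String question) {
        // Steps 1 and 2: the search starts as soon as the queries are available.
        CompletableFuture<List<String>> articles =
                generateQueries(question).thenCompose(ParallelPipelinesSketch::searchArticles);
        // Steps 3 and 4: started right away, so they overlap with steps 1 and 2.
        CompletableFuture<float[]> embedding =
                idealAnswer(question).thenCompose(ParallelPipelinesSketch::embed);
        // thenCombine runs once both chains have finished.
        return articles.thenCombine(embedding,
                (a, e) -> a.size() + " articles, embedding of " + e.length + " dimension(s)")
                .join();
    }

    public static void main(String[] args) {
        System.out.println(run("UFC 290"));
    }
}
```

Neither chain blocks a thread while waiting on the other, so the total time is governed by the slower of the two chains rather than by their sum.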

Digging in

Let’s look at an updated version of the main method within the WhoWonUFC290Async class. This code introduces asynchronous operations and uses CompletableFuture to handle asynchronous tasks. Here is the code, followed by a breakdown:

public static void main(String... args) throws Exception {

        try {

            long startTime = System.currentTimeMillis();
            final CountDownLatch countDownLatch = new CountDownLatch(2);

            // Generating a hypothetical answer and its embedding
            final var hypotheticalAnswerEmbeddingFuture = hypotheticalAnswer()
                    .thenCompose(WhoWonUFC290Async::embeddingsAsync).thenApply(floats -> {
                        countDownLatch.countDown();
                        return floats;
                    });

            // Generate a list of queries and use them to look up articles. 
            final var queriesFuture = jsonGPT(QUERIES_INPUT.replace("{USER_QUESTION}", USER_QUESTION))
                    .thenApply(queriesJson ->
                            JsonParserBuilder.builder().build().parse(queriesJson)
                                    .getObjectNode().getArrayNode("queries")
                                    .filter(node -> node instanceof StringNode)
                                    .stream().map(Object::toString).collect(Collectors.toList())
                    ).thenCompose(WhoWonUFC290Async::getArticles
                    ).thenApply(objectNodes -> {
                        countDownLatch.countDown();
                        return objectNodes;
                    });

            if (!countDownLatch.await(30, TimeUnit.SECONDS))
                throw new TimeoutException("Timed out waiting for hypotheticalAnswerEmbedding and " +
                        " articles ");

            final var articles = queriesFuture.get();

            final var hypotheticalAnswerEmbedding = hypotheticalAnswerEmbeddingFuture.get();

            // Extracting article content and generating embeddings for each article
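            final var articleContent = articles.stream().map(article -> {
                        final var content = article.getString("content");
                        // Use at most the first 100 characters; shorter content is kept whole
                        return String.format("%s %s %s", article.getString("title"),
                                article.getString("description"),
                                content.substring(0, Math.min(100, content.length())));
                    })
                    .collect(Collectors.toList());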
            final var articleEmbeddings = embeddingsAsync(articleContent).get();

            // Calculating cosine similarities between the hypothetical answer embedding and article embeddings
           final var cosineSimilarities = articleEmbeddings.stream()
                    .map(articleEmbedding -> dot(hypotheticalAnswerEmbedding, articleEmbedding))
                    .collect(Collectors.toList());

            // Creating a set of scored articles based on cosine similarities
            final var articleSet = IntStream.range(0,
                            Math.min(cosineSimilarities.size(), articleContent.size()))
                    .mapToObj(i -> new ScoredArticle(articles.get(i), cosineSimilarities.get(i)))
                    .collect(Collectors.toSet());

            final var sortedArticles = new ArrayList<>(articleSet);
            sortedArticles.sort((o1, o2) -> Float.compare(o2.getScore(), o1.getScore()));

            // Printing the top 5 scored articles
            sortedArticles.subList(0, Math.min(5, sortedArticles.size())).forEach(System.out::println);

            // Formatting the top results as JSON strings
            final var formattedTopResults = sortedArticles.stream()
                    .limit(10) // at most 10 results; avoids an exception when fewer are available
                    .map(ScoredArticle::getContent)
                    .map(article -> String.format(Json.niceJson("{'title':'%s', 'url':'%s'," +
                                    " 'description':'%s', 'content':'%s'}\n"),
                            article.getString("title"), article.getString("url"),
                            article.getString("description"),
                            getArticleContent(article)))
                    .collect(Collectors.joining(",\n"));

            System.out.println(formattedTopResults);

            // Generating the final answer with the formatted top results
            final var finalAnswer = jsonGPT(ANSWER_INPUT.replace("{USER_QUESTION}", USER_QUESTION)
                    .replace("{formatted_top_results}", formattedTopResults), "Output format is markdown").get();
            System.out.println(finalAnswer);

            long endTime = System.currentTimeMillis();

            System.out.println(endTime - startTime);
        } catch (Exception ex) {
            ex.printStackTrace();
        }

}

  1. Asynchronous Execution:
    • The code initializes a CountDownLatch with a count of 2. It is used to wait for two asynchronous tasks: generating the hypothetical answer embedding, and generating the queries and retrieving their articles.
    • The hypotheticalAnswerEmbeddingFuture is created by calling the hypotheticalAnswer method, which returns a CompletableFuture. The embeddingsAsync method is chained with thenCompose to generate the hypothetical answer embedding once the answer is available. A final thenApply passes the result through and decrements the latch count.
    • The queriesFuture is created by first calling the jsonGPT method to generate the queries JSON, parsing that JSON into a list of query strings with thenApply, and then chaining getArticles with thenCompose to obtain the list of articles. A final thenApply passes the result through and decrements the latch count.
  2. Synchronization:
    • The code uses countDownLatch.await to wait for both asynchronous tasks to complete. If the tasks don't complete within 30 seconds, a TimeoutException is thrown.
  3. Obtaining Results:
    • The code retrieves the results of the asynchronous tasks by calling queriesFuture.get() and hypotheticalAnswerEmbeddingFuture.get().
  4. Extracting Article Content and Embeddings:
    • The code continues by extracting article content and generating embeddings for each article, similar to the previous example.
  5. Calculating Cosine Similarities:
    • The code calculates cosine similarities between the hypothetical answer embedding and article embeddings, similar to the previous example.
  6. Creating Scored Articles and Sorting:
    • The code creates a set of scored articles based on cosine similarities and sorts them in descending order based on the score, similar to the previous example.
  7. Printing Top Results:
    • The code prints the top 5 scored articles by iterating over a sublist of sorted articles and calling System.out.println for each article.
  8. Formatting Top Results as JSON Strings:
    • The code formats the top results as JSON strings using the stream function, map operation, and String.format, similar to the previous example.
  9. Generating the Final Answer:
    • The code generates the final answer by replacing the placeholders in the ANSWER_INPUT string with the actual values and passing the result to the asynchronous jsonGPT method. The get method is called on the returned CompletableFuture to obtain the final answer.
  10. Printing Timing Information:
    • The code calculates the elapsed time in milliseconds by subtracting the start time from the end time and prints it using System.out.println.
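
One note on the synchronization step: in the listing above, the latch is counted down inside thenApply, which only runs when the preceding stage succeeds. If either chain fails, the latch never reaches zero and the main thread waits the full 30 seconds. A possible alternative, which is not what the original code does, is to wait on CompletableFuture.allOf with a timeout; a failure is then reported as soon as both futures have finished. The class and method names below are made up for this sketch.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Illustrative only; the class and method names are not from the original code.
class AwaitBothSketch {

    // Waits until both futures are done, then reports values, failure, or timeout.
    static String awaitBoth(CompletableFuture<String> first, CompletableFuture<String> second,
                            long timeout, TimeUnit unit) {
        try {
            CompletableFuture.allOf(first, second).get(timeout, unit);
            // Both futures are complete here, so join() returns immediately.
            return first.join() + " / " + second.join();
        } catch (TimeoutException e) {
            return "timed out";
        } catch (ExecutionException e) {
            // get() unwraps the failure of whichever future completed exceptionally.
            return "failed: " + e.getCause().getMessage();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        CompletableFuture<String> articles = CompletableFuture.completedFuture("articles");
        CompletableFuture<String> embedding = CompletableFuture.completedFuture("embedding");
        CompletableFuture<String> broken =
                CompletableFuture.failedFuture(new IllegalStateException("boom"));
        CompletableFuture<String> never = new CompletableFuture<>();

        System.out.println(awaitBoth(articles, embedding, 1, TimeUnit.SECONDS));
        System.out.println(awaitBoth(articles, broken, 1, TimeUnit.SECONDS));
        System.out.println(awaitBoth(articles, never, 50, TimeUnit.MILLISECONDS));
    }
}
```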

Let’s break it down and show the async method calls

Let’s cover the rest.

Recall that we are doing these steps:

  1. Ask the user a question and generate a list of potential queries based on the question.
  2. Execute the search queries and retrieve relevant articles.
  3. Create an ideal answer.
  4. Create an ideal answer embedding.

Steps 1 and 2 can be done in parallel with steps 3 and 4. Since JAI has an async interface for accessing the OpenAI API in Java, we can easily do this. Let’s see how to run steps 3 and 4 at the same time as steps 1 and 2. First, we will run steps 3 and 4 using the async interface of JAI.

    public static CompletableFuture<String> jsonGPT(String input) {
        return jsonGPT(input, "All output shall be JSON");
    }

    public static CompletableFuture<String> jsonGPT(String input, String system) {

        final var client = OpenAIClient.builder()
                    .setApiKey(System.getenv("OPENAI_API_KEY")).build();

        final var chatRequest = ChatRequest.builder()
           .addMessage(Message.builder().role(Role.SYSTEM).content(system).build())
           .addMessage(Message.builder().role(Role.USER).content(input).build())
           .build();

        return client.chatAsync(chatRequest).thenApply(chat -> {
            if (chat.getResponse().isPresent()) {
                return chat.getResponse().get().getChoices().get(0).getMessage().getContent();
            } else {
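                final var error = chat.getStatusCode().orElse(666) + " " + chat.getStatusMessage().orElse("");
                System.out.println(error);
                // Include the status in the exception so callers can see why the request failed
                throw new IllegalStateException(error);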
            }
        });
    }

    public static CompletableFuture<String> hypotheticalAnswer() {
        final var input = HA_INPUT.replace("{USER_QUESTION}", USER_QUESTION);
        return jsonGPT(input).thenApply(response -> JsonParserBuilder.builder().build().parse(response).getObjectNode().getString("hypotheticalAnswer"));
    }

    public static CompletableFuture<float[]> embeddingsAsync(String input) {
        System.out.println("INPUT " + input);
        return embeddingsAsync(List.of(input)).thenApply(embeddings -> embeddings.get(0));
    }

    public static CompletableFuture<List<float[]>> embeddingsAsync(List<String> input) {
        System.out.println("INPUT " + input);

        if (input == null || input.size() == 0) {
            return CompletableFuture.completedFuture(Collections.singletonList(new float[0]));
        }

        final var client = OpenAIClient.builder().setApiKey(System.getenv("OPENAI_API_KEY")).build();
        return client.embeddingAsync(EmbeddingRequest.builder().model("text-embedding-ada-002").input(input).build()).thenApply(embedding -> {
            if (embedding.getResponse().isPresent()) {
                return embedding.getResponse().get().getData().stream().map(Embedding::getEmbedding).collect(Collectors.toList());
            } else {
                System.out.println(embedding.getStatusCode().orElse(666) + " " + embedding.getStatusMessage().orElse(""));
                throw new IllegalStateException(embedding.getStatusCode().orElse(666) + " " + embedding.getStatusMessage().orElse(""));
            }
        });

    }

public static void main(String... args)  {
        try {

            long startTime = System.currentTimeMillis();
            final CountDownLatch countDownLatch = new CountDownLatch(2);

            // Generating a hypothetical answer and its embedding
            final var hypotheticalAnswerEmbeddingFuture = hypotheticalAnswer()
                    .thenCompose(WhoWonUFC290Async::embeddingsAsync).thenApply(floats -> {
                        countDownLatch.countDown();
                        return floats;
                    });

...

The example adds some additional methods to support asynchronous operations and CompletableFuture usage. Here's a breakdown of the code:

  1. jsonGPT Method:
    • This method is overloaded with two versions. The first version takes only the input parameter and sets the system message to "All output shall be JSON". The second version takes both the input and system parameters to customize the system message.
    • The method creates a ChatRequest with a system message and user message using the provided input.
    • The client.chatAsync(chatRequest) method is invoked to perform the chat interaction asynchronously.
    • The returned CompletableFuture is then processed using thenApply to extract the desired response from the chat and return it.
  2. hypotheticalAnswer Method:
    • This method generates a hypothetical answer by replacing the {USER_QUESTION} placeholder in the HA_INPUT string with the actual user question.
    • It calls the jsonGPT method with the generated input and then processes the resulting CompletableFuture using thenApply.
    • In the thenApply function, the response is parsed using JsonParserBuilder to extract the hypotheticalAnswer field from the JSON object.
  3. embeddingsAsync Methods:
    • The embeddingsAsync methods handle the generation of embeddings asynchronously.
    • The single-input version of the method takes an input string, which represents the content for which embeddings need to be generated.
    • The multi-input version of the method takes a list of input strings.
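    • Only the multi-input version checks whether the input is null or empty. If so, it returns an already completed CompletableFuture holding a list with a single empty float array. The single-input version delegates to the multi-input version and returns the first embedding.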
    • The methods then create an OpenAIClient and invoke the appropriate embedding operation (embeddingAsync) using the input parameter.
    • The resulting CompletableFuture is processed using thenApply, where the embeddings are extracted from the response and returned as a list.
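
Both jsonGPT and embeddingsAsync follow the same pattern: return the payload if the response is present, otherwise throw inside thenApply, which makes the returned future complete exceptionally. The sketch below shows that pattern and one way a caller might react to it. FakeResponse is a made-up stand-in for the client's response type, and bodyOrFallback is a hypothetical caller; neither is a JAI class or method.

```java
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

// Illustrative only; FakeResponse stands in for the real client response type.
class ResponseHandlingSketch {

    // A body on success, an empty Optional plus a status code otherwise.
    static final class FakeResponse {
        final Optional<String> body;
        final int statusCode;

        FakeResponse(Optional<String> body, int statusCode) {
            this.body = body;
            this.statusCode = statusCode;
        }
    }

    static CompletableFuture<FakeResponse> ok(String body) {
        return CompletableFuture.completedFuture(new FakeResponse(Optional.of(body), 200));
    }

    static CompletableFuture<FakeResponse> error(int statusCode) {
        return CompletableFuture.completedFuture(new FakeResponse(Optional.empty(), statusCode));
    }

    // Same shape as jsonGPT: unwrap the body or fail the future.
    static CompletableFuture<String> extract(CompletableFuture<FakeResponse> call) {
        return call.thenApply(response -> response.body.orElseThrow(
                () -> new IllegalStateException("request failed with status " + response.statusCode)));
    }

    // A caller can turn the failure into a fallback value with exceptionally().
    static String bodyOrFallback(CompletableFuture<FakeResponse> call) {
        return extract(call)
                .exceptionally(error -> {
                    // An exception thrown inside thenApply arrives wrapped in a CompletionException.
                    Throwable cause = error instanceof CompletionException ? error.getCause() : error;
                    return "fallback (" + cause.getMessage() + ")";
                })
                .join();
    }

    public static void main(String[] args) {
        System.out.println(bodyOrFallback(ok("{\"queries\": []}")));
        System.out.println(bodyOrFallback(error(429)));
    }
}
```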

Next, let’s run steps 1 and 2 (getting the list of queries and then loading the articles for each query) at the same time as steps 3 and 4.

public static void main(String... args)  {
        try {

            final CountDownLatch countDownLatch = new CountDownLatch(2);

            // Generating a hypothetical answer and its embedding
            ...

            // Generate a list of queries and use them to look up articles.
            final var queriesFuture = jsonGPT(QUERIES_INPUT.replace("{USER_QUESTION}", USER_QUESTION))
                    .thenApply(queriesJson ->
                            JsonParserBuilder.builder().build().parse(queriesJson)
                                    .getObjectNode().getArrayNode("queries")
                                    .filter(node -> node instanceof StringNode)
                                    .stream().map(Object::toString).collect(Collectors.toList())
                    ).thenCompose(WhoWonUFC290Async::getArticles
                    ).thenApply(objectNodes -> {
                        countDownLatch.countDown();
                        return objectNodes;
                    });

            if (!countDownLatch.await(30, TimeUnit.SECONDS))
                throw new TimeoutException("Timed out waiting for hypotheticalAnswerEmbedding and " +
                        " articles ");

            final var articles = queriesFuture.get();

            final var hypotheticalAnswerEmbedding = 
                              hypotheticalAnswerEmbeddingFuture.get();

            ...

  • The above example initializes a CountDownLatch with a count of 2 to wait for two asynchronous tasks: generating the hypothetical answer embedding, and generating the queries and retrieving their articles.
  • The hypotheticalAnswerEmbeddingFuture is created by calling the hypotheticalAnswer method, followed by the embeddingsAsync method to generate the embedding asynchronously. The thenApply method passes the result through and decrements the latch count.
  • The queriesFuture is built from a similar chain of asynchronous operations, described below.
  • The code checks if both asynchronous tasks complete within 30 seconds using countDownLatch.await. If not, a TimeoutException is thrown.
  • The results of the asynchronous tasks are retrieved using the get method on the respective CompletableFuture objects.
  • Creation of queriesFuture:
    • The code uses the jsonGPT method to generate a list of queries based on the user question.
    • The QUERIES_INPUT string is modified by replacing the {USER_QUESTION} placeholder with the actual user question.
    • The resulting CompletableFuture of the JSON queries is processed using thenApply to parse the JSON and extract the array of queries as a list of strings.
    • That list of query strings is then passed to the getArticles method using thenCompose, which chains the asynchronous article lookup. We cover WhoWonUFC290Async::getArticles in the next section.
    • After the articles are obtained, the countDownLatch is decremented by calling countDownLatch.countDown().
  • Asynchronous Waiting:
    • The code uses countDownLatch.await to wait for the completion of both the hypothetical answer embedding and the retrieval of articles.
    • If the waiting time exceeds 30 seconds, a TimeoutException is thrown.
  • Retrieval of Articles and Hypothetical Answer Embedding:
    • The articles are retrieved from the queriesFuture using queriesFuture.get().
    • The hypothetical answer embedding is obtained from the hypotheticalAnswerEmbeddingFuture using hypotheticalAnswerEmbeddingFuture.get().
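
The reason getArticles and embeddingsAsync are chained with thenCompose rather than thenApply is that they themselves return a CompletableFuture. A small sketch of the difference, using a made-up lengthAsync method in place of the real calls:

```java
import java.util.concurrent.CompletableFuture;

// Illustrative only; lengthAsync stands in for an async call such as getArticles.
class ComposeSketch {

    static CompletableFuture<Integer> lengthAsync(String text) {
        return CompletableFuture.supplyAsync(text::length);
    }

    static int viaCompose(String text) {
        // thenCompose flattens the result into a CompletableFuture<Integer>.
        CompletableFuture<Integer> flat =
                CompletableFuture.supplyAsync(() -> text).thenCompose(ComposeSketch::lengthAsync);
        return flat.join();
    }

    static int viaApply(String text) {
        // thenApply would nest instead: a future whose value is another future.
        CompletableFuture<CompletableFuture<Integer>> nested =
                CompletableFuture.supplyAsync(() -> text).thenApply(ComposeSketch::lengthAsync);
        return nested.join().join();
    }

    public static void main(String[] args) {
        System.out.println(viaCompose("hello")); // prints 5
        System.out.println(viaApply("hello"));   // prints 5
    }
}
```

Both variants produce the same value, but only the thenCompose form yields a single future that later stages can chain on directly.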

Let’s take a look at the getArticles method, which we passed to thenCompose as WhoWonUFC290Async::getArticles.

private static CompletableFuture<List<ObjectNode>> getArticles(List<String> queries) {
        final CompletableFuture<List<ObjectNode>> completableFuture = new CompletableFuture<>();
        final CountDownLatch countDownLatch = new CountDownLatch(queries.size());
        final LinkedTransferQueue<ObjectNode> results = new LinkedTransferQueue<>();

        final List<CompletableFuture<ArrayNode>> queryFutures = queries.stream()
                .map(WhoWonUFC290Async::searchNews).collect(Collectors.toList());

        final ExecutorService executorService = Executors.newCachedThreadPool();

        executorService.submit(() -> {
            for (CompletableFuture<ArrayNode> future : queryFutures) {
                try {
                    ArrayNode arrayNode = future.get();
                    arrayNode.forEach(node -> results.add((ObjectNode) node));
                    countDownLatch.countDown();
                } catch (Exception e) {
                    e.printStackTrace();
                    countDownLatch.countDown();
                }
            }
        });

        executorService.submit(() -> {
            try {
                if (!countDownLatch.await(30, TimeUnit.SECONDS)) {
                    throw new TimeoutException("Timed out waiting for articles");
                }

                // Copy whatever was collected; the list is simply empty if no query returned results
                final var list = new ArrayList<ObjectNode>(results);
                completableFuture.complete(list);

            } catch (Exception e) {
                completableFuture.completeExceptionally(e);

            } finally {
                executorService.shutdown();
            }
        });
        return completableFuture;
    }

The getArticles method is a utility method that retrieves articles for a given list of queries asynchronously using CompletableFuture. Let's break down the code:

  1. Initialization:
    • A CompletableFuture<List<ObjectNode>> named completableFuture is created to hold the final result.
    • A CountDownLatch named countDownLatch is initialized with the size of the queries list to keep track of query completions.
    • A LinkedTransferQueue<ObjectNode> named results is created to collect the intermediate results from each query.
  2. Query Execution:
    • The queries list is streamed, and for each query, the searchNews method is called asynchronously using CompletableFuture.
    • The resulting CompletableFuture<ArrayNode> instances are collected in the queryFutures list.
  3. ExecutorService and Results Collection:
    • An ExecutorService named executorService is created with a cached thread pool.
    • Two tasks are submitted to the executorService for concurrent execution.
    • The first task collects the results:
      • It iterates through the queryFutures list and retrieves the ArrayNode result from each future.
      • Each ObjectNode from the ArrayNode is added to the results queue, and the countDownLatch is decremented.
      • Any exception during retrieval is printed, and the countDownLatch is still decremented so that a failed query does not block completion.
    • The second task completes the returned future:
      • It waits for the completion of all queries by calling countDownLatch.await.
      • If the waiting time exceeds 30 seconds, a TimeoutException is thrown.
      • Once all queries are done, the results queue is copied into an ArrayList<ObjectNode> named list.
      • The completableFuture is completed with that list, which is empty if no query returned results.
      • Any exception that occurs during the process, including the timeout, is caught, and the completableFuture is completed exceptionally.
      • Finally, the executorService is shut down.
  4. Return:
    • The completableFuture is returned, which will hold the final list of ObjectNode results when completed.

This method runs multiple queries asynchronously and gathers their results into a single list of ObjectNode, which is exposed to the caller as a CompletableFuture.
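
As a design alternative, the same gathering could be expressed without an ExecutorService, a latch, or a queue by combining the query futures with CompletableFuture.allOf. The sketch below is not the original implementation; GatherSketch and its methods are made-up names, and plain strings stand in for ObjectNode so the example has no external dependencies. Like the original, it treats a failed query as returning no results and applies a 30-second timeout.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

// Illustrative alternative to the latch-based getArticles; names are hypothetical.
class GatherSketch {

    // Combines many futures of lists into one future of a flattened list.
    static <T> CompletableFuture<List<T>> gather(List<CompletableFuture<List<T>>> futures) {
        // Treat a failed query as "no results" instead of failing the whole lookup.
        List<CompletableFuture<List<T>>> lenient = futures.stream()
                .map(future -> future.exceptionally(error -> List.<T>of()))
                .collect(Collectors.toList());
        return CompletableFuture.allOf(lenient.toArray(new CompletableFuture[0]))
                // allOf has completed, so each join() below returns without blocking.
                .thenApply(ignored -> lenient.stream()
                        .flatMap(future -> future.join().stream())
                        .collect(Collectors.toList()))
                .orTimeout(30, TimeUnit.SECONDS);
    }

    static List<String> demo() {
        CompletableFuture<List<String>> first = CompletableFuture.completedFuture(List.of("a", "b"));
        CompletableFuture<List<String>> failed =
                CompletableFuture.failedFuture(new IllegalStateException("search failed"));
        CompletableFuture<List<String>> last = CompletableFuture.completedFuture(List.of("c"));
        return gather(List.of(first, failed, last)).join();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints [a, b, c]
    }
}
```

One behavioral difference worth noting: results come back in the order of the input futures, whereas the queue-based version adds them as the first task reads each future in turn, which happens to give the same order here.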

The example continues with the remaining steps of extracting article content, calculating cosine similarities, creating scored articles, sorting them, and using the selected articles as part of the final answer. The big difference is that looking up articles based on the generated queries now happens at the same time as generating the ideal answer and its embedding.

Conclusion

The async methods of JAI, the Java OpenAI API client, can be used to speed up the process of generating a list of potential queries, executing search queries, creating an ideal answer, and creating an ideal answer embedding to measure and prioritize articles. The example code demonstrates the use of CompletableFuture to handle asynchronous operations with JAI, including generating a hypothetical answer and its embedding, obtaining a list of queries, and retrieving articles for each query. The examples, as before, also show how to calculate cosine similarities, create scored articles, and sort them to select the best articles for the final answer.