Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add new include flag to GET inference/<model_id> API for model training metadata #61922

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
import org.elasticsearch.client.ml.GetRecordsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
import org.elasticsearch.client.ml.MlInfoRequest;
Expand Down Expand Up @@ -819,6 +820,31 @@ static Request getTrainedModelsStats(GetTrainedModelsStatsRequest getTrainedMode
return request;
}

static Request getTrainedModelsMetadata(GetTrainedModelsMetadataRequest getTrainedModelsMetadataRequest) {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml", "inference")
.addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsMetadataRequest.getIds()))
.addPathPart("_metadata")
.build();
RequestConverters.Params params = new RequestConverters.Params();
if (getTrainedModelsMetadataRequest.getPageParams() != null) {
PageParams pageParams = getTrainedModelsMetadataRequest.getPageParams();
if (pageParams.getFrom() != null) {
params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString());
}
if (pageParams.getSize() != null) {
params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString());
}
}
if (getTrainedModelsMetadataRequest.getAllowNoMatch() != null) {
params.putParam(GetTrainedModelsMetadataRequest.ALLOW_NO_MATCH,
Boolean.toString(getTrainedModelsMetadataRequest.getAllowNoMatch()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;
}

static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml", "inference")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
import org.elasticsearch.client.ml.GetOverallBucketsResponse;
import org.elasticsearch.client.ml.GetRecordsRequest;
import org.elasticsearch.client.ml.GetRecordsResponse;
import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest;
import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
Expand Down Expand Up @@ -2519,6 +2521,49 @@ public Cancellable getTrainedModelsStatsAsync(GetTrainedModelsStatsRequest reque
Collections.emptySet());
}

/**
* Gets trained model metadata
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/get-inference-metadata.html">
* GET Trained Model Metadata documentation</a>
*
* @param request The {@link GetTrainedModelsMetadataRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @return {@link GetTrainedModelsMetadataResponse} response object
*/
public GetTrainedModelsMetadataResponse getTrainedModelsMetadata(GetTrainedModelsMetadataRequest request,
RequestOptions options) throws IOException {
return restHighLevelClient.performRequestAndParseEntity(request,
MLRequestConverters::getTrainedModelsMetadata,
options,
GetTrainedModelsMetadataResponse::fromXContent,
Collections.emptySet());
}

/**
* Gets trained model metadata asynchronously and notifies listener upon completion
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/get-inference-metadata.html">
* GET Trained Model Metadata documentation</a>
*
* @param request The {@link GetTrainedModelsMetadataRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @param listener Listener to be notified upon request completion
* @return cancellable that may be used to cancel the request
*/
public Cancellable getTrainedModelsMetadataAsync(GetTrainedModelsMetadataRequest request,
RequestOptions options,
ActionListener<GetTrainedModelsMetadataResponse> listener) {
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
MLRequestConverters::getTrainedModelsMetadata,
options,
GetTrainedModelsMetadataResponse::fromXContent,
listener,
Collections.emptySet());
}

/**
* Deletes the given Trained Model
* <p>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.common.Nullable;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

public class GetTrainedModelsMetadataRequest implements Validatable {

public static final String ALLOW_NO_MATCH = "allow_no_match";

private final List<String> ids;
private Boolean allowNoMatch;
private PageParams pageParams;

public static GetTrainedModelsMetadataRequest getAllTrainedModelsMetadataRequest() {
return new GetTrainedModelsMetadataRequest("_all");
}

public GetTrainedModelsMetadataRequest(String... ids) {
this.ids = Arrays.asList(ids);
}

public List<String> getIds() {
return ids;
}

public Boolean getAllowNoMatch() {
return allowNoMatch;
}

/**
* Whether to ignore if a wildcard expression matches no trained models.
*
* @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all})
* does not match any trained models
*/
public GetTrainedModelsMetadataRequest setAllowNoMatch(boolean allowNoMatch) {
this.allowNoMatch = allowNoMatch;
return this;
}

public PageParams getPageParams() {
return pageParams;
}

public GetTrainedModelsMetadataRequest setPageParams(@Nullable PageParams pageParams) {
this.pageParams = pageParams;
return this;
}

@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {
return Optional.of(ValidationException.withError("trained model id must not be null"));
}
return Optional.empty();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

GetTrainedModelsMetadataRequest other = (GetTrainedModelsMetadataRequest) o;
return Objects.equals(ids, other.ids)
&& Objects.equals(allowNoMatch, other.allowNoMatch)
&& Objects.equals(pageParams, other.pageParams);
}

@Override
public int hashCode() {
return Objects.hash(ids, allowNoMatch, pageParams);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;

public class GetTrainedModelsMetadataResponse {

public static final ParseField TRAINED_MODELS_METADATA = new ParseField("trained_models_metadata");
public static final ParseField COUNT = new ParseField("count");

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<GetTrainedModelsMetadataResponse, Void> PARSER =
new ConstructingObjectParser<>(
"get_trained_models_metadata",
true,
args -> new GetTrainedModelsMetadataResponse((List<TrainedModelMetadata>) args[0], (Long) args[1]));

static {
PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelMetadata.fromXContent(p), TRAINED_MODELS_METADATA);
PARSER.declareLong(constructorArg(), COUNT);
}

public static GetTrainedModelsMetadataResponse fromXContent(final XContentParser parser) {
return PARSER.apply(parser, null);
}

private final List<TrainedModelMetadata> trainedModelsMetadata;
private final Long count;


public GetTrainedModelsMetadataResponse(List<TrainedModelMetadata> trainedModelsMetadata, Long count) {
this.trainedModelsMetadata = trainedModelsMetadata;
this.count = count;
}

public List<TrainedModelMetadata> getTrainedModelsMetadata() {
return trainedModelsMetadata;
}

/**
* @return The total count of the trained models that matched the ID pattern.
*/
public Long getCount() {
return count;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

GetTrainedModelsMetadataResponse other = (GetTrainedModelsMetadataResponse) o;
return Objects.equals(this.trainedModelsMetadata, other.trainedModelsMetadata) && Objects.equals(this.count, other.count);
}

@Override
public int hashCode() {
return Objects.hash(trainedModelsMetadata, count);
}
}
Loading