Skip to content

Commit

Permalink
[ML] Update to new client & more strict conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
qn895 committed Sep 3, 2020
1 parent 9f435f6 commit a954cc6
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import React, { FC, useMemo, useState } from 'react';
import { i18n } from '@kbn/i18n';
import { EuiHealth, EuiSpacer, EuiSuperSelect, EuiTitle } from '@elastic/eui';
import d3 from 'd3';
import { isDecisionPathData, useDecisionPathData } from './use_classification_path_data';
import {
isDecisionPathData,
useDecisionPathData,
getStringBasedClassName,
} from './use_classification_path_data';
import { FeatureImportance, TopClasses } from '../../../../../common/types/feature_importance';
import { DecisionPathChart } from './decision_path_chart';
import { MissingDecisionPathCallout } from './missing_decision_path_callout';
Expand All @@ -20,27 +24,25 @@ interface ClassificationDecisionPathProps {
topClasses: TopClasses;
}

// cast to 'True' | 'False' | value to match Eui display
const getStr = (v: string | boolean): string =>
typeof v === 'boolean' ? (v ? 'True' : 'False') : v;

export const ClassificationDecisionPath: FC<ClassificationDecisionPathProps> = ({
featureImportance,
predictedValue,
topClasses,
predictionFieldName,
}) => {
const [currentClass, setCurrentClass] = useState<string>(getStr(topClasses[0].class_name));
const [currentClass, setCurrentClass] = useState<string>(
getStringBasedClassName(topClasses[0].class_name)
);
const { decisionPathData } = useDecisionPathData({
featureImportance,
predictedValue: currentClass,
});
const options = useMemo(() => {
const predictionValueStr = getStr(predictedValue);
const predictionValueStr = getStringBasedClassName(predictedValue);

return Array.isArray(topClasses)
? topClasses.map((c) => {
const className = getStr(c.class_name);
const className = getStringBasedClassName(c.class_name);
return {
value: className,
inputDisplay:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ export const isDecisionPathData = (decisionPathData: any): boolean => {
decisionPathData[0].length === 3
);
};

// cast to 'True' | 'False' | value to match Eui display
export const getStringBasedClassName = (v: string | boolean | undefined | number): string => {
if (v === undefined) {
return '';
}
if (typeof v === 'boolean') {
return v ? 'True' : 'False';
}
if (typeof v === 'number') {
return v.toString();
}
return v;
};

export const useDecisionPathData = ({
baseline,
featureImportance,
Expand Down Expand Up @@ -139,7 +154,7 @@ export const buildClassificationDecisionPathData = ({
ExtendedFeatureImportance | undefined
> = featureImportance.map((feature) => {
const classFeatureImportance = Array.isArray(feature.classes)
? feature.classes.find((c) => c.class_name === currentClass)
? feature.classes.find((c) => getStringBasedClassName(c.class_name) === currentClass)
: feature;
if (classFeatureImportance && typeof classFeatureImportance[FEATURE_IMPORTANCE] === 'number') {
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
* you may not use this file except in compliance with the Elastic License.
*/

import { ILegacyScopedClusterClient } from 'kibana/server';
import { IScopedClusterClient } from 'kibana/server';
import { getPredictionFieldName, isRegressionAnalysis } from '../../../common/util/analytics_utils';

// Obtains data for the data frame analytics feature importance functionalities
// such as baseline, decision paths, or importance summary.
export function analyticsFeatureImportanceProvider({
callAsCurrentUser,
callAsInternalUser,
}: ILegacyScopedClusterClient) {
asInternalUser,
asCurrentUser,
}: IScopedClusterClient) {
async function getRegressionAnalyticsBaseline(analyticsId: string): Promise<number | undefined> {
const results = await callAsInternalUser('ml.getDataFrameAnalytics', {
analyticsId,
const { body } = await asInternalUser.ml.getDataFrameAnalytics({
id: analyticsId,
});
const jobConfig = results.data_frame_analytics[0];
const jobConfig = body.data_frame_analytics[0];
if (!isRegressionAnalysis) return undefined;
const destinationIndex = jobConfig.dest.index;
const predictionField = getPredictionFieldName(jobConfig.analysis);
Expand Down Expand Up @@ -46,7 +46,7 @@ export function analyticsFeatureImportanceProvider({
},
};
let baseline;
const aggregationResult = await callAsCurrentUser('search', params);
const { body: aggregationResult } = await asCurrentUser.search(params);
if (aggregationResult) {
baseline = aggregationResult.aggregations.featureImportanceBaseline.value;
}
Expand Down

0 comments on commit a954cc6

Please sign in to comment.