Skip to content

Commit

Permalink
[ML] Data Frame Analytics: Fix feature importance cell value and deci…
Browse files Browse the repository at this point in the history
…sion path chart (#82011)

Fixes a regression that caused data grid cells for feature importance to be empty, and that caused the whole page to render empty when clicking the button that shows the decision path chart popover.
  • Loading branch information
walterra authored Oct 30, 2020
1 parent e01fc2f commit b0a223e
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 10 deletions.
5 changes: 4 additions & 1 deletion x-pack/plugins/ml/common/types/feature_importance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ export interface ClassFeatureImportance {
class_name: string | boolean;
importance: number;
}

// TODO We should split this interface into separate types because `classes` and
// `importance` aren't both optional: an entry has either one or the other.
/**
 * Importance entry for a single feature.
 *
 * An entry carries either a single `importance` value or a list of per-class
 * values in `classes`.
 */
export interface FeatureImportance {
// Name of the feature this entry describes.
feature_name: string;
// Single importance value for the feature.
importance?: number;
// Per-class importance values (class name plus importance) for the feature.
classes?: ClassFeatureImportance[];
// NOTE(review): `importance` is declared both before and after `classes` in this view —
// presumably the removed and the added line of the diff; confirm only one declaration remains.
importance?: number;
}

export interface TopClass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { EuiDataGridSorting } from '@elastic/eui';

import { multiColumnSortFactory } from './common';

describe('Transform: Define Pivot Common', () => {
describe('Data Frame Analytics: Data Grid Common', () => {
test('multiColumnSortFactory()', () => {
const data = [
{ s: 'a', n: 1 },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import {
KBN_FIELD_TYPES,
} from '../../../../../../../src/plugins/data/public';

import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { extractErrorMessage } from '../../../../common/util/errors';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';

import {
BASIC_NUMERICAL_TYPES,
Expand Down Expand Up @@ -158,6 +160,90 @@ export const getDataGridSchemaFromKibanaFieldType = (
return schema;
};

/**
 * Converts a class name read from a flattened results field to the value used
 * in the nested feature importance structure.
 *
 * @param className - class name as a string
 * @param isClassTypeBoolean - whether class names should be turned into booleans
 * @returns `className === 'true'` when `isClassTypeBoolean` is set, otherwise the unchanged string
 */
const getClassName = (className: string, isClassTypeBoolean: boolean) =>
  isClassTypeBoolean ? className === 'true' : className;
/**
 * Helper to transform feature importance flattened fields with arrays back to object structure.
 *
 * The following flattened columns are read from the row:
 * - `<mlResultsField>.feature_importance.feature_name`
 * - `<mlResultsField>.feature_importance.classes.class_name`
 * - `<mlResultsField>.feature_importance.classes.importance`
 * - `<mlResultsField>.feature_importance.importance`
 *
 * When both `classes.*` columns are arrays, the classification structure
 * (`feature_name` + `classes`) is returned, otherwise the regression structure
 * (`feature_name` + `importance`).
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @param isClassTypeBoolean - if true, class names are converted to booleans (`'true'` becomes `true`, anything else `false`)
 * @returns nested object structure of feature importance values; an empty array
 * if the row has no feature names or none of the importance columns
 */
export const getFeatureImportance = (
  row: Record<string, any>,
  mlResultsField: string,
  isClassTypeBoolean = false
): FeatureImportance[] => {
  const fieldPrefix = `${mlResultsField}.feature_importance`;
  const featureNames: string[] | undefined = row[`${fieldPrefix}.feature_name`];
  const classNames: string[] | undefined = row[`${fieldPrefix}.classes.class_name`];
  const classImportance: number[] | undefined = row[`${fieldPrefix}.classes.importance`];

  // Nothing to transform without feature names. The empty check also avoids
  // dividing by zero when computing the number of classes per feature below.
  if (!Array.isArray(featureNames) || featureNames.length === 0) {
    return [];
  }

  // return object structure for classification job
  if (Array.isArray(classNames) && Array.isArray(classImportance)) {
    // The flattened class columns repeat the same set of classes once per feature,
    // so the first chunk holds the class names shared by every feature.
    const classCount = Math.floor(classNames.length / featureNames.length);
    const overallClassNames = classNames.slice(0, classCount);

    return featureNames.map((fName, index) => {
      // Start of this feature's chunk within the flattened importance column.
      const offset = classCount * index;
      return {
        feature_name: fName,
        classes: overallClassNames.map((fClassName, fIndex) => ({
          class_name: getClassName(fClassName, isClassTypeBoolean),
          importance: classImportance[offset + fIndex],
        })),
      };
    });
  }

  // return object structure for regression job
  const importance: number[] | undefined = row[`${fieldPrefix}.importance`];

  // Without importance values there is nothing meaningful to return. Bail out
  // instead of throwing while indexing into `undefined`.
  if (!Array.isArray(importance)) {
    return [];
  }

  return featureNames.map((fName, index) => ({
    feature_name: fName,
    importance: importance[index],
  }));
};

/**
 * Helper to transform top classes flattened fields with arrays back to object structure.
 *
 * Reads the `class_name`, `class_probability` and `class_score` columns below
 * `<mlResultsField>.top_classes` and zips them by index into one object per class.
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @returns nested object structure of top classes; an empty array if any of
 * the three columns is undefined
 */
export const getTopClasses = (row: Record<string, any>, mlResultsField: string): TopClasses => {
  const fieldPrefix = `${mlResultsField}.top_classes`;
  const names: string[] | undefined = row[`${fieldPrefix}.class_name`];
  const probabilities: number[] | undefined = row[`${fieldPrefix}.class_probability`];
  const scores: number[] | undefined = row[`${fieldPrefix}.class_score`];

  if (names !== undefined && probabilities !== undefined && scores !== undefined) {
    return names.map((name, position) => {
      return {
        class_name: name,
        class_probability: probabilities[position],
        class_score: scores[position],
      };
    });
  }

  // At least one of the three columns is missing.
  return [];
};

export const useRenderCellValue = (
indexPattern: IndexPattern | undefined,
pagination: IndexPagination,
Expand Down Expand Up @@ -207,6 +293,15 @@ export const useRenderCellValue = (
return item[cId];
}

// For classification and regression results, we need to treat some fields with a custom transform.
if (cId === `${resultsField}.feature_importance`) {
return getFeatureImportance(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
}

if (cId === `${resultsField}.top_classes`) {
return getTopClasses(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
}

// Try if the field name is available as a nested field.
return getNestedProperty(tableItems[adjustedRowIndex], cId, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_h

import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';

import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
import {
euiDataGridStyle,
euiDataGridToolbarSettings,
getFeatureImportance,
getTopClasses,
} from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { TopClasses } from '../../../../common/types/feature_importance';
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';

Expand Down Expand Up @@ -118,18 +123,28 @@ export const DataGrid: FC<Props> = memo(
if (!row) return <div />;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
const parsedFIArray = row[mlResultsField].feature_importance;
let predictedValue: string | number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
row &&
row[mlResultsField][predictionFieldName] !== undefined
row[`${mlResultsField}.${predictionFieldName}`] !== undefined
) {
predictedValue = row[mlResultsField][predictionFieldName];
topClasses = row[mlResultsField].top_classes;
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
topClasses = getTopClasses(row, mlResultsField);
}

const isClassTypeBoolean = topClasses.reduce(
(p, c) => typeof c.class_name === 'boolean' || p,
false
);

const parsedFIArray: FeatureImportance[] = getFeatureImportance(
row,
mlResultsField,
isClassTypeBoolean
);

return (
<DecisionPathPopover
analysisType={analysisType}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ export const getDefaultFieldsFromJobCaps = (
name: `${resultsField}.${FEATURE_IMPORTANCE}`,
type: KBN_FIELD_TYPES.UNKNOWN,
});
// remove flattened feature importance fields
fields = fields.filter(
(field: any) => !field.name.includes(`${resultsField}.${FEATURE_IMPORTANCE}.`)
);
}

if ((numTopClasses ?? 0) > 0) {
Expand All @@ -221,6 +225,10 @@ export const getDefaultFieldsFromJobCaps = (
name: `${resultsField}.${TOP_CLASSES}`,
type: KBN_FIELD_TYPES.UNKNOWN,
});
// remove flattened top classes fields
fields = fields.filter(
(field: any) => !field.name.includes(`${resultsField}.${TOP_CLASSES}.`)
);
}

// Only need to add these fields if we didn't use dest index pattern to get the fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export const getIndexData = async (
index: jobConfig.dest.index,
body: {
fields: ['*'],
_source: [],
_source: false,
query: searchQuery,
from: pageIndex * pageSize,
size: pageSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ interface Props {
}

export const ExplorationResultsTable: FC<Props> = React.memo(
({ indexPattern, jobConfig, jobStatus, needsDestIndexPattern, searchQuery }) => {
({ indexPattern, jobConfig, needsDestIndexPattern, searchQuery }) => {
const {
services: {
mlServices: { mlApiServices },
Expand Down

0 comments on commit b0a223e

Please sign in to comment.