Skip to content

Commit

Permalink
Implemented YoloV5 as a 3rd model for inference. Works as the second …
Browse files Browse the repository at this point in the history
…option for object detection. You can import you own custom model in the public/yolov5 folder.

Disclaimer, I don't know javascript. Expect code not to meet the guidelines and might be buggy
  • Loading branch information
IuliuNovac committed Dec 16, 2021
1 parent dbccbfd commit 0bcb0b6
Show file tree
Hide file tree
Showing 14 changed files with 24,666 additions and 74 deletions.
24,576 changes: 24,519 additions & 57 deletions package-lock.json

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
"@tensorflow-models/coco-ssd": "^2.2.2",
"@tensorflow-models/posenet": "^2.2.2",
"@tensorflow/tfjs": "^3.8.0",
"@tensorflow/tfjs-backend-cpu": "^3.9.0",
"@tensorflow/tfjs-backend-webgl": "^3.9.0",
"@tensorflow/tfjs-core": "^3.9.0",
"@tensorflow/tfjs-node": "^3.9.0",
"@tensorflow/tfjs-backend-cpu": "^3.8.0",
"@tensorflow/tfjs-backend-webgl": "^3.8.0",
"@tensorflow/tfjs-core": "^3.8.0",
"@tensorflow/tfjs-node": "^3.8.0",
"@types/jest": "27.0.1",
"@types/node": "16.7.6",
"@types/react": "17.0.19",
Expand Down
Binary file added public/yolov5/group1-shard1of7.bin
Binary file not shown.
Binary file added public/yolov5/group1-shard2of7.bin
Binary file not shown.
Binary file added public/yolov5/group1-shard3of7.bin
Binary file not shown.
Binary file added public/yolov5/group1-shard4of7.bin
Binary file not shown.
Binary file added public/yolov5/group1-shard5of7.bin
Binary file not shown.
Binary file added public/yolov5/group1-shard6of7.bin
Binary file not shown.
Binary file added public/yolov5/group1-shard7of7.bin
Binary file not shown.
1 change: 1 addition & 0 deletions public/yolov5/model.json

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions src/ai/ObjectDetectorYOLO.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import '@tensorflow/tfjs-backend-cpu';
import * as tf from '@tensorflow/tfjs';
import {store} from '../index';
import {updateObjectDetectorStatus} from '../store/ai/actionCreators';
import {LabelType} from '../data/enums/LabelType';
import {LabelsSelector} from '../store/selectors/LabelsSelector';
import {AIObjectDetectionActions} from '../logic/actions/AIObjectDetectionActions';
import {updateActiveLabelType} from '../store/labels/actionCreators';
import {DetectedObject, ObjectDetection} from '@tensorflow-models/coco-ssd';
import { AIModel } from '../data/enums/AIModel';

export class ObjectDetectorYolov5 {
private static model: tf.GraphModel;
private static width = 640;
private static height = 640;
public static AIModel = AIModel.OBJECT_DETECTION;
private static names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'];
// private static predictions: DetectedObject[];
public static async loadModel(callback?: () => any) {
const path = '/yolov5/model.json';
await tf.loadGraphModel(path).then((model ) => {
ObjectDetectorYolov5.model = model;
ObjectDetectorYolov5.AIModel= AIModel.OBJECT_DETECTION_YOLOv5;
store.dispatch(updateObjectDetectorStatus(true));
store.dispatch(updateActiveLabelType(LabelType.RECT));
const activeLabelType: LabelType = LabelsSelector.getActiveLabelType();
if (activeLabelType === LabelType.RECT) {
AIObjectDetectionActions.detectRectsForActiveImage();
}
if (callback) {
callback();
}
}).catch((error) => {
store.dispatch(updateObjectDetectorStatus(false));
throw new Error(error as string);
});
}

public static imgToTensor(img) {
const imgTensor = tf.browser.fromPixels(img);
const originHeight = img.height;
const originWidth = img.width;
const inputTensor = tf.image
.resizeBilinear(imgTensor, [ObjectDetectorYolov5.height, ObjectDetectorYolov5.width])
.div(255.0)
.expandDims(0);
return [inputTensor, originHeight, originWidth];
}

public static async predict(image: HTMLImageElement, callback?: (predictions: DetectedObject[]) => any) {
if (!ObjectDetectorYolov5.model) return;
tf.engine().startScope();
const [input, originHeight, originWidth] = ObjectDetectorYolov5.imgToTensor(image);
const predictions: DetectedObject[] = [];
const results = await ObjectDetectorYolov5.model.executeAsync(input);
const boxes = await results[0].dataSync();
const scores = await results[1].dataSync();
const classes = await results[2].dataSync();
const validDetections = await results[3].dataSync();
for (let i = 0; i < validDetections; i++) {
let [x1, y1, x2, y2] = boxes.slice(i * 4, (i + 1) * 4);
x1 *= originWidth;
x2 *= originWidth;
y1 *= originHeight;
y2 *= originHeight;
const width = x2 - x1;
const height = y2 - y1;
const className = ObjectDetectorYolov5.names[classes[i]];
const score = scores[i];
predictions.push({
bbox: [x1, y1, width, height],
class: className,
score: score.toFixed(2)
});
}
tf.dispose(results);
tf.engine().endScope();
if (callback) {
callback(predictions);
}
}
}
3 changes: 2 additions & 1 deletion src/data/enums/AIModel.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export enum AIModel {
OBJECT_DETECTION = "OBJECT_DETECTION",
POSE_DETECTION = "POSE_DETECTION"
POSE_DETECTION = "POSE_DETECTION",
OBJECT_DETECTION_YOLOv5 = "OBJECT_DETECTION_YOLOv5"
}
51 changes: 39 additions & 12 deletions src/logic/actions/AIObjectDetectionActions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import {PopupWindowType} from '../../data/enums/PopupWindowType';
import {updateActivePopupType} from '../../store/general/actionCreators';
import {AISelector} from '../../store/selectors/AISelector';
import {AIActions} from './AIActions';

import {ObjectDetectorYolov5} from "../../ai/ObjectDetectorYOLO";
import { AIModel } from '../../data/enums/AIModel';
export class AIObjectDetectionActions {

public static detectRectsForActiveImage(): void {
const activeImageData: ImageData = LabelsSelector.getActiveImageData();
AIObjectDetectionActions.detectRects(activeImageData.id, ImageRepository.getById(activeImageData.id))
Expand All @@ -25,18 +27,43 @@ export class AIObjectDetectionActions {
return;

store.dispatch(updateActivePopupType(PopupWindowType.LOADER));
ObjectDetector.predict(image, (predictions: DetectedObject[]) => {
const suggestedLabelNames = AIObjectDetectionActions.extractNewSuggestedLabelNames(LabelsSelector.getLabelNames(), predictions);
const rejectedLabelNames = AISelector.getRejectedSuggestedLabelList();
const newlySuggestedNames = AIActions.excludeRejectedLabelNames(suggestedLabelNames, rejectedLabelNames);
if (newlySuggestedNames.length > 0) {
store.dispatch(updateSuggestedLabelList(newlySuggestedNames));
store.dispatch(updateActivePopupType(PopupWindowType.SUGGEST_LABEL_NAMES));
} else {
store.dispatch(updateActivePopupType(null));
switch (ObjectDetectorYolov5.AIModel) {
case AIModel.OBJECT_DETECTION_YOLOv5: {
ObjectDetectorYolov5.predict(image, (predictions: DetectedObject[]) => {
const suggestedLabelNames = AIObjectDetectionActions.extractNewSuggestedLabelNames(LabelsSelector.getLabelNames(), predictions);
const rejectedLabelNames = AISelector.getRejectedSuggestedLabelList();
const newlySuggestedNames = AIActions.excludeRejectedLabelNames(suggestedLabelNames, rejectedLabelNames);
if (newlySuggestedNames.length > 0) {
store.dispatch(updateSuggestedLabelList(newlySuggestedNames));
store.dispatch(updateActivePopupType(PopupWindowType.SUGGEST_LABEL_NAMES));
} else {
store.dispatch(updateActivePopupType(null));
}
AIObjectDetectionActions.saveRectPredictions(imageId, predictions);
})
break;
}
AIObjectDetectionActions.saveRectPredictions(imageId, predictions);
})
case AIModel.OBJECT_DETECTION: {
ObjectDetectorYolov5.predict(image, (predictions: DetectedObject[]) => {
const suggestedLabelNames = AIObjectDetectionActions.extractNewSuggestedLabelNames(LabelsSelector.getLabelNames(), predictions);
const rejectedLabelNames = AISelector.getRejectedSuggestedLabelList();
const newlySuggestedNames = AIActions.excludeRejectedLabelNames(suggestedLabelNames, rejectedLabelNames);
if (newlySuggestedNames.length > 0) {
store.dispatch(updateSuggestedLabelList(newlySuggestedNames));
store.dispatch(updateActivePopupType(PopupWindowType.SUGGEST_LABEL_NAMES));
} else {
store.dispatch(updateActivePopupType(null));
}
AIObjectDetectionActions.saveRectPredictions(imageId, predictions);
})
break;
}
default: {
throw new Error('Unknown model');
break;
}
}

}

public static saveRectPredictions(imageId: string, predictions: DetectedObject[]) {
Expand Down
11 changes: 11 additions & 0 deletions src/views/PopupView/LoadModelPopup/LoadModelPopup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import React, { useState } from "react";
import { PopupActions } from "../../../logic/actions/PopupActions";
import { GenericYesNoPopup } from "../GenericYesNoPopup/GenericYesNoPopup";
import { ObjectDetector } from "../../../ai/ObjectDetector";
import { ObjectDetectorYolov5 } from "../../../ai/ObjectDetectorYOLO";
import './LoadModelPopup.scss'
import { ClipLoader } from "react-spinners";
import { AIModel } from "../../../data/enums/AIModel";
Expand All @@ -21,6 +22,11 @@ const models: SelectableModel[] = [
name: "COCO SSD - object detection using rectangles",
flag: false
},
{
model: AIModel.OBJECT_DETECTION_YOLOv5,
name: "YOLO v5 - object detection using bounding boxes",
flag: false
},
{
model: AIModel.POSE_DETECTION,
name: "POSE-NET - pose estimation using points",
Expand All @@ -40,6 +46,11 @@ export const LoadModelPopup: React.FC = () => {
PopupActions.close();
});
break;
case AIModel.OBJECT_DETECTION_YOLOv5:
ObjectDetectorYolov5.loadModel(() => {
PopupActions.close();
});
break;
case AIModel.OBJECT_DETECTION:
ObjectDetector.loadModel(() => {
PopupActions.close();
Expand Down

0 comments on commit 0bcb0b6

Please sign in to comment.