Skip to content

Commit

Permalink
[api] Improve drawJoints behavior (#3305)
Browse files Browse the repository at this point in the history
* [api] Improve drawJoints behavior
  • Loading branch information
frankfliu authored Jul 9, 2024
1 parent 4142934 commit cf9e5cf
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
38 changes: 31 additions & 7 deletions api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,37 @@ public void drawJoints(Joints joints) {
convertIdNeeded();

Graphics2D g = (Graphics2D) image.getGraphics();
int stroke = 2;
g.setStroke(new BasicStroke(stroke));
g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);

int imageWidth = image.getWidth();
int imageHeight = image.getHeight();

for (Joints.Joint joint : joints.getJoints()) {
g.setPaint(randomColor().darker());
List<Joints.Joint> list = joints.getJoints();
if (list.size() == 17) {
g.setColor(new Color(224, 255, 37));
g.setStroke(new BasicStroke(3));
drawLine(g, list.get(5), list.get(7), imageWidth, imageHeight);
drawLine(g, list.get(7), list.get(9), imageWidth, imageHeight);
drawLine(g, list.get(6), list.get(8), imageWidth, imageHeight);
drawLine(g, list.get(8), list.get(10), imageWidth, imageHeight);
drawLine(g, list.get(11), list.get(13), imageWidth, imageHeight);
drawLine(g, list.get(12), list.get(14), imageWidth, imageHeight);
drawLine(g, list.get(13), list.get(15), imageWidth, imageHeight);
drawLine(g, list.get(14), list.get(16), imageWidth, imageHeight);
drawLine(g, list.get(5), list.get(6), imageWidth, imageHeight);
drawLine(g, list.get(11), list.get(12), imageWidth, imageHeight);
drawLine(g, list.get(5), list.get(11), imageWidth, imageHeight);
drawLine(g, list.get(6), list.get(12), imageWidth, imageHeight);
}

g.setColor(new Color(37, 150, 190));
g.setStroke(new BasicStroke(2));
for (Joints.Joint joint : list) {
int x = (int) (joint.getX() * imageWidth);
int y = (int) (joint.getY() * imageHeight);
g.fillOval(x, y, 10, 10);
g.fillOval(x - 6, y - 6, 12, 12);
}

g.dispose();
}

Expand All @@ -380,8 +399,13 @@ public void drawImage(Image overlay, boolean resize) {
image = target;
}

private Color randomColor() {
return new Color(RandomUtils.nextInt(255));
private void drawLine(
Graphics2D g, Joints.Joint from, Joints.Joint to, int width, int height) {
int x0 = (int) (from.getX() * width);
int y0 = (int) (from.getY() * height);
int x1 = (int) (to.getX() * width);
int y1 = (int) (to.getY() * height);
g.drawLine(x0, y0, x1, y1);
}

private void drawText(Graphics2D g, String text, int x, int y, int stroke, int padding) {
Expand Down
33 changes: 27 additions & 6 deletions extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,25 @@ public void drawJoints(Joints joints) {
int imageWidth = image.width();
int imageHeight = image.height();

Scalar color =
new Scalar(
RandomUtils.nextInt(178),
RandomUtils.nextInt(178),
RandomUtils.nextInt(178));
for (Joints.Joint joint : joints.getJoints()) {
List<Joints.Joint> list = joints.getJoints();
if (list.size() == 17) {
Scalar color = new Scalar(37, 255, 224);
drawLine(list.get(5), list.get(7), imageWidth, imageHeight, color);
drawLine(list.get(7), list.get(9), imageWidth, imageHeight, color);
drawLine(list.get(6), list.get(8), imageWidth, imageHeight, color);
drawLine(list.get(8), list.get(10), imageWidth, imageHeight, color);
drawLine(list.get(11), list.get(13), imageWidth, imageHeight, color);
drawLine(list.get(12), list.get(14), imageWidth, imageHeight, color);
drawLine(list.get(13), list.get(15), imageWidth, imageHeight, color);
drawLine(list.get(14), list.get(16), imageWidth, imageHeight, color);
drawLine(list.get(5), list.get(6), imageWidth, imageHeight, color);
drawLine(list.get(11), list.get(12), imageWidth, imageHeight, color);
drawLine(list.get(5), list.get(11), imageWidth, imageHeight, color);
drawLine(list.get(6), list.get(12), imageWidth, imageHeight, color);
}

Scalar color = new Scalar(190, 150, 37);
for (Joints.Joint joint : list) {
int x = (int) (joint.getX() * imageWidth);
int y = (int) (joint.getY() * imageHeight);
Point point = new Point(x, y);
Expand Down Expand Up @@ -340,6 +353,14 @@ public OpenCVImage normalize(float[] mean, float[] std) {
return new OpenCVImage(result);
}

private void drawLine(Joints.Joint from, Joints.Joint to, int width, int height, Scalar color) {
int x0 = (int) (from.getX() * width);
int y0 = (int) (from.getY() * height);
int x1 = (int) (to.getX() * width);
int y1 = (int) (to.getY() * height);
Imgproc.line(image, new Point(x0, y0), new Point(x1, y1), color, 2, Imgproc.LINE_AA);
}

private void drawLandmarks(BoundingBox box) {
Scalar color = new Scalar(0, 96, 246);
for (ai.djl.modality.cv.output.Point point : box.getPath()) {
Expand Down

0 comments on commit cf9e5cf

Please sign in to comment.