Skip to content

Commit

Permalink
[Fix] Fix multisports dataset detection (#2584)
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 authored Sep 6, 2023
1 parent 4fee8c2 commit 8ff889a
Show file tree
Hide file tree
Showing 32 changed files with 387 additions and 444 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
shared_head=dict(type='ACRNHead', in_channels=4608, out_channels=2304),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2304,
num_classes=81,
multilabel=True,
Expand Down Expand Up @@ -88,9 +89,6 @@
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'

file_client_args = dict(io_backend='disk')
file_client_args = dict(
io_backend='petrel',
path_mapping=dict({'data/ava': 's254:s3://openmmlab/datasets/action/ava'}))
train_pipeline = [
dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
dict(type='RawFrameDecode', **file_client_args),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
shared_head=dict(type='ACRNHead', in_channels=4608, out_channels=2304),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2304,
num_classes=81,
multilabel=True,
Expand Down
1 change: 1 addition & 0 deletions configs/detection/lfb/slowonly-lfb-infer_r50_ava21-rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2560,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2304,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2304,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2304,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=False,
in_channels=2048,
num_classes=num_classes,
multilabel=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=2048,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=768,
num_classes=81,
multilabel=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
with_temporal_pool=True),
bbox_head=dict(
type='BBoxHeadAVA',
background_class=True,
in_channels=1024,
num_classes=81,
multilabel=True,
Expand Down
1 change: 0 additions & 1 deletion mmaction/datasets/ava_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def parse_img_record(self, img_records: List[dict]) -> tuple:

labels.append(label)
entity_ids.append(img_record['entity_id'])

bboxes = np.stack(bboxes)
labels = np.stack(labels)
entity_ids = np.stack(entity_ids)
Expand Down
25 changes: 13 additions & 12 deletions mmaction/evaluation/functional/multisports_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import defaultdict

import numpy as np
from mmengine.logging import MMLogger
from rich.progress import track


Expand Down Expand Up @@ -314,7 +315,7 @@ def tubescore(tt):


def frameAP(GT, alldets, thr, print_info=True):

logger = MMLogger.get_current_instance()
vlist = GT['test_videos'][0]

results = {}
Expand All @@ -326,7 +327,7 @@ def frameAP(GT, alldets, thr, print_info=True):
'basketball save', 'basketball jump ball'
]:
if print_info:
print('do not evaluate {}'.format(label))
logger.info('do not evaluate {}'.format(label))
continue
# det format: <video_index><frame_number><label_index><score><x1><y1><x2><y2> # noqa: E501
detections = alldets[alldets[:, 2] == ilabel, :]
Expand Down Expand Up @@ -355,7 +356,7 @@ def frameAP(GT, alldets, thr, print_info=True):
gt_num = sum([g.shape[0] for g in gt.values()])
if gt_num == 0:
if print_info:
print('no such label', ilabel, label)
logger.info('no such label', ilabel, label)
continue
fp = 0 # false positives
tp = 0 # true positives
Expand Down Expand Up @@ -395,15 +396,15 @@ def frameAP(GT, alldets, thr, print_info=True):
class_result[label] = pr_to_ap_voc(results[label]) * 100
frameap_result = np.mean(ap)
if print_info:
print('frameAP_{}\n'.format(thr))
logger.info('frameAP_{}\n'.format(thr))
for label in class_result:
print('{:20s} {:8.2f}'.format(label, class_result[label]))
print('{:20s} {:8.2f}'.format('mAP', frameap_result))
logger.info('{:20s} {:8.2f}'.format(label, class_result[label]))
logger.info('{:20s} {:8.2f}'.format('mAP', frameap_result))
return frameap_result


def videoAP(GT, alldets, thr, print_info=True):

logger = MMLogger.get_current_instance()
vlist = GT['test_videos'][0]

res = {}
Expand All @@ -414,7 +415,7 @@ def videoAP(GT, alldets, thr, print_info=True):
'basketball save', 'basketball jump ball'
]:
if print_info:
print('do not evaluate{}'.format(GT['labels'][ilabel]))
logger.info('do not evaluate{}'.format(GT['labels'][ilabel]))
continue
detections = alldets[ilabel]
# load ground-truth
Expand All @@ -438,7 +439,7 @@ def videoAP(GT, alldets, thr, print_info=True):
tp = 0 # true positives
if gt_num == 0:
if print_info:
print('no such label', ilabel, GT['labels'][ilabel])
logger.info('no such label', ilabel, GT['labels'][ilabel])
continue
is_gt_box_detected = {}
for i, j in enumerate(
Expand Down Expand Up @@ -471,10 +472,10 @@ def videoAP(GT, alldets, thr, print_info=True):
for label in res:
class_result[label] = pr_to_ap_voc(res[label]) * 100
if print_info:
print('VideoAP_{}\n'.format(thr))
logger.info('VideoAP_{}\n'.format(thr))
for label in class_result:
print('{:20s} {:8.2f}'.format(label, class_result[label]))
print('{:20s} {:8.2f}'.format('mAP', videoap_result))
logger.info('{:20s} {:8.2f}'.format(label, class_result[label]))
logger.info('{:20s} {:8.2f}'.format('mAP', videoap_result))
return videoap_result


Expand Down
10 changes: 0 additions & 10 deletions mmaction/models/backbones/vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
from mmaction.registry import MODELS
from mmaction.utils import ConfigType, OptConfigType

try:
from mmdet.registry import MODELS as MMDET_MODELS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
mmdet_imported = False


class Attention(BaseModule):
"""Multi-head Self-attention.
Expand Down Expand Up @@ -387,7 +381,3 @@ def forward(self, x: Tensor) -> Tensor:
return self.fc_norm(x.mean(1))

return x[:, 0]


if mmdet_imported:
MMDET_MODELS.register_module()(VisionTransformer)
31 changes: 22 additions & 9 deletions mmaction/models/roi_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_heads import BBoxHeadAVA
from .roi_extractors import SingleRoIExtractor3D
from .roi_head import AVARoIHead
from .shared_heads import ACRNHead, FBOHead, LFBInferHead

__all__ = [
'AVARoIHead', 'BBoxHeadAVA', 'SingleRoIExtractor3D', 'ACRNHead', 'FBOHead',
'LFBInferHead'
]
try:
from mmdet.registry import MODELS as MMDET_MODELS

from .bbox_heads import BBoxHeadAVA
from .roi_extractors import SingleRoIExtractor3D
from .roi_head import AVARoIHead
from .shared_heads import ACRNHead, FBOHead, LFBInferHead

for module in [
AVARoIHead, BBoxHeadAVA, SingleRoIExtractor3D, ACRNHead, FBOHead,
LFBInferHead
]:

MMDET_MODELS.register_module()(module)

__all__ = [
'AVARoIHead', 'BBoxHeadAVA', 'SingleRoIExtractor3D', 'ACRNHead',
'FBOHead', 'LFBInferHead'
]

except (ImportError, ModuleNotFoundError):
pass
35 changes: 15 additions & 20 deletions mmaction/models/roi_heads/bbox_heads/bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.task_modules.samplers import SamplingResult
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmaction.structures.bbox import bbox_target
from mmaction.utils import InstanceList

try:
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.registry import MODELS as MMDET_MODELS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
from mmaction.utils import SamplingResult
mmdet_imported = False

# Resolve cross-entropy function to support multi-target in Torch < 1.10
# This is a very basic 'hack', with minimal functionality to support the
# procedure under prior torch versions
from packaging import version as pv
from torch import Tensor

from mmaction.structures.bbox import bbox_target
from mmaction.utils import InstanceList

if pv.parse(torch.__version__) < pv.parse('1.10'):

Expand All @@ -44,6 +36,8 @@ class BBoxHeadAVA(nn.Module):
"""Simplest RoI head, with only one fc layer for classification.
Args:
background_class (bool): Whether set class 0 as background class and
ignore it when calculate loss.
temporal_pool_type (str): The temporal pool type. Choices are ``avg``
or ``max``. Defaults to ``avg``.
spatial_pool_type (str): The spatial pool type. Choices are ``avg`` or
Expand All @@ -70,6 +64,7 @@ class BBoxHeadAVA(nn.Module):

def __init__(
self,
background_class: bool,
temporal_pool_type: str = 'avg',
spatial_pool_type: str = 'max',
in_channels: int = 2048,
Expand Down Expand Up @@ -98,6 +93,8 @@ def __init__(
self.focal_gamma = focal_gamma
self.focal_alpha = focal_alpha

self.background_class = background_class

if topk is None:
self.topk = ()
elif isinstance(topk, int):
Expand Down Expand Up @@ -251,9 +248,11 @@ def loss_and_target(self, cls_score: Tensor, rois: Tensor,
losses = dict()
# Only use the cls_score
if cls_score is not None:
labels = labels[:, 1:] # Get valid labels (ignore first one)
if self.background_class:
labels = labels[:, 1:] # Get valid labels (ignore first one)
cls_score = cls_score[:, 1:]
pos_inds = torch.sum(labels, dim=-1) > 0
cls_score = cls_score[pos_inds, 1:]
cls_score = cls_score[pos_inds]
labels = labels[pos_inds]

# Compute First Recall/Precisions
Expand All @@ -268,7 +267,7 @@ def loss_and_target(self, cls_score: Tensor, rois: Tensor,

# If Single-label, need to ensure that target labels sum to 1: ie
# that they are valid probabilities.
if not self.multilabel:
if not self.multilabel and self.background_class:
labels = labels / labels.sum(dim=1, keepdim=True)

# Select Loss function based on single/multi-label
Expand Down Expand Up @@ -414,7 +413,3 @@ def _bbox_crop_undo(bboxes, crop_quadruple):
results.scores = scores

return results


if mmdet_imported:
MMDET_MODELS.register_module()(BBoxHeadAVA)
Loading

0 comments on commit 8ff889a

Please sign in to comment.