diff --git a/README.md b/README.md
index b77e89ddf..4e54c0126 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,7 @@ Supported Methods
 - [x] [FGFA](configs/vid/fgfa) (ICCV 2017)
 - [x] [SELSA](configs/vid/selsa) (ICCV 2019)
 - [x] [Temporal RoI Align](configs/vid/temporal_roi_align) (AAAI 2021)
+- [x] [TF-Blender](configs/vid/tf_blender) (ICCV 2021)
 
 Supported Datasets
 
diff --git a/configs/vid/tf_blender/README.md b/configs/vid/tf_blender/README.md
new file mode 100644
index 000000000..d0834ac1f
--- /dev/null
+++ b/configs/vid/tf_blender/README.md
@@ -0,0 +1,24 @@
+# TF-Blender: Temporal Feature Blender for Video Object Detection
+
+## Abstract
+
+Video object detection is a challenging task because isolated video frames may encounter appearance deterioration, which introduces great confusion for detection. One of the popular solutions is to exploit temporal information and enhance the per-frame representation by aggregating features from neighboring frames. Despite achieving improvements in detection, existing methods focus on the selection of higher-level video frames for aggregation rather than modeling lower-level temporal relations to strengthen the feature representation. To address this limitation, we propose a novel solution named TF-Blender, which includes three modules: 1) Temporal relation models the relations between the current frame and its neighboring frames to preserve spatial information; 2) Feature adjustment enriches the representation of every neighboring feature map; 3) Feature blender combines the outputs of the first two modules and produces stronger features for the later detection tasks. Thanks to its simplicity, TF-Blender can be effortlessly plugged into any detection network to improve detection performance. Extensive evaluations on the ImageNet VID and YouTube-VIS benchmarks indicate the performance gains of using TF-Blender on recent state-of-the-art methods.
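+
+The snippet below is a minimal usage sketch (not from the paper): it instantiates the `TFBlenderAggregator` added in this PR directly and runs it on dummy feature maps. The tensor shapes are placeholders; only the channel dimension has to match the `channels` argument (512 in the FGFA-style configs in this folder).
+
+```python
+import torch
+
+from mmtrack.models.aggregators import TFBlenderAggregator
+
+# `channels` must equal the channel dimension of the input feature maps.
+aggregator = TFBlenderAggregator(num_convs=1, channels=512, kernel_size=3)
+
+x = torch.rand(1, 512, 12, 20)      # feature map of the current frame
+ref_x = torch.rand(4, 512, 12, 20)  # feature maps of 4 reference frames
+agg_x = aggregator(x, ref_x)        # aggregated feature map, [1, 512, 12, 20]
+```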
+
+## Citation
+
+```latex
+@inproceedings{cui2021tf,
+  title={TF-Blender: Temporal Feature Blender for Video Object Detection},
+  author={Cui, Yiming and Yan, Liqi and Cao, Zhiwen and Liu, Dongfang},
+  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+  pages={8138--8147},
+  year={2021}
+}
+```
+
diff --git a/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_r101_dc5_7e_imagenetvid.py b/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_r101_dc5_7e_imagenetvid.py
new file mode 100644
index 000000000..59b0c4d18
--- /dev/null
+++ b/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_r101_dc5_7e_imagenetvid.py
@@ -0,0 +1,7 @@
+_base_ = ['./fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py']
+model = dict(
+    detector=dict(
+        backbone=dict(
+            depth=101,
+            init_cfg=dict(
+                type='Pretrained', checkpoint='torchvision://resnet101'))))
diff --git a/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py b/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py
new file mode 100644
index 000000000..b6a253191
--- /dev/null
+++ b/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py
@@ -0,0 +1,34 @@
+_base_ = [
+    '../../_base_/models/faster_rcnn_r50_dc5.py',
+    '../../_base_/datasets/imagenet_vid_fgfa_style.py',
+    '../../_base_/default_runtime.py'
+]
+model = dict(
+    type='FGFA',
+    motion=dict(
+        type='FlowNetSimple',
+        img_scale_factor=0.5,
+        init_cfg=dict(
+            type='Pretrained',
+            checkpoint=  # noqa: E251
+            'https://download.openmmlab.com/mmtracking/pretrained_weights/flownet_simple.pth'  # noqa: E501
+        )),
+    aggregator=dict(
+        type='TFBlenderAggregator', num_convs=1, channels=512, kernel_size=3))
+
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(
+    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
+
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[2, 5])
+
+# runtime settings
+total_epochs = 7
+evaluation = dict(metric=['bbox'], interval=7)
diff --git a/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_x101_dc5_7e_imagenetvid.py b/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_x101_dc5_7e_imagenetvid.py
new file mode 100644
index 000000000..d6575501b
--- /dev/null
+++ b/configs/vid/tf_blender/fgfa_tfblender_faster_rcnn_x101_dc5_7e_imagenetvid.py
@@ -0,0 +1,11 @@
+_base_ = ['./fgfa_tfblender_faster_rcnn_r50_dc5_7e_imagenetvid.py']
+model = dict(
+    detector=dict(
+        backbone=dict(
+            type='ResNeXt',
+            depth=101,
+            groups=64,
+            base_width=4,
+            init_cfg=dict(
+                type='Pretrained',
+                checkpoint='open-mmlab://resnext101_64x4d'))))
diff --git a/mmtrack/models/aggregators/__init__.py b/mmtrack/models/aggregators/__init__.py
index 08c6f99fa..c904bbf6a 100644
--- a/mmtrack/models/aggregators/__init__.py
+++ b/mmtrack/models/aggregators/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .embed_aggregator import EmbedAggregator
 from .selsa_aggregator import SelsaAggregator
+from .tfblender_aggregator import TFBlenderAggregator
 
-__all__ = ['EmbedAggregator', 'SelsaAggregator']
+__all__ = ['EmbedAggregator', 'SelsaAggregator', 'TFBlenderAggregator']
diff --git a/mmtrack/models/aggregators/tfblender_aggregator.py b/mmtrack/models/aggregators/tfblender_aggregator.py
new file mode 100644
index 000000000..9c1ce5c51
--- /dev/null
+++ b/mmtrack/models/aggregators/tfblender_aggregator.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import ConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import AGGREGATORS
+
+
+@AGGREGATORS.register_module()
+class TFBlenderAggregator(BaseModule):
+    """TF-Blender aggregator module.
+
+    This module is proposed in "TF-Blender: Temporal Feature Blender for Video
+    Object Detection". `TF-Blender `_.
+
+    Args:
+        num_convs (int): Number of embedding convs. Defaults to 1.
+        channels (int): Channels of embedding convs. Defaults to 256.
+        kernel_size (int): Kernel size of embedding convs. Defaults to 3.
+        norm_cfg (dict): Configuration of normalization method after each
+            conv. Defaults to None.
+        act_cfg (dict): Configuration of activation method after each
+            conv. Defaults to dict(type='ReLU').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 num_convs=1,
+                 channels=256,
+                 kernel_size=3,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 init_cfg=None):
+        super(TFBlenderAggregator, self).__init__(init_cfg)
+        assert num_convs > 0, 'The number of convs must be at least 1.'
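+        # Embedding convs shared by the current frame and the reference
+        # frames; their L2-normalised outputs feed the temporal-relation
+        # term of the blending-weight computation in `forward`.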
+        self.embed_convs = nn.ModuleList()
+        for i in range(num_convs):
+            if i == num_convs - 1:
+                new_norm_cfg = None
+                new_act_cfg = None
+            else:
+                new_norm_cfg = norm_cfg
+                new_act_cfg = act_cfg
+            self.embed_convs.append(
+                ConvModule(
+                    in_channels=channels,
+                    out_channels=channels,
+                    kernel_size=kernel_size,
+                    padding=(kernel_size - 1) // 2,
+                    norm_cfg=new_norm_cfg,
+                    act_cfg=new_act_cfg))
+
+        # Feature blender convs: map the concatenated relation tensor with
+        # 8 * channels channels down to per-reference blending weights with
+        # `channels` channels (8C -> 4C -> 2C -> C).
+        self.tf_blenders = nn.ModuleList()
+        self.tf_blenders.append(
+            ConvModule(
+                in_channels=channels * 8,
+                out_channels=channels * 4,
+                kernel_size=1,
+                padding=0,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg))
+        self.tf_blenders.append(
+            ConvModule(
+                in_channels=channels * 4,
+                out_channels=channels * 2,
+                kernel_size=3,
+                padding=1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg))
+        self.tf_blenders.append(
+            ConvModule(
+                in_channels=channels * 2,
+                out_channels=channels,
+                kernel_size=1,
+                padding=0,
+                norm_cfg=None,
+                act_cfg=None))
+
+    def forward(self, x, ref_x):
+        """Aggregate reference feature maps `ref_x`.
+
+        The aggregation mainly contains three steps:
+
+        1. Build a relation tensor of shape [N, C*8, H, W] by concatenating
+           `x`, `x_embed`, `ref_x`, `ref_x_embed` and their differences.
+        2. Compute blending weights by passing this tensor through the
+           temporal relation, feature adjustment and feature blender convs.
+        3. Normalize the weights with a softmax over the reference dimension
+           and use them to compute a weighted sum of `ref_x`.
+
+        Args:
+            x (Tensor): of shape [1, C, H, W].
+            ref_x (Tensor): of shape [N, C, H, W]. N is the number of
+                reference feature maps.
+
+        Returns:
+            Tensor: The aggregated feature map with shape [1, C, H, W].
+        """
+        assert len(x.shape) == 4 and len(x) == 1, \
+            "Only support 'batch_size == 1' for x"
+        # Embed and L2-normalise the current frame and the reference frames
+        # with the shared embedding convs.
+        x_embed = x
+        for embed_conv in self.embed_convs:
+            x_embed = embed_conv(x_embed)
+        x_embed = x_embed / x_embed.norm(p=2, dim=1, keepdim=True)
+
+        ref_x_embed = ref_x
+        for embed_conv in self.embed_convs:
+            ref_x_embed = embed_conv(ref_x_embed)
+        ref_x_embed = ref_x_embed / ref_x_embed.norm(p=2, dim=1, keepdim=True)
+
+        # Concatenate the raw and embedded features of the current frame
+        # (broadcast to the N references), the reference features and their
+        # pairwise differences: [N, C*8, H, W].
+        x_embed_rep = x_embed.repeat(ref_x_embed.shape[0], 1, 1, 1)
+        x_rep = x.repeat(ref_x.shape[0], 1, 1, 1)
+        tf_weight = torch.cat((
+            x_embed_rep,
+            ref_x_embed,
+            x_embed_rep - ref_x_embed,
+            x_rep,
+            ref_x,
+            x_rep - ref_x,
+            ref_x_embed - x_embed_rep,
+            ref_x - x_rep), dim=1)
+
+        for tf_blender in self.tf_blenders:
+            tf_weight = tf_blender(tf_weight)
+
+        # Softmax over the reference dimension and weighted sum of `ref_x`.
+        ada_weights = tf_weight.softmax(dim=0)
+        agg_x = torch.sum(ref_x * ada_weights, dim=0, keepdim=True)
+        return agg_x
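
A quick way to sanity-check the new module (a minimal sketch, not part of the patch itself, assuming mmtracking's `build_aggregator` helper) is to build it through the registry, exactly as the `type='TFBlenderAggregator'` entries in the configs above do, and verify the output shape. The channel count and spatial size below are arbitrary placeholders.

```python
import torch

from mmtrack.models import build_aggregator

# Build from a config dict, mirroring the `aggregator=dict(...)` entries in
# the fgfa_tfblender_* configs added by this PR.
aggregator = build_aggregator(
    dict(type='TFBlenderAggregator', num_convs=1, channels=512, kernel_size=3))

x = torch.rand(1, 512, 8, 8)      # current-frame feature map
ref_x = torch.rand(2, 512, 8, 8)  # two reference feature maps
agg_x = aggregator(x, ref_x)
assert agg_x.shape == (1, 512, 8, 8)
```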