从零到一:基于mmFewShot框架的小样本目标检测实战指南 1. 环境配置与框架安装第一次接触mmFewShot时我被它开箱即用的特性惊艳到了。这个基于PyTorch的小样本学习框架完美继承了OpenMMLab系列工具链的优势。下面分享我在工业质检项目中的实际配置经验基础环境就像搭积木需要严格对齐版本号。我推荐使用conda创建隔离环境conda create -n mmfewshot python3.7 -y conda activate mmfewshot关键依赖的版本组合就像精密齿轮pip install torch1.7.0cu101 torchvision0.8.0cu101 -f https://download.pytorch.org/whl/torch_stable.html pip install mmcv-full1.4.0 -f https://download.mmcv.mmcv-full安装mmFewShot本体时有个小技巧——先克隆仓库再安装git clone https://github.com/open-mmlab/mmfewshot.git cd mmfewshot pip install -v -e . # 可编辑模式安装方便修改源码注意如果遇到mmdetection版本冲突可以尝试先卸载原有版本再安装指定版本。我在RTX 3090上测试时发现torch 1.7需要搭配CUDA 11.0才能发挥最佳性能验证安装是否成功时别只看import是否报错。我习惯跑一个微型测试from mmfewshot.detection import build_detector config configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_base-training.py model build_detector(config) print(model)2. 数据集准备与改造处理工业缺陷数据集时我走过最长的路就是数据格式转换。假设我们有个金属表面缺陷数据集包含划痕、凹陷等10类缺陷需要改造为VOC格式目录结构要像乐高积木般规整VOC2007/ ├── Annotations/ # 存放XML标注文件 ├── ImageSets/ │ └── Main/ # 存放划分文件 └── JPEGImages/ # 存放原始图片类别的拆分艺术决定了模型上限。我的经验是Base类选择常见缺陷如划痕、污渍等7类Novel类保留罕见缺陷如龟裂、氧化等3类确保Base类样本量是Novel类的20倍以上修改mmfewshot/detection/datasets/voc.py时这三个变量是核心ALL_CLASSES_SPLIT1 (scratch, stain, ..., crack) # 全部10类 BASE_CLASSES_SPLIT1 (scratch, stain, ...) # 7个基类 NOVEL_CLASSES_SPLIT1 (crack, oxidation) # 3个新类实战技巧可以用Python脚本自动生成ImageSets文件。我在处理5000张图片时写了这样的脚本import os import random from glob import glob image_ids [os.path.basename(x).split(.)[0] for x in glob(JPEGImages/*.jpg)] random.shuffle(image_ids) # 按7:2:1划分train/val/test with open(ImageSets/Main/trainval.txt,w) as f: f.writelines(id\n for id in image_ids[:int(len(image_ids)*0.7)])3. 两阶段训练实战3.1 基类训练阶段配置base-training.py文件时这几个参数是胜负手data dict( samples_per_gpu4, # 根据显存调整RTX 3090可设8 workers_per_gpu2, # 建议等于CPU核心数 traindict( ann_filedata/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt, img_prefixdata/VOCdevkit/VOC2007/), valdict(...), testdict(...))启动训练时有个隐藏技巧——预热学习率# 单卡训练 python tools/detection/train.py \ configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_base-training.py \ --work-dir work_dirs/base_train \ --seed 42 # 固定随机种子保证可复现 # 多卡训练2卡示例 bash tools/detection/dist_train.sh \ configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_base-training.py \ 2 \ --work-dir work_dirs/base_train_dist3.2 小样本微调阶段准备5-shot数据时我开发了自动化脚本import os import shutil from collections import defaultdict # 创建按类别的图片索引 class_images defaultdict(list) for xml_file in glob(Annotations/*.xml): cls parse_xml_get_class(xml_file) # 自定义解析函数 class_images[cls].append(xml_file.replace(.xml,)) # 为每个新类随机选取5张 for novel_cls in NOVEL_CLASSES_SPLIT1: selected random.sample(class_images[novel_cls], 5) with open(ffew_shot_ann/voc/benchmark_5shot/{novel_cls}.txt,w) as f: f.writelines(fJPEGImages/{x}.jpg\n for x in selected)微调配置的关键在于继承基类知识model dict( roi_headdict( bbox_headdict( num_classes10, # 基类新类总数 init_cfgdict( typePretrained, checkpointwork_dirs/base_train/latest.pth))))启动微调时建议降低学习率python tools/detection/train.py \ configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py \ --cfg-options optimizer.lr0.001 \ --work-dir work_dirs/fewshot_finetune4. 模型推理与效果优化4.1 可视化推理我改进了官方推理脚本支持批量处理def batch_inference(model, img_dir, output_dir): import cv2 os.makedirs(output_dir, exist_okTrue) for img_path in glob(f{img_dir}/*.jpg): result inference_detector(model, img_path) vis_img model.show_result( img_path, result, score_thr0.5) cv2.imwrite( f{output_dir}/{os.path.basename(img_path)}, vis_img)4.2 性能提升技巧困难样本挖掘能显著提升小样本效果。我的实现方案在第一阶段训练时保存预测结果筛选出高loss的样本加入训练集对这些样本进行数据增强跨域适应是另一个突破点。当新类数据不足时使用StyleGAN生成近似样本应用CutMix数据增强采用域随机化技术最后分享一个模型融合技巧——在5-shot训练时# 在配置文件中添加模型集成 model dict( test_cfgdict( nmsdict(typesoft_nms, iou_threshold0.5), score_thr0.05, ensembledict( typeweighted_box_fusion, iou_thr0.5, skip_box_thr0.01)))