Mirror of https://gitee.com/rock_kim/Myolotrain.git (synced 2025-12-06 11:39:07 +08:00)

Commit: Add Transformer-based tracking and Ascend NPU support
.gitignore (vendored, 4 changed lines)

@@ -16,6 +16,8 @@ logs/*
*.log
packages/*
yolov8n/*
dl/*
README.md

# Database
postgres_data/
@@ -31,4 +33,4 @@ postgres_data/

# System files
.DS_Store
Thumbs.db
CHANGELOG.md (deleted, 53 lines)

@@ -1,53 +0,0 @@
# Myolotrain System Changelog

## [0.1.0] - 2025-04-21

### Added
- Created the system changelog to record progress on system updates

- Added OpenCV integration
  - Created the `opencv_service.py` service class, providing image processing and computer vision features
  - Added API endpoints supporting image preprocessing, image analysis, and data augmentation
  - Created a frontend interface for intuitive user interaction
  - Implemented image preprocessing: resizing, denoising, brightness and contrast adjustment, sharpening
  - Implemented image analysis: blur detection, exposure detection
  - Implemented data augmentation: flipping, rotation, noise injection, brightness and contrast variation, perspective transform

### Changed
- Modified the `import_local_dataset` function to reference raw datasets in the `datasets_import` directory directly
  - When the user selects "server dataset", the system no longer moves the dataset directory; it references the original directory directly
  - The dataset is marked as external in the database (`is_external=True`)
  - The required directory structure and configuration files are created in the original directory

- Modified `import_external_dataset` to match `import_local_dataset`
  - References the original directory directly instead of creating an alias directory
  - Validates or creates the dataset.yaml file so that training can locate the dataset correctly

- Modified `delete_dataset` so that external dataset directories are never deleted
  - When deleting a dataset, checks whether it is external (`is_external=True`)
  - For external datasets, only the database record is removed; the directory is left intact

- Fixed formatting issues in the training script template
  - Escaped dictionary literals with double braces in `train_script_template.py` to avoid clashing with format placeholders
  - Handled these double braces in `training_service.py` so the script is generated correctly

- Fixed dataset path issues
  - Used absolute paths instead of relative paths in the `dataset.yaml` file
  - Specified full absolute paths for the train, validation, and test sets to avoid path-resolution errors

- Fixed the Ultralytics datasets directory setting
  - Set the Ultralytics datasets directory to the current project directory in the training script
  - Avoids the default `D:\AI\zljc\yolov5-master\datasets` directory

- Disabled the automatic mixed precision (AMP) check
  - Set `amp=False` in the training script to avoid downloading an extra model
  - Set the environment variable `ULTRALYTICS_SKIP_AMP_CHECK=1` to disable the AMP check
  - Fixed parameter settings when resuming training to keep them consistent

### Planned
- Further optimize handling of external datasets to improve training efficiency
- Extend OpenCV functionality
  - Add video processing to extract frames from videos as training data
  - Add semi-automatic annotation using contour and edge detection to help users label quickly
  - Add real-time object tracking with multi-object support
  - Add model visualization showing per-layer feature maps and attention maps
@@ -1,6 +1,6 @@
from fastapi import APIRouter

-from app.api.endpoints import datasets, models, training, detection, opencv, video
+from app.api.endpoints import datasets, models, training, detection, opencv, video, tracking

api_router = APIRouter()
api_router.include_router(datasets.router, prefix="/datasets", tags=["datasets"])
@@ -9,3 +9,4 @@ api_router.include_router(training.router, prefix="/training", tags=["training"])
api_router.include_router(detection.router, prefix="/detection", tags=["detection"])
api_router.include_router(opencv.router, prefix="/opencv", tags=["opencv"])
api_router.include_router(video.router, prefix="/video", tags=["video"])
+api_router.include_router(tracking.router, prefix="/tracking", tags=["tracking"])
@@ -8,10 +8,10 @@ from app.services import detection_service

router = APIRouter()

-@router.post("/", response_model=DetectionTask)
+@router.post("/")
async def create_detection_task(
-    model_id: str = Form(...),
    file: UploadFile = File(...),
+    model_id: str = Form(None),
    conf_thres: float = Form(0.25),
    iou_thres: float = Form(0.45),
    db: Session = Depends(get_db)
@@ -24,13 +24,29 @@ async def create_detection_task(
        "iou_thres": iou_thres
    }

-    return await detection_service.create_detection_task(
+    task = await detection_service.create_detection_task(
        db=db,
        model_id=model_id,
        file=file,
        parameters=parameters
    )

+    # Fetch the detection result
+    result = detection_service.get_detection_result(db, task_id=task.id)
+
+    # For a frontend tracking request, return the detections directly
+    if file.filename.endswith('.jpg') and 'frame' in file.filename:
+        # Collect all detections
+        detections = []
+        if result and result.get("results") and isinstance(result["results"], list):
+            for res in result["results"]:
+                if res.get("detections") and isinstance(res["detections"], list):
+                    detections.extend(res["detections"])
+        return {"detections": detections}
+
+    # Otherwise return the task info
+    return task

@router.get("/", response_model=List[DetectionTask])
def read_detection_tasks(
    skip: int = 0,
@@ -3,6 +3,7 @@ OpenCV API endpoints - APIs for image processing and computer vision
"""
import os
import shutil
import uuid
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
@@ -28,6 +29,10 @@ os.makedirs(PROCESSED_DIR, exist_ok=True)
AUGMENTED_DIR = Path(settings.STATIC_DIR) / "augmented_datasets"
os.makedirs(AUGMENTED_DIR, exist_ok=True)

# Advanced data augmentation directory
ADVANCED_AUG_DIR = Path(settings.STATIC_DIR) / "advanced_augmentations"
os.makedirs(ADVANCED_AUG_DIR, exist_ok=True)


@router.post("/preprocess", response_class=FileResponse)
async def preprocess_image(
@@ -37,7 +42,7 @@ async def preprocess_image(
):
    """
    Preprocess an image

    Operation format:
    [
        {"name": "resize_image", "params": {"width": 640, "height": 480}},
@@ -45,42 +50,42 @@
    ]
    """
    import json

    try:
        # Parse the operations
        operations_list = json.loads(operations)

        # Save the uploaded image
        temp_file = await save_upload_file_temp(image, TEMP_DIR)

        # Read the image
        img = opencv_service.read_image(temp_file)

        # Apply the operations
        for operation in operations_list:
            op_name = operation["name"]
            op_params = operation.get("params", {})

            # Look up the operation method
            op_method = getattr(opencv_service, op_name, None)
            if op_method is None:
                raise HTTPException(status_code=400, detail=f"Unknown operation: {op_name}")

            # Apply the operation
            img = op_method(img, **op_params)

        # Save the processed image
        output_filename = f"processed_{Path(image.filename).stem}.jpg"
        output_path = PROCESSED_DIR / output_filename
        opencv_service.save_image(img, output_path)

        # Return the processed image
        return FileResponse(
            path=output_path,
            filename=output_filename,
            media_type="image/jpeg"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Image processing failed: {str(e)}")
    finally:
@@ -100,7 +105,7 @@ async def batch_process_images(
):
    """
    Batch-process images

    Operation format:
    [
        {"name": "resize_image", "params": {"width": 640, "height": 480}},
@@ -109,31 +114,31 @@
    """
    import json
    import uuid

    try:
        # Parse the operations
        operations_list = json.loads(operations)

        # Create the batch directories
        batch_id = str(uuid.uuid4())
        batch_dir = TEMP_DIR / f"batch_{batch_id}"
        output_dir = PROCESSED_DIR / f"batch_{batch_id}"
        os.makedirs(batch_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # Save the uploaded images
        temp_files = []
        for image in images:
            temp_file = await save_upload_file_temp(image, batch_dir)
            temp_files.append(temp_file)

        # Batch-process the images
        output_paths = opencv_service.batch_process_images(
            temp_files,
            str(output_dir),
            operations_list
        )

        # Return the processing result
        return {
            "success": True,
@@ -141,7 +146,7 @@ async def batch_process_images(
            "batch_id": batch_id,
            "output_paths": [str(path) for path in output_paths]
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch image processing failed: {str(e)}")

@@ -150,10 +155,10 @@ async def batch_process_images(
async def get_processed_image(batch_id: str, filename: str):
    """Get a processed image"""
    file_path = PROCESSED_DIR / f"batch_{batch_id}" / filename

    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Image not found")

    return FileResponse(path=file_path)


@@ -166,15 +171,15 @@ async def analyze_image(
    try:
        # Save the uploaded image
        temp_file = await save_upload_file_temp(image, TEMP_DIR)

        # Read the image
        img = opencv_service.read_image(temp_file)

        # Analyze image quality
        is_blurry, blur_score = opencv_service.detect_blur(img)
        is_overexposed, overexposed_ratio = opencv_service.detect_overexposure(img)
        is_underexposed, underexposed_ratio = opencv_service.detect_underexposure(img)

        # Return the analysis result
        return {
            "filename": image.filename,
@@ -194,7 +199,7 @@ async def analyze_image(
            "Image is underexposed; consider increasing exposure" if is_underexposed else None
            ]
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Image analysis failed: {str(e)}")
    finally:
@@ -216,7 +221,7 @@ async def augment_dataset(
):
    """
    Augment a dataset

    Augmentation options format:
    {
        "flip": true,
@@ -229,21 +234,21 @@
    """
    import json
    import uuid

    try:
        # Parse the augmentation options
        augmentation_options_dict = json.loads(augmentation_options)

        # Validate the dataset path
        dataset_dir = Path(dataset_path)
        if not dataset_dir.exists() or not dataset_dir.is_dir():
            raise HTTPException(status_code=400, detail=f"Dataset directory does not exist: {dataset_path}")

        # Create the output directory
        augmentation_id = str(uuid.uuid4())
        output_dir = AUGMENTED_DIR / augmentation_id
        os.makedirs(output_dir, exist_ok=True)

        # Augment the dataset in a background task
        def augment_dataset_task():
            try:
@@ -253,19 +258,19 @@ async def augment_dataset(
                    augmentation_options_dict,
                    multiplier
                )

                # Save the statistics
                with open(output_dir / "stats.json", "w") as f:
                    json.dump(stats, f, indent=2)

            except Exception as e:
                # Record the error
                with open(output_dir / "error.txt", "w") as f:
                    f.write(f"Dataset augmentation failed: {str(e)}")

        # Schedule the background task
        background_tasks.add_task(augment_dataset_task)

        # Return the task info
        return {
            "success": True,
@@ -273,7 +278,7 @@ async def augment_dataset(
            "augmentation_id": augmentation_id,
            "output_dir": str(output_dir)
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to start dataset augmentation task: {str(e)}")

@@ -282,10 +287,10 @@ async def augment_dataset(
async def get_augmentation_status(augmentation_id: str):
    """Get the status of a dataset augmentation task"""
    output_dir = AUGMENTED_DIR / augmentation_id

    if not output_dir.exists():
        raise HTTPException(status_code=404, detail="Augmentation task not found")

    # Check for errors
    error_file = output_dir / "error.txt"
    if error_file.exists():
@@ -295,7 +300,7 @@ async def get_augmentation_status(augmentation_id: str):
            "status": "failed",
            "error": error_message
        }

    # Check whether the task has completed
    stats_file = output_dir / "stats.json"
    if stats_file.exists():
@@ -306,7 +311,7 @@ async def get_augmentation_status(augmentation_id: str):
            "status": "completed",
            "stats": stats
        }

    # The task is still in progress
    return {
        "status": "in_progress"
@@ -323,32 +328,32 @@ async def compare_images(
    try:
        # Parse the titles
        titles_list = titles.split(",") if titles else None

        # Save the uploaded images
        temp_files = []
        for image in images:
            temp_file = await save_upload_file_temp(image, TEMP_DIR)
            temp_files.append(temp_file)

        # Read the images
        imgs = [opencv_service.read_image(file) for file in temp_files]

        # Create the comparison image
        comparison = opencv_service.create_comparison_image(imgs, titles_list)

        # Save the comparison image
        import uuid
        output_filename = f"comparison_{uuid.uuid4()}.jpg"
        output_path = PROCESSED_DIR / output_filename
        opencv_service.save_image(comparison, output_path)

        # Return the comparison image
        return FileResponse(
            path=output_path,
            filename=output_filename,
            media_type="image/jpeg"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to create image comparison view: {str(e)}")
    finally:
@@ -372,27 +377,27 @@ async def draw_bounding_boxes(
):
    """
    Draw bounding boxes on an image

    Bounding-box format:
    [[x1, y1, x2, y2], [x1, y1, x2, y2], ...]
    or
    [[x, y, w, h, 1], [x, y, w, h, 1], ...] (the trailing 1 marks xywh format)
    """
    import json

    try:
        # Parse the parameters
        boxes_list = json.loads(boxes)
        labels_list = json.loads(labels) if labels else None
        confidences_list = json.loads(confidences) if confidences else None
        color_tuple = tuple(map(int, color.split(",")))

        # Save the uploaded image
        temp_file = await save_upload_file_temp(image, TEMP_DIR)

        # Read the image
        img = opencv_service.read_image(temp_file)

        # Draw the bounding boxes
        result = opencv_service.draw_bounding_boxes(
            img,
@@ -402,20 +407,20 @@ async def draw_bounding_boxes(
            color_tuple,
            thickness
        )

        # Save the result image
        import uuid
        output_filename = f"boxes_{uuid.uuid4()}.jpg"
        output_path = PROCESSED_DIR / output_filename
        opencv_service.save_image(result, output_path)

        # Return the result image
        return FileResponse(
            path=output_path,
            filename=output_filename,
            media_type="image/jpeg"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to draw bounding boxes: {str(e)}")
    finally:
@@ -425,3 +430,144 @@ async def draw_bounding_boxes(
            os.remove(temp_file)
        except:
            pass


@router.post("/advanced-augmentation")
async def advanced_augmentation(
    background_tasks: BackgroundTasks,
    image1: Optional[UploadFile] = File(None),
    image2: Optional[UploadFile] = File(None),
    image3: Optional[UploadFile] = File(None),
    image4: Optional[UploadFile] = File(None),
    augmentation_type: str = Form(...),  # cutmix, mixup, mosaic, weather
    cutmix_alpha: Optional[float] = Form(None),
    mixup_alpha: Optional[float] = Form(None),
    weather_type: Optional[str] = Form(None),  # rain, snow, fog, etc.
    weather_intensity: Optional[float] = Form(None),
    db: Session = Depends(get_db)
):
    """
    Advanced data augmentation

    Supported augmentation types:
    - cutmix: blends two images by pasting a region of one into the other
    - mixup: blends two images by weighted averaging
    - mosaic: stitches four images into one
    - weather: simulates weather effects (rain, snow, fog, etc.)
    """
    try:
        # Create a temporary directory
        temp_dir = TEMP_DIR / f"advanced_aug_{uuid.uuid4()}"
        os.makedirs(temp_dir, exist_ok=True)

        # Save the uploaded images
        temp_files = []

        # Handle images according to the augmentation type
        if augmentation_type == "cutmix":
            # CutMix requires two images
            if not image1 or not image2:
                raise HTTPException(status_code=400, detail="CutMix requires two images")

            # Save the images
            temp_file1 = await save_upload_file_temp(image1, temp_dir)
            temp_file2 = await save_upload_file_temp(image2, temp_dir)
            temp_files.extend([temp_file1, temp_file2])

            # Read the images
            img1 = opencv_service.read_image(temp_file1)
            img2 = opencv_service.read_image(temp_file2)

            # Apply CutMix
            alpha = cutmix_alpha or 0.5
            result = opencv_service.apply_cutmix(img1, img2, alpha)

        elif augmentation_type == "mixup":
            # MixUp requires two images
            if not image1 or not image2:
                raise HTTPException(status_code=400, detail="MixUp requires two images")

            # Save the images
            temp_file1 = await save_upload_file_temp(image1, temp_dir)
            temp_file2 = await save_upload_file_temp(image2, temp_dir)
            temp_files.extend([temp_file1, temp_file2])

            # Read the images
            img1 = opencv_service.read_image(temp_file1)
            img2 = opencv_service.read_image(temp_file2)

            # Apply MixUp
            alpha = mixup_alpha or 0.5
            result = opencv_service.apply_mixup(img1, img2, alpha)

        elif augmentation_type == "mosaic":
            # Mosaic requires four images
            if not image1 or not image2 or not image3 or not image4:
                raise HTTPException(status_code=400, detail="Mosaic requires four images")

            # Save the images
            temp_file1 = await save_upload_file_temp(image1, temp_dir)
            temp_file2 = await save_upload_file_temp(image2, temp_dir)
            temp_file3 = await save_upload_file_temp(image3, temp_dir)
            temp_file4 = await save_upload_file_temp(image4, temp_dir)
            temp_files.extend([temp_file1, temp_file2, temp_file3, temp_file4])

            # Read the images
            img1 = opencv_service.read_image(temp_file1)
            img2 = opencv_service.read_image(temp_file2)
            img3 = opencv_service.read_image(temp_file3)
            img4 = opencv_service.read_image(temp_file4)

            # Apply Mosaic
            result = opencv_service.apply_mosaic([img1, img2, img3, img4])

        elif augmentation_type == "weather":
            # Weather requires only one image
            if not image1:
                raise HTTPException(status_code=400, detail="Weather requires one image")

            # Save the image
            temp_file1 = await save_upload_file_temp(image1, temp_dir)
            temp_files.append(temp_file1)

            # Read the image
            img = opencv_service.read_image(temp_file1)

            # Check the weather type
            if not weather_type:
                raise HTTPException(status_code=400, detail="Please specify a weather type")

            # Apply the weather effect
            intensity = weather_intensity or 0.5
            result = opencv_service.apply_weather_effect(img, weather_type, intensity)

        else:
            raise HTTPException(status_code=400, detail=f"Unsupported augmentation type: {augmentation_type}")

        # Save the result image
        output_filename = f"{augmentation_type}_{uuid.uuid4()}.jpg"
        output_path = ADVANCED_AUG_DIR / output_filename
        opencv_service.save_image(result, output_path)

        # Return the result
        return {
            "success": True,
            "output_path": f"/static/advanced_augmentations/{output_filename}"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Advanced data augmentation failed: {str(e)}")
    finally:
        # Clean up the temporary files
        for temp_file in locals().get('temp_files', []):
            try:
                os.remove(temp_file)
            except:
                pass

        # Clean up the temporary directory
        if 'temp_dir' in locals():
            try:
                shutil.rmtree(temp_dir)
            except:
                pass
app/api/endpoints/tracking.py (new file, 134 lines)

@@ -0,0 +1,134 @@
"""
Object tracking API endpoints - self-attention-based object tracking
"""
import json
import logging
from typing import Optional
from fastapi import APIRouter, UploadFile, File, Form, HTTPException

from app.services.tracking_service import tracking_service

# Set up the logger
logger = logging.getLogger(__name__)

router = APIRouter()

# Video tracking has been removed; only camera tracking remains


@router.post("/track-frame")
async def track_frame(
    image: UploadFile = File(...),
    detections: str = Form("[]"),
    target_class_id: Optional[int] = Form(None),
    enable_tracking: bool = Form(False),
    cancel_tracking: bool = Form(False)
):
    """
    Track objects in a single frame - no local files are saved

    Parameters:
        image: the input frame
        detections: list of detection results, in JSON format
        target_class_id: ID of the target class to track
        enable_tracking: whether to enable tracking of a specific class
        cancel_tracking: whether to cancel tracking
    """
    try:
        # Parse the detection results
        detections_list = json.loads(detections)

        # Validate the detection format
        if not isinstance(detections_list, list):
            raise HTTPException(status_code=400, detail="Detections must be a list")

        # Make sure every detection has the required fields
        for i, det in enumerate(detections_list):
            if not isinstance(det, dict):
                raise HTTPException(status_code=400, detail=f"Detection #{i} must be a dict")

            # Ensure the required fields exist
            if 'bbox' not in det:
                raise HTTPException(status_code=400, detail=f"Detection #{i} is missing the 'bbox' field")
            if 'class_id' not in det:
                raise HTTPException(status_code=400, detail=f"Detection #{i} is missing the 'class_id' field")
            if 'confidence' not in det:
                raise HTTPException(status_code=400, detail=f"Detection #{i} is missing the 'confidence' field")

        # Read the uploaded image data directly; nothing is saved locally
        import cv2
        import numpy as np

        # Decode the image data (never written to a local file)
        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if frame is None:
            raise HTTPException(status_code=400, detail="Unable to read the image")

        # Track the targets, passing through the tracking parameters
        tracks = tracking_service.track_frame(
            frame,
            detections_list,
            target_class_id=target_class_id,
            enable_tracking=enable_tracking,
            cancel_tracking=cancel_tracking
        )

        # Return a consistent response format
        return {
            "tracks": tracks,
            "tracking_status": {
                "is_tracking": enable_tracking and not cancel_tracking,
                "target_class_id": target_class_id if enable_tracking and not cancel_tracking else None
            }
        }

    except Exception as e:
        logger.error(f"Frame tracking failed: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Frame tracking failed: {str(e)}")


@router.post("/reset-tracker")
async def reset_tracker():
    """Reset the tracker state"""
    try:
        tracking_service.reset_tracker()
        return {"message": "Tracker has been reset"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reset the tracker: {str(e)}")


@router.post("/set-single-target")
async def set_single_target(
    enable: bool = Form(True),
    target_id: Optional[int] = Form(None),
    target_class_id: Optional[int] = Form(None)
):
    """
    Set single-target tracking mode

    Parameters:
        enable: whether to enable single-target mode
        target_id: target ID
        target_class_id: target class ID
    """
    try:
        # Log the request parameters
        logger.info(f"Set single-target tracking mode: enable={enable}, target_id={target_id}, target_class_id={target_class_id}")

        # Call the service method
        tracking_service.set_single_target_mode(enable, target_id, target_class_id)

        # Return the result
        return {
            "message": f"Single-target mode {'enabled' if enable else 'disabled'}",
            "target_id": target_id if enable else None,
            "target_class_id": target_class_id if enable else None
        }

    except Exception as e:
        logger.error(f"Failed to set single-target mode: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to set single-target mode: {str(e)}")
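As a hedged usage sketch for the new track-frame endpoint: the host, port, and any mount prefix such as `/api` are assumptions (only the `/tracking` router prefix is confirmed by the routing change above), and the detection payload values are made up for illustration.

```python
import json
import requests

# Hypothetical detections for one frame (bbox, class_id, confidence are the required fields)
detections = [
    {"bbox": [100, 120, 220, 300], "class_id": 0, "confidence": 0.91},
]

with open("frame_0001.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/tracking/track-frame",  # illustrative base URL
        files={"image": f},
        data={
            "detections": json.dumps(detections),
            "target_class_id": 0,
            "enable_tracking": "true",
        },
    )
print(resp.json()["tracks"])
```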
@@ -8,6 +8,7 @@ from app.db.session import get_db
from app.schemas.training_task import TrainingTask
from app.services import training_service
from app.services.training_service import DeviceManager
+from app.services.ascend_service import AscendDeviceManager

router = APIRouter()

@@ -56,8 +57,9 @@ def create_training_task(
    if not dataset_id and not local_dataset_path:
        raise HTTPException(status_code=400, detail="Either dataset_id or local_dataset_path must be provided")

-    # Initialize GPU info
+    # Initialize device info
    gpu_info = None
+    ascend_info = None
    recommended_memory = None

    # If a GPU device was selected, fetch GPU info automatically
@@ -94,6 +96,40 @@ def create_training_task(
            hardware_config["gpu_memory"] = recommended_memory
            print(f"Automatically set GPU memory limit to: {hardware_config['gpu_memory']}MB")

    # If an Ascend NPU device was selected, fetch Ascend NPU info automatically
    elif hardware_config and hardware_config.get("device_type") == "ascend":
        # Get Ascend NPU info
        ascends = AscendDeviceManager.get_available_ascends()

        # If Ascend NPUs are available
        if ascends:
            # Check whether an Ascend NPU ID was specified
            ascend_index = hardware_config.get("ascend_index", 0)

            # Find the specified Ascend NPU
            for ascend in ascends:
                if ascend.get("index", 0) == ascend_index:
                    ascend_info = ascend
                    break

            # If the specified Ascend NPU was not found, use the first one
            if not ascend_info and ascends:
                ascend_info = ascends[0]
                hardware_config["ascend_index"] = ascend_info.get("index", 0)
                print(f"The specified Ascend NPU ID {ascend_index} does not exist; using Ascend NPU ID {hardware_config['ascend_index']}")

            if ascend_info:
                free_memory = ascend_info.get("memory_free", 0)

                if free_memory > 0:
                    # Compute the recommended memory setting (80% of free memory)
                    recommended_memory = int(free_memory * 0.8)

                    # If no memory limit was set, use the recommended value
                    if not hardware_config.get("ascend_memory"):
                        hardware_config["ascend_memory"] = recommended_memory
                        print(f"Automatically set Ascend NPU memory limit to: {hardware_config['ascend_memory']}MB")

    # Create the training task
    task = training_service.create_training_task(
        db=db,
@@ -134,6 +170,18 @@ def create_training_task(
            "gpu_index": gpu_info.get("index", 0)
        }

    # Add Ascend NPU info to the response
    if ascend_info:
        response_data["ascend_info"] = {
            "has_ascend": True,
            "ascend_name": ascend_info.get("name", ""),
            "total_memory": ascend_info.get("memory", 0),
            "used_memory": ascend_info.get("memory_used", 0),
            "free_memory": ascend_info.get("memory_free", 0),
            "recommended_memory": recommended_memory,
            "ascend_index": ascend_info.get("index", 0)
        }

    return response_data

@router.post("/{task_id}/start", response_model=TrainingTask)
@@ -311,9 +359,11 @@ def delete_training_task(
def get_device_info():
    """Get device info, including the list of available GPUs and their memory sizes"""
    gpus = DeviceManager.get_available_gpus()
+    ascends = AscendDeviceManager.get_available_ascends()
    recommended_memory = None
    current_device = None
    gpu_info = None
+    ascend_info = None

    # If GPUs are available
    if gpus:
@@ -335,12 +385,19 @@ def get_device_info():
        except Exception as e:
            print(f"Failed to get the currently active GPU device: {str(e)}")

    # If Ascend NPUs are available
    if ascends:
        ascend_info = ascends[0]  # Use the first Ascend NPU's info

    return {
        "gpus": gpus,
        "ascends": ascends,
        "has_cuda": torch.cuda.is_available(),
        "has_ascend": len(ascends) > 0,
        "recommended_memory": recommended_memory,
        "current_device": current_device,
-        "gpu_info": gpu_info
+        "gpu_info": gpu_info,
+        "ascend_info": ascend_info
    }

class GPUMemoryValidationRequest(BaseModel):
@@ -410,6 +467,93 @@ def validate_gpu_memory(request: GPUMemoryValidationRequest):
        "gpu_index": request.gpu_index
    }

class AscendMemoryValidationRequest(BaseModel):
    ascend_memory: int
    ascend_index: Optional[int] = 0

@router.post("/validate-ascend-memory")
def validate_ascend_memory(request: AscendMemoryValidationRequest):
    """Validate whether the Ascend NPU memory setting is reasonable"""
    # Validate the requested memory
    is_valid, message, total_memory = AscendDeviceManager.validate_ascend_memory(
        requested_memory=request.ascend_memory,
        ascend_index=request.ascend_index
    )

    # Get Ascend NPU info
    ascends = AscendDeviceManager.get_available_ascends()

    # Initialize variables
    free_memory = 0
    used_memory = 0
    recommended_memory = None
    ascend_name = ""
    has_ascend = False

    # Check whether the specified Ascend NPU exists
    ascend_info = None
    for ascend in ascends:
        if ascend.get("index", 0) == request.ascend_index:
            ascend_info = ascend
            break

    if ascend_info:
        has_ascend = True
        ascend_name = ascend_info.get("name", "")
        total_memory = ascend_info.get("memory", 0)
        free_memory = ascend_info.get("memory_free", 0)
        used_memory = ascend_info.get("memory_used", 0)

        # Compute the recommended memory setting (80% of free memory)
        if free_memory > 0:
            recommended_memory = int(free_memory * 0.8)

    return {
        "valid": is_valid,
        "message": message,
        "total_memory": total_memory,
        "free_memory": free_memory,
        "used_memory": used_memory,
        "recommended_memory": recommended_memory,
        "ascend_name": ascend_name,
        "has_ascend": has_ascend,
        "ascend_index": request.ascend_index
    }

@router.get("/ascend-info", response_model=Dict)
def get_ascend_info():
    """Get info for all available Ascend NPUs, for frontend display and selection"""
    ascends = AscendDeviceManager.get_available_ascends()

    response = {
        "has_ascend": len(ascends) > 0,
        "ascends": [],
        "current_device": None
    }

    # If Ascend NPUs are available, add them to the response
    if ascends:
        for i, ascend in enumerate(ascends):
            # Add a display_name field for display purposes
            ascend_info = {
                "index": ascend.get("index", i),
                "name": ascend.get("name", f"Ascend NPU {i}"),
                "display_name": f"NPU {i}: {ascend.get('name', 'Unknown')}",
                "total_memory": ascend.get("memory", 0),
                "used_memory": ascend.get("memory_used", 0),
                "free_memory": ascend.get("memory_free", 0),
                "recommended_memory": ascend.get("recommended_memory", 0)
            }
            response["ascends"].append(ascend_info)

        # Set the current device to the first Ascend NPU
        response["current_device"] = {
            "index": ascends[0].get("index", 0),
            "name": ascends[0].get("name", "Unknown")
        }

    return response
app/nn/modules/attention.py (new file, 504 lines)

@@ -0,0 +1,504 @@
"""
Self-Attention modules

This module implements the self-attention mechanism from the Transformer
architecture for use in various deep learning tasks. Self-attention lets a
model consider the relationships between all positions in a sequence,
capturing long-range dependencies and global context.

Main components:
1. SelfAttention: classic multi-head self-attention
2. SpatialSelfAttention: self-attention designed for image features
3. EfficientSelfAttention: self-attention with spatial downsampling to reduce computation
4. SelfAttentionBlock: a self-attention module with a residual connection
5. PositionalEncoding: adds positional information to a sequence
6. TransformerEncoderLayer: a complete Transformer encoder layer
7. TransformerEncoder: a stack of encoder layers

Technical overview:
- Self-attention follows the "attention as weights" idea: it computes the
  similarity between queries (Q) and keys (K), then uses those weights to form
  a weighted sum over the values (V), producing the attention output.
- Multi-head attention splits the input into several heads, each computing
  self-attention independently before the results are merged, which lets the
  model attend to information from different representation subspaces at once.
- The Transformer architecture stacks self-attention and feed-forward layers,
  with residual connections and layer normalization, to build powerful
  sequence models.

Use cases:
- Sequence modeling: text processing, time-series analysis
- Computer vision: image classification, object detection, semantic segmentation
- Multimodal tasks: image captioning, visual question answering

References:
- "Attention Is All You Need" (Vaswani et al., 2017)
- "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List, Union


class SelfAttention(nn.Module):
    """
    Self-attention module

    Args:
        dim (int): channel dimension of the input features
        num_heads (int): number of attention heads
        qkv_bias (bool): whether to use a bias in the QKV projection
        attn_drop (float): attention dropout rate
        proj_drop (float): output projection dropout rate
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # scaling factor

        # QKV projection
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x (torch.Tensor): input features [B, N, C], where N is the sequence length and C the channel count

        Returns:
            torch.Tensor: attention-enhanced features [B, N, C]
        """
        B, N, C = x.shape

        # Compute QKV
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, num_heads, N, head_dim]

        # Compute attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, num_heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted aggregation
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class SpatialSelfAttention(nn.Module):
    """
    Spatial self-attention module, designed for image features

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels, defaults to the input count
        num_heads (int): number of attention heads
        reduction_ratio (int): channel reduction ratio used to cut computation
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_heads: int = 8,
        reduction_ratio: int = 8
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels
        self.num_heads = num_heads

        # Reduce channels to lower the computational cost
        self.reduced_channels = max(1, in_channels // reduction_ratio)

        # QKV projections
        self.q_conv = nn.Conv2d(in_channels, self.reduced_channels, kernel_size=1, bias=False)
        self.k_conv = nn.Conv2d(in_channels, self.reduced_channels, kernel_size=1, bias=False)
        self.v_conv = nn.Conv2d(in_channels, self.reduced_channels, kernel_size=1, bias=False)

        # Output projection
        self.out_conv = nn.Conv2d(self.reduced_channels, out_channels or in_channels, kernel_size=1)

        # Initialization
        self._init_weights()

    def _init_weights(self):
        """Initialize weights"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x (torch.Tensor): input feature map [B, C, H, W]

        Returns:
            torch.Tensor: attention-enhanced feature map [B, C, H, W]
        """
        B, C, H, W = x.shape

        # Compute QKV
        q = self.q_conv(x)  # [B, reduced_C, H, W]
        k = self.k_conv(x)  # [B, reduced_C, H, W]
        v = self.v_conv(x)  # [B, reduced_C, H, W]

        # Reshape into sequence form
        q = q.flatten(2).permute(0, 2, 1)  # [B, H*W, reduced_C]
        k = k.flatten(2).permute(0, 2, 1)  # [B, H*W, reduced_C]
        v = v.flatten(2).permute(0, 2, 1)  # [B, H*W, reduced_C]

        # Compute attention scores
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.reduced_channels)  # [B, H*W, H*W]
        attn = F.softmax(attn, dim=-1)

        # Weighted aggregation
        out = (attn @ v).permute(0, 2, 1).reshape(B, self.reduced_channels, H, W)  # [B, reduced_C, H, W]
        out = self.out_conv(out)  # [B, C, H, W]

        return out


class EfficientSelfAttention(nn.Module):
    """
    Efficient self-attention module that uses spatial downsampling to reduce computation

    Args:
        in_channels (int): number of input channels
        key_channels (int): number of channels for keys and queries
        value_channels (int): number of channels for values
        out_channels (int): number of output channels
        scale (int): spatial downsampling ratio
    """

    def __init__(
        self,
        in_channels: int,
        key_channels: int,
        value_channels: int,
        out_channels: int,
        scale: int = 1
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale = scale

        # Query, key, and value projections
        self.query_conv = nn.Conv2d(in_channels, key_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, key_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, value_channels, kernel_size=1)

        # Output projection
        self.out_conv = nn.Conv2d(value_channels, out_channels, kernel_size=1)

        # Initialization
        self._init_weights()

    def _init_weights(self):
        """Initialize weights"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x (torch.Tensor): input feature map [B, C, H, W]

        Returns:
            torch.Tensor: attention-enhanced feature map [B, C, H, W]
        """
        B, C, H, W = x.shape

        # Compute the queries
        query = self.query_conv(x)  # [B, key_channels, H, W]

        # If downsampling is enabled, downsample the keys and values
        if self.scale > 1:
            x_sampled = F.avg_pool2d(x, kernel_size=self.scale, stride=self.scale)
            key = self.key_conv(x_sampled)      # [B, key_channels, H/scale, W/scale]
            value = self.value_conv(x_sampled)  # [B, value_channels, H/scale, W/scale]
        else:
            key = self.key_conv(x)      # [B, key_channels, H, W]
            value = self.value_conv(x)  # [B, value_channels, H, W]

        # Reshape into sequence form
        query = query.flatten(2).permute(0, 2, 1)  # [B, H*W, key_channels]
        key = key.flatten(2)                       # [B, key_channels, H*W/scale^2]
        value = value.flatten(2).permute(0, 2, 1)  # [B, H*W/scale^2, value_channels]

        # Compute attention scores
        sim_map = torch.matmul(query, key)  # [B, H*W, H*W/scale^2]
        sim_map = (sim_map / math.sqrt(self.key_conv.out_channels)).softmax(dim=-1)

        # Weighted aggregation
        context = torch.matmul(sim_map, value)  # [B, H*W, value_channels]
        context = context.permute(0, 2, 1).reshape(B, -1, H, W)  # [B, value_channels, H, W]

        # Output projection
        output = self.out_conv(context)  # [B, out_channels, H, W]

        return output


class SelfAttentionBlock(nn.Module):
    """
    Self-attention block containing an attention module and a residual connection

    Args:
        in_channels (int): number of input channels
        attention_type (str): attention type, one of 'standard', 'spatial', 'efficient'
        **kwargs: arguments forwarded to the chosen attention module
    """

    def __init__(
        self,
        in_channels: int,
        attention_type: str = 'spatial',
        **kwargs
    ):
        super().__init__()
        self.in_channels = in_channels
        self.attention_type = attention_type

        # Select the module according to the attention type
        if attention_type == 'standard':
            # Standard self-attention first converts the feature map into a sequence
            self.norm = nn.LayerNorm(in_channels)
            self.attention = SelfAttention(in_channels, **kwargs)
        elif attention_type == 'spatial':
            self.norm = nn.BatchNorm2d(in_channels)
            self.attention = SpatialSelfAttention(in_channels, **kwargs)
        elif attention_type == 'efficient':
            self.norm = nn.BatchNorm2d(in_channels)
            self.attention = EfficientSelfAttention(
                in_channels=in_channels,
                key_channels=kwargs.get('key_channels', in_channels // 8),
                value_channels=kwargs.get('value_channels', in_channels // 2),
                out_channels=in_channels,
                scale=kwargs.get('scale', 1)
            )
        else:
            raise ValueError(f"Unsupported attention type: {attention_type}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x (torch.Tensor): input features

        Returns:
            torch.Tensor: attention-enhanced features
        """
        if self.attention_type == 'standard':
            # Standard self-attention needs special handling
            B, C, H, W = x.shape
            shortcut = x

            # Convert the feature map into a sequence
            x = x.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
            x = self.norm(x)
            x = self.attention(x)  # [B, H*W, C]

            # Convert the sequence back into a feature map
            x = x.permute(0, 2, 1).reshape(B, C, H, W)
        else:
            # Spatial and efficient self-attention operate on the feature map directly
            shortcut = x
            x = self.norm(x)
            x = self.attention(x)

        # Residual connection
        return x + shortcut


class PositionalEncoding(nn.Module):
    """
    Positional encoding module

    Adds positional information to a sequence so the model can exploit element order.
    Positions are represented by a combination of sine and cosine functions.

    Args:
        d_model (int): model dimension
        max_len (int): maximum sequence length
        dropout (float): dropout rate
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Build the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Use sine and cosine functions
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]

        # Register as a buffer, not a model parameter
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: input tensor [batch_size, seq_len, d_model]

        Returns:
            The tensor with positional encodings added [batch_size, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class TransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer

    Contains self-attention and a feed-forward network, with residual
    connections and layer normalization.

    Args:
        d_model (int): model dimension
        nhead (int): number of heads in multi-head attention
        dim_feedforward (int): hidden dimension of the feed-forward network
        dropout (float): dropout rate
        activation (str): activation function, 'relu' or 'gelu'
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()

        # Multi-head self-attention
        self.self_attn = SelfAttention(d_model, num_heads=nhead, attn_drop=dropout, proj_drop=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Activation function
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            src: input sequence [batch_size, seq_len, d_model]

        Returns:
            Output sequence [batch_size, seq_len, d_model]
        """
        # Self-attention sublayer
        src2 = self.self_attn(self.norm1(src))
        src = src + self.dropout1(src2)  # residual connection

        # Feed-forward sublayer
        src2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(src)))))
        src = src + self.dropout2(src2)  # residual connection

        return src


class TransformerEncoder(nn.Module):
    """
    Transformer encoder

    Built by stacking several encoder layers.

    Args:
        d_model (int): model dimension
        nhead (int): number of heads in multi-head attention
        num_layers (int): number of encoder layers
        dim_feedforward (int): hidden dimension of the feed-forward network
        dropout (float): dropout rate
        activation (str): activation function, 'relu' or 'gelu'
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu"
    ):
        super().__init__()

        # Create the encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation
            )
            for _ in range(num_layers)
        ])

        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            src: input sequence [batch_size, seq_len, d_model]

        Returns:
            Output sequence [batch_size, seq_len, d_model]
        """
        output = src

        for layer in self.layers:
            output = layer(output)

        return self.norm(output)
app/nn/modules/attention_doc.md (new file, 333 lines)

@@ -0,0 +1,333 @@
# Self-Attention Mechanism Technical Documentation

## Contents
1. [Overview](#1-overview)
2. [Technical Architecture](#2-technical-architecture)
   - [2.1 Core Components](#21-core-components)
   - [2.2 Data Flow](#22-data-flow)
   - [2.3 System Integration](#23-system-integration)
3. [Technical Principles](#3-technical-principles)
   - [3.1 Basics of Self-Attention](#31-basics-of-self-attention)
   - [3.2 Multi-Head Attention](#32-multi-head-attention)
   - [3.3 Positional Encoding](#33-positional-encoding)
   - [3.4 Transformer Encoder](#34-transformer-encoder)
4. [Module Components](#4-module-components)
   - [4.1 SelfAttention](#41-selfattention)
   - [4.2 SpatialSelfAttention](#42-spatialselfattention)
   - [4.3 EfficientSelfAttention](#43-efficientselfattention)
   - [4.4 SelfAttentionBlock](#44-selfattentionblock)
   - [4.5 PositionalEncoding](#45-positionalencoding)
   - [4.6 TransformerEncoderLayer](#46-transformerencoderlayer)
   - [4.7 TransformerEncoder](#47-transformerencoder)
5. [Enhancement Techniques](#5-enhancement-techniques)
   - [5.1 Layer Scale](#51-layer-scale)
   - [5.2 CBAM Attention](#52-cbam-attention)
   - [5.3 Cross-Attention](#53-cross-attention)
6. [Applications to Object Tracking](#6-applications-to-object-tracking)
   - [6.1 AttentionTracker](#61-attentiontracker)
   - [6.2 EnhancedAttentionTracker](#62-enhancedattentiontracker)
7. [Usage Examples](#7-usage-examples)
8. [Application Scenarios](#8-application-scenarios)
9. [References](#9-references)

## 1. Overview

The self-attention mechanism is the core component of the Transformer architecture. It lets a model consider the relationships between all positions in a sequence, capturing long-range dependencies and global context. Compared with traditional recurrent neural networks (RNNs) and convolutional neural networks (CNNs), self-attention offers stronger parallelism and better modeling of long-range dependencies.

This document describes the self-attention mechanism and related components implemented in `app/nn/modules/attention.py`, together with their application to object tracking.

## 2. Technical Architecture

### 2.1 Core Components

The core components of the system are:

1. **Basic attention modules**:
   - `SelfAttention`: standard multi-head self-attention
   - `SpatialSelfAttention`: spatial self-attention
   - `EfficientSelfAttention`: efficient self-attention

2. **Transformer components**:
   - `PositionalEncoding`: positional encoding
   - `TransformerEncoderLayer`: Transformer encoder layer
   - `TransformerEncoder`: a complete Transformer encoder

3. **Enhancement modules**:
   - `LayerScale`: improves training stability of deep networks
   - `CBAM`: fused channel and spatial attention
   - `CrossAttention`: cross-attention mechanism

4. **Trackers**:
   - `AttentionTracker`: basic self-attention tracker
   - `EnhancedAttentionTracker`: enhanced self-attention tracker

### 2.2 Data Flow

The data flow through the self-attention system is:

1. **Input processing**:
   - Sequence input → feature embedding → positional encoding
   - Image input → feature extraction → feature maps

2. **Attention computation**:
   - Feature projection → QKV computation → attention scores → weighted aggregation

3. **Multi-layer processing**:
   - Self-attention layer → residual connection → layer normalization → feed-forward network

4. **Output generation**:
   - Feature fusion → task-specific head → final output

### 2.3 System Integration

The self-attention mechanism integrates with the rest of the system as follows:

1. **With the object detection system**:
   - Receives YOLO detection results
   - Extracts target features
   - Computes target associations

2. **With the video processing system**:
   - Processes consecutive video frames
   - Maintains target trajectories
   - Produces tracking results

## 3. Technical Principles

### 3.1 Basics of Self-Attention

The core idea of self-attention is "attention as weights". For each element of the input sequence, the mechanism computes the strength of its relationship to every element in the sequence (including itself), then uses these strengths to form a weighted sum over the value vectors, yielding the attention output.

The basic formula is:
```
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
```

where:
- Q (Query): the query matrix, representing the features at the current position
- K (Key): the key matrix, used to compute similarity against the query
- V (Value): the value matrix, combined by a weighted sum according to the attention weights
- d_k: the dimension of the key vectors, used to scale the dot product and prevent vanishing gradients
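As a minimal illustration of this formula (a sketch, not part of the module itself), scaled dot-product attention can be written directly in PyTorch:

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k: [B, N, d_k], v: [B, N, d_v]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # [B, N, N]
    weights = scores.softmax(dim=-1)  # each row sums to 1
    return weights @ v                # [B, N, d_v]

q = k = v = torch.randn(2, 5, 64)
out = scaled_dot_product_attention(q, k, v)  # [2, 5, 64]
```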
### 3.2 Multi-Head Attention

Multi-head attention splits the input into several heads, each computing self-attention independently before the results are merged. This lets the model attend to information from different representation subspaces at the same time and strengthens its representational capacity.

The formula is:
```
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O
where head_i = Attention(Q*W_i^Q, K*W_i^K, V*W_i^V)
```

### 3.3 Positional Encoding

Because self-attention itself carries no positional information, a positional encoding must be added so the model can tell where each element sits in the sequence. A common choice is a combination of sine and cosine functions:

```
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
```

where pos is the position, i is the dimension index, and d_model is the model dimension.
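A small sketch that evaluates these formulas for a short sequence (illustrative only; it mirrors the `PositionalEncoding` module above):

```python
import math
import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sin
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cos
    return pe

pe = sinusoidal_pe(max_len=4, d_model=8)
print(pe[0])  # at position 0 the sin terms are 0 and the cos terms are 1
```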
### 3.4 Transformer Encoder

A Transformer encoder stacks several encoder layers; each layer contains a self-attention sublayer and a feed-forward sublayer, with residual connections and layer normalization.

## 4. Module Components

### 4.1 SelfAttention

The standard self-attention module, implementing multi-head self-attention.

**Parameters**:
- `dim (int)`: channel dimension of the input features
- `num_heads (int)`: number of attention heads
- `qkv_bias (bool)`: whether to use a bias in the QKV projection
- `attn_drop (float)`: attention dropout rate
- `proj_drop (float)`: output projection dropout rate

**Input**:
- A tensor of shape `[B, N, C]`, where B is the batch size, N the sequence length, and C the channel count

**Output**:
- Attention-enhanced features of shape `[B, N, C]`

### 4.2 SpatialSelfAttention

A spatial self-attention module designed for image features; it operates directly on 2D feature maps.

**Parameters**:
- `in_channels (int)`: number of input channels
- `out_channels (int)`: number of output channels, defaults to the input count
- `num_heads (int)`: number of attention heads
- `reduction_ratio (int)`: channel reduction ratio used to cut computation

**Input**:
- A feature map of shape `[B, C, H, W]`, where H and W are the spatial dimensions

**Output**:
- An attention-enhanced feature map of shape `[B, C, H, W]`

### 4.3 EfficientSelfAttention

An efficient self-attention module that uses spatial downsampling to reduce computation, suitable for high-resolution feature maps.

**Parameters**:
- `in_channels (int)`: number of input channels
- `key_channels (int)`: number of channels for keys and queries
- `value_channels (int)`: number of channels for values
- `out_channels (int)`: number of output channels
- `scale (int)`: spatial downsampling ratio

### 4.4 SelfAttentionBlock

A self-attention block combining an attention module with a residual connection; it supports several attention types.

**Parameters**:
- `in_channels (int)`: number of input channels
- `attention_type (str)`: attention type, one of 'standard', 'spatial', 'efficient'

### 4.5 PositionalEncoding

A positional encoding module that adds positional information to a sequence.

**Parameters**:
- `d_model (int)`: model dimension
- `max_len (int)`: maximum sequence length
- `dropout (float)`: dropout rate

### 4.6 TransformerEncoderLayer

A Transformer encoder layer containing self-attention and a feed-forward network.

**Parameters**:
- `d_model (int)`: model dimension
- `nhead (int)`: number of heads in multi-head attention
- `dim_feedforward (int)`: hidden dimension of the feed-forward network
- `dropout (float)`: dropout rate
- `activation (str)`: activation function, 'relu' or 'gelu'

### 4.7 TransformerEncoder

A Transformer encoder built by stacking several encoder layers.

**Parameters**:
- `d_model (int)`: model dimension
- `nhead (int)`: number of heads in multi-head attention
- `num_layers (int)`: number of encoder layers
- `dim_feedforward (int)`: hidden dimension of the feed-forward network
- `dropout (float)`: dropout rate
- `activation (str)`: activation function, 'relu' or 'gelu'

## 5. Enhancement Techniques

### 5.1 Layer Scale

Layer Scale adds a learnable scaling parameter to each residual branch, improving the training stability of deep networks; a minimal sketch follows the feature list below.

**Features**:
- Improves gradient flow in deep networks
- Prevents shallow features from dominating the network output
- Improves convergence speed and performance
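The `LayerScale` module itself lives in the tracker code rather than in `attention.py`; as a rough sketch of the idea only (the class body and initialization value are illustrative, following Touvron et al.):

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Scales a residual branch by a learnable per-channel factor."""
    def __init__(self, dim: int, init_value: float = 1e-4):
        super().__init__()
        # Initialized near zero so each block starts close to the identity
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x

# Typical use inside a residual block: x = x + layer_scale(sublayer(x))
```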
### 5.2 CBAM Attention

CBAM (Convolutional Block Attention Module) combines channel attention and spatial attention to strengthen feature representations.

**Components**:
- Channel attention: learns the importance of each channel
- Spatial attention: learns the importance of each spatial position
- Serial structure: channel attention is applied first, then spatial attention

### 5.3 Cross-Attention

Cross-attention models the relationships between two different sequences, which makes it particularly suitable for target matching in object tracking; a minimal sketch follows the feature list below.

**Features**:
- Queries and keys come from different sequences
- Captures dependencies between sequences
- Supports input sequences of different lengths
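A minimal sketch of cross-attention built from the same QKV machinery; the class name, projections, and dimensions here are illustrative, not the exact tracker implementation:

```python
import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Queries come from one sequence, keys/values from another."""
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, queries: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # queries: [B, N_q, C] (e.g. current-frame targets)
        # context: [B, N_k, C] (e.g. historical track features)
        q = self.q_proj(queries)
        k = self.k_proj(context)
        v = self.v_proj(context)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # [B, N_q, N_k]
        return attn.softmax(dim=-1) @ v  # [B, N_q, C]
```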
## 6. Applications to Object Tracking

### 6.1 AttentionTracker

The basic self-attention tracker, which uses self-attention to compute associations between targets.

**Functions**:
- Feature extraction: extracts feature vectors from target regions
- Self-attention matching: computes associations between targets in the current frame and targets in historical frames
- Trajectory management: updates trajectories of matched targets, creates new trajectories, and terminates stale ones

### 6.2 EnhancedAttentionTracker

The enhanced self-attention tracker, which integrates Layer Scale, CBAM, cross-attention, and other recent techniques.

**Enhancements**:
- More stable feature extraction: a CBAM-enhanced feature extractor
- More precise target matching: cross-attention for computing target associations
- More stable training: Layer Scale for network stability

## 7. Usage Examples

### 7.1 Basic Self-Attention

```python
import torch
from app.nn.modules.attention import SelfAttention

# Create the self-attention module
attn = SelfAttention(dim=512, num_heads=8)

# Input features
x = torch.randn(32, 100, 512)  # [batch_size, seq_len, dim]

# Forward pass
output = attn(x)  # [32, 100, 512]
```

### 7.2 Object Tracking

```python
from app.nn.modules.attention_tracker import AttentionTracker

# Create the tracker
tracker = AttentionTracker(
    max_age=30,
    min_hits=3,
    iou_threshold=0.3,
    feature_similarity_weight=0.7,
    motion_weight=0.3
)

# Update the tracker
tracks = tracker.update(frame, detections)

# Consume the tracking results
for track in tracks:
    track_id = track['id']
    bbox = track['boxes'][-1]
    class_id = track['class_id']
    # Draw the bounding box and ID
```

## 8. Application Scenarios

The self-attention mechanism applies to many deep learning tasks:

1. **Sequence modeling**: text processing, time-series analysis
2. **Computer vision**: image classification, object detection, semantic segmentation
3. **Multimodal tasks**: image captioning, visual question answering
4. **Object tracking**: multi-object and single-object tracking
5. **Video analysis**: action recognition, anomaly detection

## 9. References

1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. In Advances in Neural Information Processing Systems.

2. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.

3. Woo, S., Park, J., Lee, J. Y., & Kweon, I. S. (2018). CBAM: Convolutional block attention module. In Proceedings of the European Conference on Computer Vision (ECCV).

4. Touvron, H., Bojanowski, P., Caron, M., Cord, M., El-Nouby, A., Grave, E., ... & Jégou, H. (2021). ResMLP: Feedforward networks for image classification with data-efficient training. arXiv preprint arXiv:2105.03404.
219
app/nn/modules/attention_doc_detailed.md
Normal file
219
app/nn/modules/attention_doc_detailed.md
Normal file
@@ -0,0 +1,219 @@
|
||||
# 自注意力(Self-Attention)机制详细说明文档
|
||||
|
||||
## 1. 概述
|
||||
|
||||
自注意力机制是Transformer架构的核心组件,它允许模型在处理序列数据时考虑序列中所有位置之间的关系,从而捕获长距离依赖和全局上下文信息。本文档详细介绍了自注意力机制的工作原理、系统逻辑过程和算法执行流程。
|
||||
|
||||
## 2. 自注意力机制的基本原理
|
||||
|
||||
### 2.1 核心思想
|
||||
|
||||
自注意力机制的核心思想是"注意力即权重"。对于输入序列中的每个元素,自注意力机制计算该元素与序列中所有元素(包括自身)的关系强度,然后根据这些关系强度对值向量进行加权求和,得到注意力输出。
|
||||
|
||||
### 2.2 数学表达式
|
||||
|
||||
自注意力机制的数学表达式如下:
|
||||
|
||||
```
|
||||
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
|
||||
```
|
||||
|
||||
其中:
|
||||
- Q (Query): 查询矩阵,表示当前位置的特征
|
||||
- K (Key): 键矩阵,用于与查询计算相似度
|
||||
- V (Value): 值矩阵,根据注意力权重进行加权求和
|
||||
- d_k: 键向量的维度,用于缩放点积,防止梯度消失
|
||||
|
||||
## 3. 系统逻辑过程
|
||||
|
||||
自注意力机制的系统逻辑过程可以分为以下几个步骤:
|
||||
|
||||
### 3.1 输入处理
|
||||
|
||||
1. **输入嵌入**:将输入序列转换为嵌入向量
|
||||
2. **位置编码**:添加位置信息,使模型能够区分不同位置的元素
|
||||
|
||||
### 3.2 注意力计算
|
||||
|
||||
1. **线性投影**:将输入向量投影到查询(Q)、键(K)和值(V)空间
|
||||
2. **注意力分数计算**:计算查询和键之间的点积,得到注意力分数
|
||||
3. **缩放**:将注意力分数除以缩放因子(sqrt(d_k)),防止梯度消失
|
||||
4. **掩码应用**(可选):应用掩码,屏蔽某些位置的注意力
|
||||
5. **Softmax归一化**:对注意力分数应用Softmax函数,得到注意力权重
|
||||
6. **加权求和**:使用注意力权重对值向量进行加权求和,得到注意力输出
|
||||
|
||||
### 3.3 多头注意力
|
||||
|
||||
1. **头部分割**:将查询、键、值分割为多个头
|
||||
2. **并行计算**:每个头独立计算自注意力
|
||||
3. **结果合并**:将多个头的结果合并,并通过线性投影得到最终输出
|
||||
|
||||
### 3.4 残差连接和层归一化
|
||||
|
||||
1. **残差连接**:将注意力输出与输入相加,形成残差连接
|
||||
2. **层归一化**:对结果应用层归一化,提高训练稳定性
|
||||
|
||||
## 4. 算法执行流程
|
||||
|
||||
下面是自注意力机制的详细算法执行流程,以伪代码形式呈现:
|
||||
|
||||
```
|
||||
输入: 序列X,维度为[batch_size, seq_len, d_model]
|
||||
|
||||
# 步骤1: 线性投影
|
||||
Q = X * W_Q # W_Q是查询投影矩阵
|
||||
K = X * W_K # W_K是键投影矩阵
|
||||
V = X * W_V # W_V是值投影矩阵
|
||||
|
||||
# 步骤2: 分割为多头
|
||||
Q_heads = split(Q, num_heads) # 形状变为[batch_size, num_heads, seq_len, d_k]
|
||||
K_heads = split(K, num_heads)
|
||||
V_heads = split(V, num_heads)
|
||||
|
||||
# 步骤3: 对每个头计算注意力
|
||||
outputs = []
|
||||
for h in range(num_heads):
|
||||
# 计算注意力分数
|
||||
scores = matmul(Q_heads[h], transpose(K_heads[h])) # [batch_size, seq_len, seq_len]
|
||||
|
||||
# 缩放
|
||||
scaled_scores = scores / sqrt(d_k)
|
||||
|
||||
# 应用掩码(可选)
|
||||
if mask is not None:
|
||||
scaled_scores = apply_mask(scaled_scores, mask)
|
||||
|
||||
# Softmax归一化
|
||||
weights = softmax(scaled_scores, dim=-1) # [batch_size, seq_len, seq_len]
|
||||
|
||||
# 应用dropout(可选)
|
||||
weights = dropout(weights, p=dropout_rate)
|
||||
|
||||
# 加权求和
|
||||
head_output = matmul(weights, V_heads[h]) # [batch_size, seq_len, d_v]
|
||||
outputs.append(head_output)
|
||||
|
||||
# 步骤4: 合并多头结果
|
||||
multi_head_output = concat(outputs, dim=-1) # [batch_size, seq_len, d_model]
|
||||
|
||||
# 步骤5: 最终线性投影
|
||||
output = multi_head_output * W_O # W_O是输出投影矩阵
|
||||
|
||||
# 步骤6: 残差连接和层归一化
|
||||
output = layer_norm(output + X)
|
||||
|
||||
返回: output
|
||||
```
|
||||
|
||||
## 5. Self-Attention in Object Tracking

In object tracking, self-attention can score the affinity between targets detected in different frames, enabling cross-frame matching. The pipeline is as follows:

### 5.1 Feature Extraction

1. Use a convolutional neural network to extract a feature vector from each detected target region
2. These feature vectors encode the target's appearance: color, texture, shape, and so on

### 5.2 Attention-Based Matching

1. Use the target features of the current frame as queries (Q)
2. Use the target features of past frames as keys (K) and values (V)
3. Use self-attention to compute the affinity between current-frame and past-frame targets
4. Match targets according to the affinity scores (a minimal greedy matching scheme is sketched after this section)

### 5.3 Track Management

1. For successfully matched targets, update the corresponding track
2. For unmatched detections, create new tracks
3. For unmatched tracks, decide whether to terminate them according to predefined rules

### 5.4 Execution Flow

```
Input video frame
        ↓
Object detection (e.g. a YOLOv8 detector)
        ↓
Feature extraction (from the detected target regions)
        ↓
Attention-based matching (affinity between current and past targets)
        ↓
Track update (update matched tracks, create new ones, terminate stale ones)
        ↓
Visualization (draw bounding boxes, IDs, trajectories)
```
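The matching step referenced in 5.2 can be made concrete with a small sketch. The greedy scheme below mirrors the one used by the trackers in this commit: repeatedly pick the highest-scoring (track, detection) pair until the best remaining score drops below a threshold (the threshold value here is an illustrative assumption):

```python
import numpy as np

def greedy_match(similarity, threshold=0.3):
    """similarity: [num_tracks, num_detections] -> list of (track, detection) pairs."""
    sim = similarity.copy()
    matches = []
    while sim.size and sim.max() >= threshold:
        t, d = np.unravel_index(np.argmax(sim), sim.shape)
        matches.append((t, d))
        sim[t, :] = -1.0   # a matched track cannot be matched again
        sim[:, d] = -1.0   # a matched detection cannot be matched again
    return matches

sim = np.array([[0.9, 0.1], [0.2, 0.6]])
print(greedy_match(sim))   # [(0, 0), (1, 1)]
```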
## 6. Implementation Details of the Attention Tracker

### 6.1 Feature Extractor

The feature extractor is a convolutional network that maps a target crop to a fixed-dimensional feature vector. It consists of several convolution, batch-normalization, and pooling layers followed by a fully connected layer.

```python
class FeatureExtractor(nn.Module):
    def __init__(self, input_dim=1024, feature_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, feature_dim)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.gap(x).squeeze(-1).squeeze(-1)
        x = self.fc(x)
        return x
```

### 6.2 Attention-Based Matching

The matching module uses self-attention to compute the affinity between tracks and detections. The snippet is a method of the tracker, so it receives `self` and uses the tracker's `self.attention` module:

```python
def compute_attention_similarity(self, track_features, detection_features):
    # concatenate track and detection features into one sequence
    combined_features = torch.cat([track_features, detection_features], dim=0)

    # apply self-attention over the combined sequence
    attended_features, attention_weights = self.attention(
        combined_features,
        combined_features,
        combined_features
    )

    # extract the track-to-detection block of the attention map
    num_tracks = track_features.size(0)
    similarity = attention_weights[0, :num_tracks, num_tracks:]

    return similarity
```

### 6.3 Combined Similarity

In practice, the appearance similarity is usually combined with a motion-prediction similarity:

```python
# appearance similarity from attention
feature_similarity = compute_attention_similarity(track_features, detection_features)

# IOU similarity between predicted and detected boxes
iou_matrix = compute_iou_matrix(predicted_boxes, detection_boxes)

# combined similarity
similarity_matrix = (
    feature_similarity_weight * feature_similarity +
    motion_weight * iou_matrix
)
```
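As a worked example of this weighting: with `feature_similarity_weight = 0.7` and `motion_weight = 0.3` (the defaults used by the trackers in this commit), a pair with appearance similarity 0.8 and IOU 0.4 scores 0.7 * 0.8 + 0.3 * 0.4 = 0.68, so appearance dominates the decision while motion acts as a tie-breaker.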
## 7. Summary

By relating every element of a sequence to every other element, self-attention effectively captures long-range dependencies and global context. In object tracking, it scores the affinity between targets in different frames, enabling accurate matching and track management.

By combining appearance features with motion prediction, the attention tracker can cope with occlusion and with targets appearing and disappearing in complex scenes, producing stable and reliable tracking results.
576
app/nn/modules/attention_tracker.py
Normal file
@@ -0,0 +1,576 @@
"""
|
||||
基于自注意力机制的目标追踪模块
|
||||
|
||||
该模块实现了一个基于自注意力机制的目标追踪器,可以在视频序列中追踪多个目标。
|
||||
追踪器结合了外观特征和运动特征,使用自注意力机制计算目标在不同帧之间的关联性。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import List, Dict, Tuple, Optional, Union, Any
|
||||
from pathlib import Path
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from collections import deque
|
||||
|
||||
from app.nn.modules.attention import SelfAttention, TransformerEncoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FeatureExtractor(nn.Module):
|
||||
"""
|
||||
特征提取器,用于从目标图像中提取特征
|
||||
|
||||
参数:
|
||||
input_dim (int): 输入特征的维度
|
||||
feature_dim (int): 输出特征的维度
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim: int = 1024, feature_dim: int = 256):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(128)
|
||||
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
||||
self.bn3 = nn.BatchNorm2d(256)
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Linear(256, feature_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入图像 [B, 3, H, W]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 提取的特征 [B, feature_dim]
|
||||
"""
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
x = self.gap(x).squeeze(-1).squeeze(-1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionTracker:
|
||||
"""
|
||||
基于自注意力机制的目标追踪器
|
||||
|
||||
参数:
|
||||
max_age (int): 目标消失后保持跟踪的最大帧数
|
||||
min_hits (int): 确认目标存在所需的最小检测次数
|
||||
iou_threshold (float): IOU匹配阈值
|
||||
feature_similarity_weight (float): 特征相似度权重
|
||||
motion_weight (float): 运动预测权重
|
||||
device (str): 使用的设备 ('cpu' 或 'cuda')
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_age: int = 30,
|
||||
min_hits: int = 3,
|
||||
iou_threshold: float = 0.3,
|
||||
feature_similarity_weight: float = 0.7,
|
||||
motion_weight: float = 0.3,
|
||||
device: str = 'cpu'
|
||||
):
|
||||
self.max_age = max_age
|
||||
self.min_hits = min_hits
|
||||
self.iou_threshold = iou_threshold
|
||||
self.feature_similarity_weight = feature_similarity_weight
|
||||
self.motion_weight = motion_weight
|
||||
self.device = device
|
||||
|
||||
# 初始化特征提取器
|
||||
self.feature_extractor = FeatureExtractor().to(device)
|
||||
self.feature_extractor.eval()
|
||||
|
||||
# 初始化自注意力模块
|
||||
self.attention = SelfAttention(dim=256, num_heads=8).to(device)
|
||||
self.attention.eval()
|
||||
|
||||
# 初始化轨迹列表
|
||||
self.tracks = []
|
||||
self.next_id = 1
|
||||
|
||||
# 初始化帧计数器
|
||||
self.frame_count = 0
|
||||
|
||||
# 是否处于单目标模式
|
||||
self.single_target_mode = False
|
||||
self.target_id = None
|
||||
|
||||
def reset(self):
|
||||
"""重置追踪器状态"""
|
||||
self.tracks = []
|
||||
self.next_id = 1
|
||||
self.frame_count = 0
|
||||
self.single_target_mode = False
|
||||
self.target_id = None
|
||||
|
||||
def set_single_target_mode(self, enable: bool, target_id: Optional[int] = None, target_class_id: Optional[int] = None):
|
||||
"""
|
||||
设置单目标追踪模式
|
||||
|
||||
参数:
|
||||
enable (bool): 是否启用单目标模式
|
||||
target_id (int, optional): 要追踪的目标ID
|
||||
target_class_id (int, optional): 要追踪的目标类别ID
|
||||
"""
|
||||
logger.info(f"设置单目标追踪模式: enable={enable}, target_id={target_id}, target_class_id={target_class_id}")
|
||||
|
||||
# 更新单目标模式状态
|
||||
self.single_target_mode = enable
|
||||
|
||||
# 如果禁用单目标模式,清除目标ID
|
||||
if not enable:
|
||||
self.target_id = None
|
||||
logger.info("已禁用单目标追踪模式")
|
||||
return
|
||||
|
||||
# 启用单目标模式
|
||||
if target_id is not None:
|
||||
# 如果提供了目标ID,直接使用
|
||||
self.target_id = target_id
|
||||
logger.info(f"已启用单目标追踪模式,追踪目标ID: {target_id}")
|
||||
elif target_class_id is not None:
|
||||
# 如果提供了类别ID,保存类别ID,在update方法中处理
|
||||
self.target_class_id = target_class_id
|
||||
logger.info(f"已启用单目标追踪模式,追踪类别ID: {target_class_id}")
|
||||
|
||||
# 如果已有轨迹,查找该类别的第一个目标
|
||||
for track in self.tracks:
|
||||
if track['class_id'] == target_class_id:
|
||||
self.target_id = track['id']
|
||||
logger.info(f"找到类别ID为 {target_class_id} 的目标,目标ID: {self.target_id}")
|
||||
break
|
||||
|
||||
# 如果没有找到该类别的目标,记录信息
|
||||
if self.target_id is None:
|
||||
logger.info(f"未找到类别ID为 {target_class_id} 的目标,将在下一帧中查找")
|
||||
else:
|
||||
logger.warning("启用单目标追踪模式但未提供目标ID或类别ID,单目标追踪模式可能无效")
|
||||
|
||||
def extract_features(self, image: np.ndarray, boxes: List[List[int]]) -> torch.Tensor:
|
||||
"""
|
||||
从图像中提取目标特征
|
||||
|
||||
参数:
|
||||
image (np.ndarray): 输入图像
|
||||
boxes (List[List[int]]): 目标边界框列表 [[x1, y1, x2, y2], ...]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 提取的特征 [num_boxes, feature_dim]
|
||||
"""
|
||||
if not boxes:
|
||||
return torch.zeros((0, 256), device=self.device)
|
||||
|
||||
# 裁剪目标区域
|
||||
crops = []
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box
|
||||
crop = image[y1:y2, x1:x2]
|
||||
if crop.size == 0:
|
||||
# 如果裁剪区域为空,使用整个图像
|
||||
crop = image
|
||||
# 调整大小为固定尺寸
|
||||
crop = cv2.resize(crop, (64, 64))
|
||||
crops.append(crop)
|
||||
|
||||
# 转换为张量
|
||||
crops_tensor = torch.stack([
|
||||
torch.from_numpy(crop.transpose(2, 0, 1)).float() / 255.0
|
||||
for crop in crops
|
||||
]).to(self.device)
|
||||
|
||||
# 提取特征
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(crops_tensor)
|
||||
|
||||
return features
|
||||
|
||||
def compute_attention_similarity(self, track_features: torch.Tensor, detection_features: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
使用自注意力机制计算轨迹和检测之间的相似度
|
||||
|
||||
参数:
|
||||
track_features (torch.Tensor): 轨迹特征 [num_tracks, feature_dim]
|
||||
detection_features (torch.Tensor): 检测特征 [num_detections, feature_dim]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 相似度矩阵 [num_tracks, num_detections]
|
||||
"""
|
||||
if track_features.size(0) == 0 or detection_features.size(0) == 0:
|
||||
return torch.zeros((track_features.size(0), detection_features.size(0)), device=self.device)
|
||||
|
||||
# 合并特征
|
||||
combined_features = torch.cat([track_features, detection_features], dim=0)
|
||||
|
||||
# 添加位置编码
|
||||
batch_size, seq_len = 1, combined_features.size(0)
|
||||
combined_features = combined_features.unsqueeze(0) # [1, N, feature_dim]
|
||||
|
||||
# 应用自注意力
|
||||
with torch.no_grad():
|
||||
attended_features, attention_weights = self.attention(
|
||||
combined_features,
|
||||
combined_features,
|
||||
combined_features
|
||||
)
|
||||
|
||||
# 提取相似度矩阵
|
||||
num_tracks = track_features.size(0)
|
||||
similarity = attention_weights[0, :num_tracks, num_tracks:]
|
||||
|
||||
return similarity
|
||||
|
||||
def compute_iou(self, box1: List[int], box2: List[int]) -> float:
|
||||
"""
|
||||
计算两个边界框的IOU
|
||||
|
||||
参数:
|
||||
box1 (List[int]): 第一个边界框 [x1, y1, x2, y2]
|
||||
box2 (List[int]): 第二个边界框 [x1, y1, x2, y2]
|
||||
|
||||
返回:
|
||||
float: IOU值
|
||||
"""
|
||||
try:
|
||||
# 确保边界框至少有4个元素
|
||||
if len(box1) < 4 or len(box2) < 4:
|
||||
logger.warning(f"边界框维度不足4: box1={len(box1)}, box2={len(box2)}")
|
||||
# 填充边界框
|
||||
if len(box1) < 4:
|
||||
box1 = box1 + [0] * (4 - len(box1))
|
||||
if len(box2) < 4:
|
||||
box2 = box2 + [0] * (4 - len(box2))
|
||||
|
||||
# 确保边界框坐标是浮点数
|
||||
x1_1, y1_1, x2_1, y2_1 = map(float, box1[:4])
|
||||
x1_2, y1_2, x2_2, y2_2 = map(float, box2[:4])
|
||||
|
||||
# 计算交集区域
|
||||
x1_i = max(x1_1, x1_2)
|
||||
y1_i = max(y1_1, y1_2)
|
||||
x2_i = min(x2_1, x2_2)
|
||||
y2_i = min(y2_1, y2_2)
|
||||
|
||||
if x2_i < x1_i or y2_i < y1_i:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2_i - x1_i) * (y2_i - y1_i)
|
||||
|
||||
# 计算各自面积
|
||||
area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
|
||||
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
|
||||
|
||||
# 计算IOU
|
||||
iou = intersection / (area1 + area2 - intersection + 1e-6)
|
||||
|
||||
return iou
|
||||
except Exception as e:
|
||||
logger.warning(f"计算IOU时出错: {str(e)}")
|
||||
return 0.0
|
||||
|
||||
def predict(self):
|
||||
"""预测所有轨迹的下一个位置"""
|
||||
for track in self.tracks:
|
||||
# 获取最后一个边界框
|
||||
box = track['boxes'][-1]
|
||||
|
||||
# 如果有足够的历史记录,使用卡尔曼滤波或简单线性预测
|
||||
if len(track['boxes']) >= 2:
|
||||
prev_box = track['boxes'][-2]
|
||||
|
||||
# 计算速度
|
||||
vx = box[0] - prev_box[0]
|
||||
vy = box[1] - prev_box[1]
|
||||
vw = (box[2] - box[0]) - (prev_box[2] - prev_box[0])
|
||||
vh = (box[3] - box[1]) - (prev_box[3] - prev_box[1])
|
||||
|
||||
# 预测下一个位置
|
||||
x1 = box[0] + vx
|
||||
y1 = box[1] + vy
|
||||
x2 = box[2] + vx + vw
|
||||
y2 = box[3] + vy + vh
|
||||
|
||||
track['predicted_box'] = [x1, y1, x2, y2]
|
||||
else:
|
||||
# 如果没有足够的历史记录,使用当前位置作为预测
|
||||
track['predicted_box'] = box
|
||||
|
||||
def update(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
更新追踪器状态
|
||||
|
||||
参数:
|
||||
image (np.ndarray): 当前帧图像
|
||||
detections (List[Dict]): 检测结果列表,每个检测包含 'bbox', 'class_id', 'confidence' 等字段
|
||||
|
||||
返回:
|
||||
List[Dict]: 更新后的轨迹列表
|
||||
"""
|
||||
self.frame_count += 1
|
||||
|
||||
# 如果没有轨迹,初始化轨迹
|
||||
if len(self.tracks) == 0:
|
||||
for i, det in enumerate(detections):
|
||||
if self.single_target_mode and i != self.target_id:
|
||||
continue
|
||||
|
||||
# 确保检测结果格式正确
|
||||
if not isinstance(det, dict):
|
||||
logger.warning(f"检测结果 #{i} 不是字典: {det}")
|
||||
continue
|
||||
|
||||
# 获取必要的字段,使用安全的方式
|
||||
try:
|
||||
bbox = det.get('bbox', None)
|
||||
if bbox is None or not isinstance(bbox, list) or len(bbox) != 4:
|
||||
logger.warning(f"检测结果 #{i} 的边界框格式不正确: {bbox}")
|
||||
continue
|
||||
|
||||
class_id = det.get('class_id', 0)
|
||||
confidence = det.get('confidence', 0.5)
|
||||
except Exception as e:
|
||||
logger.warning(f"处理检测结果 #{i} 时出错: {str(e)}")
|
||||
continue
|
||||
|
||||
# 提取特征
|
||||
features = self.extract_features(image, [bbox])
|
||||
|
||||
# 创建新轨迹
|
||||
self.tracks.append({
|
||||
'id': self.next_id,
|
||||
'boxes': [bbox],
|
||||
'class_id': class_id,
|
||||
'confidence': confidence,
|
||||
'features': [features[0]],
|
||||
'age': 1,
|
||||
'hits': 1,
|
||||
'time_since_update': 0,
|
||||
'predicted_box': bbox,
|
||||
'trajectory': [((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)]
|
||||
})
|
||||
|
||||
self.next_id += 1
|
||||
|
||||
return self.get_active_tracks()
|
||||
|
||||
# 预测轨迹的下一个位置
|
||||
self.predict()
|
||||
|
||||
# 如果没有检测,更新轨迹状态
|
||||
if len(detections) == 0:
|
||||
for track in self.tracks:
|
||||
track['time_since_update'] += 1
|
||||
|
||||
return self.get_active_tracks()
|
||||
|
||||
# 提取检测特征
|
||||
detection_boxes = []
|
||||
for det in detections:
|
||||
try:
|
||||
if isinstance(det, dict) and 'bbox' in det:
|
||||
bbox = det['bbox']
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
detection_boxes.append(bbox)
|
||||
except Exception as e:
|
||||
logger.warning(f"提取检测框时出错: {str(e)}")
|
||||
|
||||
# 如果没有有效的检测框,返回空结果
|
||||
if not detection_boxes:
|
||||
return self.get_active_tracks()
|
||||
|
||||
detection_features = self.extract_features(image, detection_boxes)
|
||||
|
||||
# 提取轨迹特征
|
||||
track_features = torch.stack([track['features'][-1] for track in self.tracks])
|
||||
|
||||
# 计算特征相似度
|
||||
feature_similarity = self.compute_attention_similarity(track_features, detection_features)
|
||||
|
||||
# 计算IOU矩阵
|
||||
iou_matrix = torch.zeros((len(self.tracks), len(detections)), device=self.device)
|
||||
for i, track in enumerate(self.tracks):
|
||||
for j, det in enumerate(detections):
|
||||
try:
|
||||
if isinstance(det, dict) and 'bbox' in det:
|
||||
bbox = det['bbox']
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
iou = self.compute_iou(track['predicted_box'], bbox)
|
||||
iou_matrix[i, j] = iou
|
||||
except Exception as e:
|
||||
logger.warning(f"计算IOU时出错: {str(e)}")
|
||||
|
||||
# 计算综合相似度
|
||||
similarity_matrix = (
|
||||
self.feature_similarity_weight * feature_similarity +
|
||||
self.motion_weight * iou_matrix
|
||||
)
|
||||
|
||||
# 匹配轨迹和检测
|
||||
matched_indices = []
|
||||
unmatched_tracks = list(range(len(self.tracks)))
|
||||
unmatched_detections = list(range(len(detections)))
|
||||
|
||||
# 贪心匹配
|
||||
similarity_matrix_np = similarity_matrix.cpu().numpy()
|
||||
while len(unmatched_tracks) > 0 and len(unmatched_detections) > 0:
|
||||
# 找到最大相似度
|
||||
track_idx, det_idx = np.unravel_index(
|
||||
np.argmax(similarity_matrix_np[unmatched_tracks][:, unmatched_detections]),
|
||||
(len(unmatched_tracks), len(unmatched_detections))
|
||||
)
|
||||
|
||||
track_idx = unmatched_tracks[track_idx]
|
||||
det_idx = unmatched_detections[det_idx]
|
||||
|
||||
# 如果相似度太低,停止匹配
|
||||
if similarity_matrix_np[track_idx, det_idx] < self.iou_threshold:
|
||||
break
|
||||
|
||||
# 添加匹配
|
||||
matched_indices.append((track_idx, det_idx))
|
||||
|
||||
# 从未匹配列表中移除
|
||||
unmatched_tracks.remove(track_idx)
|
||||
unmatched_detections.remove(det_idx)
|
||||
|
||||
# 更新匹配的轨迹
|
||||
for track_idx, det_idx in matched_indices:
|
||||
try:
|
||||
track = self.tracks[track_idx]
|
||||
det = detections[det_idx]
|
||||
|
||||
# 确保检测结果格式正确
|
||||
if not isinstance(det, dict):
|
||||
logger.warning(f"匹配的检测结果 #{det_idx} 不是字典: {det}")
|
||||
continue
|
||||
|
||||
# 获取必要的字段
|
||||
bbox = det.get('bbox', None)
|
||||
if bbox is None or not isinstance(bbox, list) or len(bbox) != 4:
|
||||
logger.warning(f"匹配的检测结果 #{det_idx} 的边界框格式不正确: {bbox}")
|
||||
continue
|
||||
|
||||
class_id = det.get('class_id', track['class_id'])
|
||||
confidence = det.get('confidence', track['confidence'])
|
||||
|
||||
# 更新轨迹
|
||||
track['boxes'].append(bbox)
|
||||
track['features'].append(detection_features[det_idx])
|
||||
track['class_id'] = class_id # 更新类别ID
|
||||
track['confidence'] = confidence # 更新置信度
|
||||
track['age'] += 1
|
||||
track['hits'] += 1
|
||||
track['time_since_update'] = 0
|
||||
|
||||
# 更新轨迹中心点
|
||||
center_x = (bbox[0] + bbox[2]) / 2
|
||||
center_y = (bbox[1] + bbox[3]) / 2
|
||||
track['trajectory'].append((center_x, center_y))
|
||||
except Exception as e:
|
||||
logger.warning(f"更新轨迹 #{track_idx} 时出错: {str(e)}")
|
||||
|
||||
# 更新未匹配的轨迹
|
||||
for track_idx in unmatched_tracks:
|
||||
track = self.tracks[track_idx]
|
||||
track['time_since_update'] += 1
|
||||
|
||||
# 创建新轨迹
|
||||
for det_idx in unmatched_detections:
|
||||
try:
|
||||
det = detections[det_idx]
|
||||
|
||||
# 如果是单目标模式且已有目标,不创建新轨迹
|
||||
if self.single_target_mode and len(self.tracks) > 0:
|
||||
continue
|
||||
|
||||
# 确保检测结果格式正确
|
||||
if not isinstance(det, dict):
|
||||
logger.warning(f"未匹配的检测结果 #{det_idx} 不是字典: {det}")
|
||||
continue
|
||||
|
||||
# 获取必要的字段
|
||||
bbox = det.get('bbox', None)
|
||||
if bbox is None or not isinstance(bbox, list) or len(bbox) != 4:
|
||||
logger.warning(f"未匹配的检测结果 #{det_idx} 的边界框格式不正确: {bbox}")
|
||||
continue
|
||||
|
||||
class_id = det.get('class_id', 0)
|
||||
confidence = det.get('confidence', 0.5)
|
||||
|
||||
# 提取特征
|
||||
if det_idx >= len(detection_features):
|
||||
logger.warning(f"检测特征索引越界: {det_idx} >= {len(detection_features)}")
|
||||
continue
|
||||
|
||||
features = detection_features[det_idx].unsqueeze(0)
|
||||
|
||||
# 创建新轨迹
|
||||
self.tracks.append({
|
||||
'id': self.next_id,
|
||||
'boxes': [bbox],
|
||||
'class_id': class_id,
|
||||
'confidence': confidence,
|
||||
'features': [features[0]],
|
||||
'age': 1,
|
||||
'hits': 1,
|
||||
'time_since_update': 0,
|
||||
'predicted_box': bbox,
|
||||
'trajectory': [((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)]
|
||||
})
|
||||
|
||||
self.next_id += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"创建新轨迹时出错: {str(e)}")
|
||||
|
||||
# 删除过期的轨迹
|
||||
self.tracks = [
|
||||
track for track in self.tracks
|
||||
if track['time_since_update'] <= self.max_age
|
||||
]
|
||||
|
||||
return self.get_active_tracks()
|
||||
|
||||
def get_active_tracks(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取活跃的轨迹
|
||||
|
||||
返回:
|
||||
List[Dict]: 活跃轨迹列表
|
||||
"""
|
||||
active_tracks = []
|
||||
|
||||
for track in self.tracks:
|
||||
# 只返回命中次数足够的轨迹
|
||||
if track['hits'] >= self.min_hits and track['time_since_update'] <= 1:
|
||||
# 如果是单目标模式,只返回目标轨迹
|
||||
if self.single_target_mode and track['id'] != self.target_id:
|
||||
continue
|
||||
|
||||
# 复制轨迹信息
|
||||
track_info = {
|
||||
'id': track['id'],
|
||||
'bbox': track['boxes'][-1],
|
||||
'class_id': track['class_id'],
|
||||
'confidence': track['confidence'],
|
||||
'age': track['age'],
|
||||
'time_since_update': track['time_since_update'],
|
||||
'trajectory': track['trajectory']
|
||||
}
|
||||
|
||||
active_tracks.append(track_info)
|
||||
|
||||
return active_tracks
|
||||
|
||||
# 追踪报告功能已移除
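# Minimal usage sketch (assumes detections come from an upstream detector such
# as YOLOv8; the detection dict and frame below are illustrative, and tracks
# only become "active" once a target has been confirmed for min_hits frames):
#
#   tracker = AttentionTracker(max_age=30, min_hits=3, device='cpu')
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   detections = [{'bbox': [100, 100, 200, 200], 'class_id': 0, 'confidence': 0.9}]
#   tracks = tracker.update(frame, detections)
#   for t in tracks:
#       print(t['id'], t['bbox'], t['trajectory'][-1])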
304
app/nn/modules/cbam.py
Normal file
@@ -0,0 +1,304 @@
"""
|
||||
CBAM (Convolutional Block Attention Module) 模块
|
||||
|
||||
CBAM是一种有效的注意力机制,结合了通道注意力和空间注意力,
|
||||
可以显著提高模型性能。
|
||||
|
||||
主要组件:
|
||||
1. 通道注意力 (ChannelAttention): 捕获通道间的依赖关系,强调"什么"是重要的特征
|
||||
2. 空间注意力 (SpatialAttention): 捕获空间位置的依赖关系,强调"哪里"有重要特征
|
||||
3. CBAM模块 (CBAM): 组合通道和空间注意力
|
||||
|
||||
参考文献:
|
||||
- "CBAM: Convolutional Block Attention Module" (Woo et al., 2018)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple, List, Union
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""
|
||||
通道注意力模块
|
||||
|
||||
使用全局平均池化和最大池化提取通道统计信息,
|
||||
然后通过共享MLP生成通道注意力权重。
|
||||
|
||||
参数:
|
||||
in_channels (int): 输入特征的通道数
|
||||
reduction_ratio (int): 通道减少比例,用于降低计算复杂度
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, reduction_ratio: int = 16):
|
||||
super().__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局平均池化
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1) # 全局最大池化
|
||||
|
||||
# 共享MLP
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入特征 [B, C, H, W]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 通道注意力权重 [B, C, 1, 1]
|
||||
"""
|
||||
# 全局平均池化分支
|
||||
avg_out = self.fc(self.avg_pool(x))
|
||||
|
||||
# 全局最大池化分支
|
||||
max_out = self.fc(self.max_pool(x))
|
||||
|
||||
# 融合两个分支
|
||||
out = avg_out + max_out
|
||||
|
||||
# 应用sigmoid激活函数
|
||||
return torch.sigmoid(out)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
"""
|
||||
空间注意力模块
|
||||
|
||||
使用通道平均池化和最大池化提取空间统计信息,
|
||||
然后通过卷积层生成空间注意力权重。
|
||||
|
||||
参数:
|
||||
kernel_size (int): 卷积核大小
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size: int = 7):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入特征 [B, C, H, W]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 空间注意力权重 [B, 1, H, W]
|
||||
"""
|
||||
# 通道平均池化
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
|
||||
# 通道最大池化
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
|
||||
# 拼接平均池化和最大池化结果
|
||||
out = torch.cat([avg_out, max_out], dim=1)
|
||||
|
||||
# 应用卷积和sigmoid激活函数
|
||||
out = self.conv(out)
|
||||
|
||||
return torch.sigmoid(out)
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
"""
|
||||
CBAM (Convolutional Block Attention Module)
|
||||
|
||||
结合通道注意力和空间注意力,先应用通道注意力,再应用空间注意力。
|
||||
|
||||
参数:
|
||||
in_channels (int): 输入特征的通道数
|
||||
reduction_ratio (int): 通道减少比例
|
||||
spatial_kernel_size (int): 空间注意力卷积核大小
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
reduction_ratio: int = 16,
|
||||
spatial_kernel_size: int = 7
|
||||
):
|
||||
super().__init__()
|
||||
self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
|
||||
self.spatial_attention = SpatialAttention(spatial_kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入特征 [B, C, H, W]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 注意力增强后的特征 [B, C, H, W]
|
||||
"""
|
||||
# 应用通道注意力
|
||||
x = x * self.channel_attention(x)
|
||||
|
||||
# 应用空间注意力
|
||||
x = x * self.spatial_attention(x)
|
||||
|
||||
return x
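# Usage sketch: CBAM is a drop-in refinement over a feature map and preserves
# its shape (the tensor sizes below are illustrative):
#
#   feat = torch.randn(2, 64, 32, 32)
#   refined = CBAM(in_channels=64)(feat)   # still [2, 64, 32, 32]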
|
||||
|
||||
|
||||
class CBAMResBlock(nn.Module):
|
||||
"""
|
||||
带有CBAM的残差块
|
||||
|
||||
在残差连接后应用CBAM注意力机制。
|
||||
|
||||
参数:
|
||||
in_channels (int): 输入特征的通道数
|
||||
out_channels (int): 输出特征的通道数
|
||||
stride (int): 卷积步长
|
||||
reduction_ratio (int): 通道减少比例
|
||||
spatial_kernel_size (int): 空间注意力卷积核大小
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int = 1,
|
||||
reduction_ratio: int = 16,
|
||||
spatial_kernel_size: int = 7
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 主分支
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
# 残差连接
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_channels != out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
|
||||
# CBAM注意力
|
||||
self.cbam = CBAM(out_channels, reduction_ratio, spatial_kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入特征 [B, C, H, W]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 注意力增强后的特征 [B, C, H, W]
|
||||
"""
|
||||
# 主分支
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
# 残差连接
|
||||
shortcut = self.shortcut(x)
|
||||
|
||||
# 应用CBAM
|
||||
out = self.cbam(out)
|
||||
|
||||
# 残差连接
|
||||
out += shortcut
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SelfAttentionWithCBAM(nn.Module):
|
||||
"""
|
||||
带有CBAM的自注意力模块
|
||||
|
||||
在自注意力计算前应用CBAM增强特征表示。
|
||||
|
||||
参数:
|
||||
dim (int): 输入特征的通道维度
|
||||
num_heads (int): 注意力头的数量
|
||||
qkv_bias (bool): 是否在QKV投影中使用偏置
|
||||
attn_drop (float): 注意力dropout率
|
||||
proj_drop (float): 输出投影dropout率
|
||||
reduction_ratio (int): CBAM通道减少比例
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
reduction_ratio: int = 16
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 从attention模块导入SelfAttention
|
||||
from app.nn.modules.attention import SelfAttention
|
||||
|
||||
# 自注意力模块
|
||||
self.self_attn = SelfAttention(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop
|
||||
)
|
||||
|
||||
# CBAM注意力
|
||||
self.cbam = CBAM(dim, reduction_ratio)
|
||||
|
||||
# 特征转换
|
||||
self.to_2d = lambda x, h, w: x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
|
||||
self.to_1d = lambda x: x.flatten(2).permute(0, 2, 1)
|
||||
|
||||
def forward(self, x: torch.Tensor, h: int = None, w: int = None) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入特征 [B, N, C] 或 [B, C, H, W]
|
||||
h (int): 特征图高度,当输入为[B, N, C]时需要提供
|
||||
w (int): 特征图宽度,当输入为[B, N, C]时需要提供
|
||||
|
||||
返回:
|
||||
torch.Tensor: 注意力增强后的特征,与输入形状相同
|
||||
"""
|
||||
# 检查输入维度
|
||||
if x.dim() == 4: # [B, C, H, W]
|
||||
h, w = x.shape[2], x.shape[3]
|
||||
x_2d = x
|
||||
x_1d = self.to_1d(x)
|
||||
elif x.dim() == 3: # [B, N, C]
|
||||
assert h is not None and w is not None, "需要提供特征图的高度和宽度"
|
||||
x_1d = x
|
||||
x_2d = self.to_2d(x, h, w)
|
||||
else:
|
||||
raise ValueError(f"不支持的输入维度: {x.dim()}")
|
||||
|
||||
# 应用CBAM
|
||||
enhanced_x_2d = self.cbam(x_2d)
|
||||
|
||||
# 转换回序列形式
|
||||
enhanced_x_1d = self.to_1d(enhanced_x_2d)
|
||||
|
||||
# 应用自注意力
|
||||
output = self.self_attn(enhanced_x_1d)
|
||||
|
||||
# 返回与输入相同形状的输出
|
||||
if x.dim() == 4:
|
||||
return self.to_2d(output, h, w)
|
||||
else:
|
||||
return output
|
||||
307
app/nn/modules/cross_attention.py
Normal file
@@ -0,0 +1,307 @@
"""
|
||||
交叉注意力 (Cross-Attention) 模块
|
||||
|
||||
交叉注意力机制允许模型在不同特征集之间建立关联,
|
||||
非常适合目标追踪任务中的目标匹配。
|
||||
|
||||
主要组件:
|
||||
1. 交叉注意力 (CrossAttention): 使用一组特征作为查询(Q),另一组特征作为键(K)和值(V)
|
||||
2. 交叉注意力追踪器 (CrossAttentionTracker): 使用交叉注意力进行目标匹配
|
||||
|
||||
技术原理:
|
||||
- 与自注意力不同,交叉注意力使用不同的特征集作为查询和键/值
|
||||
- 这允许模型学习两组特征之间的关系,而不仅仅是单个特征集内的关系
|
||||
- 在目标追踪中,可以使用当前帧的目标特征作为查询,历史帧的目标特征作为键和值
|
||||
"""
|
||||
|
||||
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union

# Module-level logger: compute_motion_features below logs warnings on its
# normal path, so the logger must exist before any except-block fallback runs.
logger = logging.getLogger(__name__)
class CrossAttention(nn.Module):
|
||||
"""
|
||||
交叉注意力模块
|
||||
|
||||
使用一组特征作为查询(Q),另一组特征作为键(K)和值(V),
|
||||
计算它们之间的关联性。
|
||||
|
||||
参数:
|
||||
dim (int): 特征维度
|
||||
num_heads (int): 注意力头的数量
|
||||
qkv_bias (bool): 是否在QKV投影中使用偏置
|
||||
attn_drop (float): 注意力dropout率
|
||||
proj_drop (float): 输出投影dropout率
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim ** -0.5 # 缩放因子
|
||||
|
||||
# 查询、键、值的线性投影
|
||||
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key_value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
query (torch.Tensor): 查询特征 [B, Nq, C]
|
||||
key_value (torch.Tensor): 键值特征 [B, Nkv, C]
|
||||
attn_mask (torch.Tensor, optional): 注意力掩码 [B, Nq, Nkv]
|
||||
|
||||
返回:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- 输出特征 [B, Nq, C]
|
||||
- 注意力权重 [B, num_heads, Nq, Nkv]
|
||||
"""
|
||||
B, Nq, C = query.shape
|
||||
_, Nkv, _ = key_value.shape
|
||||
|
||||
# 投影查询、键、值
|
||||
q = self.q_proj(query).reshape(B, Nq, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
k = self.k_proj(key_value).reshape(B, Nkv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
v = self.v_proj(key_value).reshape(B, Nkv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
# 计算注意力分数
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, Nq, Nkv]
|
||||
|
||||
# 应用掩码(如果提供)
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 3: # [B, Nq, Nkv]
|
||||
attn_mask = attn_mask.unsqueeze(1) # [B, 1, Nq, Nkv]
|
||||
attn = attn.masked_fill(attn_mask == 0, float('-inf'))
|
||||
|
||||
# 应用softmax和dropout
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
# 加权聚合
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
return x, attn
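# Usage sketch: queries come from the current frame, keys/values from history
# (batch of 1; the feature counts below are illustrative):
#
#   attn_mod = CrossAttention(dim=256, num_heads=8)
#   cur = torch.randn(1, 3, 256)    # 3 detections in the current frame
#   hist = torch.randn(1, 5, 256)   # 5 tracked targets from past frames
#   out, weights = attn_mod(cur, hist)   # out: [1, 3, 256], weights: [1, 8, 3, 5]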
|
||||
|
||||
|
||||
class CrossAttentionTracker(nn.Module):
|
||||
"""
|
||||
基于交叉注意力的目标追踪器
|
||||
|
||||
使用交叉注意力机制计算当前帧目标与历史帧目标之间的关联性,
|
||||
实现目标匹配和轨迹管理。
|
||||
|
||||
参数:
|
||||
feature_dim (int): 特征维度
|
||||
num_heads (int): 注意力头的数量
|
||||
dropout (float): dropout率
|
||||
appearance_weight (float): 外观特征权重
|
||||
motion_weight (float): 运动特征权重
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_dim: int = 256,
|
||||
num_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
appearance_weight: float = 0.7,
|
||||
motion_weight: float = 0.3
|
||||
):
|
||||
super().__init__()
|
||||
self.feature_dim = feature_dim
|
||||
self.appearance_weight = appearance_weight
|
||||
self.motion_weight = motion_weight
|
||||
|
||||
# 交叉注意力模块
|
||||
self.cross_attention = CrossAttention(
|
||||
dim=feature_dim,
|
||||
num_heads=num_heads,
|
||||
attn_drop=dropout,
|
||||
proj_drop=dropout
|
||||
)
|
||||
|
||||
# 特征转换
|
||||
self.appearance_transform = nn.Sequential(
|
||||
nn.Linear(feature_dim, feature_dim),
|
||||
nn.LayerNorm(feature_dim),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
self.motion_transform = nn.Sequential(
|
||||
nn.Linear(4, feature_dim // 4),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(feature_dim // 4, feature_dim),
|
||||
nn.LayerNorm(feature_dim)
|
||||
)
|
||||
|
||||
# 特征融合
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(feature_dim * 2, feature_dim),
|
||||
nn.LayerNorm(feature_dim),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def compute_motion_features(
|
||||
self,
|
||||
current_boxes: torch.Tensor,
|
||||
history_boxes: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
计算运动特征
|
||||
|
||||
参数:
|
||||
current_boxes (torch.Tensor): 当前帧边界框 [num_current, 4]
|
||||
history_boxes (torch.Tensor): 历史帧边界框 [num_history, 4]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 运动特征 [num_current, num_history, feature_dim]
|
||||
"""
|
||||
try:
|
||||
# 检查边界框维度
|
||||
if current_boxes.dim() < 2 or history_boxes.dim() < 2:
|
||||
raise ValueError(f"边界框维度不正确: current_boxes.dim()={current_boxes.dim()}, history_boxes.dim()={history_boxes.dim()}")
|
||||
|
||||
# 确保边界框是4维的
|
||||
if current_boxes.size(-1) != 4 or history_boxes.size(-1) != 4:
|
||||
logger.warning(f"边界框维度不是4: current_boxes.size(-1)={current_boxes.size(-1)}, history_boxes.size(-1)={history_boxes.size(-1)}")
|
||||
# 如果维度不是4,创建一个全零的特征张量
|
||||
num_current = current_boxes.size(0)
|
||||
num_history = history_boxes.size(0)
|
||||
return torch.zeros((num_current, num_history, self.feature_dim), device=current_boxes.device)
|
||||
|
||||
num_current = current_boxes.size(0)
|
||||
num_history = history_boxes.size(0)
|
||||
|
||||
# 计算边界框中心点和尺寸
|
||||
def box_to_center_size(boxes):
|
||||
# 确保boxes的最后一个维度是4
|
||||
if boxes.size(-1) != 4:
|
||||
logger.warning(f"边界框维度不是4: {boxes.size()}")
|
||||
# 如果不是4,填充或截断到4
|
||||
if boxes.size(-1) < 4:
|
||||
# 填充
|
||||
padding = torch.zeros(*boxes.shape[:-1], 4 - boxes.size(-1), device=boxes.device)
|
||||
boxes = torch.cat([boxes, padding], dim=-1)
|
||||
else:
|
||||
# 截断
|
||||
boxes = boxes[..., :4]
|
||||
|
||||
x1, y1, x2, y2 = boxes.unbind(-1)
|
||||
cx = (x1 + x2) / 2
|
||||
cy = (y1 + y2) / 2
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
return torch.stack([cx, cy, w, h], dim=-1)
|
||||
|
||||
current_centers = box_to_center_size(current_boxes) # [num_current, 4]
|
||||
history_centers = box_to_center_size(history_boxes) # [num_history, 4]
|
||||
|
||||
# 计算所有当前框与历史框之间的差异
|
||||
current_expanded = current_centers.unsqueeze(1).expand(-1, num_history, -1) # [num_current, num_history, 4]
|
||||
history_expanded = history_centers.unsqueeze(0).expand(num_current, -1, -1) # [num_current, num_history, 4]
|
||||
|
||||
# 计算中心点距离和尺寸比例
|
||||
motion_features = torch.cat([
|
||||
current_expanded - history_expanded, # 中心点位移
|
||||
current_expanded / (history_expanded + 1e-6) # 尺寸比例
|
||||
], dim=-1) # [num_current, num_history, 8]
|
||||
|
||||
# 转换为特征表示
|
||||
motion_features = self.motion_transform(motion_features[:, :, :4]) # 只使用位移信息
|
||||
|
||||
return motion_features
|
||||
except Exception as e:
# Log through the module-level logger defined at the top of this file.
logger.error(f"Error while computing motion features: {str(e)}")
|
||||
# 创建一个全零的特征张量作为备选
|
||||
num_current = current_boxes.size(0) if current_boxes.dim() > 0 else 0
|
||||
num_history = history_boxes.size(0) if history_boxes.dim() > 0 else 0
|
||||
return torch.zeros((max(num_current, 1), max(num_history, 1), self.feature_dim),
|
||||
device=current_boxes.device)
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
current_features: torch.Tensor,
|
||||
history_features: torch.Tensor,
|
||||
current_boxes: Optional[torch.Tensor] = None,
|
||||
history_boxes: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
计算当前帧目标与历史帧目标之间的相似度
|
||||
|
||||
参数:
|
||||
current_features (torch.Tensor): 当前帧目标特征 [num_current, feature_dim]
|
||||
history_features (torch.Tensor): 历史帧目标特征 [num_history, feature_dim]
|
||||
current_boxes (torch.Tensor, optional): 当前帧边界框 [num_current, 4]
|
||||
history_boxes (torch.Tensor, optional): 历史帧边界框 [num_history, 4]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 相似度矩阵 [num_current, num_history]
|
||||
"""
|
||||
num_current = current_features.size(0)
|
||||
num_history = history_features.size(0)
|
||||
|
||||
if num_current == 0 or num_history == 0:
|
||||
return torch.zeros((num_current, num_history), device=current_features.device)
|
||||
|
||||
# 转换外观特征
|
||||
current_appearance = self.appearance_transform(current_features) # [num_current, feature_dim]
|
||||
history_appearance = self.appearance_transform(history_features) # [num_history, feature_dim]
|
||||
|
||||
# 添加批次维度
|
||||
current_appearance = current_appearance.unsqueeze(0) # [1, num_current, feature_dim]
|
||||
history_appearance = history_appearance.unsqueeze(0) # [1, num_history, feature_dim]
|
||||
|
||||
# 应用交叉注意力
|
||||
_, appearance_attn = self.cross_attention(current_appearance, history_appearance)
|
||||
|
||||
# 提取外观相似度(取第一个头的注意力权重)并分离梯度
|
||||
appearance_similarity = appearance_attn[0, 0].detach() # [num_current, num_history]
|
||||
|
||||
# 如果提供了边界框,计算运动相似度
|
||||
if current_boxes is not None and history_boxes is not None:
|
||||
try:
|
||||
# 计算运动特征
|
||||
motion_features = self.compute_motion_features(current_boxes, history_boxes) # [num_current, num_history, feature_dim]
|
||||
|
||||
# 计算运动相似度并分离梯度
|
||||
motion_similarity = torch.sum(motion_features, dim=-1).detach() # [num_current, num_history]
|
||||
motion_similarity = torch.sigmoid(motion_similarity) # 归一化到[0,1]
|
||||
|
||||
# 融合外观和运动相似度
|
||||
similarity = (
|
||||
self.appearance_weight * appearance_similarity +
|
||||
self.motion_weight * motion_similarity
|
||||
)
|
||||
except Exception as e:
# Log through the module-level logger defined at the top of this file.
logger.error(f"Error while computing motion similarity: {str(e)}")
|
||||
# 如果出错,只使用外观相似度
|
||||
similarity = appearance_similarity
|
||||
else:
|
||||
similarity = appearance_similarity
|
||||
|
||||
return similarity
|
||||
761
app/nn/modules/enhanced_attention_tracker.py
Normal file
@@ -0,0 +1,761 @@
"""
|
||||
增强型自注意力追踪器
|
||||
|
||||
该模块实现了一个增强型自注意力追踪器,集成了多种先进技术:
|
||||
1. Layer Scale: 增强深层网络的训练稳定性
|
||||
2. CBAM: 通道+空间注意力融合,增强特征表示
|
||||
3. 交叉注意力: 更精确地建模目标间关系
|
||||
|
||||
这些技术的结合显著提高了追踪的准确性和稳定性,
|
||||
特别是在复杂场景和遮挡情况下。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
import logging
|
||||
from typing import List, Dict, Tuple, Optional, Union, Any
|
||||
from pathlib import Path
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
|
||||
# 导入自定义模块
|
||||
from app.nn.modules.attention import SelfAttention, TransformerEncoder
|
||||
from app.nn.modules.layer_scale import LayerScale, DropPath, TransformerEncoderWithLayerScale
|
||||
from app.nn.modules.cbam import CBAM, SelfAttentionWithCBAM
|
||||
from app.nn.modules.cross_attention import CrossAttention, CrossAttentionTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnhancedFeatureExtractor(nn.Module):
|
||||
"""
|
||||
增强型特征提取器,集成了CBAM注意力机制
|
||||
|
||||
参数:
|
||||
input_dim (int): 输入特征的维度
|
||||
feature_dim (int): 输出特征的维度
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim: int = 1024, feature_dim: int = 256):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.cbam1 = CBAM(64)
|
||||
|
||||
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(128)
|
||||
self.cbam2 = CBAM(128)
|
||||
|
||||
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
||||
self.bn3 = nn.BatchNorm2d(256)
|
||||
self.cbam3 = CBAM(256)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Linear(256, feature_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 输入图像 [B, 3, H, W]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 提取的特征 [B, feature_dim]
|
||||
"""
|
||||
# 第一层卷积+CBAM
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = self.cbam1(x)
|
||||
|
||||
# 第二层卷积+CBAM
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = self.cbam2(x)
|
||||
|
||||
# 第三层卷积+CBAM
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
x = self.cbam3(x)
|
||||
|
||||
# 全局平均池化和全连接
|
||||
x = self.gap(x).squeeze(-1).squeeze(-1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EnhancedAttentionTracker:
|
||||
"""
|
||||
增强型自注意力追踪器
|
||||
|
||||
集成了Layer Scale、CBAM和交叉注意力等先进技术,
|
||||
提高追踪的准确性和稳定性。
|
||||
|
||||
参数:
|
||||
max_age (int): 目标消失后保持跟踪的最大帧数
|
||||
min_hits (int): 确认目标存在所需的最小检测次数
|
||||
iou_threshold (float): IOU匹配阈值
|
||||
feature_similarity_weight (float): 特征相似度权重
|
||||
motion_weight (float): 运动预测权重
|
||||
device (str): 使用的设备 ('cpu' 或 'cuda')
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_age: int = 30,
|
||||
min_hits: int = 3,
|
||||
iou_threshold: float = 0.3,
|
||||
feature_similarity_weight: float = 0.7,
|
||||
motion_weight: float = 0.3,
|
||||
device: str = 'cpu'
|
||||
):
|
||||
self.max_age = max_age
|
||||
self.min_hits = min_hits
|
||||
self.iou_threshold = iou_threshold
|
||||
self.feature_similarity_weight = feature_similarity_weight
|
||||
self.motion_weight = motion_weight
|
||||
self.device = device
|
||||
|
||||
# 初始化特征提取器
|
||||
self.feature_extractor = EnhancedFeatureExtractor().to(device)
|
||||
self.feature_extractor.eval()
|
||||
|
||||
# 初始化带有Layer Scale的Transformer编码器
|
||||
self.transformer = TransformerEncoderWithLayerScale(
|
||||
d_model=256,
|
||||
nhead=8,
|
||||
num_layers=3,
|
||||
dim_feedforward=1024,
|
||||
dropout=0.1,
|
||||
layer_scale_init_value=1e-6,
|
||||
drop_path_rate=0.1
|
||||
).to(device)
|
||||
self.transformer.eval()
|
||||
|
||||
# 初始化交叉注意力追踪器
|
||||
self.cross_attention_tracker = CrossAttentionTracker(
|
||||
feature_dim=256,
|
||||
num_heads=8,
|
||||
dropout=0.1,
|
||||
appearance_weight=feature_similarity_weight,
|
||||
motion_weight=motion_weight
|
||||
).to(device)
|
||||
self.cross_attention_tracker.eval()
|
||||
|
||||
# 初始化轨迹列表
|
||||
self.tracks = []
|
||||
self.next_id = 1
|
||||
|
||||
# 初始化帧计数器
|
||||
self.frame_count = 0
|
||||
|
||||
# 是否处于单目标模式
|
||||
self.single_target_mode = False
|
||||
self.target_id = None
|
||||
|
||||
def reset(self):
|
||||
"""重置追踪器状态"""
|
||||
self.tracks = []
|
||||
self.next_id = 1
|
||||
self.frame_count = 0
|
||||
self.single_target_mode = False
|
||||
self.target_id = None
|
||||
|
||||
def set_single_target_mode(self, enable: bool, target_id: Optional[int] = None, target_class_id: Optional[int] = None):
|
||||
"""
|
||||
设置单目标追踪模式
|
||||
|
||||
参数:
|
||||
enable (bool): 是否启用单目标模式
|
||||
target_id (int, optional): 要追踪的目标ID
|
||||
target_class_id (int, optional): 要追踪的目标类别ID
|
||||
"""
|
||||
logger.info(f"设置单目标追踪模式: enable={enable}, target_id={target_id}, target_class_id={target_class_id}")
|
||||
|
||||
# 更新单目标模式状态
|
||||
self.single_target_mode = enable
|
||||
|
||||
# 如果禁用单目标模式,清除目标ID和类别ID
|
||||
if not enable:
|
||||
self.target_id = None
|
||||
if hasattr(self, 'target_class_id'):
|
||||
delattr(self, 'target_class_id')
|
||||
logger.info("已禁用单目标追踪模式")
|
||||
return
|
||||
|
||||
# 启用单目标模式
|
||||
if target_id is not None:
|
||||
# 如果提供了目标ID,直接使用
|
||||
self.target_id = target_id
|
||||
# 清除之前的类别ID
|
||||
if hasattr(self, 'target_class_id'):
|
||||
delattr(self, 'target_class_id')
|
||||
logger.info(f"已启用单目标追踪模式,追踪目标ID: {target_id}")
|
||||
elif target_class_id is not None:
|
||||
# 确保target_class_id是整数
|
||||
try:
|
||||
self.target_class_id = int(target_class_id)
|
||||
except (ValueError, TypeError):
|
||||
self.target_class_id = target_class_id
|
||||
|
||||
logger.info(f"已启用单目标追踪模式,追踪类别ID: {self.target_class_id}")
|
||||
|
||||
# 重置目标ID,让系统在下一帧中重新查找该类别的目标
|
||||
self.target_id = None
|
||||
|
||||
# 如果已有轨迹,查找该类别的第一个目标
|
||||
for track in self.tracks:
|
||||
if track['class_id'] == self.target_class_id:
|
||||
self.target_id = track['id']
|
||||
logger.info(f"找到类别ID为 {self.target_class_id} 的目标,目标ID: {self.target_id}")
|
||||
break
|
||||
|
||||
# 如果没有找到该类别的目标,记录信息
|
||||
if self.target_id is None:
|
||||
logger.info(f"未找到类别ID为 {self.target_class_id} 的目标,将在下一帧中查找")
|
||||
else:
|
||||
logger.warning("启用单目标追踪模式但未提供目标ID或类别ID,单目标追踪模式可能无效")
|
||||
|
||||
def extract_features(self, image: np.ndarray, boxes: List[List[int]]) -> torch.Tensor:
|
||||
"""
|
||||
从图像中提取目标特征
|
||||
|
||||
参数:
|
||||
image (np.ndarray): 输入图像
|
||||
boxes (List[List[int]]): 目标边界框列表 [[x1, y1, x2, y2], ...]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 提取的特征 [num_boxes, feature_dim]
|
||||
"""
|
||||
if not boxes:
|
||||
return torch.zeros((0, 256), device=self.device)
|
||||
|
||||
# 裁剪目标区域
|
||||
crops = []
|
||||
for box in boxes:
|
||||
try:
|
||||
# 确保边界框坐标是整数
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# 确保坐标在图像范围内
|
||||
height, width = image.shape[:2]
|
||||
x1 = max(0, min(x1, width - 1))
|
||||
y1 = max(0, min(y1, height - 1))
|
||||
x2 = max(0, min(x2, width))
|
||||
y2 = max(0, min(y2, height))
|
||||
|
||||
# 确保裁剪区域有效
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
# 无效的裁剪区域,使用整个图像
|
||||
crop = image
|
||||
else:
|
||||
crop = image[y1:y2, x1:x2]
|
||||
if crop.size == 0:
|
||||
# 如果裁剪区域为空,使用整个图像
|
||||
crop = image
|
||||
|
||||
# 调整大小为固定尺寸 - 减小尺寸以降低内存占用
|
||||
crop = cv2.resize(crop, (32, 32)) # 从64x64减小到32x32
|
||||
crops.append(crop)
|
||||
except Exception as e:
|
||||
logger.warning(f"裁剪目标区域失败: {str(e)}")
|
||||
# 使用空白图像
|
||||
crop = np.zeros((32, 32, 3), dtype=np.uint8)
|
||||
crops.append(crop)
|
||||
|
||||
# 转换为张量
|
||||
crops_tensor = torch.stack([
|
||||
torch.from_numpy(crop.transpose(2, 0, 1)).float() / 255.0
|
||||
for crop in crops
|
||||
]).to(self.device)
|
||||
|
||||
# 提取特征
|
||||
with torch.no_grad():
|
||||
# 使用torch.cuda.empty_cache()清理GPU内存
|
||||
if self.device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
features = self.feature_extractor(crops_tensor)
|
||||
|
||||
# 应用Transformer编码器增强特征
|
||||
features = features.unsqueeze(0) # [1, num_boxes, feature_dim]
|
||||
features = self.transformer(features)
|
||||
features = features.squeeze(0) # [num_boxes, feature_dim]
|
||||
|
||||
# 再次清理内存
|
||||
if self.device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 确保从GPU移动到CPU,减少GPU内存占用
|
||||
features = features.cpu()
|
||||
|
||||
return features
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
track_features: torch.Tensor,
|
||||
detection_features: torch.Tensor,
|
||||
track_boxes: List[List[int]],
|
||||
detection_boxes: List[List[int]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
计算轨迹和检测之间的相似度
|
||||
|
||||
参数:
|
||||
track_features (torch.Tensor): 轨迹特征 [num_tracks, feature_dim]
|
||||
detection_features (torch.Tensor): 检测特征 [num_detections, feature_dim]
|
||||
track_boxes (List[List[int]]): 轨迹边界框 [[x1, y1, x2, y2], ...]
|
||||
detection_boxes (List[List[int]]): 检测边界框 [[x1, y1, x2, y2], ...]
|
||||
|
||||
返回:
|
||||
torch.Tensor: 相似度矩阵 [num_tracks, num_detections]
|
||||
"""
|
||||
try:
|
||||
if track_features.size(0) == 0 or detection_features.size(0) == 0:
|
||||
return torch.zeros((track_features.size(0), detection_features.size(0)), device=self.device)
|
||||
|
||||
# 确保所有边界框都是4维的
|
||||
normalized_track_boxes = []
|
||||
for box in track_boxes:
|
||||
if len(box) == 4:
|
||||
normalized_track_boxes.append(box)
|
||||
elif len(box) > 4:
|
||||
normalized_track_boxes.append(box[:4]) # 截断
|
||||
else:
|
||||
# 填充
|
||||
padded_box = box + [0] * (4 - len(box))
|
||||
normalized_track_boxes.append(padded_box)
|
||||
|
||||
normalized_detection_boxes = []
|
||||
for box in detection_boxes:
|
||||
if len(box) == 4:
|
||||
normalized_detection_boxes.append(box)
|
||||
elif len(box) > 4:
|
||||
normalized_detection_boxes.append(box[:4]) # 截断
|
||||
else:
|
||||
# 填充
|
||||
padded_box = box + [0] * (4 - len(box))
|
||||
normalized_detection_boxes.append(padded_box)
|
||||
|
||||
# 转换边界框为张量
|
||||
track_boxes_tensor = torch.tensor(normalized_track_boxes, dtype=torch.float32, device=self.device)
|
||||
detection_boxes_tensor = torch.tensor(normalized_detection_boxes, dtype=torch.float32, device=self.device)
|
||||
|
||||
# 使用交叉注意力计算相似度
|
||||
similarity = self.cross_attention_tracker.compute_similarity(
|
||||
detection_features,
|
||||
track_features,
|
||||
detection_boxes_tensor,
|
||||
track_boxes_tensor
|
||||
)
|
||||
|
||||
# 转置相似度矩阵,使其形状为 [num_tracks, num_detections]
|
||||
similarity = similarity.t()
|
||||
|
||||
return similarity
|
||||
except Exception as e:
|
||||
logger.error(f"计算相似度时出错: {str(e)}")
|
||||
# 创建一个全零的相似度矩阵作为备选
|
||||
return torch.zeros((len(track_boxes), len(detection_boxes)), device=self.device)
|
||||
|
||||
def compute_iou(self, box1: List[float], box2: List[float]) -> float:
|
||||
"""
|
||||
计算两个边界框的IOU
|
||||
|
||||
参数:
|
||||
box1 (List[float]): 第一个边界框 [x1, y1, x2, y2]
|
||||
box2 (List[float]): 第二个边界框 [x1, y1, x2, y2]
|
||||
|
||||
返回:
|
||||
float: IOU值
|
||||
"""
|
||||
try:
|
||||
# 确保边界框至少有4个元素
|
||||
if len(box1) < 4 or len(box2) < 4:
|
||||
logger.warning(f"边界框维度不足4: box1={len(box1)}, box2={len(box2)}")
|
||||
# 填充边界框
|
||||
if len(box1) < 4:
|
||||
box1 = box1 + [0] * (4 - len(box1))
|
||||
if len(box2) < 4:
|
||||
box2 = box2 + [0] * (4 - len(box2))
|
||||
|
||||
# 确保边界框坐标是浮点数
|
||||
x1_1, y1_1, x2_1, y2_1 = map(float, box1[:4])
|
||||
x1_2, y1_2, x2_2, y2_2 = map(float, box2[:4])
|
||||
|
||||
# 计算交集区域
|
||||
x1_i = max(x1_1, x1_2)
|
||||
y1_i = max(y1_1, y1_2)
|
||||
x2_i = min(x2_1, x2_2)
|
||||
y2_i = min(y2_1, y2_2)
|
||||
|
||||
if x2_i < x1_i or y2_i < y1_i:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2_i - x1_i) * (y2_i - y1_i)
|
||||
|
||||
# 计算各自面积
|
||||
area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
|
||||
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
|
||||
|
||||
# 计算IOU
|
||||
iou = intersection / (area1 + area2 - intersection + 1e-6)
|
||||
|
||||
return iou
|
||||
except Exception as e:
|
||||
logger.warning(f"计算IOU时出错: {str(e)}")
|
||||
return 0.0
|
||||
|
||||
def predict(self):
|
||||
"""预测所有轨迹的下一个位置"""
|
||||
for track in self.tracks:
|
||||
# 获取最后一个边界框
|
||||
box = track['boxes'][-1]
|
||||
|
||||
# 如果有足够的历史记录,使用卡尔曼滤波或简单线性预测
|
||||
if len(track['boxes']) >= 2:
|
||||
prev_box = track['boxes'][-2]
|
||||
|
||||
# 计算速度
|
||||
vx = box[0] - prev_box[0]
|
||||
vy = box[1] - prev_box[1]
|
||||
vw = (box[2] - box[0]) - (prev_box[2] - prev_box[0])
|
||||
vh = (box[3] - box[1]) - (prev_box[3] - prev_box[1])
|
||||
|
||||
# 预测下一个位置
|
||||
x1 = box[0] + vx
|
||||
y1 = box[1] + vy
|
||||
x2 = box[2] + vx + vw
|
||||
y2 = box[3] + vy + vh
|
||||
|
||||
track['predicted_box'] = [x1, y1, x2, y2]
|
||||
else:
|
||||
# 如果没有足够的历史记录,使用当前位置作为预测
|
||||
track['predicted_box'] = box
|
||||
|
||||
def update(self, image: np.ndarray, detections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
更新追踪器状态
|
||||
|
||||
参数:
|
||||
image (np.ndarray): 当前帧图像
|
||||
detections (List[Dict]): 检测结果列表,每个检测包含 'bbox', 'class_id', 'confidence' 等字段
|
||||
|
||||
返回:
|
||||
List[Dict]: 更新后的轨迹列表
|
||||
"""
|
||||
self.frame_count += 1
|
||||
|
||||
# 如果没有轨迹,初始化轨迹
|
||||
if len(self.tracks) == 0:
|
||||
for i, det in enumerate(detections):
|
||||
try:
|
||||
# 确保检测结果格式正确
|
||||
if not isinstance(det, dict):
|
||||
logger.warning(f"检测结果 #{i} 不是字典: {det}")
|
||||
continue
|
||||
|
||||
# 获取必要的字段
|
||||
bbox = det.get('bbox', None)
|
||||
if bbox is None or not isinstance(bbox, list) or len(bbox) != 4:
|
||||
logger.warning(f"检测结果 #{i} 的边界框格式不正确: {bbox}")
|
||||
continue
|
||||
|
||||
class_id = det.get('class_id', 0)
|
||||
confidence = det.get('confidence', 0.5)
|
||||
|
||||
# 如果是单目标模式,检查类别ID
|
||||
if self.single_target_mode:
|
||||
# 如果有目标类别ID,只处理该类别的目标
|
||||
if hasattr(self, 'target_class_id') and class_id != self.target_class_id:
|
||||
logger.info(f"跳过类别ID不匹配的检测: 检测类别ID={class_id}, 目标类别ID={self.target_class_id}")
|
||||
continue
|
||||
# 如果有目标ID,只处理该ID的目标
|
||||
elif self.target_id is not None and i != self.target_id:
|
||||
continue
|
||||
|
||||
# 提取特征
|
||||
features = self.extract_features(image, [bbox])
|
||||
|
||||
# 创建新轨迹
|
||||
self.tracks.append({
|
||||
'id': self.next_id,
|
||||
'boxes': [bbox],
|
||||
'class_id': class_id,
|
||||
'confidence': confidence,
|
||||
'features': [features[0]],
|
||||
'age': 1,
|
||||
'hits': 1,
|
||||
'time_since_update': 0,
|
||||
'predicted_box': bbox,
|
||||
'trajectory': [((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)]
|
||||
})
|
||||
|
||||
self.next_id += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"处理检测结果 #{i} 时出错: {str(e)}")
|
||||
|
||||
return self.get_active_tracks()
|
||||
|
||||
# 预测轨迹的下一个位置
|
||||
self.predict()
|
||||
|
||||
# 如果没有检测,更新轨迹状态
|
||||
if len(detections) == 0:
|
||||
for track in self.tracks:
|
||||
track['time_since_update'] += 1
|
||||
|
||||
return self.get_active_tracks()
|
||||
|
||||
# 提取检测特征
|
||||
detection_boxes = []
|
||||
valid_detections = []
|
||||
for det in detections:
|
||||
try:
|
||||
if isinstance(det, dict) and 'bbox' in det:
|
||||
bbox = det['bbox']
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
detection_boxes.append(bbox)
|
||||
valid_detections.append(det)
|
||||
except Exception as e:
|
||||
logger.warning(f"提取检测框时出错: {str(e)}")
|
||||
|
||||
# 如果没有有效的检测框,返回空结果
|
||||
if not detection_boxes:
|
||||
for track in self.tracks:
|
||||
track['time_since_update'] += 1
|
||||
return self.get_active_tracks()
|
        detection_features = self.extract_features(image, detection_boxes)

        # Extract track features and bounding boxes
        track_features = torch.stack([track['features'][-1] for track in self.tracks])
        track_boxes = [track['predicted_box'] for track in self.tracks]

        # Compute feature similarity
        feature_similarity = self.compute_similarity(
            track_features,
            detection_features,
            track_boxes,
            detection_boxes
        )

        # Compute the IoU matrix
        iou_matrix = torch.zeros((len(self.tracks), len(detection_boxes)), device=self.device)
        for i, track in enumerate(self.tracks):
            for j, bbox in enumerate(detection_boxes):
                try:
                    iou = self.compute_iou(track['predicted_box'], bbox)
                    iou_matrix[i, j] = iou
                except Exception as e:
                    logger.warning(f"Error computing IoU: {str(e)}")

        # Combine the similarity terms
        similarity_matrix = (
            self.feature_similarity_weight * feature_similarity +
            self.motion_weight * iou_matrix
        )

        # Match tracks to detections
        matched_indices = []
        unmatched_tracks = list(range(len(self.tracks)))
        unmatched_detections = list(range(len(detection_boxes)))

        # Greedy matching - detach() keeps gradients from propagating
        similarity_matrix_np = similarity_matrix.detach().cpu().numpy()
        while len(unmatched_tracks) > 0 and len(unmatched_detections) > 0:
            # Find the highest similarity
            track_idx, det_idx = np.unravel_index(
                np.argmax(similarity_matrix_np[unmatched_tracks][:, unmatched_detections]),
                (len(unmatched_tracks), len(unmatched_detections))
            )

            track_idx = unmatched_tracks[track_idx]
            det_idx = unmatched_detections[det_idx]

            # Stop matching once the best remaining similarity is too low
            if similarity_matrix_np[track_idx, det_idx] < self.iou_threshold:
                break

            # Record the match
            matched_indices.append((track_idx, det_idx))

            # Remove both from the unmatched lists
            unmatched_tracks.remove(track_idx)
            unmatched_detections.remove(det_idx)

        # Update the matched tracks
        for track_idx, det_idx in matched_indices:
            try:
                track = self.tracks[track_idx]
                det = valid_detections[det_idx]

                # Make sure the detection has the expected format
                if not isinstance(det, dict):
                    logger.warning(f"Matched detection #{det_idx} is not a dict: {det}")
                    continue

                # Get the required fields
                bbox = det.get('bbox', None)
                if bbox is None or not isinstance(bbox, list) or len(bbox) != 4:
                    logger.warning(f"Matched detection #{det_idx} has a malformed bounding box: {bbox}")
                    continue

                class_id = det.get('class_id', track['class_id'])
                confidence = det.get('confidence', track['confidence'])

                # Update the track - cap the amount of history kept
                MAX_HISTORY = 5  # keep only the last 5 frames of data

                # Update the box history, keeping only the last MAX_HISTORY entries
                track['boxes'].append(bbox)
                if len(track['boxes']) > MAX_HISTORY:
                    track['boxes'] = track['boxes'][-MAX_HISTORY:]

                # Update the feature history, keeping only the last MAX_HISTORY entries
                track['features'].append(detection_features[det_idx])
                if len(track['features']) > MAX_HISTORY:
                    track['features'] = track['features'][-MAX_HISTORY:]

                track['class_id'] = class_id  # update the class ID
                track['confidence'] = confidence  # update the confidence
                track['age'] += 1
                track['hits'] += 1
                track['time_since_update'] = 0

                # Update the track center points - cap the trajectory length
                center_x = (bbox[0] + bbox[2]) / 2
                center_y = (bbox[1] + bbox[3]) / 2
                track['trajectory'].append((center_x, center_y))
                if len(track['trajectory']) > MAX_HISTORY * 2:  # trajectories may be a little longer
                    track['trajectory'] = track['trajectory'][-MAX_HISTORY * 2:]
            except Exception as e:
                logger.warning(f"Error updating track #{track_idx}: {str(e)}")

        # Update the unmatched tracks
        for track_idx in unmatched_tracks:
            track = self.tracks[track_idx]
            track['time_since_update'] += 1

        # Create new tracks
        for det_idx in unmatched_detections:
            try:
                det = valid_detections[det_idx]

                # In single-target mode, check whether a new track should be created
                if self.single_target_mode:
                    # Get the detection's class ID
                    class_id = det.get('class_id', 0)

                    # If a target class ID is set, only create tracks for that class
                    if hasattr(self, 'target_class_id'):
                        if class_id != self.target_class_id:
                            logger.info(f"Skipping track creation for a mismatched class ID: detection class ID={class_id}, target class ID={self.target_class_id}")
                            continue
                        # Multiple tracks of the same class are allowed so that several objects of that class can be followed
                        # The check below is commented out to permit multiple tracks of the same class
                        # elif any(track['class_id'] == self.target_class_id for track in self.tracks):
                        #     logger.info(f"A track with class ID {self.target_class_id} already exists; not creating a new one")
                        #     continue
                    # If a target ID is set and a track already exists, do not create a new track
                    elif self.target_id is not None and len(self.tracks) > 0:
                        continue

                # Make sure the detection has the expected format
                if not isinstance(det, dict):
                    logger.warning(f"Unmatched detection #{det_idx} is not a dict: {det}")
                    continue

                # Get the required fields
                bbox = det.get('bbox', None)
                if bbox is None or not isinstance(bbox, list) or len(bbox) != 4:
                    logger.warning(f"Unmatched detection #{det_idx} has a malformed bounding box: {bbox}")
                    continue

                class_id = det.get('class_id', 0)
                confidence = det.get('confidence', 0.5)

                # Extract the features
                if det_idx >= len(detection_features):
                    logger.warning(f"Detection feature index out of range: {det_idx} >= {len(detection_features)}")
                    continue

                features = detection_features[det_idx].unsqueeze(0)

                # Create the new track
                self.tracks.append({
                    'id': self.next_id,
                    'boxes': [bbox],
                    'class_id': class_id,
                    'confidence': confidence,
                    'features': [features[0]],
                    'age': 1,
                    'hits': 1,
                    'time_since_update': 0,
                    'predicted_box': bbox,
                    'trajectory': [((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)]
                })

                self.next_id += 1
            except Exception as e:
                logger.warning(f"Error creating a new track: {str(e)}")

        # Drop expired tracks
        self.tracks = [
            track for track in self.tracks
            if track['time_since_update'] <= self.max_age
        ]

        return self.get_active_tracks()

    def get_active_tracks(self) -> List[Dict[str, Any]]:
        """
        Get the active tracks

        Returns:
            List[Dict]: list of active tracks
        """
        active_tracks = []

        # Log some state to help with debugging
        logger.info(f"Getting active tracks: single_target_mode={self.single_target_mode}, target_id={self.target_id}, target_class_id={getattr(self, 'target_class_id', None)}")
        logger.info(f"Current number of tracks: {len(self.tracks)}")

        # In single-target mode with a target class ID, log every track's class ID
        if self.single_target_mode and hasattr(self, 'target_class_id'):
            track_class_ids = [track.get('class_id') for track in self.tracks]
            logger.info(f"Track class IDs: {track_class_ids}")

        for track in self.tracks:
            # Only return tracks with enough hits
            if track['hits'] >= self.min_hits and track['time_since_update'] <= 1:
                # In single-target mode, check the target ID or the target class ID
                if self.single_target_mode:
                    # If a target ID is set, only return the track with that ID
                    if self.target_id is not None and track['id'] != self.target_id:
                        continue
                    # If a target class ID is set, only return tracks of that class
                    elif hasattr(self, 'target_class_id') and track['class_id'] != self.target_class_id:
                        logger.info(f"Skipping track with a mismatched class ID: track class ID={track['class_id']}, target class ID={self.target_class_id}")
                        continue

                # Copy the track info
                track_info = {
                    'id': track['id'],
                    'bbox': track['boxes'][-1],
                    'class_id': track['class_id'],
                    'confidence': track['confidence'],
                    'age': track['age'],
                    'time_since_update': track['time_since_update'],
                    'trajectory': track['trajectory']
                }

                active_tracks.append(track_info)

        logger.info(f"Returning {len(active_tracks)} active tracks")
        return active_tracks

    # The tracking report feature has been removed
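The matcher above repeatedly picks the highest-scoring (track, detection) pair until the best remaining score drops below the threshold. A minimal standalone sketch of that loop on a toy similarity matrix (hypothetical values, NumPy only):

    import numpy as np

    def greedy_match(similarity: np.ndarray, threshold: float):
        # Pair rows (tracks) with columns (detections) in descending score order.
        unmatched_rows = list(range(similarity.shape[0]))
        unmatched_cols = list(range(similarity.shape[1]))
        matches = []
        while unmatched_rows and unmatched_cols:
            sub = similarity[np.ix_(unmatched_rows, unmatched_cols)]
            r, c = np.unravel_index(np.argmax(sub), sub.shape)
            row, col = unmatched_rows[r], unmatched_cols[c]
            if similarity[row, col] < threshold:
                break
            matches.append((row, col))
            unmatched_rows.remove(row)
            unmatched_cols.remove(col)
        return matches

    sim = np.array([[0.9, 0.2],
                    [0.4, 0.8]])
    print(greedy_match(sim, threshold=0.3))  # [(0, 0), (1, 1)]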
227 app/nn/modules/layer_scale.py Normal file
@@ -0,0 +1,227 @@
"""
Layer Scale and Stochastic Depth modules

These modules improve the training stability and generalization of deep networks.
- Layer Scale: adds a learnable scaling parameter to each layer's output
- Stochastic Depth: randomly drops entire layers during training
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Callable


class LayerScale(nn.Module):
    """
    Layer Scale module

    Adds a learnable scaling parameter to each layer's output to stabilize the
    training of deep networks. The parameters start at a very small value and
    grow gradually during training.

    Args:
        dim (int): feature dimension
        init_value (float): initial value, usually very small, e.g. 1e-6
    """

    def __init__(self, dim: int, init_value: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x (torch.Tensor): input features

        Returns:
            torch.Tensor: scaled features
        """
        # Reshape gamma to match the input dimensionality
        if x.dim() == 3:  # [B, N, C]
            return x * self.gamma.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 4:  # [B, C, H, W]
            return x * self.gamma.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        else:
            raise ValueError(f"Unsupported input dimensionality: {x.dim()}")
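A quick shape check of the broadcasting logic above (a sketch; it assumes this file is importable as app.nn.modules.layer_scale):

    import torch
    from app.nn.modules.layer_scale import LayerScale

    scale = LayerScale(dim=256, init_value=1e-6)
    tokens = torch.randn(2, 196, 256)            # [B, N, C]
    feature_map = torch.randn(2, 256, 14, 14)    # [B, C, H, W]
    print(scale(tokens).shape)         # torch.Size([2, 196, 256])
    print(scale(feature_map).shape)    # torch.Size([2, 256, 14, 14])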
class DropPath(nn.Module):
    """
    Stochastic Depth module

    Randomly drops entire layers during training to improve the model's
    generalization and stability.

    Args:
        drop_prob (float): drop probability; 0.0 disables dropping
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x (torch.Tensor): input features

        Returns:
            torch.Tensor: features after random dropping
        """
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # keep the batch dimension, set the rest to 1
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor  # rescale the surviving activations

        return output
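Because surviving samples are divided by keep_prob, the layer is unbiased in expectation: E[output] = x. A rough empirical check (a sketch, assuming the module above):

    import torch
    from app.nn.modules.layer_scale import DropPath

    drop = DropPath(drop_prob=0.2)
    drop.train()
    x = torch.ones(10000, 8)
    print(drop(x).mean().item())  # close to 1.0; each row is either 0 or 1/0.8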
class TransformerEncoderLayerWithLayerScale(nn.Module):
    """
    Transformer encoder layer with Layer Scale

    Applies Layer Scale after the residual connections of the self-attention
    and feed-forward sublayers to stabilize the training of deep networks.

    Args:
        d_model (int): model dimension
        nhead (int): number of attention heads
        dim_feedforward (int): hidden dimension of the feed-forward network
        dropout (float): dropout rate
        activation (str): activation function, 'relu' or 'gelu'
        layer_scale_init_value (float): initial Layer Scale value
        drop_path_rate (float): stochastic depth drop probability
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_scale_init_value: float = 1e-6,
        drop_path_rate: float = 0.0
    ):
        super().__init__()

        # Multi-head self-attention
        from app.nn.modules.attention import SelfAttention
        self.self_attn = SelfAttention(d_model, num_heads=nhead, attn_drop=dropout, proj_drop=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Layer Scale
        self.layer_scale_1 = LayerScale(d_model, init_value=layer_scale_init_value)
        self.layer_scale_2 = LayerScale(d_model, init_value=layer_scale_init_value)

        # Stochastic depth
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

        # Activation function
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            src: input sequence [batch_size, seq_len, d_model]

        Returns:
            output sequence [batch_size, seq_len, d_model]
        """
        # Self-attention sublayer
        src2 = self.self_attn(self.norm1(src))
        src = src + self.drop_path(self.layer_scale_1(self.dropout1(src2)))

        # Feed-forward sublayer
        src2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(src)))))
        src = src + self.drop_path(self.layer_scale_2(self.dropout2(src2)))

        return src


class TransformerEncoderWithLayerScale(nn.Module):
    """
    Transformer encoder with Layer Scale

    A stack of encoder layers with Layer Scale.

    Args:
        d_model (int): model dimension
        nhead (int): number of attention heads
        num_layers (int): number of encoder layers
        dim_feedforward (int): hidden dimension of the feed-forward network
        dropout (float): dropout rate
        activation (str): activation function, 'relu' or 'gelu'
        layer_scale_init_value (float): initial Layer Scale value
        drop_path_rate (float): maximum stochastic depth drop probability;
            per-layer rates increase linearly from 0 up to this value
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_scale_init_value: float = 1e-6,
        drop_path_rate: float = 0.1
    ):
        super().__init__()

        # Compute a different drop probability for each layer (higher for deeper layers)
        drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]

        # Create the encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayerWithLayerScale(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_scale_init_value=layer_scale_init_value,
                drop_path_rate=drop_path_rates[i]
            )
            for i in range(num_layers)
        ])

        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            src: input sequence [batch_size, seq_len, d_model]

        Returns:
            output sequence [batch_size, seq_len, d_model]
        """
        output = src

        for layer in self.layers:
            output = layer(output)

        return self.norm(output)
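A minimal usage sketch of the encoder (hypothetical sizes; it assumes app.nn.modules.attention.SelfAttention, imported above, preserves the [batch_size, seq_len, d_model] shape):

    import torch
    from app.nn.modules.layer_scale import TransformerEncoderWithLayerScale

    encoder = TransformerEncoderWithLayerScale(
        d_model=256, nhead=8, num_layers=6,
        dim_feedforward=1024, drop_path_rate=0.1
    )
    src = torch.randn(4, 100, 256)  # [batch_size, seq_len, d_model]
    print(encoder(src).shape)       # torch.Size([4, 100, 256])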
191 app/services/ascend_service.py Normal file
@@ -0,0 +1,191 @@
"""Ascend NPU service module"""
import os
import sys
import platform
from typing import List, Dict, Any, Optional

class AscendDeviceManager:
    """Huawei Ascend NPU device manager"""

    @staticmethod
    def get_available_ascends() -> list:
        """
        Get information about all available Ascend NPUs
        :return: list of NPU info dicts [{'index': 0, 'name': 'NPU name', 'memory': total memory (MB), 'memory_used': used memory (MB), 'memory_free': free memory (MB)}]
        """
        ascends = []

        try:
            # Try to import the Ascend libraries
            # Note: the corresponding dependencies must be installed
            import torch_npu
            import acl

            # Get the number of NPUs
            device_count = torch_npu.npu.device_count()

            if device_count == 0:
                # No Ascend NPU devices detected
                print("No Ascend NPU devices detected")
                return []

            # Collect the information for every Ascend NPU
            for i in range(device_count):
                # Get the NPU properties
                props = torch_npu.npu.get_device_properties(i)
                total_memory = int(props.total_memory / (1024 * 1024))  # convert to MB

                # Get the memory usage
                memory_used = int(torch_npu.npu.memory_used(i) / (1024 * 1024))
                memory_free = total_memory - memory_used

                ascends.append({
                    'index': i,
                    'name': props.name,
                    'memory': total_memory,
                    'memory_used': memory_used,
                    'memory_free': memory_free,
                    'recommended_memory': int(memory_free * 0.8)  # recommend using 80% of the free memory
                })
        except Exception as e:
            print(f"Failed to get Ascend NPU information: {str(e)}")

        return ascends
    @staticmethod
    def validate_ascend_memory(requested_memory: int, ascend_index: int = 0) -> tuple[bool, str, int]:
        """
        Validate that the requested Ascend NPU memory is reasonable
        :param requested_memory: requested memory size (MB)
        :param ascend_index: Ascend NPU index, defaults to 0
        :return: (valid, message, total memory)
        """
        # Get the Ascend NPU information
        ascends = AscendDeviceManager.get_available_ascends()
        if not ascends:
            return False, "Ascend NPU is unavailable; please train in another mode", 0

        # Find the requested Ascend NPU
        ascend_info = None
        for ascend in ascends:
            if ascend.get("index", 0) == ascend_index:
                ascend_info = ascend
                break

        # If the requested Ascend NPU was not found, use the first one
        if not ascend_info and ascends:
            ascend_info = ascends[0]

        if not ascend_info:
            return False, "Failed to get Ascend NPU information; please train in another mode", 0

        total_memory = ascend_info.get("memory", 0)
        free_memory = ascend_info.get("memory_free", 0)
        used_memory = ascend_info.get("memory_used", 0)

        if requested_memory <= 0:
            return False, "The requested memory must be greater than 0MB", total_memory

        if requested_memory > total_memory:
            return False, f"The requested memory ({requested_memory}MB) exceeds the NPU's total memory ({total_memory}MB)", total_memory

        # Check against the free memory
        if requested_memory > free_memory:
            return False, f"The requested memory ({requested_memory}MB) exceeds the currently free memory ({free_memory}MB)", total_memory

        # Recommend using at most 90% of the free memory
        recommended_memory = int(free_memory * 0.9)
        if requested_memory > recommended_memory:
            return False, f"Use at most {recommended_memory}MB of memory ({free_memory}MB currently free)", total_memory

        return True, "The memory setting is valid", total_memory
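A usage sketch: check a requested memory budget against NPU 0 before launching a training task (values illustrative):

    from app.services.ascend_service import AscendDeviceManager

    ok, message, total = AscendDeviceManager.validate_ascend_memory(4096, ascend_index=0)
    if not ok:
        print(f"Rejecting the request: {message} (total memory: {total}MB)")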
    @staticmethod
    def get_device_info(ascend_memory: Optional[int] = None, ascend_index: int = 0) -> dict:
        """
        Get the Ascend NPU device information
        :param ascend_memory: Ascend NPU memory limit (MB)
        :param ascend_index: Ascend NPU index, defaults to 0
        :return: device configuration
        """
        # Get all available Ascend NPUs
        available_ascends = AscendDeviceManager.get_available_ascends()

        device_info = {
            'device_type': 'ascend',
            'device': 'cpu',  # default to CPU; updated if an Ascend NPU is available
            'ascend_memory': None,
            'ascend_index': ascend_index,
            'available_ascends': available_ascends
        }

        # Check whether any Ascend NPU is available
        has_ascend = len(available_ascends) > 0

        if not has_ascend:
            print('\n=== Warning: Ascend NPU is unavailable; falling back to CPU training ===')
            device_info['device_type'] = 'cpu'
            return device_info

        # Check whether the requested Ascend NPU exists
        selected_ascend = None
        for ascend in available_ascends:
            if ascend.get("index", 0) == ascend_index:
                selected_ascend = ascend
                break

        # If the requested Ascend NPU was not found, use the first one
        if not selected_ascend and available_ascends:
            selected_ascend = available_ascends[0]
            requested_index = ascend_index
            ascend_index = selected_ascend.get("index", 0)
            device_info['ascend_index'] = ascend_index
            print(f"\n=== Warning: the requested Ascend NPU ID {requested_index} does not exist; using Ascend NPU ID {ascend_index} ===")

        if selected_ascend:
            # Set the current device
            try:
                # This must be implemented against the real Ascend NPU API;
                # the call below is illustrative and should be replaced by the real one
                # torch_npu.npu.set_device(ascend_index)
                device_info['device'] = f'npu:{ascend_index}'
            except Exception as e:
                print(f'\n=== Warning: failed to set the current Ascend NPU device: {str(e)} ===')
                device_info['device'] = 'npu'  # use the default Ascend NPU

            # Set the Ascend NPU memory limit
            if ascend_memory:
                # Get the info for the selected Ascend NPU
                total_memory = selected_ascend.get("memory", 0)
                free_memory = selected_ascend.get("memory_free", 0)

                # Validate the Ascend NPU memory setting
                if ascend_memory <= 0:
                    print("\n=== Warning: the requested memory must be greater than 0MB ===")
                    # Use the recommended memory size (80% of the free memory)
                    ascend_memory = int(free_memory * 0.8)
                elif ascend_memory > total_memory:
                    print(f"\n=== Warning: the requested memory ({ascend_memory}MB) exceeds the NPU's total memory ({total_memory}MB) ===")
                    # Use the recommended memory size (80% of the free memory)
                    ascend_memory = int(free_memory * 0.8)
                elif ascend_memory > free_memory:
                    print(f"\n=== Warning: the requested memory ({ascend_memory}MB) exceeds the currently free memory ({free_memory}MB) ===")
                    # Use the recommended memory size (80% of the free memory)
                    ascend_memory = int(free_memory * 0.8)

                print(f"\n=== Setting Ascend NPU {ascend_index} memory limit to {ascend_memory}MB ===")

                # Apply the Ascend NPU memory limit
                try:
                    # This must be implemented against the real Ascend NPU API;
                    # the call below is illustrative and should be replaced by the real one
                    # torch_npu.npu.set_memory_limit(ascend_memory * 1024 * 1024)  # convert to bytes
                    device_info['ascend_memory'] = ascend_memory
                except Exception as e:
                    print(f'\n=== Warning: failed to set the Ascend NPU memory limit: {str(e)} ===')
        else:
            print('\n=== Warning: no Ascend NPU available; falling back to CPU training ===')
            device_info['device_type'] = 'cpu'
            device_info['device'] = 'cpu'

        return device_info
@@ -50,43 +50,56 @@ except ImportError as e:

async def create_detection_task(
    db: Session,
    model_id: str,
    file: UploadFile,
    parameters: Dict[str, Any]
    model_id: str = None,
    file: UploadFile = None,
    parameters: Dict[str, Any] = None
) -> DetectionTask:
    """
    Create a new detection task
    """
    # Check if model exists
    # Initialize the parameters
    if parameters is None:
        parameters = {}

    # Default model ID and path
    DEFAULT_MODEL_ID = "6b30123c-2c33-43c0-a045-41cda7e61d37"  # a generated UUID
    model_path = "yolov8n.pt"  # default model path

    # If no model ID was provided, fall back to the default model ID
    if not model_id:
        model_id = DEFAULT_MODEL_ID
        print(f"Using the default model ID: {model_id}")

    # Check whether the model exists
    db_model = model.get(db, id=model_id)
    if not db_model:
        raise HTTPException(
            status_code=404,
            detail="Model not found",
        )
        # If the model is missing, create a default model record
        db_model = model.create_with_fields(db, obj_in={
            "id": DEFAULT_MODEL_ID,
            "name": "YOLOv8n",
            "description": "Default YOLOv8n model",
            "path": model_path,
            "type": "detection",
            "task": "object_detection",
            "source": "default"
        })
        print(f"Created the default model: {db_model.name}")

    # Use the model path
    model_path = db_model.path

    # Generate unique ID for the detection task
    task_id = str(uuid.uuid4())

    # Create input and output directories
    input_dir = settings.UPLOADS_DIR / task_id
    output_dir = settings.RESULTS_DIR / task_id
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    # Read the uploaded file straight into memory instead of saving it locally
    content = await file.read()

    # Save the uploaded file
    file_extension = Path(file.filename).suffix
    input_path = input_dir / f"input{file_extension}"
    with open(input_path, "wb") as buffer:
        content = await file.read()
        buffer.write(content)

    # Create detection task with all required fields
    # Create the in-memory detection task record
    obj_in_data = {
        "model_id": model_id,
        "model_id": model_id,  # make sure model_id always has a value
        "parameters": parameters,
        "input_path": str(input_path),
        "output_path": str(output_dir),
        "input_path": f"memory:{file.filename}",  # in-memory identifier
        "output_path": f"memory:{task_id}",  # in-memory identifier
        "status": "pending"
    }
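A hypothetical client call against this endpoint (the URL prefix depends on how the API router is mounted in a given deployment, and model_id may be omitted to fall back to the default model):

    import requests

    with open("test.jpg", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/api/v1/detection/",   # assumed mount point
            files={"file": ("test.jpg", f, "image/jpeg")},
            data={"conf_thres": 0.25, "iou_thres": 0.45},
        )
    print(resp.json())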
@@ -94,7 +107,7 @@ async def create_detection_task(
    db_task = detection_task.create_with_fields(db, obj_in=obj_in_data)

    # In a real implementation this would be handled by a background task
    # For now, run the actual detection process
    # Run the actual detection process
    try:
        # Import the YOLO model
        from ultralytics import YOLO
@@ -135,32 +148,36 @@ async def create_detection_task(
    except ImportError as e:
        print(f"Warning: Could not import required classes: {e}")

        # Import the required libraries
        import cv2
        import numpy as np
        from io import BytesIO

        # Load the model
        model_path = db_model.path
        yolo_model = YOLO(model_path)
        print(f"Model loaded successfully: {model_path}")

        # Run detection
        results = yolo_model(str(input_path), conf=parameters.get('conf_thres', 0.25), iou=parameters.get('iou_thres', 0.45))
        # Convert the raw bytes to a numpy array
        nparr = np.frombuffer(content, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if img is None:
            raise ValueError("Failed to decode the image data")

        # Run detection - pass the image array directly instead of a file path
        results = yolo_model(img, conf=parameters.get('conf_thres', 0.25), iou=parameters.get('iou_thres', 0.45))

        # Process the results in memory
        import json
        memory_results = []

        # Save the results
        for i, result in enumerate(results):
            # Save the image with the detection boxes drawn
            result_path = output_dir / f"result_{i}.jpg"

            # Save the image with the new Ultralytics API
            # Option 1: use the plotted image from the result
            import cv2
            import numpy as np

            # Get the plotted image
            # If the result provides a plotted image, use it directly
            if hasattr(result, 'plot') and callable(getattr(result, 'plot')):
                # Use the plot method to get the annotated image
                plotted_img = result.plot()
                cv2.imwrite(str(result_path), plotted_img)
            else:
                # If there is no plot method, draw the detection boxes manually
                img = result.orig_img.copy()
                plotted_img = img.copy()
                boxes = result.boxes
                for box in boxes:
                    x1, y1, x2, y2 = map(int, box.xyxy.tolist()[0])
@@ -170,34 +187,41 @@ async def create_detection_task(

                    # Draw the detection box
                    color = (0, 255, 0)  # BGR, green
                    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                    cv2.rectangle(plotted_img, (x1, y1), (x2, y2), color, 2)

                    # Draw the label
                    font = cv2.FONT_HERSHEY_SIMPLEX
                    cv2.putText(img, label, (x1, y1 - 10), font, 0.5, color, 2)
                    cv2.putText(plotted_img, label, (x1, y1 - 10), font, 0.5, color, 2)

            # Save the image
            cv2.imwrite(str(result_path), img)
            # Convert the detections to JSON format
            boxes = result.boxes
            json_results = []
            for box in boxes:
                class_id = int(box.cls.item())
                json_results.append({
                    'class': class_id,  # keep the 'class' field for backwards compatibility
                    'class_id': class_id,  # add 'class_id' for the frontend code
                    'class_name': result.names[class_id],
                    'confidence': float(box.conf.item()),
                    'bbox': box.xyxy.tolist()[0],
                })

            # Save the JSON results
            json_path = output_dir / f"result_{i}.json"
            with open(json_path, 'w') as f:
                import json
                # Convert the detections to JSON format
                boxes = result.boxes
                json_results = []
                for box in boxes:
                    json_results.append({
                        'class': int(box.cls.item()),
                        'class_name': result.names[int(box.cls.item())],
                        'confidence': float(box.conf.item()),
                        'bbox': box.xyxy.tolist()[0],
                    })
                json.dump(json_results, f, indent=2)
            # Do not store the image data, only the detection results;
            # this greatly reduces memory usage
            memory_results.append({
                'detections': json_results
            })

        # Store the results in the parameters
        updated_parameters = {
            **parameters,
            'memory_results': memory_results
        }

        # Mark the task as completed
        db_task = detection_task.update(db, db_obj=db_task, obj_in={
            "status": "completed"
            "status": "completed",
            "parameters": updated_parameters
        })
    except Exception as e:
        # On error, mark the task as failed
@@ -232,7 +256,7 @@ def get_detection_tasks(db: Session, skip: int = 0, limit: int = 100) -> List[De

def get_detection_result(db: Session, task_id: str) -> Dict[str, Any]:
    """
    Get the detection result for a task
    Get the detection result for a task - supports in-memory results
    """
    # Get the task info
    db_task = detection_task.get(db, id=task_id)
@@ -250,6 +274,42 @@ def get_detection_result(db: Session, task_id: str) -> Dict[str, Any]:
            "results": None
        }

    # Check whether the result lives in memory
    if db_task.output_path.startswith("memory:"):
        # Read the in-memory results from the parameters
        memory_results = db_task.parameters.get('memory_results', [])

        if not memory_results:
            return {
                "status": "error",
                "message": "No detection results found in memory",
                "results": None
            }

        # Build the result object
        results = []
        for i, result in enumerate(memory_results):
            try:
                # Extract the detections from the in-memory result
                detections = result.get('detections', [])

                # Return only the detections, not the image data
                results.append({
                    "memory_image_id": f"memory_image_{i}",  # in-memory identifier
                    "detections": detections,
                    "count": len(detections)
                })
            except Exception as e:
                print(f"Error processing memory result #{i}: {e}")

        return {
            "status": "completed",
            "message": "Detection completed successfully (memory mode)",
            "results": results,
            "input_image": db_task.input_path.replace("memory:", "")
        }

    # The original filesystem logic follows, kept for compatibility with old data
    # Get the output directory
    output_dir = Path(db_task.output_path)
    if not output_dir.exists():
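For memory-mode tasks, the payload returned by get_detection_result therefore has the following shape (values illustrative, reconstructed from the code above):

    example = {
        "status": "completed",
        "message": "Detection completed successfully (memory mode)",
        "results": [
            {
                "memory_image_id": "memory_image_0",
                "detections": [
                    {"class": 0, "class_id": 0, "class_name": "person",
                     "confidence": 0.91, "bbox": [34.0, 50.0, 210.0, 380.0]},
                ],
                "count": 1,
            },
        ],
        "input_image": "test.jpg",
    }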
@@ -4,8 +4,6 @@ OpenCV service module - provides image processing and computer vision features
import os
import cv2
import numpy as np
import math
import shutil
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional, Union
import logging
@@ -989,6 +987,200 @@ class OpenCVService:

        return grid

    # ==================== Advanced data augmentation ====================

    @staticmethod
    def apply_cutmix(img1: np.ndarray, img2: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """
        Apply CutMix data augmentation

        Args:
            img1: first image
            img2: second image
            alpha: mixing parameter controlling the size of the cut region

        Returns:
            np.ndarray: augmented image
        """
        # Make sure both images have the same size
        h, w = img1.shape[:2]
        img2 = cv2.resize(img2, (w, h))

        # Sample the random cut region
        lam = np.random.beta(alpha, alpha)

        # Compute the cut region size
        cut_w = int(w * np.sqrt(1 - lam))
        cut_h = int(h * np.sqrt(1 - lam))

        # Pick a random center for the cut region
        cx = np.random.randint(w)
        cy = np.random.randint(h)

        # Compute the cut region boundaries
        x1 = max(0, cx - cut_w // 2)
        y1 = max(0, cy - cut_h // 2)
        x2 = min(w, cx + cut_w // 2)
        y2 = min(h, cy + cut_h // 2)

        # Create the result image
        result = img1.copy()

        # Paste the region from img2 onto img1
        result[y1:y2, x1:x2] = img2[y1:y2, x1:x2]

        return result
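Since cut_w = w * sqrt(1 - lam) and cut_h = h * sqrt(1 - lam), the pasted rectangle covers roughly a (1 - lam) fraction of the image area. A usage sketch on random images (it assumes this class lives at app.services.opencv_service.OpenCVService):

    import numpy as np
    from app.services.opencv_service import OpenCVService

    img1 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    img2 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    mixed = OpenCVService.apply_cutmix(img1, img2, alpha=0.5)
    print(mixed.shape)  # (480, 640, 3)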
    @staticmethod
    def apply_mixup(img1: np.ndarray, img2: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """
        Apply MixUp data augmentation

        Args:
            img1: first image
            img2: second image
            alpha: mixing parameter controlling the blend weights

        Returns:
            np.ndarray: augmented image
        """
        # Make sure both images have the same size
        h, w = img1.shape[:2]
        img2 = cv2.resize(img2, (w, h))

        # Sample the blend weight
        lam = np.random.beta(alpha, alpha)

        # Blend the images
        result = cv2.addWeighted(img1, lam, img2, 1 - lam, 0)

        return result
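MixUp is a per-pixel convex blend, result = lam * img1 + (1 - lam) * img2 with lam ~ Beta(alpha, alpha); small alpha values favor lam near 0 or 1, i.e. blends dominated by one image. A sketch (same assumed import path as above):

    import numpy as np
    from app.services.opencv_service import OpenCVService

    img1 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    img2 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    blended = OpenCVService.apply_mixup(img1, img2, alpha=0.5)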
    @staticmethod
    def apply_mosaic(images: List[np.ndarray]) -> np.ndarray:
        """
        Apply Mosaic data augmentation

        Args:
            images: list of four images

        Returns:
            np.ndarray: augmented image
        """
        if len(images) != 4:
            raise ValueError("Mosaic requires four images")

        # Determine the output image size
        output_size = 640

        # Create a blank canvas
        mosaic_img = np.zeros((output_size, output_size, 3), dtype=np.uint8)

        # Compute the center point
        cx = output_size // 2
        cy = output_size // 2

        # Resize the images and place them on the canvas
        # Top left
        img1 = cv2.resize(images[0], (cx, cy))
        mosaic_img[:cy, :cx] = img1

        # Top right
        img2 = cv2.resize(images[1], (cx, cy))
        mosaic_img[:cy, cx:] = img2

        # Bottom left
        img3 = cv2.resize(images[2], (cx, cy))
        mosaic_img[cy:, :cx] = img3

        # Bottom right
        img4 = cv2.resize(images[3], (cx, cy))
        mosaic_img[cy:, cx:] = img4

        return mosaic_img
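A usage sketch: four differently sized inputs are each resized to 320x320 and tiled into one 640x640 training image (same assumed import path):

    import numpy as np
    from app.services.opencv_service import OpenCVService

    quads = [np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
             for h, w in [(480, 640), (720, 1280), (600, 800), (480, 480)]]
    mosaic = OpenCVService.apply_mosaic(quads)
    print(mosaic.shape)  # (640, 640, 3)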
    @staticmethod
    def apply_weather_effect(img: np.ndarray, weather_type: str, intensity: float = 0.5) -> np.ndarray:
        """
        Apply a weather effect

        Args:
            img: input image
            weather_type: weather type (rain, snow, fog)
            intensity: effect strength (0.0-1.0)

        Returns:
            np.ndarray: augmented image
        """
        h, w = img.shape[:2]
        result = img.copy()

        # Clamp the intensity range
        intensity = max(0.1, min(1.0, intensity))

        if weather_type == "rain":
            # Simulate raindrops
            rain_drops = np.zeros((h, w, 3), dtype=np.uint8)

            # The number of drops scales with the intensity
            num_drops = int(intensity * 1000)

            # Generate random raindrops
            for _ in range(num_drops):
                x = np.random.randint(0, w)
                y = np.random.randint(0, h)
                length = np.random.randint(5, 15)
                angle = np.random.uniform(0.7, 0.9) * np.pi  # raindrop angle

                # Draw the raindrop
                x2 = int(x + length * np.cos(angle))
                y2 = int(y + length * np.sin(angle))

                # Make sure the raindrop stays inside the image
                if 0 <= x2 < w and 0 <= y2 < h:
                    cv2.line(rain_drops, (x, y), (x2, y2), (200, 200, 255), 1)

            # Blend the raindrops into the original image
            result = cv2.addWeighted(result, 1.0, rain_drops, intensity, 0)

            # Lower the brightness and contrast
            result = cv2.addWeighted(result, 0.8, np.zeros_like(result), 0, 0)

        elif weather_type == "snow":
            # Simulate snowflakes
            snow_layer = np.zeros((h, w, 3), dtype=np.uint8)

            # The number of flakes scales with the intensity
            num_flakes = int(intensity * 1000)

            # Generate random snowflakes
            for _ in range(num_flakes):
                x = np.random.randint(0, w)
                y = np.random.randint(0, h)
                size = np.random.randint(1, 4)

                # Draw the snowflake
                cv2.circle(snow_layer, (x, y), size, (255, 255, 255), -1)

            # Blend the snowflakes into the original image
            result = cv2.addWeighted(result, 1.0, snow_layer, intensity, 0)

            # Increase the brightness
            brightness = int(intensity * 30)
            result = cv2.addWeighted(result, 1.0, np.zeros_like(result), 0, brightness)

        elif weather_type == "fog":
            # Create the fog effect
            fog = np.ones((h, w, 3), dtype=np.uint8) * 255

            # Blend the fog into the original image
            result = cv2.addWeighted(result, 1 - intensity * 0.7, fog, intensity * 0.7, 0)

        else:
            raise ValueError(f"Unsupported weather type: {weather_type}")

        return result
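A usage sketch applying the three supported effects to one frame (same assumed import path):

    import numpy as np
    from app.services.opencv_service import OpenCVService

    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    rainy = OpenCVService.apply_weather_effect(frame, "rain", intensity=0.6)
    snowy = OpenCVService.apply_weather_effect(frame, "snow", intensity=0.5)
    foggy = OpenCVService.apply_weather_effect(frame, "fog", intensity=0.4)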
    # ==================== Video processing ====================

    @staticmethod
152 app/services/tracking_service.py Normal file
@@ -0,0 +1,152 @@
"""
Object tracking service module - provides self-attention-based object tracking
"""
import numpy as np
import torch
from typing import List, Dict, Any, Tuple, Optional
import logging
from fastapi import HTTPException

from app.nn.modules.enhanced_attention_tracker import EnhancedAttentionTracker

logger = logging.getLogger(__name__)

class TrackingService:
    """
    Object tracking service class, providing self-attention-based object tracking
    """

    def __init__(self):
        """Initialize the tracking service - no local files are saved"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing the tracking service on device: {self.device}")

        # Initialize the tracker
        self.tracker = EnhancedAttentionTracker(
            max_age=30,
            min_hits=3,
            iou_threshold=0.3,
            feature_similarity_weight=0.7,
            motion_weight=0.3,
            device=self.device
        )

    def reset_tracker(self):
        """Reset the tracker state"""
        self.tracker.reset()

    def set_single_target_mode(self, enable: bool, target_id: Optional[int] = None, target_class_id: Optional[int] = None):
        """
        Configure single-target tracking mode

        Args:
            enable (bool): whether to enable single-target mode
            target_id (int, optional): ID of the target to track
            target_class_id (int, optional): class ID of the target to track
        """
        logger.info(f"Setting single-target mode: enable={enable}, target_id={target_id}, target_class_id={target_class_id}")
        self.tracker.set_single_target_mode(enable, target_id, target_class_id)

    # Video tracking has been removed; only camera tracking remains

    def track_frame(
        self,
        frame: np.ndarray,
        detections: List[Dict[str, Any]],
        target_class_id: Optional[int] = None,
        enable_tracking: bool = False,
        cancel_tracking: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Track the objects in a single frame - no local files are saved

        Args:
            frame: input frame
            detections: list of detection results
            target_class_id: class ID of the target to track
            enable_tracking: whether to track a specific class
            cancel_tracking: whether to cancel tracking

        Returns:
            List[Dict]: list of tracking results
        """
        try:
            # Log the request parameters
            logger.info(f"track_frame parameters: target_class_id={target_class_id}, enable_tracking={enable_tracking}, cancel_tracking={cancel_tracking}")

            # Handle tracking mode switches
            if cancel_tracking:
                # Cancel tracking - disable single-target mode
                logger.info("Cancelling tracking; disabling single-target mode")
                self.tracker.set_single_target_mode(False)
            elif enable_tracking and target_class_id is not None:
                # Enable tracking for a specific class
                logger.info(f"Enabling tracking for class ID: {target_class_id}")
                # Make sure target_class_id is an integer
                try:
                    target_class_id_int = int(target_class_id)
                    self.tracker.set_single_target_mode(True, target_class_id=target_class_id_int)
                except (ValueError, TypeError) as e:
                    logger.warning(f"Could not convert target_class_id to an integer: {e}")
                    # Fall back to the raw value
                    self.tracker.set_single_target_mode(True, target_class_id=target_class_id)

            # Update the tracker - no local files are saved
            tracks = self.tracker.update(frame, detections)

            # Log the tracking results
            logger.info(f"Tracking results: {len(tracks)} targets")

            # If class tracking is enabled but no target was found, retry once
            if enable_tracking and target_class_id is not None and len(tracks) == 0:
                logger.info(f"No target with class ID {target_class_id} found; retrying")

                # Check whether any detection has the target class
                for det in detections:
                    if det.get('class_id') == target_class_id:
                        logger.info(f"Found a detection with class ID {target_class_id}; re-enabling tracking mode")
                        self.tracker.set_single_target_mode(True, target_class_id=target_class_id)

                        # Update the tracker again - no local files are saved
                        tracks = self.tracker.update(frame, detections)
                        logger.info(f"Tracking results after the retry: {len(tracks)} targets")
                        break

            return tracks

        except Exception as e:
            logger.error(f"Frame tracking failed: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Frame tracking failed: {str(e)}")
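A usage sketch of the frame API (the detection dicts follow the bbox/class_id/confidence shape consumed above; values illustrative):

    import numpy as np
    from app.services.tracking_service import tracking_service

    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    detections = [{"class_id": 0, "confidence": 0.9,
                   "bbox": [100.0, 120.0, 220.0, 360.0]}]
    tracks = tracking_service.track_frame(
        frame, detections, target_class_id=0, enable_tracking=True
    )
    for t in tracks:
        print(t["id"], t["class_id"], t["bbox"])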
    # The tracking report feature has been removed

    def _get_color_by_id(self, track_id: int) -> Tuple[int, int, int]:
        """
        Generate a color from a track ID

        Args:
            track_id: track ID

        Returns:
            Tuple[int, int, int]: BGR color
        """
        # Use a fixed color list
        colors = [
            (0, 0, 255),    # red
            (0, 255, 0),    # green
            (255, 0, 0),    # blue
            (0, 255, 255),  # yellow
            (255, 0, 255),  # magenta
            (255, 255, 0),  # cyan
            (128, 0, 0),    # dark blue
            (0, 128, 0),    # dark green
            (0, 0, 128),    # dark red
            (128, 128, 0),  # dark cyan
        ]

        # Select a color by taking the ID modulo the list length
        return colors[track_id % len(colors)]


# Create the service instance
tracking_service = TrackingService()
@@ -16,6 +16,7 @@ from app.core.config import settings
from app.crud import training_task, dataset, model
from app.models.training_task import TrainingTask
from app.schemas.training_task import TrainingTaskCreate, TrainingTaskUpdate
from app.services.ascend_service import AscendDeviceManager

# Import the PyTorch and Ultralytics model classes that may be needed
try:
@@ -139,34 +140,77 @@ class DeviceManager:
        return True, "The VRAM setting is valid", total_memory

    @staticmethod
    def get_device_info(device_type: str = 'auto', gpu_memory: Optional[int] = None, gpu_index: int = 0) -> dict:
    def get_device_info(device_type: str = 'auto', gpu_memory: Optional[int] = None, gpu_index: int = 0,
                        ascend_memory: Optional[int] = None, ascend_index: int = 0) -> dict:
        """
        Get the device information and configure the training device
        :param device_type: 'cpu', 'gpu' or 'auto'
        :param device_type: 'cpu', 'gpu', 'ascend' or 'auto'
        :param gpu_memory: GPU VRAM limit (MB)
        :param gpu_index: GPU index, defaults to 0
        :param ascend_memory: Ascend NPU memory limit (MB)
        :param ascend_index: Ascend NPU index, defaults to 0
        :return: device configuration
        """
        # Get all available GPUs
        # Get all available GPUs and Ascend NPUs
        available_gpus = DeviceManager.get_available_gpus()
        available_ascends = AscendDeviceManager.get_available_ascends()

        device_info = {
            'device_type': device_type,
            'device': 'cpu',
            'gpu_memory': None,
            'gpu_index': gpu_index,
            'ascend_memory': None,
            'ascend_index': ascend_index,
            'cpu_cores': None,
            'available_gpus': available_gpus
            'available_gpus': available_gpus,
            'available_ascends': available_ascends
        }

        # Detect available GPUs
        # Detect available GPUs and Ascend NPUs
        has_cuda = torch.cuda.is_available()
        has_ascend = len(available_ascends) > 0

        # In auto mode, prefer the GPU
        # In auto mode, prefer the Ascend NPU, then the GPU, then the CPU
        if device_type == 'auto':
            device_type = 'gpu' if has_cuda else 'cpu'
            if has_ascend:
                device_type = 'ascend'
            elif has_cuda:
                device_type = 'gpu'
            else:
                device_type = 'cpu'
            device_info['device_type'] = device_type

        # Handle the Ascend NPU device
        if device_type == 'ascend':
            if not has_ascend:
                print('\n=== Warning: Ascend NPU is unavailable; trying GPU training instead ===')
                if has_cuda:
                    device_type = 'gpu'
                    device_info['device_type'] = 'gpu'
                else:
                    device_type = 'cpu'
                    device_info['device_type'] = 'cpu'
            else:
                # Let the Ascend device manager build the device info
                ascend_device_info = AscendDeviceManager.get_device_info(
                    ascend_memory=ascend_memory,
                    ascend_index=ascend_index
                )

                # Merge the device info
                device_info.update({
                    'device': ascend_device_info['device'],
                    'ascend_memory': ascend_device_info['ascend_memory'],
                    'ascend_index': ascend_device_info['ascend_index']
                })

                # If the Ascend NPU is unavailable, fall back to another device
                if ascend_device_info['device_type'] != 'ascend':
                    device_type = ascend_device_info['device_type']
                    device_info['device_type'] = device_type

        # Handle the GPU device
        if device_type == 'gpu':
            if not has_cuda:
                print('\n=== Warning: GPU is unavailable; falling back to CPU training ===')
@@ -230,6 +274,7 @@ class DeviceManager:
                device_info['device_type'] = 'cpu'
                device_info['device'] = 'cpu'

        # Handle the CPU device
        if device_type == 'cpu':
            # Get the number of CPU cores
            cpu_cores = os.cpu_count()
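A usage sketch of the extended device selection (it assumes DeviceManager is importable from the training service module):

    from app.services.training_service import DeviceManager  # assumed module path

    # Auto mode now resolves to Ascend NPU, then CUDA GPU, then CPU.
    info = DeviceManager.get_device_info(device_type='auto', ascend_memory=8192)
    print(info['device_type'], info['device'])  # e.g. 'ascend' 'npu:0'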
@@ -260,19 +305,30 @@ def train_model(
    image_size: int,
    device_type: str = 'auto',
    gpu_memory: Optional[int] = None,
    gpu_index: int = 0,
    ascend_memory: Optional[int] = None,
    ascend_index: int = 0,
    **kwargs
) -> Path:
    """
    Main model-training entry point
    """
    # Get the device configuration
    device_info = DeviceManager.get_device_info(device_type, gpu_memory)
    device_info = DeviceManager.get_device_info(
        device_type=device_type,
        gpu_memory=gpu_memory,
        gpu_index=gpu_index,
        ascend_memory=ascend_memory,
        ascend_index=ascend_index
    )

    # Print the device information
    if device_info['device_type'] == 'cpu':
        print(f"\n=== Training on CPU with {device_info['cpu_cores']} threads ===")
    else:
    elif device_info['device_type'] == 'gpu':
        print(f"\n=== Training on GPU with a {device_info['gpu_memory']}MB VRAM limit ===")
    elif device_info['device_type'] == 'ascend':
        print(f"\n=== Training on Ascend NPU with a {device_info['ascend_memory']}MB memory limit ===")

    # Configure the training arguments
    train_args = {
@@ -289,6 +345,23 @@ def train_model(
    try:
        from ultralytics import YOLO
        model = YOLO(model_type)

        # Ascend NPU devices need special handling
        if device_info['device_type'] == 'ascend':
            # This must be implemented against the real Ascend NPU API;
            # the code below is illustrative and should be replaced by real API calls
            try:
                import torch_npu
                # Set the environment variable
                os.environ['ASCEND_VISIBLE_DEVICES'] = str(device_info['ascend_index'])
                # Other Ascend-specific settings
                # ...
            except ImportError:
                print("\n=== Warning: could not import torch_npu; trying another device ===")
                # Fall back to CPU
                train_args['device'] = 'cpu'

        # Run the training
        results = model.train(**train_args)

    # Return the path of the trained model
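A hypothetical call into the extended entry point (the leading parameters before image_size are not shown in this hunk, so the names used here are assumptions):

    from app.services.training_service import train_model  # assumed module path

    best_weights = train_model(
        model_type="yolov8n.pt",      # assumed parameter name
        epochs=100, batch_size=16,    # assumed parameter names
        image_size=640,
        device_type="ascend", ascend_memory=8192, ascend_index=0,
    )
    print(best_weights)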
@@ -4,10 +4,11 @@
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>YOLO Training Platform</title>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css">
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.11.1/font/bootstrap-icons.css">
    <link rel="stylesheet" href="css/style.css">
    <link rel="stylesheet" href="css/alert-styles.css">
    <!-- Use local Bootstrap assets -->
    <link rel="stylesheet" href="static/libs/bootstrap.min.css">
    <link rel="stylesheet" href="static/libs/bootstrap-icons.css">
    <link rel="stylesheet" href="static/css/style.css">
    <link rel="stylesheet" href="static/css/alert-styles.css">
</head>
<body>
    <nav class="navbar navbar-expand-lg navbar-dark bg-dark">
@@ -36,6 +37,9 @@
                <li class="nav-item">
                    <a class="nav-link" href="#" data-page="video">Video Processing</a>
                </li>
                <li class="nav-item">
                    <a class="nav-link" href="#" data-page="tracking">Object Tracking</a>
                </li>
            </ul>
        </div>
    </div>
@@ -413,6 +417,255 @@
        <button type="submit" class="btn btn-primary" id="augment-btn">Start Augmentation</button>
    </form>

    <!-- Advanced data augmentation section -->
    <div class="mt-4">
        <h5>Advanced Data Augmentation</h5>
        <ul class="nav nav-tabs" id="advancedAugTabs" role="tablist">
            <li class="nav-item" role="presentation">
                <button class="nav-link active" id="cutmix-tab" data-bs-toggle="tab" data-bs-target="#cutmix" type="button" role="tab" aria-controls="cutmix" aria-selected="true">CutMix</button>
            </li>
            <li class="nav-item" role="presentation">
                <button class="nav-link" id="mixup-tab" data-bs-toggle="tab" data-bs-target="#mixup" type="button" role="tab" aria-controls="mixup" aria-selected="false">MixUp</button>
            </li>
            <li class="nav-item" role="presentation">
                <button class="nav-link" id="mosaic-tab" data-bs-toggle="tab" data-bs-target="#mosaic" type="button" role="tab" aria-controls="mosaic" aria-selected="false">Mosaic</button>
            </li>
            <li class="nav-item" role="presentation">
                <button class="nav-link" id="weather-tab" data-bs-toggle="tab" data-bs-target="#weather" type="button" role="tab" aria-controls="weather" aria-selected="false">Weather Simulation</button>
            </li>
        </ul>
        <div class="tab-content" id="advancedAugTabContent">
            <!-- CutMix -->
            <div class="tab-pane fade show active" id="cutmix" role="tabpanel" aria-labelledby="cutmix-tab">
                <div class="card mt-3">
                    <div class="card-body">
                        <p class="card-text">CutMix pastes a region of one image into another to create new training samples.</p>
                        <form id="cutmix-form" enctype="multipart/form-data">
                            <input type="hidden" name="augmentation_type" value="cutmix">
                            <div class="row mb-3">
                                <div class="col-md-6">
                                    <label for="cutmix-image1" class="form-label">Image 1</label>
                                    <input class="form-control" type="file" id="cutmix-image1" name="image1" accept="image/*" required>
                                </div>
                                <div class="col-md-6">
                                    <label for="cutmix-image2" class="form-label">Image 2</label>
                                    <input class="form-control" type="file" id="cutmix-image2" name="image2" accept="image/*" required>
                                </div>
                            </div>
                            <div class="mb-3">
                                <label for="cutmix-alpha" class="form-label">Alpha parameter: <span id="cutmix-alpha-value">0.5</span></label>
                                <input type="range" class="form-range" id="cutmix-alpha" name="cutmix_alpha" min="0.1" max="0.9" step="0.1" value="0.5">
                                <div class="form-text">Controls the size of the mixed region</div>
                            </div>
                            <button type="submit" class="btn btn-primary">Apply CutMix</button>
                        </form>
                        <div class="mt-3" id="cutmix-result" style="display: none;">
                            <h6>Result</h6>
                            <div class="row">
                                <div class="col-md-4">
                                    <div class="card">
                                        <div class="card-header">Image 1</div>
                                        <div class="card-body p-0">
                                            <img id="cutmix-original1" class="img-fluid" alt="Original image 1">
                                        </div>
                                    </div>
                                </div>
                                <div class="col-md-4">
                                    <div class="card">
                                        <div class="card-header">Image 2</div>
                                        <div class="card-body p-0">
                                            <img id="cutmix-original2" class="img-fluid" alt="Original image 2">
                                        </div>
                                    </div>
                                </div>
                                <div class="col-md-4">
                                    <div class="card">
                                        <div class="card-header">CutMix result</div>
                                        <div class="card-body p-0">
                                            <img id="cutmix-output" class="img-fluid" alt="CutMix result">
                                        </div>
                                        <div class="card-footer">
                                            <a id="download-cutmix" class="btn btn-sm btn-success" download>Download result</a>
                                        </div>
                                    </div>
                                </div>
                            </div>
                        </div>
                    </div>
                </div>
            </div>

            <!-- MixUp -->
            <div class="tab-pane fade" id="mixup" role="tabpanel" aria-labelledby="mixup-tab">
                <div class="card mt-3">
                    <div class="card-body">
                        <p class="card-text">MixUp blends two images by weight to create new training samples.</p>
                        <form id="mixup-form" enctype="multipart/form-data">
                            <input type="hidden" name="augmentation_type" value="mixup">
                            <div class="row mb-3">
                                <div class="col-md-6">
                                    <label for="mixup-image1" class="form-label">Image 1</label>
                                    <input class="form-control" type="file" id="mixup-image1" name="image1" accept="image/*" required>
                                </div>
                                <div class="col-md-6">
                                    <label for="mixup-image2" class="form-label">Image 2</label>
                                    <input class="form-control" type="file" id="mixup-image2" name="image2" accept="image/*" required>
                                </div>
                            </div>
                            <div class="mb-3">
                                <label for="mixup-alpha" class="form-label">Alpha parameter: <span id="mixup-alpha-value">0.5</span></label>
                                <input type="range" class="form-range" id="mixup-alpha" name="mixup_alpha" min="0.1" max="0.9" step="0.1" value="0.5">
                                <div class="form-text">Controls the distribution of the blend weights</div>
                            </div>
                            <button type="submit" class="btn btn-primary">Apply MixUp</button>
                        </form>
                        <div class="mt-3" id="mixup-result" style="display: none;">
                            <h6>Result</h6>
                            <div class="row">
                                <div class="col-md-4">
                                    <div class="card">
                                        <div class="card-header">Image 1</div>
                                        <div class="card-body p-0">
                                            <img id="mixup-original1" class="img-fluid" alt="Original image 1">
                                        </div>
                                    </div>
                                </div>
                                <div class="col-md-4">
                                    <div class="card">
                                        <div class="card-header">Image 2</div>
                                        <div class="card-body p-0">
                                            <img id="mixup-original2" class="img-fluid" alt="Original image 2">
                                        </div>
                                    </div>
                                </div>
                                <div class="col-md-4">
                                    <div class="card">
                                        <div class="card-header">MixUp result</div>
                                        <div class="card-body p-0">
                                            <img id="mixup-output" class="img-fluid" alt="MixUp result">
                                        </div>
                                        <div class="card-footer">
                                            <a id="download-mixup" class="btn btn-sm btn-success" download>Download result</a>
                                        </div>
                                    </div>
                                </div>
                            </div>
                        </div>
                    </div>
                </div>
            </div>

            <!-- Mosaic -->
            <div class="tab-pane fade" id="mosaic" role="tabpanel" aria-labelledby="mosaic-tab">
                <div class="card mt-3">
                    <div class="card-body">
                        <p class="card-text">Mosaic stitches four images into one to create new training samples.</p>
                        <form id="mosaic-form" enctype="multipart/form-data">
                            <input type="hidden" name="augmentation_type" value="mosaic">
                            <div class="row mb-3">
                                <div class="col-md-6">
                                    <label for="mosaic-image1" class="form-label">Image 1 (top left)</label>
                                    <input class="form-control" type="file" id="mosaic-image1" name="image1" accept="image/*" required>
                                </div>
                                <div class="col-md-6">
                                    <label for="mosaic-image2" class="form-label">Image 2 (top right)</label>
                                    <input class="form-control" type="file" id="mosaic-image2" name="image2" accept="image/*" required>
                                </div>
                            </div>
                            <div class="row mb-3">
                                <div class="col-md-6">
                                    <label for="mosaic-image3" class="form-label">Image 3 (bottom left)</label>
                                    <input class="form-control" type="file" id="mosaic-image3" name="image3" accept="image/*" required>
                                </div>
                                <div class="col-md-6">
                                    <label for="mosaic-image4" class="form-label">Image 4 (bottom right)</label>
                                    <input class="form-control" type="file" id="mosaic-image4" name="image4" accept="image/*" required>
                                </div>
                            </div>
                            <button type="submit" class="btn btn-primary">Apply Mosaic</button>
                        </form>
                        <div class="mt-3" id="mosaic-result" style="display: none;">
                            <h6>Result</h6>
                            <div class="row">
                                <div class="col-md-8 mx-auto">
                                    <div class="card">
                                        <div class="card-header">Mosaic result</div>
                                        <div class="card-body p-0">
                                            <img id="mosaic-output" class="img-fluid" alt="Mosaic result">
                                        </div>
                                        <div class="card-footer">
                                            <a id="download-mosaic" class="btn btn-sm btn-success" download>Download result</a>
                                        </div>
                                    </div>
                                </div>
                            </div>
                        </div>
                    </div>
                </div>
            </div>

            <!-- Weather simulation -->
            <div class="tab-pane fade" id="weather" role="tabpanel" aria-labelledby="weather-tab">
                <div class="card mt-3">
                    <div class="card-body">
                        <p class="card-text">Simulates different weather conditions to make the model more robust across environments.</p>
                        <form id="weather-form" enctype="multipart/form-data">
                            <input type="hidden" name="augmentation_type" value="weather">
                            <div class="mb-3">
                                <label for="weather-image" class="form-label">Select an image</label>
                                <input class="form-control" type="file" id="weather-image" name="image1" accept="image/*" required>
                            </div>
                            <div class="mb-3">
                                <label class="form-label">Weather type</label>
                                <div class="form-check">
                                    <input class="form-check-input" type="radio" name="weather_type" id="weather-rain" value="rain" checked>
                                    <label class="form-check-label" for="weather-rain">Rain</label>
                                </div>
                                <div class="form-check">
                                    <input class="form-check-input" type="radio" name="weather_type" id="weather-snow" value="snow">
                                    <label class="form-check-label" for="weather-snow">Snow</label>
                                </div>
                                <div class="form-check">
                                    <input class="form-check-input" type="radio" name="weather_type" id="weather-fog" value="fog">
                                    <label class="form-check-label" for="weather-fog">Fog</label>
                                </div>
                            </div>
                            <div class="mb-3">
                                <label for="weather-intensity" class="form-label">Intensity: <span id="weather-intensity-value">0.5</span></label>
                                <input type="range" class="form-range" id="weather-intensity" name="weather_intensity" min="0.1" max="1.0" step="0.1" value="0.5">
                            </div>
                            <button type="submit" class="btn btn-primary">Apply Weather Effect</button>
                        </form>
                        <div class="mt-3" id="weather-result" style="display: none;">
                            <h6>Result</h6>
                            <div class="row">
                                <div class="col-md-6">
                                    <div class="card">
                                        <div class="card-header">Original image</div>
                                        <div class="card-body p-0">
                                            <img id="weather-original" class="img-fluid" alt="Original image">
                                        </div>
                                    </div>
                                </div>
                                <div class="col-md-6">
                                    <div class="card">
                                        <div class="card-header">Weather effect</div>
                                        <div class="card-body p-0">
                                            <img id="weather-output" class="img-fluid" alt="Weather effect">
                                        </div>
                                        <div class="card-footer">
                                            <a id="download-weather" class="btn btn-sm btn-success" download>Download result</a>
                                        </div>
                                    </div>
                                </div>
                            </div>
                        </div>
                    </div>
                </div>
            </div>
        </div>
    </div>

    <div class="mt-4" id="augment-result" style="display: none;">
        <div class="alert alert-info">
            <h5>Augmentation task started</h5>
@@ -1416,6 +1669,7 @@
            <select class="form-select" id="device-type">
                <option value="cpu">CPU</option>
                <option value="gpu">GPU</option>
                <option value="ascend">Huawei Ascend</option>
            </select>
        </div>

@@ -1446,6 +1700,28 @@
            </div>
        </div>

        <div class="mb-3 ascend-option" style="display: none;">
            <button type="button" class="btn btn-info mb-3" id="get-ascend-info">Get Ascend Info</button>

            <div id="ascend-info" style="display: none; background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-bottom: 15px;">
                <div class="mb-3">
                    <label for="ascend-select" class="form-label">Select Ascend NPU:</label>
                    <select class="form-select" id="ascend-select"></select>
                </div>

                <div class="mb-3">
                    <label for="ascend-memory" class="form-label">NPU memory limit (MB):</label>
                    <input type="number" class="form-control" id="ascend-memory" value="8192" min="1024" step="1024">
                    <div id="ascend-memory-info" class="mt-2 small"></div>
                </div>

                <div class="mb-3">
                    <button type="button" class="btn btn-sm btn-secondary" id="validate-ascend-memory">Validate Memory Setting</button>
                    <div id="ascend-validation-result" class="mt-2"></div>
                </div>
            </div>
        </div>

        <div class="mb-3">
            <label for="memory" class="form-label">Memory (MB)</label>
            <input type="number" class="form-control" id="memory" value="8192" min="1024" step="1024">
@@ -1456,20 +1732,135 @@
    </form>
</template>

<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
<!-- Object tracking page -->
<template id="tracking-template">
    <div class="row mb-4">
        <div class="col">
            <h2>Object Tracking</h2>
            <p class="lead">Transformer-based object tracking</p>
        </div>
    </div>

    <div class="card mb-4">
        <div class="card-header">
            <h5>Camera Object Tracking</h5>
        </div>
        <div class="card-body">
            <div>
                <!-- Camera tracking -->
                <div id="camera-tracking" role="tabpanel">
                    <h5 class="card-title">Camera Tracking</h5>
                    <p class="card-text">Track objects in real time using a camera.</p>

                    <div class="mb-3">
                        <label for="camera-select" class="form-label">Select a camera</label>
                        <select class="form-select" id="camera-select">
                            <option value="">Loading...</option>
                        </select>
                    </div>

                    <div class="mb-3">
                        <label for="camera-tracking-model-select" class="form-label">Select a detection model</label>
                        <select class="form-select" id="camera-tracking-model-select">
                            <option value="default">Use the default model (YOLOv8n)</option>
                            <!-- Model options are loaded dynamically via JavaScript -->
                        </select>
                    </div>

                    <div class="row mb-3">
                        <div class="col-md-6">
                            <label for="camera-conf-threshold" class="form-label">Confidence threshold: <span id="camera-conf-value">0.25</span></label>
                            <input type="range" class="form-range" id="camera-conf-threshold" min="0.1" max="1.0" step="0.05" value="0.25">
                        </div>
                        <div class="col-md-6">
                            <label for="camera-iou-threshold" class="form-label">IoU threshold: <span id="camera-iou-value">0.45</span></label>
                            <input type="range" class="form-range" id="camera-iou-threshold" min="0.1" max="1.0" step="0.05" value="0.45">
                        </div>
                    </div>

                    <div class="mb-3">
                        <div class="alert alert-info">
                            <i class="bi bi-info-circle"></i> Once detection starts, all recognized targets appear on the right. Click the "Track" button next to a target to track it individually.
                        </div>
                    </div>

                    <div class="mb-3">
                        <button id="start-camera-tracking" class="btn btn-primary">Start Tracking</button>
                        <button id="stop-camera-tracking" class="btn btn-danger" style="display: none;">Stop Tracking</button>
                    </div>

                    <div class="mt-4" id="camera-tracking-container" style="display: none;">
                        <div class="row">
                            <div class="col-md-8">
                                <div class="card">
                                    <div class="card-header">
                                        <span id="camera-tracking-status">Ready</span>
                                    </div>
                                    <div class="card-body p-0">
                                        <div id="video-container" style="position: relative; width: 100%; height: auto; overflow: hidden; border: 2px solid green;">
                                            <!-- Video element -->
                                            <video id="camera-video" autoplay muted style="width: 100%; z-index: 1; display: block;"></video>

                                            <!-- Note: the canvas element is now created and attached dynamically by JavaScript -->
                                        </div>
                                    </div>
                                </div>
                            </div>
                            <div class="col-md-4">
                                <div class="card mb-3">
                                    <div class="card-header">Detected Targets</div>
                                    <div class="card-body">
                                        <div class="alert alert-warning" id="no-objects-detected" style="display: block;">
                                            No targets detected yet.
                                        </div>
                                        <div id="detected-objects-list" class="list-group">
                                            <!-- Detected targets are loaded dynamically via JavaScript -->
                                        </div>
                                    </div>
                                </div>

                                <div class="card">
                                    <div class="card-header">Currently Tracked Targets</div>
                                    <div class="card-body">
                                        <div class="alert alert-info" id="no-objects-tracked" style="display: block;">
                                            No tracking target selected yet.
                                        </div>
                                        <div id="tracked-objects-list" class="list-group">
                                            <!-- Tracked targets are loaded dynamically via JavaScript -->
                                        </div>
                                    </div>
                                </div>
                            </div>
                        </div>
                        <div class="mt-3">
                            <button id="reset-tracking" class="btn btn-warning">Reset Tracker</button>
                        </div>
                    </div>
                </div>

                <!-- Tracking report feature has been removed -->
            </div>
        </div>
    </div>
</template>

<!-- Use local JavaScript assets -->
<script src="static/libs/bootstrap.bundle.min.js"></script>
<!-- Chart.js charting library -->
<script src="https://cdn.jsdelivr.net/npm/chart.js@3.9.1/dist/chart.min.js"></script>
<script src="static/libs/chart.min.js"></script>
<!-- Polling manager for API polling -->
<script src="js/polling_manager.js"></script>
<script src="static/js/polling_manager.js"></script>
<!-- Alert sound generator -->
<script src="js/alert-sound.js"></script>
<script src="static/js/alert-sound.js"></script>
<!-- Alert zone settings -->
<script src="js/alert-zone.js"></script>
<script src="static/js/alert-zone.js"></script>
<!-- Alert statistics -->
<script src="js/alert-stats.js"></script>
<script src="static/js/alert-stats.js"></script>
<!-- Trajectory recording and prediction -->
<script src="js/trajectory.js"></script>
<script src="js/main.js"></script>
<script src="static/js/trajectory.js"></script>
<!-- Self-attention tracking -->
<script src="static/js/attention_tracker.js"></script>
<script src="static/js/main.js"></script>

<!-- Modal container -->
<div id="modals-container"></div>
@@ -1485,4 +1876,4 @@
    });
</script>
</body>
</html>
</html>
1164	app/static/js/attention_tracker.js	Normal file
File diff suppressed because it is too large
@@ -43,7 +26 @@ function testApiEndpoints() {
// Run once the page has finished loading
document.addEventListener('DOMContentLoaded', function() {
    // Initialize the Bootstrap modal
    modal = new bootstrap.Modal(document.getElementById('mainModal'));
    try {
        if (typeof bootstrap !== 'undefined' && bootstrap.Modal) {
            modal = new bootstrap.Modal(document.getElementById('mainModal'));
            console.log('Bootstrap模态框初始化成功');
        } else {
            console.warn('Bootstrap未加载,跳过模态框初始化');
            // Create a stub object so later code does not break
            modal = {
                show: function() { console.warn('模态框功能不可用'); },
                hide: function() { console.warn('模态框功能不可用'); }
            };
        }
    } catch (error) {
        console.error('初始化Bootstrap模态框失败:', error);
        // Create a stub object so later code does not break
        modal = {
            show: function() { console.warn('模态框功能不可用'); },
            hide: function() { console.warn('模态框功能不可用'); }
        };
    }

    // Navigation menu click events
    document.querySelectorAll('.nav-link').forEach(link => {
@@ -109,6 +128,9 @@ function loadPage(page) {
            loadVideoPage();
            bindVideoEvents();
            break;
        case 'tracking':
            initTrackingPage();
            break;
    }
}

@@ -676,6 +698,233 @@ function bindOpenCVEventHandlers() {
        });
    }

    // Advanced data-augmentation events

    // CutMix alpha slider
    const cutmixAlpha = document.getElementById('cutmix-alpha');
    const cutmixAlphaValue = document.getElementById('cutmix-alpha-value');
    if (cutmixAlpha && cutmixAlphaValue) {
        cutmixAlpha.addEventListener('input', function() {
            cutmixAlphaValue.textContent = this.value;
        });
    }

    // MixUp alpha slider
    const mixupAlpha = document.getElementById('mixup-alpha');
    const mixupAlphaValue = document.getElementById('mixup-alpha-value');
    if (mixupAlpha && mixupAlphaValue) {
        mixupAlpha.addEventListener('input', function() {
            mixupAlphaValue.textContent = this.value;
        });
    }

    // Weather intensity slider
    const weatherIntensity = document.getElementById('weather-intensity');
    const weatherIntensityValue = document.getElementById('weather-intensity-value');
    if (weatherIntensity && weatherIntensityValue) {
        weatherIntensity.addEventListener('input', function() {
            weatherIntensityValue.textContent = this.value;
        });
    }

    // CutMix form submission
    const cutmixForm = document.getElementById('cutmix-form');
    if (cutmixForm) {
        cutmixForm.addEventListener('submit', function(event) {
            event.preventDefault();

            // Collect the form data
            const formData = new FormData(this);

            // Grab the original images
            const image1 = document.getElementById('cutmix-image1').files[0];
            const image2 = document.getElementById('cutmix-image2').files[0];

            if (!image1 || !image2) {
                alert('请选择两张图像');
                return;
            }

            // Preview the original images
            const reader1 = new FileReader();
            reader1.onload = function(e) {
                document.getElementById('cutmix-original1').src = e.target.result;
            };
            reader1.readAsDataURL(image1);

            const reader2 = new FileReader();
            reader2.onload = function(e) {
                document.getElementById('cutmix-original2').src = e.target.result;
            };
            reader2.readAsDataURL(image2);

            // Send the request
            fetch(`${API_URL}/opencv/advanced-augmentation`, {
                method: 'POST',
                body: formData
            })
            .then(response => {
                if (!response.ok) {
                    return response.json().then(err => { throw new Error(err.detail || 'CutMix 增强失败'); });
                }
                return response.json();
            })
            .then(data => {
                // Show the result
                document.getElementById('cutmix-result').style.display = 'block';
                document.getElementById('cutmix-output').src = data.output_path;
                document.getElementById('download-cutmix').href = data.output_path;
            })
            .catch(error => {
                alert('错误: ' + error.message);
            });
        });
    }

    // MixUp form submission
    const mixupForm = document.getElementById('mixup-form');
    if (mixupForm) {
        mixupForm.addEventListener('submit', function(event) {
            event.preventDefault();

            // Collect the form data
            const formData = new FormData(this);

            // Grab the original images
            const image1 = document.getElementById('mixup-image1').files[0];
            const image2 = document.getElementById('mixup-image2').files[0];

            if (!image1 || !image2) {
                alert('请选择两张图像');
                return;
            }

            // Preview the original images
            const reader1 = new FileReader();
            reader1.onload = function(e) {
                document.getElementById('mixup-original1').src = e.target.result;
            };
            reader1.readAsDataURL(image1);

            const reader2 = new FileReader();
            reader2.onload = function(e) {
                document.getElementById('mixup-original2').src = e.target.result;
            };
            reader2.readAsDataURL(image2);

            // Send the request
            fetch(`${API_URL}/opencv/advanced-augmentation`, {
                method: 'POST',
                body: formData
            })
            .then(response => {
                if (!response.ok) {
                    return response.json().then(err => { throw new Error(err.detail || 'MixUp 增强失败'); });
                }
                return response.json();
            })
            .then(data => {
                // Show the result
                document.getElementById('mixup-result').style.display = 'block';
                document.getElementById('mixup-output').src = data.output_path;
                document.getElementById('download-mixup').href = data.output_path;
            })
            .catch(error => {
                alert('错误: ' + error.message);
            });
        });
    }

    // Mosaic form submission
    const mosaicForm = document.getElementById('mosaic-form');
    if (mosaicForm) {
        mosaicForm.addEventListener('submit', function(event) {
            event.preventDefault();

            // Collect the form data
            const formData = new FormData(this);

            // Make sure four images were selected
            const image1 = document.getElementById('mosaic-image1').files[0];
            const image2 = document.getElementById('mosaic-image2').files[0];
            const image3 = document.getElementById('mosaic-image3').files[0];
            const image4 = document.getElementById('mosaic-image4').files[0];

            if (!image1 || !image2 || !image3 || !image4) {
                alert('请选择四张图像');
                return;
            }

            // Send the request
            fetch(`${API_URL}/opencv/advanced-augmentation`, {
                method: 'POST',
                body: formData
            })
            .then(response => {
                if (!response.ok) {
                    return response.json().then(err => { throw new Error(err.detail || 'Mosaic 增强失败'); });
                }
                return response.json();
            })
            .then(data => {
                // Show the result
                document.getElementById('mosaic-result').style.display = 'block';
                document.getElementById('mosaic-output').src = data.output_path;
                document.getElementById('download-mosaic').href = data.output_path;
            })
            .catch(error => {
                alert('错误: ' + error.message);
            });
        });
    }

    // Weather simulation form submission
    const weatherForm = document.getElementById('weather-form');
    if (weatherForm) {
        weatherForm.addEventListener('submit', function(event) {
            event.preventDefault();

            // Collect the form data
            const formData = new FormData(this);

            // Grab the original image
            const image = document.getElementById('weather-image').files[0];

            if (!image) {
                alert('请选择图像');
                return;
            }

            // Preview the original image
            const reader = new FileReader();
            reader.onload = function(e) {
                document.getElementById('weather-original').src = e.target.result;
            };
            reader.readAsDataURL(image);

            // Send the request
            fetch(`${API_URL}/opencv/advanced-augmentation`, {
                method: 'POST',
                body: formData
            })
            .then(response => {
                if (!response.ok) {
                    return response.json().then(err => { throw new Error(err.detail || '天气模拟增强失败'); });
                }
                return response.json();
            })
            .then(data => {
                // Show the result
                document.getElementById('weather-result').style.display = 'block';
                document.getElementById('weather-output').src = data.output_path;
                document.getElementById('download-weather').href = data.output_path;
            })
            .catch(error => {
                alert('错误: ' + error.message);
            });
        });
    }

    console.log('OpenCV events bound successfully');

    // Add debug info
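All four handlers above post the same FormData to `/opencv/advanced-augmentation`; the server-side mixing lives in `app/services/opencv_service.py` and is not part of this hunk. A minimal sketch of what the CutMix branch could look like (hypothetical function name and signature, assuming two BGR images loaded with cv2):

```
import cv2
import numpy as np

def cutmix(img1: np.ndarray, img2: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Paste a random rectangle from img2 into img1 (CutMix)."""
    h, w = img1.shape[:2]
    img2 = cv2.resize(img2, (w, h))           # align sizes
    lam = np.random.beta(alpha, alpha)        # mixing ratio drawn from Beta(alpha, alpha)
    cut_w = int(w * np.sqrt(1.0 - lam))       # patch sized so its area is (1 - lam)
    cut_h = int(h * np.sqrt(1.0 - lam))
    cx, cy = np.random.randint(w), np.random.randint(h)  # random patch center
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    out = img1.copy()
    out[y1:y2, x1:x2] = img2[y1:y2, x1:x2]    # the actual cut-and-mix
    return out
```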
@@ -2261,13 +2510,20 @@ function showAddTrainingModal() {
    deviceTypeSelect.addEventListener('change', function() {
        const cpuOptions = document.querySelectorAll('.cpu-option');
        const gpuOptions = document.querySelectorAll('.gpu-option');
        const ascendOptions = document.querySelectorAll('.ascend-option');

        if (this.value === 'cpu') {
            cpuOptions.forEach(option => option.style.display = 'block');
            gpuOptions.forEach(option => option.style.display = 'none');
        } else {
            ascendOptions.forEach(option => option.style.display = 'none');
        } else if (this.value === 'gpu') {
            cpuOptions.forEach(option => option.style.display = 'none');
            gpuOptions.forEach(option => option.style.display = 'block');
            ascendOptions.forEach(option => option.style.display = 'none');
        } else if (this.value === 'ascend') {
            cpuOptions.forEach(option => option.style.display = 'none');
            gpuOptions.forEach(option => option.style.display = 'none');
            ascendOptions.forEach(option => option.style.display = 'block');
        }
    });
}
@@ -2341,6 +2597,75 @@ function showAddTrainingModal() {
    });
}

// Bind the "get Ascend NPU info" button
const getAscendInfoBtn = document.getElementById('get-ascend-info');
const ascendInfoDiv = document.getElementById('ascend-info');
const ascendSelect = document.getElementById('ascend-select');
const ascendMemoryInput = document.getElementById('ascend-memory');
const ascendMemoryInfoDiv = document.getElementById('ascend-memory-info');
const validateAscendMemoryBtn = document.getElementById('validate-ascend-memory');
const ascendValidationResultDiv = document.getElementById('ascend-validation-result');

if (getAscendInfoBtn) {
    getAscendInfoBtn.addEventListener('click', async function() {
        try {
            // Show the loading state
            getAscendInfoBtn.disabled = true;
            getAscendInfoBtn.textContent = '获取中...';

            const response = await fetch(`${API_URL}/training/ascend-info`);
            const data = await response.json();

            // Restore the button state
            getAscendInfoBtn.disabled = false;
            getAscendInfoBtn.textContent = '获取昇腾信息';

            if (data.has_ascend) {
                // Clear the Ascend select box
                ascendSelect.innerHTML = '';

                // Add one option per Ascend NPU
                data.ascends.forEach(ascend => {
                    const option = document.createElement('option');
                    option.value = ascend.index;
                    option.textContent = ascend.display_name;
                    option.dataset.memory = ascend.free_memory;
                    option.dataset.recommended = ascend.recommended_memory;
                    option.dataset.totalMemory = ascend.total_memory;
                    option.dataset.usedMemory = ascend.used_memory;
                    option.dataset.name = ascend.name;
                    ascendSelect.appendChild(option);
                });

                // Set the default memory value
                const selectedAscend = data.ascends[0];
                ascendMemoryInput.value = selectedAscend.recommended_memory;
                ascendMemoryInfoDiv.innerHTML = `
                    <div>昇腾NPU型号: ${selectedAscend.name}</div>
                    <div>总内存: ${selectedAscend.total_memory} MB</div>
                    <div>已用内存: ${selectedAscend.used_memory} MB</div>
                    <div>可用内存: ${selectedAscend.free_memory} MB</div>
                    <div class="text-success">推荐内存: ${selectedAscend.recommended_memory} MB</div>
                `;

                // Show the Ascend info panel
                ascendInfoDiv.style.display = 'block';
            } else {
                alert('没有可用的昇腾NPU,请使用其他模式训练');
                deviceTypeSelect.value = 'cpu';
                deviceTypeSelect.dispatchEvent(new Event('change'));
            }
        } catch (error) {
            console.error('获取昇腾NPU信息失败:', error);
            alert('获取昇腾NPU信息失败,请检查网络连接');

            // Restore the button state
            getAscendInfoBtn.disabled = false;
            getAscendInfoBtn.textContent = '获取昇腾信息';
        }
    });
}
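The fields consumed above (index, display_name, free_memory, recommended_memory, total_memory, used_memory, name) imply a backend response shaped roughly like the following sketch of `/training/ascend-info` (hypothetical; `query_npus()` is an assumed helper wrapping npu-smi and is not shown in this commit):

```
from fastapi import APIRouter

router = APIRouter()

@router.get("/ascend-info")
async def get_ascend_info():
    # query_npus() is assumed to return [{"index", "name", "total_memory", "used_memory"}, ...] in MB
    npus = query_npus()
    ascends = []
    for npu in npus:
        free = npu["total_memory"] - npu["used_memory"]
        ascends.append({
            "index": npu["index"],
            "name": npu["name"],
            "display_name": f"NPU {npu['index']}: {npu['name']}",
            "total_memory": npu["total_memory"],
            "used_memory": npu["used_memory"],
            "free_memory": free,
            # leave headroom so training does not exhaust the device
            "recommended_memory": int(free * 0.8),
        })
    return {"has_ascend": len(ascends) > 0, "ascends": ascends}
```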
// Bind the GPU selection change event
if (gpuSelect) {
    gpuSelect.addEventListener('change', function() {
@@ -2412,6 +2737,77 @@ function showAddTrainingModal() {
    });
}

// Bind the Ascend NPU selection change event
if (ascendSelect) {
    ascendSelect.addEventListener('change', function() {
        const selectedOption = this.options[this.selectedIndex];
        const freeMemory = parseInt(selectedOption.dataset.memory);
        const recommendedMemory = parseInt(selectedOption.dataset.recommended);
        const totalMemory = parseInt(selectedOption.dataset.totalMemory);
        const usedMemory = parseInt(selectedOption.dataset.usedMemory);
        const ascendName = selectedOption.dataset.name;

        ascendMemoryInput.value = recommendedMemory;

        ascendMemoryInfoDiv.innerHTML = `
            <div>昇腾NPU型号: ${ascendName}</div>
            <div>总内存: ${totalMemory} MB</div>
            <div>已用内存: ${usedMemory} MB</div>
            <div>可用内存: ${freeMemory} MB</div>
            <div class="text-success">推荐内存: ${recommendedMemory} MB</div>
        `;
    });
}

// Bind the "validate Ascend memory" button
if (validateAscendMemoryBtn) {
    validateAscendMemoryBtn.addEventListener('click', async function() {
        const ascendMemory = parseInt(ascendMemoryInput.value);
        const ascendIndex = parseInt(ascendSelect.value);

        if (isNaN(ascendMemory) || ascendMemory <= 0) {
            ascendValidationResultDiv.innerHTML = '<div class="text-danger">请输入有效的内存值</div>';
            return;
        }

        try {
            // Show the loading state
            validateAscendMemoryBtn.disabled = true;
            validateAscendMemoryBtn.textContent = '验证中...';

            const response = await fetch(`${API_URL}/training/validate-ascend-memory`, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({
                    ascend_memory: ascendMemory,
                    ascend_index: ascendIndex
                })
            });

            const data = await response.json();

            // Restore the button state
            validateAscendMemoryBtn.disabled = false;
            validateAscendMemoryBtn.textContent = '验证内存设置';

            if (data.valid) {
                ascendValidationResultDiv.innerHTML = `<div class="text-success">${data.message}</div>`;
            } else {
                ascendValidationResultDiv.innerHTML = `<div class="text-danger">${data.message}</div>`;
            }
        } catch (error) {
            console.error('验证昇腾内存设置失败:', error);
            ascendValidationResultDiv.innerHTML = '<div class="text-danger">验证昇腾内存设置失败,请检查网络连接</div>';

            // Restore the button state
            validateAscendMemoryBtn.disabled = false;
            validateAscendMemoryBtn.textContent = '验证内存设置';
        }
    });
}
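Likewise, the JSON body posted above suggests a validation route along these lines (a sketch only; `free_ascend_memory()` is an assumed helper, and the real check in `app/api/endpoints/training.py` may differ):

```
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()

class AscendMemoryRequest(BaseModel):
    ascend_memory: int  # MB requested by the user
    ascend_index: int   # which NPU to check against

@router.post("/validate-ascend-memory")
async def validate_ascend_memory(req: AscendMemoryRequest):
    free = free_ascend_memory(req.ascend_index)  # assumed helper, returns free MB
    if req.ascend_memory <= free:
        return {"valid": True, "message": f"{req.ascend_memory} MB fits within the {free} MB currently free"}
    return {"valid": False, "message": f"Requested {req.ascend_memory} MB exceeds the {free} MB currently free"}
```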
// Bind the rectangular-training mode toggle
const enableRectTraining = document.getElementById('enable-rect-training');
if (enableRectTraining) {
@@ -2560,7 +2956,7 @@ function submitAddTraining() {
    // Add parameters according to the device type
    if (deviceTypeSelect.value === 'cpu') {
        hardwareConfig.cpu_cores = parseInt(cpuCoresInput.value);
    } else {
    } else if (deviceTypeSelect.value === 'gpu') {
        // Read the GPU memory and GPU ID
        hardwareConfig.gpu_memory = parseInt(gpuMemoryInput.value);

@@ -2569,6 +2965,15 @@ function submitAddTraining() {
        if (gpuSelect && gpuSelect.value) {
            hardwareConfig.gpu_index = parseInt(gpuSelect.value);
        }
    } else if (deviceTypeSelect.value === 'ascend') {
        // Read the Ascend NPU memory and NPU ID
        hardwareConfig.ascend_memory = parseInt(ascendMemoryInput.value);

        // If the user picked a specific Ascend NPU, add its ID
        const ascendSelect = document.getElementById('ascend-select');
        if (ascendSelect && ascendSelect.value) {
            hardwareConfig.ascend_index = parseInt(ascendSelect.value);
        }
    }

    // Add memory parameters
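For an Ascend run, the hardwareConfig assembled above would therefore carry something like `{"ascend_memory": 16384, "ascend_index": 0}` (values illustrative) alongside the usual training parameters.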
2078	app/static/libs/bootstrap-icons.css	vendored	Normal file
File diff suppressed because it is too large
7	app/static/libs/bootstrap.bundle.min.js	vendored	Normal file
File diff suppressed because one or more lines are too long
6	app/static/libs/bootstrap.min.css	vendored	Normal file
File diff suppressed because one or more lines are too long
13	app/static/libs/chart.min.js	vendored	Normal file
File diff suppressed because one or more lines are too long
BIN	app/static/libs/fonts/bootstrap-icons.woff	Normal file
Binary file not shown.
BIN	app/static/libs/fonts/bootstrap-icons.woff2	Normal file
Binary file not shown.
@@ -1,75 +0,0 @@
# Docker Configuration Notes

To work around Docker Hub connection timeouts, the following optimizations have been made:

## Completed optimizations

1. **Updated docker-compose.yml**:
   - Removed the obsolete version field
   - Switched from the official registry to the Aliyun mirror
   - Added container names
   - Added health checks
   - Tidied up service dependencies

2. **Updated the Dockerfile**:
   - Switched from the official registry to the Aliyun mirror
   - Configured apt to use the USTC mirror
   - Configured pip to use the Tsinghua mirror

3. **Added a .env file**:
   - Provides proxy settings
   - Sets the image registry and database parameters

4. **Added a .dockerignore file**:
   - Excludes unnecessary files to speed up builds

## Usage

### 1. Configure the Docker daemon (one-time step)

Copy the `daemon.json` file into Docker's configuration directory:

```
# Windows
copy daemon.json %USERPROFILE%\.docker\daemon.json

# or copy it manually to C:\Users\<username>\.docker\daemon.json
```

Then restart Docker Desktop.

### 2. Start directly with docker-compose

```
docker-compose up -d
```

### 3. If problems persist

If you still hit connection timeouts, try pulling the images manually:

```
docker pull registry.cn-hangzhou.aliyuncs.com/library/postgres:13
docker pull registry.cn-hangzhou.aliyuncs.com/library/python:3.9-slim
```

Then run again:

```
docker-compose up -d
```

## Troubleshooting

If you hit one of the following errors:

- **context deadline exceeded**: usually a network issue; check your connection or use a proxy.
- **repository does not exist**: verify the image name, or try another mirror.
- **authentication required**: some mirrors require authentication; check the registry settings.

## Other useful commands

- Check container status: `docker-compose ps`
- Tail container logs: `docker-compose logs -f`
- Stop all containers: `docker-compose down`
- Rebuild and start: `docker-compose up -d --build`
28	package-lock.json	generated	Normal file
@@ -0,0 +1,28 @@
{
  "name": "Myolotrain",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "dependencies": {
        "bootstrap-icons": "^1.12.1"
      }
    },
    "node_modules/bootstrap-icons": {
      "version": "1.12.1",
      "resolved": "https://registry.npmjs.org/bootstrap-icons/-/bootstrap-icons-1.12.1.tgz",
      "integrity": "sha512-ekwupjsteHQmgGV+haQ0nNMoSyKCbJj5ou+06vFzb9uR2/bwN9isNEgXBaQzcT+fLzhKS3OaBNpwz8XdZlIgYQ==",
      "funding": [
        {
          "type": "github",
          "url": "https://github.com/sponsors/twbs"
        },
        {
          "type": "opencollective",
          "url": "https://opencollective.com/bootstrap"
        }
      ],
      "license": "MIT"
    }
  }
}
5	package.json	Normal file
@@ -0,0 +1,5 @@
{
  "dependencies": {
    "bootstrap-icons": "^1.12.1"
  }
}
@@ -1,10 +1,15 @@
# =============================================
# 深度学习框架 - GPU支持
# Deep Learning Framework - GPU support
# =============================================
# torch must be installed separately

# Huawei Ascend additionally requires separate installation of:
#torch_npu
#acl
#onnxruntime-ascend

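Since torch and the Ascend packages are installed out of band, a quick probe can confirm the runtime actually sees a device before training (a minimal sketch; `torch.npu` is the namespace that torch_npu patches into torch):

```
import torch

try:
    import torch_npu  # noqa: F401 -- adds the torch.npu namespace on Ascend hosts
    npu_ok = torch.npu.is_available()
except ImportError:
    npu_ok = False

# prefer NPU, then CUDA, then CPU
device = "npu:0" if npu_ok else ("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"selected training device: {device}")
```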
# YOLOv8及其依赖

# YOLOv8 and dependencies
ultralytics>=8.0.196
numpy>=1.24.0
opencv-python>=4.8.0
@@ -16,27 +21,31 @@ psutil>=5.9.0
pynvml>=11.5.0

# =============================================
# GPU支持的额外依赖
# GPU support additional dependencies
# =============================================
# GPU monitoring tool, similar to nvidia-smi but more capable
nvitop>=1.0.0

# API和Web框架
# CUDA-accelerated NumPy library; install the build matching your CUDA version
# cupy-cuda12x>=12.0.0

# API and Web Framework
fastapi>=0.115.0
uvicorn>=0.34.0
python-multipart>=0.0.5

# 数据库
# Database
sqlalchemy>=2.0.0
psycopg2-binary>=2.9.0 # PostgreSQL database driver
psycopg2-binary>=2.9.0

# 配置

# Configuration
pydantic>=2.0.0
pydantic-settings>=2.0.0

# 文件处理
# File handling
aiofiles>=23.0

# 监控和日志
# Monitoring and Logging
tensorboard>=2.15.0

@@ -1,4 +0,0 @@
import torch
print(torch.cuda.is_available())      # should print True
print(torch.cuda.device_count())      # should print the number of GPUs
print(torch.cuda.get_device_name(0))  # should print the GPU name