当前位置:网站首页 > C++编程 > 正文

pointnet++分割算法网络(unet分割网络)



在医疗图像分割任务中,transformer模型获得了巨大的成功,UNETR提出了efficient paired attention (EPA) 模块,利用了空间和通道注意力来有效地学习通道和空间的特征,该模型在Synapse,BTCV,ACDC,BRaTs数据集上都获得了很好地效果。

论文:UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

代码:https://github.com/Amshaker/unetr_plus_plus

一、论文笔记

首先看一下模型架构,整体还是UNet结构,在其中引入了提出的EPA模块。

该论文的核心就是EPA模块,EPA的提出主要是解决2个问题

1、计算更有效率:传统的self-attention计算成本很高,对于3D的医疗图像来说更高,EPA将self-attention的K和V投影到低纬度再计算,降低了计算复杂度;

2、增强了空间和通道特征表示能力:transformer本身就是一种空间注意力机制,但是它忽略了通道特征,EPA将空间和通道特征融合在了一起。

再仔细看一下EPA的结构图,上方蓝底部分式空间注意力,下方绿底部分式通道注意力。再空间注意部分,为了降低self-attention计算量,将HWDXC的K和V降维到pXC维度。

代码如下(类中的self.EF用于降低K和V的维度,空间注意力和通道注意力的K和Q是共享的):

class EPA(nn.Module):

"""

Efficient Paired Attention Block, based on: "Shaker et al.,

UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

"""

def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,

channel_attn_drop=0.1, spatial_attn_drop=0.1):

super().__init__()

self.num_heads = num_heads

self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

# qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)

self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

# E and F are projection matrices with shared weights used in spatial attention module to project

# keys and values from HWD-dimension to P-dimension

self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

self.attn_drop = nn.Dropout(channel_attn_drop)

self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

def forward(self, x):

B, N, C = x.shape

qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)

qkvv = qkvv.permute(2, 0, 3, 1, 4)

q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

q_shared = q_shared.transpose(-2, -1)

k_shared = k_shared.transpose(-2, -1)

v_CA = v_CA.transpose(-2, -1)

v_SA = v_SA.transpose(-2, -1)

proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args)

k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF)))

q_shared = torch.nn.functional.normalize(q_shared, dim=-1)

k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature

attn_CA = attn_CA.softmax(dim=-1)

attn_CA = self.attn_drop(attn_CA)

x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2

attn_SA = attn_SA.softmax(dim=-1)

attn_SA = self.attn_drop_2(attn_SA)

x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

return x_CA + x_SA

@torch.jit.ignore

def no_weight_decay(self):

return {'temperature', 'temperature2'}

二、代码实践

官方给出了Synapse,BTCV,ACDC,BRaTs数据集的跑通实例,我这里只跑一个BRaTs数据集,其他的是一样的步骤。

1、安装环境

使用conda安装环境:

conda create --name unetr_pp python=3.10

conda activate unetr_pp

安装torch:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

安装依赖:

pip install -r requirements.txt

2、准备数据集

官方给出了处理好的数据集地址,直接下载即可:

数据集链接SynapseOneDriveACDCOneDriveDecathon-LungOneDriveBRaTsOneDrive

本文下载好了BraTs数据集作为实例,将其放入以下目录:

3、训练

因为只是跑通一下,把unetr_plus_plus/unetr_pp/training/network_training/unetr_pp_trainer_tumor.py中的epoch改成10:

训练就非常简单了,进入训练集脚本目录并运行脚本:

cd training_scripts

bash run_training_tumor.sh

训练起来了:

4、评估

首先将自己训练的权重放到指定位置(原来output_tumor的unetr_pp文件夹放到unetr_plus_plus舋_ppevaluation舋_pp_tumor_checkpoint里面去):

修改代码unetr_plus_plus/unetr_pp/inference/predict.py,共有两处:

进入评估脚本目录并运行脚本:

cd evaluation_scripts

修改run_evaluation_tumor.sh脚本,相关路径替换为自己的路径(自带的脚本我没成功,大家可以自行尝试):

#!/bin/sh

DATASET_PATH=https://www.51969.com/DATASET_Tumor

export PYTHONPATH=https://www.51969.com/https://www.51969.com/post/

export RESULTS_FOLDER=https://www.51969.com/unetr_pp/evaluation/unetr_pp_tumor_checkpoint

export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor

export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw

# Only for Tumor, it is recommended to train unetr_plus_plus first, and then use the provided checkpoint to evaluate. It might raise issues regarding the pickle files if you evaluated without training

python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference/predict_simple.py -i https://www.51969.com/DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs -o https://www.51969.com/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/inferTs -m 3d_fullres -t 3 -f 0 -chk model_final_checkpoint -tr unetr_pp_trainer_tumor

python /deeplearning/medicalseg/unetr_plus_plus/unetr_pp/inference_tumor.py 0

修改unetr_plus_plus/unetr_pp/inference_tumor.py的数据集路径,可以根据自己的情况改:

运行脚本:

bash run_evaluation_tumor.sh

在推理结果的目录unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/下多了一个

dice_five.txt文件,里面有相关精度,如下(因为就训练了10个epoch,效果不行):

本文到此结束。

好文阅读

  

/br>

到此这篇pointnet++分割算法网络(unet分割网络)的文章就介绍到这了,更多相关内容请继续浏览下面的相关推荐文章,希望大家都能在编程的领域有一番成就!

版权声明


相关文章:

  • ceph存储池容量阈值(ceph存储池有哪些类型)2025-10-26 13:45:08
  • iec104协议详解(iec101和iec104区别)2025-10-26 13:45:08
  • 泰拉瑞亚时间指令怎么用(泰拉瑞亚指令怎么用pc)2025-10-26 13:45:08
  • cnn什么意思骂人不带脏字(cnn什么意思的缩写)2025-10-26 13:45:08
  • libc.so是什么(libc++abi.so.1)2025-10-26 13:45:08
  • 查看k8s版本(k8s查看configmap)2025-10-26 13:45:08
  • pc和apc的好坏(apc和pc能对接吗)2025-10-26 13:45:08
  • kubectl 命令(kubectl 命令补全)2025-10-26 13:45:08
  • dbf文件怎么转换成excel(dbf文件怎么转换成excel网站)2025-10-26 13:45:08
  • cnns认证的检测机构(cnas检测机构认证)2025-10-26 13:45:08
  • 全屏图片