当前位置:网站首页 > R语言数据分析 > 正文

pointnet++复现(transunet复现)



前文词嵌入和位置嵌入已经做了模型推理前对token序列的预处理工作,本文来重点讨论语言模型的核心组件——Attention。

语言模型最早面临的问题是,如何从一种语言翻译成另一种语言,例如:德文翻译成英文,由于两种语言的语法结构不同,如果逐字翻译会造成语法错误。
在这里插入图片描述

为了解决无法逐字翻译文本的问题,出现了基于编码器-解码器架构的递归神经网络()。

  • 编码器按顺序对输入文本进行处理,在每个步骤中更新其隐藏状态,并试图在最终的隐藏状态中捕获输入句子的全部含义。
  • 解码器则以这个最终的隐藏状态作为输入,开始一个单词一个单词的生成翻译后的句子,并在每一步更新其隐藏状态,以确保下一个单词预测时携带了已经生成的词作为上下文。
    在这里插入图片描述

但RNN只适用于翻译短句,而不适用于较长的文本,它最大的在于:解码阶段完全依赖于编码器输出的最终隐藏状态,而无法访问编码器中更早期的隐藏状态,这可能会导致上下文丢失,尤其是在依赖关系可能跨越很长距离的复杂句子中。

为此,研究人员在 2014 年开发了 RNN 的 Bahdanau 注意力机制,该机制所作的重要修改是允许在解码步骤中选择性地访问输入序列的不同部分。
在这里插入图片描述

仅仅三年后,谷歌发表了论文《Attention is all you need》,指出RNN对于构建自然语言的神经网络并不是必需的,并提出了完全基于self-attention的transformer架构。

自注意力是一种更有效的进行输入表示的机制。它允许序列中的每个位置在计算它在序列中的表示时,都能关注序列中的所有位置。通俗来讲,就是。

下面我们将从头开始编写这种自注意力机制。

2.1 单个token的注意力权重

自注意力机制中的(self)是指注意力权重所要关注的是单一输入序列内部不同位置的联系。作为对比,传统注意力关注的是两个不同序列之间的元素关系。

自注意力的目标是为每个输入token计算一个上下文向量,该向量结合了序列中所有其他输入token的信息。

以一个输入文本 “Your journey starts with one step.” 为例,假设一个单词对应一个token,如果我们要计算第2个token的上下文向量z(2),那最终每个token x(1)……x(T)对z(2)的重要性将由注意力权重 21到 2T决定。
在这里插入图片描述

接下来,我们将用代码来演示这个计算的过程。

为了可显示的需要,我们降低了嵌入的维度,我们对每个token采用三维嵌入。

 
  
计算注意力得分

实现自注意力机制的第一步是计算中间变量 ω,这些变量被称为注意力得分。方法是计算 的嵌入向量x(2)与其它token的嵌入向量之间的点积。

注:点积是一个衡量两个向量相似度的数学工具,具体的计算可以认为是两个向量对应位置的元素相乘后求和。以x(1)和x(2)的点积为例,就是0.43 * 0.55 + 0.15 * 0.87 + 0.89 * 0.66 = 0.9544。

 
  
 
  

注:torch.empty()返回的是指定形状的零张量,例如size为6时返回。

那通过点积计算的这些数值有什么含义呢?

在自注意力机制中,,点积值越高,两个token之间的相似性和注意力得分就越高。

分数归一化

归一化就是将注意力得分中的每个数值进行变换,目标是各项数值的总和为1。归一化的最简单做法是每个元素除以所有元素之和,如下所示:

 
  
 
  

查验这些注意力分数相加之和是否为1:

 
  
 
  

但实际场景中,更推荐使用softmax进行归一化,与上面我们直接采用元素本身进行运算不同,softmax先对所有元素进行指数运算。

 
  
 
  

使用softmax进行归一化的好处在于:指数运算确保了所有注意力得分为正,这意味着输出可以被解释为概率,高数值代表更大的重要性。

现在我们已经得到了注意力权重, 下一步是将每个嵌入的向量 x(i) 与相应的注意力权重相乘,然后将结果向量求和,计算出上下文向量 z(2)。

 
  
 
  
2.2 所有token的注意力权重

上面对第2个token计算了上下文向量,下面会对代码作一些修改,以将计算过程扩展至所有输入token的注意力权重和上下文向量。

 
  
 
  

上面计算过程本质上是inputs矩阵和inputs矩阵的转置相乘,在pytorch中矩阵相乘有更简单的写法:

 
  
 
  

pytorch中的这种矩阵乘法运算不仅简洁,而且执行效率也比python中的for循环更高效。

下面对每行的注意力得分进行归一化,使其总和为1。同样,对于归一化也使用pytorch中封装的softmax来代替我们手工编写的softmax。

 
  
 
  

可以看到,使用pytorch的softmax计算的第二行权重向量和我们之前手工编写的softmax函数的计算结果是相同的,softmax函数的实现与我们预期相同。

简单确认下,各行相加之和是否为1。

 
  
 
  

最后通过矩阵乘法来生成所有的上下文向量,在得到的输出结果中,每一行都包含一个三维的上下文向量。

 
  
 
  

其中,第二行的输出张量与我们上面计算的完全相同,也说明代码计算无误。

到这里,我们就完成了一个简单自注意力的代码演示。接下来,我们将添加可训练的权重,使自己编写的注意力具备可学习性。

3.1 初始化权重矩阵

首先,我们定义输入和输出的维度dim_in和dim_out。

 
  

在前面简单注意力的基础上,引入三个可训练的权重矩阵 Wq、Wk 和 Wv,这三个矩阵会用于将嵌入的输入向量 x(i) 投影为查询向量Q、键向量K和值向量V。

 
  

注:这里将requires_grad设为False只是为了显示矩阵结果时更清晰,如果要将这些权重用于模型训练,则需要将require_grads设为True。

3.2 计算Q、K、V向量

同样,我们先以第二个token的输入向量x(2)作查询向量,来计算与其对应的查询(q)、键(k)和值(v)向量。

 
  
 
  

注:前面简单自注意力的实现代码中,我们是直接将嵌入向量自身x_i作为q、k、v,而这里的q、k、v三个向量是通过输入向量x_i与三个权重矩阵计算而来,这样的好处是能通过模型训练让这个自注意力不断进化。

尽管只是计算一个token的上下文向量z(2),但我们仍然需要所有token的键向量k和值向量v,因为查询向量q_2需要与序列中所有token的k向量和v向量运算,才能得到x_2的注意力权重和上下文向量。

 
  
 
  
3.3 计算注意力得分

如同前面的简单自注意力实现一样,注意力得分是一个点积运算,有所不同的是,我们不再直接计算输入元素,而是经过权重矩阵变换过的查询向量q、键向量k和值向量v。

在这里插入图片描述

 
  
 
  

注:点积操作是让查询向量q_2与keys中的每个行向量运算,由于keys是一个6行3列的矩阵,对keys进行转置变换后形状变为3行6列,这个点积操作就可以变换成q_2与keys.T的矩阵乘法运算,即q_2与keys.T中的每个列向量相乘。

3.4 计算注意力权重

从上面可以知道,注意力得分到注意力权重的变换是通过softmax进行归一化运算。与之前不同的是,这里我们会引入来计算注意力,具体操作就是在softmax之前,先将注意力得分除以键的嵌入维度的平方根来缩放注意力得分。

注:平方根可以表示成以1/2为底的冥运算,即嵌入维度的0.5次方。

 
  
 
  

缩放点积的原因在于:在类似GPT一样的大语言模型中,嵌入维度大到接近上千,大的嵌入维度在进行点积操作时也会产生大的点积和,而较大的点积应用softmax函数后,会在反向传播过程中产生非常小的梯度,这些小的梯度会减缓学习速度,甚至训练停滞。

通过嵌入维度的平方根进行缩放的这种注意力机制也被称为。

区分权重参数注意力权重:我们上面提到的Wq、Wk、Wv三个权重矩阵都属于权重参数,它是整个神经网络中的通用概念,是可以被训练的。而注意力权重只是Attention机制中的一个动态变量,它决定了输入序列中的不同token分别在多大程度上影响上下文向量。

3.5 计算上下文向量

得到了注意力权重后,我们通过对所有值向量加权求和就能计算出上下文向量,而注意力权重就是衡量每个值向量重要性的权重因子。

 
  
 
  

上面的整个计算过程可以用下面这个图来表示:

在这里插入图片描述

上面逐步计算自注意力的过程是为了理解和展示,实际应用中,会把上面的代码整合到一个python类中以便高效使用。

 
  

直接使用封装的自注意力类来计算上下文向量。

 
  
 
  

可以看到,第二行的上下文向量与我们上面手动计算的(tensor([0.6864, 1.0577, 1.1389]))是完全相同。

小结:本节从一个简单文本的三维嵌入开始,一步一步复现了自注意力的详细计算过程,而后从简单自注意力逐步过滤到可训练权重的自注意力,从基础的矩阵计算逐步过渡到pytorch的高级API,最后将这个计算过程封装为一个组件以备复用。下一节,我们将结合因果关系和多头元素对自注意机制进行改进。

  • 什么是词嵌入和位置嵌入?
  • 带你从零训练分词器
到此这篇pointnet++复现(transunet复现)的文章就介绍到这了,更多相关内容请继续浏览下面的相关推荐文章,希望大家都能在编程的领域有一番成就!

版权声明


相关文章:

  • 数组的some(数组的sort排序)2025-10-03 11:45:06
  • nowel是什么意思(nowre什么意思)2025-10-03 11:45:06
  • arse是什么意思(c0arse是什么意思)2025-10-03 11:45:06
  • naa/cr是啥(nac是什么意思中文)2025-10-03 11:45:06
  • ettercap扫描不到主机(twincat3扫描不到设备)2025-10-03 11:45:06
  • ucharit怎么读(ucan怎么读)2025-10-03 11:45:06
  • yum命令和rpm命令区别(linux yum和rpm)2025-10-03 11:45:06
  • 字符串转integer类型(字符串转成int类型)2025-10-03 11:45:06
  • 文件比较工具 beyond compare使用(bcompare比较文件夹)2025-10-03 11:45:06
  • vconn什么意思(vcorn什么意思)2025-10-03 11:45:06
  • 全屏图片