深度学习-----------------------注意力分数

news/2024/10/8 12:32:38 标签: 深度学习, 人工智能

目录

  • 注意力分数
    • 注意力打分函数代码
  • 掩蔽softmax操作
  • 拓展到高纬度
    • Additive Attention(加性注意力)
      • 加性注意力代码
      • 演示一下AdditiveAttention类
      • 该部分总代码
      • 注意力权重
    • Scaled Dot-Product Attention(缩放点积注意力)
      • 缩放点积注意力代码
      • 演示一下DotProductAttention类
      • 该部分总代码
      • 注意力权重
  • 总结

注意力分数

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述




注意力打分函数代码

import math
import torch
from torch import nn
from d2l import torch as d2l



掩蔽softmax操作

import torch
from torch import nn
from d2l import torch as d2l


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上遮蔽元素来执行softmax操作"""
    if valid_lens is None:
        # 如果valid_lens为空,则对X执行softmax操作
        return nn.functional.softmax(X, dim=-1)
    else:
        # shape的形状为(2,2,4)
        shape = X.shape
        # 判断有效长度是否是一维的
        if valid_lens.dim() == 1:
            # valid_lens重复两次[2,3]→[2,2,3,3],和x的列数一样
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将valid_lens重塑为一维向量
            valid_lens = valid_lens.reshape(-1)
        # 在X的最后一个维度(即:列)上进行遮蔽操作
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        # 对遮蔽后的X执行softmax操作,并将形状还原为原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))

在这里插入图片描述

print(masked_softmax(torch.rand(2,2,4), torch.tensor([[1,3],[2,4]])))

在这里插入图片描述





拓展到高纬度

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述




Additive Attention(加性注意力)

(拓展到多维)

在这里插入图片描述


可学参数

在这里插入图片描述
等价于将key和query合并起来后放入到一个隐藏大小为h输出大小为1的单隐藏层MLP。

它的好处是:key、value、query可以是任意的长度。




加性注意力代码

需要学习三个参数:key_size, query_size, num_hiddens

class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        # 用于生成注意力分数的线性变换
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        # queries的形状:(batch_size,查询的个数,num_hiddens),key是(batch_size,键的数目,num_hiddens)
        # 两者不能直接相加
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 执行加性操作,将查询和键相加
        # queries加一维进去,变成了(batch_size,查询的个数,1,num_hiddens),key在第一维加一个维度,变成了(batch_size,1,键的数目,num_hiddens)
        # 最后features变成了(batch_size,number_querys,number_keys,num_hiddens)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # features的形状:(batch_size,number_querys,number_keys,1)
        # 使用线性变换生成注意力分数,并将最后一维的维度压缩掉
        scores = self.w_v(features).squeeze(-1)
        # 使用遮蔽softmax计算注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 根据注意力权重对values进行加权求和
        return torch.bmm(self.dropout(self.attention_weights), values)



演示一下AdditiveAttention类

# queries是一个批量大小为2,1个query,query长度为20
# keys是一个批量大小为2,10个key,key的长度为2
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# repeat(2, 1, 1)沿着第一个维度重复两次(共两个)
# values是一个批量大小为2,10个value,value的长度为4
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
# 第一个样本看前两个,第二个样本看前6个
valid_lens = torch.tensor([2, 6])
# 创建加性注意力对象
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
# 调用加性注意力对象的forward方法
print(attention(queries, keys, values, valid_lens))



该部分总代码

import math
import torch
from torch import nn
from d2l import torch as d2l


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上遮蔽元素来执行softmax操作"""
    if valid_lens is None:
        # 如果valid_lens为空,则对X执行softmax操作
        return nn.functional.softmax(X, dim=-1)
    else:
        # shape的形状为(2,2,4)
        shape = X.shape
        # 判断有效长度是否是一维的
        if valid_lens.dim() == 1:
            # valid_lens重复两次[2,3]→[2,2,3,3],和x的列数一样
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将valid_lens重塑为一维向量
            valid_lens = valid_lens.reshape(-1)
        # 在X的最后一个维度(即:列)上进行遮蔽操作
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        # 对遮蔽后的X执行softmax操作,并将形状还原为原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


# 加性注意力
class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        # 用于生成注意力分数的线性变换
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))

在这里插入图片描述


注意力权重

# 调用d2l.show_heatmaps函数,显示注意力权重的热图
d2l.show_heatmaps(attention.attention_weights.reshape((1,1,2,10)),
                 xlabel='Keys', ylabel='Queries')

在这里插入图片描述




Scaled Dot-Product Attention(缩放点积注意力)

如果query和key都是同样的长度,q、k∈ R d R^d Rd,那么可以:

在这里插入图片描述

向量化版本(拓展到多维)

在这里插入图片描述




缩放点积注意力代码

好处是不需要学习参数

class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Dropout层,用于随机丢弃一部分注意力权重
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取查询向量的维度d
        d = queries.shape[-1]
        # 计算点积注意力得分,并进行缩放
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # 使用遮蔽softmax计算注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 根据注意力权重对values进行加权求和
        return torch.bmm(self.dropout(self.attention_weights), values)



演示一下DotProductAttention类

queries = torch.normal(0,1,(2,1,2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
# 调用缩放点积注意力对象的forward方法
attention(queries, keys, values, valid_lens)



该部分总代码

import math
import torch
from torch import nn
from d2l import torch as d2l


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上遮蔽元素来执行softmax操作"""
    if valid_lens is None:
        # 如果valid_lens为空,则对X执行softmax操作
        return nn.functional.softmax(X, dim=-1)
    else:
        # shape的形状为(2,2,4)
        shape = X.shape
        # 判断有效长度是否是一维的
        if valid_lens.dim() == 1:
            # valid_lens重复两次[2,3]→[2,2,3,3],和x的列数一样
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将valid_lens重塑为一维向量
            valid_lens = valid_lens.reshape(-1)
        # 在X的最后一个维度(即:列)上进行遮蔽操作
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        # 对遮蔽后的X执行softmax操作,并将形状还原为原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Dropout层,用于随机丢弃一部分注意力权重
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取查询向量的维度d
        d = queries.shape[-1]
        # 计算点积注意力得分,并进行缩放
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # 使用遮蔽softmax计算注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 根据注意力权重对values进行加权求和
        return torch.bmm(self.dropout(self.attention_weights), values)


# keys是一个批量大小为2,10个key,key的长度为2
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.ones((2, 10, 2))
# repeat(2, 1, 1)沿着第一个维度重复两次(共两个)
# values是一个批量大小为2,10个value,value的长度为4
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
# 创建缩放点积注意力对象
attention = DotProductAttention(dropout=0.5)
# 设置为评估模式,不使用dropout
attention.eval()
# 调用缩放点积注意力对象的forward方法
print(attention(queries, keys, values, valid_lens))

在这里插入图片描述



注意力权重

# 调用d2l.show_heatmaps函数,显示注意力权重的热图
d2l.show_heatmaps(attention.attention_weights.reshape((1,1,2,10)),
                 xlabel='Keys', ylabel='Queries')

在这里插入图片描述




总结

注意力分数是query和key的相似度,注意力权重分数的softmax结果。

两种常见的分数计算
    将query和key合并起来进入一个单输出单隐藏的MLP。(加性注意力)
    直接将query和key做内积。(缩放点积注意力)


http://www.niftyadmin.cn/n/5694129.html

相关文章

《Windows PE》4.1.4 手工重构导入表

接下来我们做一个稍微复杂一些的实验,实验需要四个程序: HelloWorld.exe:弹出MessageBox窗口(实验1已实现)。 Regedit.exe:添加注册表启动项。 LockTray.exe:锁定任务栏窗口。 UnLockTray.exe&…

用java编写飞机大战

游戏界面使用JFrame和JPanel构建。背景图通过BG类绘制。英雄机和敌机在界面上显示并移动。子弹从英雄机发射并在屏幕上移动。游戏有四种状态:READY、RUNNING、PAUSE、GAMEOVER。状态通过鼠标点击进行切换:点击开始游戏(从READY变为RUNNING&am…

Studying-多线程学习Part1-线程库的基本使用、线程函数中的数据未定义错误、互斥量解决多线程数据共享问题

来源:多线程编程 线程库的基本使用 两个概念: 进程是运行中的程序线程是进程中的进程 串行运行:一次只能取得一个任务并执行这一个任务 并行运行:可以同时通过多进程/多线程的方式取得多个任务,并以多进程或多线程…

黑马JavaWeb开发跟学(十二)SpringBootWeb案例

黑马JavaWeb开发跟学十二.SpringBootWeb案例 案例-登录认证1. 登录功能1.1 需求1.2 接口文档1.3 思路分析1.4 功能开发1.5 测试 2. 登录校验2.1 问题分析2.2 会话技术2.2.1 会话技术介绍2.2.2 会话跟踪方案2.2.2.1 方案一 - Cookie2.2.2.2 方案二 - Session2.2.2.3 方案三 - 令…

【分布式微服务云原生】Redis持久化策略:RDB vs AOF

Redis持久化策略:RDB vs AOF 摘要 本文深入探讨了Redis的两种主要持久化策略:RDB和AOF。我们将分析它们的工作原理、优缺点,并探讨如何在不同的应用场景中选择最合适的持久化策略。此外,文章还将提供Java代码示例和流程图&#…

828华为云征文 | 华为云Flexus X实例在混合云环境中的应用与实践

目录 前言 1. 混合云环境的优势与挑战 1.1 混合云的优势 1.2 混合云的挑战 2. Flexus X实例的配置与集成 2.1 Flexus X实例简介 2.2 Flexus X实例的混合云部署 2.3 配置步骤与措施 3. 数据迁移与同步策略 3.1 数据迁移方案 3.2 数据同步措施 4. 安全性与合规性管理…

性能剖析利器-Conan|得物技术

作者 / 得物技术 - 仁慈的狮子 目录 一、背景 1. 局限性 2. 向前一步 二、原理剖析 1. 系统架构 2. 工作模式 3. reporter 三、稳定性验证 四、案例分析 五、写在最后 一、背景 线上问题的定位与优化是程序员进阶的必经之路,常见的问题定位手段有日志排查、分布式链…

操作系统 | 学习笔记 | 王道 | 3.2 虚拟内存管理

3.2 虚拟内存管理 3.2.1 虚拟内存的基本概念 传统存储管理方式的特征 传统存储管理方式 连续分配 单一连续分配固定分区分配动态分区分配 非连续分配 基本分页存储管理基本分段存储管理基本段页式存储管理 特征: 一次性: 作业必须一次性全部装入内存后…