import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention


class MOBAAttention(nn.Module):
    """MoBA-style block-sparse attention: each query attends only to its own
    chunk plus the top-k most relevant key/value chunks, where chunk relevance
    is scored against mean-pooled keys (naive reference implementation)."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        chunk_size: int = 64,
        topk: int = 4,
        dropout: float = 0.1,
        causal: bool = True,
        use_flash: bool = True,
    ):
        """
        Args:
            embed_dim (int): total embedding dimension (must be divisible by num_heads)
            num_heads (int): number of attention heads
            chunk_size (int): number of tokens per chunk
            topk (int): maximum number of chunks each query may attend to
            dropout (float): attention dropout rate
            causal (bool): whether to apply a causal mask (blocks attention to future tokens)
            use_flash (bool): whether to try Flash / SDPA acceleration
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.chunk_size = chunk_size
        self.topk = topk
        self.dropout = dropout      # stored but not applied in this naive implementation
        self.causal = causal
        self.use_flash = use_flash  # reserved; this naive implementation always uses the explicit path
        self.scale = self.head_dim ** (-0.5)

        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            x: shape [B, T, E]
            attention_mask: shape [B, T], 1 for valid tokens, 0 for padding
                (right padding is assumed: valid tokens come first)
        Returns:
            Output tensor of shape [B, T, E]
        """
        B, T, E = x.shape
        H = self.num_heads
        D = self.head_dim

        # Step 1: project to Q, K, V
        qkv = self.qkv_proj(x).reshape(B, T, 3, H, D).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each is [B, H, T, D]

        o = torch.zeros_like(q)

        # Step 2: build cu_seqlens (mirrors the batched varlen convention)
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=-1).cpu()
        else:
            lengths = torch.full((B,), T, dtype=torch.long)
        cu_seqlens = torch.cat([torch.tensor([0]), lengths.cumsum(dim=0)], dim=0).to(x.device)

        for b in range(B):
            start = cu_seqlens[b].item()
            end = cu_seqlens[b + 1].item()
            seq_len = end - start  # number of valid tokens in this sequence
            q_ = q[b, :, :seq_len]  # [H, Q, D]
            k_ = k[b, :, :seq_len]  # [H, K, D]
            v_ = v[b, :, :seq_len]  # [H, K, D]

            # Step 3.1: build chunk gate weights (mean-pooled keys per chunk)
            key_gate_weight = []
            num_block = math.ceil(seq_len / self.chunk_size)
            for i in range(num_block):
                s = i * self.chunk_size
                e = min(seq_len, s + self.chunk_size)
                # mean over the chunk's token dimension -> [H, 1, D]
                key_gate_weight.append(k_[:, s:e].mean(dim=1, keepdim=True))
            key_gate_weight = torch.cat(key_gate_weight, dim=1)  # [H, N, D]

            # Step 3.2: chunk-level gate scores
            gate = torch.einsum("hqd,hnd->hqn", q_, key_gate_weight)  # [H, Q, N]
            gate = gate.type(torch.float32)

            # Step 3.3: chunk-level causal masking
            if self.causal:
                for i in range(num_block):
                    # queries at or before the end of chunk i must not select it via the
                    # gate (its mean contains future keys) ...
                    gate[:, : (i + 1) * self.chunk_size, i] = float("-inf")
                    # ... except that queries inside chunk i always select their own chunk
                    gate[:, i * self.chunk_size : (i + 1) * self.chunk_size, i] = float("inf")

            # Step 3.4: top-k chunk selection
            gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=min(self.topk, num_block), dim=-1)
            gate_top_k_val_min, _ = gate_top_k_val.min(dim=-1)
            need_attend = gate >= gate_top_k_val_min.unsqueeze(-1)
            # guard against ties: keep only the indices actually returned by topk
            gate_idx_mask = torch.zeros_like(gate, dtype=torch.bool).scatter_(-1, gate_top_k_idx, True)
            need_attend = torch.logical_and(need_attend, gate_idx_mask)
            gate[need_attend] = 0.0             # selected chunks add no bias to the logits
            gate[~need_attend] = -float("inf")  # unselected chunks are masked out

            # Step 3.5: expand the chunk-level gate to token level
            gate = gate.repeat_interleave(self.chunk_size, dim=-1)[..., :seq_len]  # [H, Q, K]

            # Step 3.6: token-level causal mask (optional)
            if self.causal:
                causal_mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=x.device).tril(diagonal=0)
                gate = gate.masked_fill(~causal_mask.unsqueeze(0), -float("inf"))

            # Step 3.7: compute the final attention and output
            qk = torch.einsum("hqd,hkd->hqk", q_, k_)  # [H, Q, K]
            qk = (qk + gate) * self.scale
            p = qk.softmax(dim=-1)
            p = p.masked_fill(p.isnan(), 0)  # rows with no attendable keys become zeros
            o[b, :, :seq_len] = torch.einsum("hqk,hkd->hqd", p, v_)
        attn_output = o

        # Step 4: reshape the output and apply the output projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        return self.out_proj(attn_output)


if __name__ == '__main__':
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    B, T, E, H = 2, 256, 768, 12

    moba_attn = MOBAAttention(embed_dim=E, num_heads=H).to(device)

    x = torch.randn(B, T, E).to(device)
    # simulate padding; the mask is reduced to per-sequence lengths, so right padding is assumed
    attn_mask = torch.randint(0, 2, (B, T)).to(device)
    output = moba_attn(x, attention_mask=attn_mask)
    print(output.shape)  # expected: torch.Size([2, 256, 768])
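    # Optional sanity check (a sketch, not part of the module above). Assumption:
    # with the defaults here there are T / chunk_size = 256 / 64 = 4 chunks and
    # topk = 4, so every query selects every chunk, the gate collapses to zeros,
    # and MOBA should reduce to dense causal attention over the same Q/K/V.
    # We rebuild Q/K/V from the module's own projections and compare against
    # scaled_dot_product_attention (imported above).
    x2 = torch.randn(B, T, E, device=device)
    with torch.no_grad():
        moba_out = moba_attn(x2)  # no attention_mask -> full-length sequences
        qkv = moba_attn.qkv_proj(x2).reshape(B, T, 3, H, E // H).permute(2, 0, 3, 1, 4)
        q_ref, k_ref, v_ref = qkv.unbind(0)  # each [B, H, T, D]
        ref = scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True)
        ref = moba_attn.out_proj(ref.transpose(1, 2).reshape(B, T, E))
    # Expected to be close up to floating-point tolerance if the reduction holds.
    print(torch.allclose(moba_out, ref, atol=1e-4))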