Code

文件名称 实现功能 文件地址
MHA 多头注意力模块,支持 flash_attn,输入数据格式为:x:(B,T,C),atten_mask:(B,T) 🔗
GQA 分组注意力模块,支持 flash_attn,输入数据格式为:x:(B,T,C),atten_mask:(B,T) 🔗
MQA 多查询注意力模块,支持 flash_attn,输入数据格式为:x:(B,T,C),atten_mask:(B,T) 🔗
SWA 滑动窗口注意力模块,支持 flash_attn,输入数据格式为:x:(B,T,C),atten_mask:(B,T) 🔗
PosEncoding 位置编码,RotaryPositionalEncoding,AbsolutePositionEmbedding,LearnedPositionEmbedding。输入:x:(B,T,C) 🔗
Norm 归一化操作,LayerNorm,BatchNorm,RMSNorm,InstanceNorm,GlobalResponseNorm。输入:(B,T,C) 或者 (B,C,H,W) 🔗
ResNet 视觉编码器,ResNet50, ResNet101, ResNet152系列 🔗
ConvNeXt 视觉编码器,ConvNeXt v1系列 🔗
Vit 视觉编码器,Vit 🔗
SwinTransformer 视觉编码器,SwinTransformer 🔗
-->