AI知识库 AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


LLaMA 3/2/1模型结构总览
浏览次数: 1539

作者:孟繁续,北京大学博士生 ,研究方向 LLM(大型语言模型)和模型压缩
主页:fxmeng.github.io
声明:原文已经授权,版权归原作者!
原文:https://zhuanlan.zhihu.com/p/636784644

LLaMA-3又出来了,综合表现非常惊艳,我在实际测试中能力也比LLaMA-2-7B,Mistral-7B和Gemma-7B效果好。模型还是直接复用之前的代码,不过最小的8B模型也用上了GQA了,实测速度挺快。手头的llama-2可以丢了,可以拥抱llama-3了。

llama2 出来了,并且开源可商用,这下开源社区又要变天了。快速看一下官网以及paper,看看llamav2相比v1有什么更新吧:

  • • 预训练语料从1->2 Trillion tokens

  • • context window 长度从2048->4096

  • • 收集了100k人类标注数据进行SFT

  • • 收集了1M人类偏好数据进行RLHF

  • • 在reasoning, coding, proficiency, and knowledge tests上表现超越MPT和Falcon

  • • 和falcon一样,使用了Group query attention,节省cache

LLaMA现在已经是开源社区里炙手可热的模型了,但是原文中仅仅介绍了其和标准Transformer的差别,并没有一个全局的模型介绍。找了找其他博客也都是和原文一样,没有介绍模型的结构总览。因此打算写这篇文章,争取让读者不参考任何其他资料把LLaMA的模型搞懂。

结构

如图所示为LLaMA的示意图,由Attention和MLP层堆叠而成:

模型的主要特点为:

  • • 前置的RMSNorm,

  • • 在Q、K上使用RoPE旋转式位置编码,

  • • 使用causal mask保证每个位置只能看到前面的tokens,

  • • LLaMA可以将更早的K、V拼接到当前K、V前面,可以用Q查找更早的信息,为了清晰没在图中画出来。

  • • MLP表达式:$down(up)(x) x SiLU(gate(x))$ ,其中down, up, gate都是线性层。

  • • V2 context window 4096,使用了Group Query Attention。

LLaMA各个不同大小的结构设置如下表所示。其中最大的65B的LLaMA用了2048张80GB的A100,batch size为4百万,训练一次需要21天。

params dimension n heads n layers learning rate n tokens A100-hours
6.7B 4096 32 32 3.0e−4 1.0T 82432
13.0B 5120 40 40 3.0e−4 1.0T 135168
32.5B 6656 52 60 1.5e−4 1.4T 530432
65.2B 8192 64 80 1.5e−4 1.4T 530432

Group Query Attention(V2 only)

自回归模型生成回答时,需要前面生成的KV缓存起来,来加速计算。多头注意力机制(MHA)需要的缓存量很大,Multi-Query Attention指出多个头之间可以共享KV对。Group Query Attention没有像MQA一样极端,将query分组,组内共享KV,效果接近MHA,速度上与MQA可比较。p.s. 这个技术falcon已经用上了,当时falcon说自己用的是multi query attention,因为当group=1时,GQA和MQA是等价的。falcon支持设置不同的G。

RMSNorm

这是在BERT、GPT等模型中广泛使用的LayerNorm:

RMSNorm(root mean square)发现LayerNorm的中心偏移没什么用(减去均值等操作)。将其去掉之后,效果几乎不变,但是速度提升了40%。最终公式为:

注意除了没有减均值,加偏置以外,分母上求的RMS而不是方差。

LLaMA在 Attention Layer和MLP的输入上使用了RMSNorm,相比在输出上使用,训练会更加稳定。

SwiGLU

LLaMA没有使用ReLU,而是使用了SwiGLU,有时也被称为SiLU。公式为:  ,效果类似平滑版的ReLU:

RoPE

LLaMA使用了Rotary Position Embedding。对于Q的第m个位置向量q,通过以下方法注入位置编码:

其中  是值介于[1,0)之间的固定向量。通过以下代码得到了上式中的第二项  和第四项  。

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_mbeddings)
        freqs = torch.einsum("i,j->ij", t, theta)

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, seq_len=None):
        return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]

# 在LlamaAttention通过以下命令调用:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)

以下代码将q沿着最后一个维度劈成两半,将后一半乘-1,然后连接在第一半之前,就得到了上式第三项。

# 在接下来的apply_rotary_pos_emb函数里调用

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

后通过以下代码得到结合了位置编码的Q,K(K和Q使用同样的方式进行位置编码)。

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    q_embed = (q * cos[position_ids]) + (rotate_half(q) * sin[position_ids])
    k_embed = (k * cos[position_ids]) + (rotate_half(k) * sin[position_ids])
    return q_embed, k_embed

# 在LlamaAttention中通过以下命令调用:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

使用了这么复杂的位置编码,有什么好处呢?从上面的公式可以看出,RoPE形式上是绝对位置编码,即依赖其绝对位置m。

绝对位置编码的优点是计算速度快等,缺点是拓展长度比较麻烦,且绝对位置并没有什么实际意义。而相对位置编码对学习token之间的关系很有意义,比如距离的很远的两个token之间的关联大概率很小,使用相对位置编码往往能够获得更好的效果。此外拓展长度也更容易,因为不论context size多长,只需关注最长距离以内的输入即可。相对位置编码的缺点是没有绝对位置编码计算速度快。

当我们计算Attention时,RoPE可以变成相对位置编码

从上面这个公式可以看出,q和k的attention依赖相对距离m-n。因此RoPE为q、k注入的绝对位置编码,计算得到的attention,却变成了相对位置编码。妙的很,我这里为了不参考其他文章就很容易搞懂LLaMA的结构,简化了很多东西,推荐大家看一看RoPE原作者苏剑林[1]的博客了解更多信息。

文中参考的代码是huggingface的transformers库实现的版本,并不是Meta官方的代码。受笔者水平限制,如果哪里讲的不对,或者不够清晰易懂,欢迎在评论区交流。

引用链接

[1] 苏剑林: https://kexue.fm/archives/8265

青稞Talk预告

5月10日(周五)晚7点,【青稞Talk】第五期,3D-VLA第一作者甄昊宇,直播分享《3D-VLA:构建生成式三维具身世界模型》。



都看到这了,点个关注再走吧?~



》加入青稞社区·与青科同行《

备注:姓名+学校/公司+方向

推荐新闻
RAG系列04:使用ReRank进行重排序
本文介绍了重排序的原理和两种主流的重排序方法:基于重排模型和基于 LLM。文章指出,重排序是对检索到的上下文进行再次筛选的过程,类似于排序过程中的粗排和精排。在检索增强生成中,精排的术语就叫重排序。文章还介绍了使用 Cohere 提供的在线模型、bge-reranker-base 和 bge-reranker-large 等开源模型以及 LLM 实现重排序的方法。最后,文章得出结论:使用重排模型的方法轻量级、开销较小;而使用 LLM 的方法在多个基准测试上表现良好,但成本较高,且只有在使用 ChatGPT 和 GPT-4 时表现良好,如使用其他开源模型,如 FLAN-T5 和 Vicuna-13B 时,其性能就不那么理想。因此,在实际项目中,需要做出特定的权衡。
LangGPT论文:面向大语言模型的自然语言编程框架(中文版)
大语言模型 (Large Language Models, LLMs) 在不同领域都表现出了优异的性能。然而,对于非AI专家来说,制定高质量的提示来引导 LLMs 是目前AI应用领域的一项重要挑战。
第三篇:要真正入门AI,OpenAI的官方Prompt工程指南肯定还不够,您必须了解的强大方法论和框架!!!
自从ChatGPT(全名:Chat Generative Pre-trained Transformer)于2022年11月30日发布以来,一个新兴的行业突然兴起,那就是提示工程(Prompt engineering),可谓如日冲天。从简单的文章扩写,到RAG,ChatGPT展现了前所未有的惊人能力。
(三)12个RAG痛点及其解决方案
痛点9:结构化数据QA 痛点10:从复杂 PDF 中提取数据 痛点11:后备模型 痛点12:LLM安全
(二)12个RAG痛点及其解决方案
痛点5:格式错误 痛点6:不正确的特异性 痛点7:不完整 痛点8:数据摄取可扩展性

联系我们

售前咨询
186 6662 7370
产品演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询