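The original plan was to first cover Google's new paper on transformers that allocate compute dynamically, but Meta recently released Llama 3, which is big news, so we start with the Llama architecture instead.
Like GPT-2, Llama is a decoder-only transformer, and its architecture does not differ much from the GPT-2 we covered earlier. The differences come down to three points: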
1 Grouped-Query Attention (GQA)
2 Rotary position embedding (RoPE)
3 RMSNorm
Compared with Llama 2, the main changes in Llama 3 are:
1 Trained on more data (15T tokens)
2 A new tokenizer that expands the vocabulary to 128,256 tokens (the previous version had 32,000)
The Llama architecture is as follows:
[Figure: Llama architecture diagram]
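Let's start with Grouped-Query Attention. Multi-Head Attention splits q, k and v into N heads, runs attention on each head separately, and then concatenates the results. To cut cost while preserving quality, Grouped-Query Attention lets each group of query heads share a single K/V head, which means fewer K/V projections and a smaller KV cache.
With hidden_size = 768, num_heads = 8 and num_key_value_heads = 2, the query heads fall into two groups and each head has dimension 768 / 8 = 96. Printed out, this looks like:
[Figure: printed output]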
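The sizes implied by this configuration can be worked out by hand. The snippet below is only a small arithmetic sketch; the `*_width` variable names are made up for illustration and are not taken from any library.

```python
hidden_size = 768
num_heads = 8            # query heads
num_key_value_heads = 2  # key/value heads, i.e. the number of groups

head_dim = hidden_size // num_heads           # dimension of one head
n_rep = num_heads // num_key_value_heads      # query heads sharing one K/V head

# Output widths of the three projections (illustrative names).
q_width = num_heads * head_dim
k_width = num_key_value_heads * head_dim
v_width = num_key_value_heads * head_dim

print(head_dim, n_rep, q_width, k_width, v_width)
```

The K and V projections come out four times narrower than the Q projection, which is where the savings in this configuration come from.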
In the code, k and v are repeated 8 / 2 = 4 times:
keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
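To see which query head ends up paired with which K/V head after the repeat, here is a minimal pure-Python sketch. It uses strings as stand-ins for the per-head tensors, and `repeat_kv_heads` is a hypothetical helper written for this example, not the `repeat_kv` from the Llama code. It assumes the copies of each K/V head are kept adjacent.

```python
from typing import List


def repeat_kv_heads(kv_heads: List[str], n_rep: int) -> List[str]:
    """Repeat each key/value head n_rep times, keeping the copies adjacent."""
    return [head for head in kv_heads for _ in range(n_rep)]


num_heads = 8
num_key_value_heads = 2
n_rep = num_heads // num_key_value_heads

# Stand-ins for the per-head K (or V) tensors.
kv_heads = ["kv0", "kv1"]
expanded = repeat_kv_heads(kv_heads, n_rep)

# After expansion there is one K/V entry per query head.
for q_head, kv in enumerate(expanded):
    print(f"query head {q_head} -> {kv}")
```

Under that assumption, query head h attends with K/V head h // n_rep, so heads 0 to 3 share kv0 and heads 4 to 7 share kv1.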
Drawn out, the q k matrix multiplication looks like this:
[Figure: q k matrix multiplication]
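Rotary position embedding is one way of encoding position. Its core idea is to use the distance between tokens in the sequence to produce a relative position relationship.
Concretely, q and k are each rotated by a fixed angle scaled by their position: with angle θ, q at position m and k at position n, q is rotated by m·θ and k by n·θ. The benefit is that the encoding is no longer an absolute position. When the rotated q and k are multiplied, only the difference of the two rotations survives, so the result depends on the distance between q and k.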
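A quick numerical check of this claim, as a sketch in plain Python: the vectors, the angle and the helper names `rotate`, `dot` and `score` are all illustrative choices, and angles are in radians.

```python
import math


def rotate(vec, angle):
    """Rotate a 2-D vector counter-clockwise by `angle` radians."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]


theta = 0.3          # illustrative angle
q = (1.0, 2.0)
k = (0.5, -1.5)


def score(m, n):
    """Dot product between q placed at position m and k placed at position n."""
    return dot(rotate(q, m * theta), rotate(k, n * theta))


# Same offset n - m = 4 at different absolute positions.
print(score(3, 7), score(10, 14))
# A different offset (n - m = 5) for comparison.
print(score(3, 8))
```

The first two values agree up to floating-point error, while the third differs, which is the relative-position behaviour described above.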
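In the figure below, R denotes the rotation matrix:
[Figure: q and k multiplied by the rotation matrix R]
For a two-dimensional vector, the rotation looks like this:
[Figure: rotation of a two-dimensional vector]
The problem is that q and k are high-dimensional vectors and cannot be rotated directly. So they are split into two-dimensional pairs, each pair gets its own θ, and the rotation is applied pair by pair. The rotation matrix then becomes:
[Figure: rotation matrix for the high-dimensional case]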
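For reference, the matrices described here can be written out as follows. This is a reconstruction from the standard RoPE formulation rather than a copy of the original figures. The symbols d (head dimension) and b (frequency base) are notation introduced here, and the formula for θ_i follows the usual convention, with b treated as a configurable value.

```latex
% Rotation of one 2-D pair at position m
R(m\theta) =
\begin{pmatrix}
\cos m\theta & -\sin m\theta \\
\sin m\theta & \cos m\theta
\end{pmatrix}

% Block-diagonal rotation for a d-dimensional head, one angle per pair
R_m = \operatorname{diag}\bigl(R(m\theta_0),\, R(m\theta_1),\, \dots,\, R(m\theta_{d/2-1})\bigr),
\qquad \theta_i = b^{-2i/d}

% Relative-position property of the attention logit
(R_m q)^{\top} (R_n k) = q^{\top} R_m^{\top} R_n\, k = q^{\top} R_{n-m}\, k
```

The last line is the reason the scheme gives a relative encoding: the absolute positions m and n cancel, leaving only n - m.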
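The full rotation of q and k is shown in the figure below:
[Figure: rotation of q and k]
To sum up, RoPE splits q and k into dim/2 two-dimensional pairs, gives each pair a fixed angle θ, and rotates each pair according to its position, so that the q k product depends on (n-m)*θ. That is how it obtains a relative position encoding.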
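In code we never actually multiply by that large rotation matrix R. An equivalent formulation is used instead, and sin and cos do not even appear explicitly: complex numbers are enough.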
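The equivalence is easy to check on a single pair. The sketch below, with made-up values for `theta`, `m` and `x`, rotates one two-dimensional pair both ways: once with the explicit 2x2 rotation matrix, and once by treating the pair as a complex number and multiplying by a unit complex number with phase m·θ.

```python
import cmath
import math

theta = 0.3      # illustrative angle, in radians
m = 5            # token position
x = (1.0, 2.0)   # one 2-D pair taken from q or k

# Route 1: explicit 2x2 rotation matrix.
angle = m * theta
c, s = math.cos(angle), math.sin(angle)
via_matrix = (c * x[0] - s * x[1], s * x[0] + c * x[1])

# Route 2: treat the pair as x0 + i*x1 and multiply by e^{i*m*theta}.
z = complex(x[0], x[1]) * cmath.rect(1.0, angle)
via_complex = (z.real, z.imag)

print(via_matrix)
print(via_complex)
```

Both routes give the same pair, which is why a single complex multiplication can replace the matrix product.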
In the code, Llama precomputes the rotation angles for all positions:
self.freqs_cis = precompute_freqs_cis(
    self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
)
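What such a precomputed table might contain can be sketched in plain Python. `precompute_rotations` below is a hypothetical stand-in and not the actual `precompute_freqs_cis`. The base of 10000 is an assumption taken from the common RoPE convention, and a given model may configure a different value.

```python
import cmath
from typing import List


def precompute_rotations(head_dim: int, max_positions: int,
                         base: float = 10000.0) -> List[List[complex]]:
    """Return table[pos][i] = exp(1j * pos * theta_i) for each 2-D pair i."""
    # One angle per pair of dimensions: theta_i = base ** (-2i / head_dim).
    thetas = [base ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
    return [[cmath.rect(1.0, pos * t) for t in thetas]
            for pos in range(max_positions)]


head_dim = 768 // 8          # 96, as in the running example
table = precompute_rotations(head_dim, max_positions=16)

print(len(table), len(table[0]))   # number of positions, number of pairs
```

Each row holds one unit complex number per two-dimensional pair, so a head of dimension 96 needs 48 entries per position.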
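Note: with hidden_size = 768, num_heads = 8 and num_key_value_heads = 2, the result is shown in the figure below (because the dimensions are grouped in pairs, the head dimension 96 is halved again, to 48).
[Figure: precomputed rotation values]
Finally, the code that applies the rotary position embedding is: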
from typing import Tuple

import torch


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # View each consecutive pair in the last dimension as one complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # reshape_for_broadcast is a helper defined elsewhere in the Llama code (not shown in this article).
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication performs the rotation; flatten the pairs back into head_dim.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
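The pairing performed by `reshape(..., -1, 2)` followed by `view_as_complex` can be mimicked on a single plain-Python vector. This is only a sketch under assumptions: `apply_rotary` and `dot` are hypothetical helpers, the vectors are arbitrary, and the base of 10000 is assumed rather than taken from any model config.

```python
import cmath
from typing import List


def apply_rotary(vec: List[float], pos: int, thetas: List[float]) -> List[float]:
    """Rotate consecutive pairs (vec[2i], vec[2i+1]) by pos * thetas[i]."""
    out: List[float] = []
    for i, theta in enumerate(thetas):
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.rect(1.0, pos * theta)
        out.extend((z.real, z.imag))
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


dim = 8
thetas = [10000.0 ** (-(2 * i) / dim) for i in range(dim // 2)]  # assumed base
q = [0.1 * (j + 1) for j in range(dim)]
k = [0.2 * (dim - j) for j in range(dim)]

# Position 0 means a rotation by zero, so the vector comes back unchanged.
unchanged = apply_rotary(q, 0, thetas)

# Offset n - m = 4 in both placements, so the two dot products should agree.
near = dot(apply_rotary(q, 3, thetas), apply_rotary(k, 7, thetas))
far = dot(apply_rotary(q, 10, thetas), apply_rotary(k, 14, thetas))
print(near, far)
```

Each pair is rotated by its own angle, yet the dot product over the whole vector still depends only on the offset between the two positions.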