This post walks through an efficient quantization scheme for Qwen3-8B: implementing w8a8 online rotation quantization with QuaRot and adapting it to the transformers inference framework. It covers: 1. how to configure and run QuaRot rotation quantization; 2. how to adapt online rotation quantization to the Qwen3-8B model; 3. how to save the quantized weights and hook them into transformers inference.
代码已上传:https://github.com/taishan1994/LLM-Quantization#
If QuaRot rotation quantization uses online rotation, the model's forward pass has to be modified. Here we enable online rotation and adapt the result for transformers inference.
First, set up the base environment following llmc's installation instructions, then create a new mine folder under configs/quantization and put the online rotation quantization config there: quarot_w_a.yml.
llmc currently only supports online rotation for OPT and Llama models, so set the model type to Llama (Qwen3's architecture is essentially the same as Llama's):
base:
    seed: &seed 42
model:
    type: Llama
    path: /data/gongoubo/checkpoints/Qwen/Qwen3-8B
    torch_dtype: auto
quant:
    method: Quarot
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
        calib_algo: minmax
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    special:
        rotate_mode: hadamard
        fp32_had: True
        online_rotate: True
save:
    save_trans: True
    save_fake: True
    save_vllm: True
    save_path: /data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online
We use w8a8 quantization, with per-channel granularity for the weights and per-token granularity for the activations.
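Concretely, the granularity settings only decide along which axis the quantization scales are computed. A tiny illustrative sketch (not llmc's code, shapes chosen arbitrarily):

import torch

n_bits = 8
qmax = 2 ** (n_bits - 1) - 1

# per_channel (weights): one scale per output channel, i.e. per row of the (out, in) weight
w = torch.randn(4096, 4096)
w_scale = w.abs().amax(dim=1, keepdim=True) / qmax   # shape (4096, 1)

# per_token (activations): one scale per token, i.e. per row of the (tokens, hidden) activation
x = torch.randn(8, 4096)
x_scale = x.abs().amax(dim=-1, keepdim=True) / qmax  # shape (8, 1), recomputed at runtime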
One thing to note: since transformers itself cannot run true w8a8 inference, we load the weights saved by save_fake, i.e. weights that have been fake-quantized (quantized and immediately dequantized back to floating point), so ordinary nn.Linear layers can still consume them.
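To make that concrete, here is a minimal sketch of symmetric per-channel fake quantization; it assumes nothing about llmc's internals and the function name is made up:

import torch

def fake_quant_per_channel(w: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    # symmetric fake quantization: round to the int8 grid, then dequantize right away,
    # so the result is still a floating-point tensor a plain nn.Linear can load
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True) / qmax      # one scale per output channel
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
    return q * scale

w = torch.randn(4096, 4096)
w_fake = fake_quant_per_channel(w)   # same shape and dtype, but only 256 distinct levels per channel

The activation side is handled per token at runtime; with save_fake the weight part is already baked into the checkpoint, so no custom int8 kernels are needed.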
To add the online rotation we need to modify two places: the v/o projections in attention and the up/down projections in the MLP. The modified Qwen3MLP looks like this:
class Qwen3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        # build the Hadamard matrix (or its factorization) for the intermediate dimension
        self.had_K, self.K = get_hadK(self.intermediate_size)
        self.rotater = Rotater(
            online_full_had=True,   # for MLPs we use the online full Hadamard transform
            online_partial_had=False,
            fp32_had=True,
            K=self.K,
            had_K=self.had_K,
            had_dim=None,           # had_dim is not used for the full transform
        )
        print('enable online rotate for Qwen3MLP')

    def forward(self, x):
        act = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        # online rotation of the down_proj input; the matching inverse rotation has
        # already been folded into down_proj's weights offline by llmc
        act = self.rotater.rotate(act)
        down_proj = self.down_proj(act)
        return down_proj
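The reason this extra rotation does not change the floating-point model: llmc folds the matching rotation into down_proj's weights offline (roughly what save_trans captures), and for an orthogonal Hadamard matrix the two cancel. A quick numerical check of that identity, independent of llmc (using scipy's Hadamard matrix):

import torch
from scipy.linalg import hadamard

d = 64
H = torch.tensor(hadamard(d), dtype=torch.float32) / d ** 0.5   # orthogonal: H @ H.T = I
W = torch.randn(16, d)      # a down_proj-style weight, (out_features, in_features)
x = torch.randn(4, d)       # activations entering down_proj

y_ref = x @ W.T             # original computation
W_rot = W @ H               # offline: rotation folded into the weight
y_rot = (x @ H) @ W_rot.T   # online: rotate the activation, then apply the folded weight
print(torch.allclose(y_ref, y_rot, atol=1e-5))   # True: mathematically a no-op

The point of the rotation is not the identity itself but that the rotated activations have far fewer outliers, which is what makes 8-bit per-token activation quantization workable.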
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_head = config.num_attention_heads
        self.num_kv_head = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        # for attention, we use the online partial Hadamard transform over the head axis
        had_K_tensor, K_tensor = get_hadK(self.num_head)
        print(had_K_tensor, K_tensor, self.num_head)
        self.rotater = Rotater(
            online_full_had=False,
            online_partial_had=True,
            fp32_had=True,
            K=K_tensor,
            had_K=had_K_tensor,
            had_dim=self.head_dim,
        )
        print("enable Qwen3Attention")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # an earlier variant also applied fast_hadamard_transform.hadamard_transform directly
        # to query_states/key_states here; that path is disabled in this version, and only
        # the output rotation before o_proj below is kept

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # online rotation of the o_proj input; the matching inverse rotation has been
        # folded into the v_proj/o_proj weights offline by llmc
        attn_output = attn_output.reshape(-1, self.num_head * self.head_dim)
        attn_output = self.rotater.rotate(attn_output)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
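For the attention output the transform is partial rather than full: with had_dim=head_dim and a Hadamard built from num_head, the rotation mixes the same channel across heads instead of across the whole hidden dimension. How llmc's Rotater implements this exactly depends on the version; conceptually it is something like the following sketch (the shapes are the assumed Qwen3-8B attention sizes, and this is an illustration, not llmc's code):

import torch
from scipy.linalg import hadamard

num_heads, head_dim, tokens = 32, 128, 4
H = torch.tensor(hadamard(num_heads), dtype=torch.float32) / num_heads ** 0.5

attn_out = torch.randn(tokens, num_heads * head_dim)
x = attn_out.reshape(tokens, num_heads, head_dim)   # (tokens, heads, head_dim)
x = torch.einsum("hk,tkd->thd", H, x)               # Hadamard across the head axis only
rotated = x.reshape(tokens, num_heads * head_dim)   # this is what o_proj receives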
At inference time, just run inference as usual:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_name = "/data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online/fake_quant_model/"
# model_name = "/data/gongoubo/Qwen-1.5-Factory/model_hub/Qwen/Qwen2___5-1___5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
CONFIG = AutoConfig.from_pretrained(model_name)

# transformers==4.53.0
# from modeling_qwen3 import Qwen3ForCausalLM

process_word_embeddings = False
if CONFIG.tie_word_embeddings:
    CONFIG.tie_word_embeddings = False
    process_word_embeddings = True

from modeling_qwen3_online_llmc import Qwen3ForCausalLM
# from modeling_qwen3_online_r3_r4 import Qwen3ForCausalLM
# from modeling_qwen3 import Qwen3ForCausalLM

model = Qwen3ForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, config=CONFIG).to("cuda:0")
# model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

# make sure the model is in eval mode
model.eval()

# message = [{"role": "user", "content": "你是谁?"}]
# message = tokenizer.apply_chat_template(message, tokenize=False, add_special_tokens=True)
message = "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(message, return_tensors="pt")
input_ids = input_ids.to(model.device)

direct_output = model.generate(input_ids, max_new_tokens=256, do_sample=False, temperature=1)
direct_text = tokenizer.decode(direct_output[0])
print(direct_text)
Note: when loading the model, make sure to use the modified Qwen3 model code (modeling_qwen3_online_llmc above), not the stock transformers implementation.
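A quick way to sanity-check the whole pipeline is to compare the fake-quant model against the original FP Qwen3-8B on the same prompt. A rough sketch (paths follow the script above and are placeholders; the tie_word_embeddings handling shown earlier is omitted for brevity, and both models need enough GPU memory):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

fp_path = "/data/gongoubo/checkpoints/Qwen/Qwen3-8B"
quant_path = "/data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online/fake_quant_model/"

from modeling_qwen3_online_llmc import Qwen3ForCausalLM   # modified structure for the quant model

tokenizer = AutoTokenizer.from_pretrained(fp_path)
prompt = "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda:0")

fp_model = AutoModelForCausalLM.from_pretrained(fp_path, torch_dtype=torch.float16).to("cuda:0").eval()
q_model = Qwen3ForCausalLM.from_pretrained(quant_path, torch_dtype=torch.float16).to("cuda:0").eval()

with torch.no_grad():
    fp_out = fp_model.generate(input_ids, max_new_tokens=64, do_sample=False)
    q_out = q_model.generate(input_ids, max_new_tokens=64, do_sample=False)

print("fp   :", tokenizer.decode(fp_out[0]))
print("quant:", tokenizer.decode(q_out[0]))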