Quantizing Qwen3 with QuaRot and Running Online Inference

Published: 2025-07-30
Author: 西西嘛呦


Summary

An efficient quantization recipe for Qwen3-8B: how to implement w8a8 online rotation quantization and adapt it to the transformers inference framework.

Key points:
1. How to configure and implement QuaRot rotation quantization
2. How to adapt online rotation quantization to the Qwen3-8B model
3. How the quantized weights are saved and wired into transformers inference



The code has been uploaded to: https://github.com/taishan1994/LLM-Quantization#

If QuaRot quantization uses online rotation, the model's forward pass has to be modified. Here we use online rotation and adapt the model for transformers inference.

  • Quantization framework: llmc
  • Inference framework: transformers (it can also be swapped for vLLM or SGLang with the corresponding adaptation)

First, install the base environment following llmc's setup instructions. Then create a mine folder under configs/quantization and put the online rotation quantization config in it: quarot_w_a.yml. llmc currently only supports online rotation for OPT and Llama models, so the model type has to be set to Llama (Qwen3's architecture is essentially the same as Llama's).

base:
    seed: &seed 42
model:
    type: Llama
    path: /data/gongoubo/checkpoints/Qwen/Qwen3-8B
    torch_dtype: auto
quant:
    method: Quarot
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
        calib_algo: minmax
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    special:
        rotate_mode: hadamard
        fp32_had: True
        online_rotate: True
save:
    save_trans: True
    save_fake: True
    save_vllm: True
    save_path: /data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online

We use w8a8 quantization, with per-channel quantization for weights and per-token quantization for activations. Note the save options:

  • save_trans: saves the weights after the QuaRot rotation but without quantization.
  • save_fake: saves fake-quantized weights, i.e. rotated, quantized, and then dequantized.
  • save_vllm: saves the rotated and truly quantized weights (for vLLM).

Since transformers does not support real w8a8 inference, we use the weights saved by save_fake.
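To see why the save_fake checkpoint can be served by plain transformers, here is a minimal sketch of the w8a8 scheme in the config above: symmetric int8 fake quantization with per-channel weight scales and per-token activation scales (my own illustration, not llmc's implementation). The tensors are quantized and immediately dequantized, so they stay in floating point but carry the quantization error; save_fake bakes the weight-side error into the checkpoint, while in the plain transformers path used here the activations are not actually quantized at inference time.

import torch

def fake_quant_weight_per_channel(w: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Symmetric fake quantization with one scale per output channel (weight row)."""
    qmax = 2 ** (n_bits - 1) - 1                       # 127 for int8
    scale = w.abs().amax(dim=1, keepdim=True) / qmax   # per-channel scale
    return torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale

def fake_quant_act_per_token(x: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Symmetric fake quantization with one scale per token (last dim = hidden size)."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = x.abs().amax(dim=-1, keepdim=True) / qmax  # per-token scale
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale

w = torch.randn(16, 32)   # toy weight (out_features, in_features)
x = torch.randn(4, 32)    # toy activations (tokens, hidden)
y = fake_quant_act_per_token(x) @ fake_quant_weight_per_channel(w).T
print(y.dtype, y.shape)   # still floating point, so an fp16/fp32 runtime can consume it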

For online rotation, we need to modify two places: the v/o projections in attention and the up/down projections in the MLP. A small numerical sketch of why this is safe follows, and after it the modified Qwen3MLP and Qwen3Attention modules.
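Why the extra rotate call does not change the (pre-quantization) model output: QuaRot folds an orthogonal Hadamard matrix into the weights of the following projection offline and applies the same Hadamard to the activations online, and the two cancel because the matrix is orthogonal. Below is a minimal numerical sketch of this for the down_proj path (my own illustration, not llmc code; it uses a toy power-of-two size and a fold convention that may differ from llmc's, whose get_hadK also handles non-power-of-two sizes such as Qwen3's intermediate dimension).

import torch

def hadamard(n: int) -> torch.Tensor:
    """Sylvester construction of an orthonormal n x n Hadamard matrix (n a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (n ** 0.5)   # orthonormal: H @ H.T == I

torch.manual_seed(0)
inter, hidden, tokens = 64, 32, 4
H = hadamard(inter)
W_down = torch.randn(hidden, inter)    # original down_proj weight
act = torch.randn(tokens, inter)       # gate*up activations

y_ref = act @ W_down.T                 # original MLP output path
W_down_rot = W_down @ H                # rotation folded offline into the checkpoint
y_rot = (act @ H) @ W_down_rot.T       # online rotation + rotated weight
print(torch.allclose(y_ref, y_rot, atol=1e-5))   # True: the rotation is lossless before quantization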

class Qwen3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

        # Hadamard matrix (had_K) and block count (K) for the intermediate size
        self.had_K, self.K = get_hadK(self.intermediate_size)
        self.rotater = Rotater(
            online_full_had=True,   # the MLP path uses an online full Hadamard transform
            online_partial_had=False,
            fp32_had=True,          # run the Hadamard in fp32, then cast back
            K=self.K,
            had_K=self.had_K,
            had_dim=None,           # had_dim is only used by the partial transform
        )
        print("enable online rotate for Qwen3MLP")

    def forward(self, x):
        act = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        # online rotation of the down_proj input
        act = self.rotater.rotate(act)
        down_proj = self.down_proj(act)
        return down_proj
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_head = config.num_attention_heads
        self.num_kv_head = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

        had_K_tensor, K_tensor = get_hadK(
            self.num_head
        )  # for attention, we use partial hadamard transform
        print(had_K_tensor, K_tensor, self.num_head)
        self.rotater = Rotater(
            online_full_had=False,  # for attention, we use online partial hadamard transform
            online_partial_had=True,
            fp32_had=True,
            K=K_tensor,
            had_K=had_K_tensor,
            had_dim=self.head_dim,
        )

        print("enable Qwen3Attention")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)


        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # An R3-style online rotation of Q/K (per-head Hadamard after RoPE) could be
        # applied here via fast_hadamard_transform; it is left disabled in this setup.


        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(-1, self.num_head * self.head_dim)
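        # online partial Hadamard on the flattened attention output before o_proj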
        attn_output = self.rotater.rotate(attn_output)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
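In the attention path only part of the rotation runs online: the Rotater above is built from get_hadK(num_heads) with had_dim=head_dim, i.e. a Hadamard that mixes across heads, while in QuaRot's scheme the per-head part of the rotation is folded into the weights offline. The toy sketch below (my own illustration with made-up sizes, not llmc code) shows that this head-mixing transform is just the orthogonal matrix H_heads ⊗ I_head_dim acting on the flattened attention output, so folding it into o_proj offline cancels the online step.

import torch

def hadamard(n: int) -> torch.Tensor:
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (n ** 0.5)

torch.manual_seed(0)
num_heads, head_dim, tokens = 8, 16, 4
hidden = num_heads * head_dim
H_heads = hadamard(num_heads)
Q = torch.kron(H_heads, torch.eye(head_dim))     # head-mixing rotation on the flat layout

x = torch.randn(tokens, hidden)                  # flattened attention output (heads-major)
W_o = torch.randn(hidden, hidden)                # o_proj weight

# Mixing along the head axis after a reshape equals multiplying the flat vector by Q.
x_rot = torch.einsum("hg,tgd->thd", H_heads, x.view(tokens, num_heads, head_dim)).reshape(tokens, hidden)
print(torch.allclose(x_rot, x @ Q.T, atol=1e-4))

# Folding Q into o_proj offline cancels the online rotation.
W_o_rot = W_o @ Q.T
print(torch.allclose(x @ W_o.T, x_rot @ W_o_rot.T, atol=1e-4))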

At inference time, just run inference as usual:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_name = "/data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online/fake_quant_model/"
# model_name = "/data/gongoubo/Qwen-1.5-Factory/model_hub/Qwen/Qwen2___5-1___5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
CONFIG = AutoConfig.from_pretrained(model_name)
# transformers==4.53.0
# from modeling_qwen3 import Qwen3ForCausalLM

process_word_embeddings = False
if CONFIG.tie_word_embeddings:
    CONFIG.tie_word_embeddings = False
    process_word_embeddings = True
from modeling_qwen3_online_llmc import Qwen3ForCausalLM
# from modeling_qwen3_online_r3_r4 import Qwen3ForCausalLM
# from modeling_qwen3 import Qwen3ForCausalLM
model = Qwen3ForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, config=CONFIG).to("cuda:0")

# model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
# make sure the model is in eval mode
model.eval()

# message = [{"role":"user", "content":"你是谁?"}]
# message = tokenizer.apply_chat_template(message, tokenize=False, add_special_tokens=True)

message = "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(message, return_tensors="pt")
input_ids = input_ids.to(model.device)
direct_output = model.generate(input_ids, max_new_tokens=256, do_sample=False, temperature=1)
direct_text = tokenizer.decode(direct_output[0])

print(direct_text)

Note: the model must be loaded with the modified Qwen3 model implementation (the custom modeling file imported above), not the stock transformers one.
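A quick sanity check (my own addition, run right after the script above and assuming the rotater attribute names used in the modified modules) to confirm the custom implementation is actually the one that got loaded:

# Hypothetical check: every decoder layer should carry the online rotation modules.
for i, layer in enumerate(model.model.layers):
    assert hasattr(layer.mlp, "rotater"), f"layer {i}: MLP rotater missing"
    assert hasattr(layer.self_attn, "rotater"), f"layer {i}: attention rotater missing"
print("online rotation modules found in all layers")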

