支持私有化部署
AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


RAG召回优化完全指南:从理论到实践的三大核心策略!

发布日期:2025-07-30 08:19:52 浏览次数: 1542
作者:多模态智能体

微信搜一搜,关注“多模态智能体”

推荐语

RAG系统召回效果不佳?三大核心策略帮你精准命中知识库内容,附完整代码实现!

核心内容:
1. 多查询扩展:通过LLM生成语义相同的查询变体提升召回率
2. 语义相似度优化:改进向量化方法增强查询-文档匹配度
3. 混合检索策略:结合关键词与语义搜索的优势实现互补

杨芳贤
53AI创始人/腾讯云(TVP)最具价值专家

在构建RAG(检索增强生成)系统时,我们经常遇到这样的问题:明明知识库里有相关内容,但检索出来的结果却不够准确。用户问"如何提高深度学习训练效率?",系统却返回了一堆关于深度学习基础理论的文档,真正有用的优化技巧反而被埋没了。

这个问题的根源在于召回环节的不足。召回是RAG系统的第一道关卡,它决定了后续生成的质量上限。今天我们就来深入探讨三个核心的召回优化策略,并提供完整的代码实现。

策略一:多查询扩展 - 用不同方式问同一个问题

核心思想

用户的查询往往表达不够充分,或者用词与知识库中的描述不匹配。多查询扩展通过LLM生成多个语义相同但表述不同的查询变体,增加命中相关文档的概率。

实现原理

  1. 使用LLM将原始查询改写成多个变体

  2. 每个变体独立进行检索

  3. 合并所有检索结果并去重


完整代码实现

import osimport openaifrom typing import List, Setimport numpy as npfrom sklearn.feature_extraction.text import TfidfVectorizerfrom sklearn.metrics.pairwise import cosine_similarityimport jiebaimport reclass MultiQueryRetriever:    def __init__(self, api_key: str, knowledge_base: List[str]):        """        初始化多查询检索器                Args:            api_key: OpenAI API密钥            knowledge_base: 知识库文档列表        """        openai.api_key = api_key        self.knowledge_base = knowledge_base        self.vectorizer = TfidfVectorizer(            tokenizer=lambda x: list(jieba.cut(x)),            lowercase=False,            max_features=5000        )        # 预先构建文档向量        self.doc_vectors = self.vectorizer.fit_transform(knowledge_base)        def generate_query_variants(self, original_query: str, num_variants: int = 3) -> List[str]:        """        生成查询变体                Args:            original_query: 原始查询            num_variants: 生成变体数量                    Returns:            查询变体列表        """        prompt = f"""        请对以下查询生成{num_variants}个语义相同但表述不同的变体,要求:        1. 保持原意不变        2. 使用不同的词汇和句式        3. 每个变体占一行                原始查询:{original_query}                变体查询:        """                try:            response = openai.ChatCompletion.create(                model="gpt-3.5-turbo",                messages=[{"role": "user", "content": prompt}],                temperature=0.7,                max_tokens=200            )                        variants = response.choices[0].message.content.strip().split('\n')            variants = [v.strip() for v in variants if v.strip()]                        # 过滤掉明显不合理的结果            valid_variants = []            for variant in variants:                # 移除序号                variant = re.sub(r'^\d+[\.\)]\s*', '', variant)                if len(variant) > 5 and len(variant) < 200:                    valid_variants.append(variant)                        return valid_variants[:num_variants]                    except Exception as e:            print(f"生成查询变体失败: {e}")            return [original_query]        def retrieve_single_query(self, query: str, top_k: int = 5) -> List[tuple]:        """        单个查询的检索                Args:            query: 查询文本            top_k: 返回文档数量                    Returns:            (文档索引, 相似度分数, 文档内容) 的列表        """        # 向量化查询        query_vector = self.vectorizer.transform([query])                # 计算相似度        similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()                # 获取top_k结果        top_indices = similarities.argsort()[-top_k:][::-1]                results = []        for idx in top_indices:            if similarities[idx] > 0.1:  # 过滤相似度过低的结果                results.append((                    idx,                     float(similarities[idx]),                     self.knowledge_base[idx]                ))                return results        def retrieve(self, query: str, top_k: int = 5) -> List[dict]:        """        多查询检索主函数                Args:            query: 原始查询            top_k: 最终返回的文档数量                    Returns:            检索结果列表        """        # 生成查询变体        print(f"原始查询: {query}")        variants = self.generate_query_variants(query)        print(f"生成的查询变体: {variants}")                all_queries = [query] + variants                # 收集所有检索结果        all_results = {}  # doc_idx -> (max_score, doc_content, matched_queries)                for i, q in enumerate(all_queries):            print(f"\n检索查询 {i+1}: {q}")            results = self.retrieve_single_query(q, top_k * 2)  # 每个查询多检索一些                        for doc_idx, score, content in results:                if doc_idx not in all_results:                    all_results[doc_idx] = [score, content, [q]]                else:                    # 保留最高分数,记录匹配的查询                    if score > all_results[doc_idx][0]:                        all_results[doc_idx][0] = score                    all_results[doc_idx][2].append(q)                # 按分数排序并返回top_k        sorted_results = sorted(            all_results.items(),             key=lambda x: x[1][0],             reverse=True        )[:top_k]                final_results = []        for doc_idx, (score, content, matched_queries) in sorted_results:            final_results.append({                'document_id': doc_idx,                'content': content,                'score': score,                'matched_queries': matched_queries,                'query_count': len(matched_queries)            })                return final_results# 使用示例def demo_multi_query():    """多查询检索演示"""        # 模拟知识库    knowledge_base = [        "深度学习模型训练效率优化的关键技术包括混合精度训练、梯度累积、模型并行等方法,可以显著提升训练速度。",        "深度学习的基础理论涉及神经网络的反向传播算法、激活函数的选择以及损失函数的设计原理。",        "AdamW优化器相比传统SGD在深度学习训练中表现更好,特别是在处理稀疏梯度时有明显优势。",        "分布式训练技术如数据并行和模型并行可以充分利用多GPU资源,大幅缩短大模型的训练时间。",        "混合精度训练通过FP16和FP32的结合使用,在保持模型精度的同时减少显存占用并加速训练过程。",        "深度学习框架PyTorch和TensorFlow各有优势,PyTorch更适合研究,TensorFlow更适合生产环境部署。",        "Transformer架构的自注意力机制revolutionized自然语言处理领域,成为现代大语言模型的基础。",        "卷积神经网络CNN在计算机视觉任务中表现出色,特别是在图像分类和目标检测方面。"    ]        # 初始化检索器(注意:需要设置你的OpenAI API Key)    api_key = "your-openai-api-key"  # 请替换为你的实际API Key    retriever = MultiQueryRetriever(api_key, knowledge_base)        # 执行检索    query = "如何提高深度学习模型的训练效率?"    results = retriever.retrieve(query, top_k=3)        # 展示结果    print(f"\n=== 检索结果 ===")    for i, result in enumerate(results, 1):        print(f"\n第{i}个结果 (相似度: {result['score']:.3f}):")        print(f"内容: {result['content']}")        print(f"匹配的查询数量: {result['query_count']}")        print(f"匹配的查询: {result['matched_queries']}")if __name__ == "__main__":    demo_multi_query()

效果分析

通过多查询扩展,原本可能遗漏的相关文档得以被发现。比如用户问"训练效率",变体查询可能包括"训练速度优化"、"模型训练加速"等,这样就能匹配到更多相关内容。

策略二:重排序 - 精确识别最相关内容

核心思想

初步检索往往基于简单的相似度计算,可能返回语义相近但主题不够精确的文档。重排序使用专门的相关性模型对初步结果进行精排。

实现原理

  1. 使用快速方法(如TF-IDF)进行初步召回

  2. 使用深度模型计算查询与每个候选文档的精确相关性

  3. 根据新的相关性分数重新排序


完整代码实现

from sentence_transformers import CrossEncoderimport torchfrom typing import List, Tupleimport jiebafrom sklearn.feature_extraction.text import TfidfVectorizerfrom sklearn.metrics.pairwise import cosine_similarityclass RerankingRetriever:    def __init__(self, knowledge_base: List[str], model_name: str = 'BAAI/bge-reranker-base'):        """        初始化重排序检索器                Args:            knowledge_base: 知识库文档列表            model_name: 重排序模型名称        """        self.knowledge_base = knowledge_base                # 初始化重排序模型        try:            self.reranker = CrossEncoder(model_name)            print(f"成功加载重排序模型: {model_name}")        except:            print("未能加载BGE重排序模型,将使用简化的重排序逻辑")            self.reranker = None                # 初始化TF-IDF向量器用于初步检索        self.vectorizer = TfidfVectorizer(            tokenizer=lambda x: list(jieba.cut(x)),            lowercase=False,            max_features=5000,            ngram_range=(1, 2)        )        self.doc_vectors = self.vectorizer.fit_transform(knowledge_base)        def initial_retrieval(self, query: str, top_k: int = 20) -> List[Tuple[int, float, str]]:        """        初步检索阶段                Args:            query: 查询文本            top_k: 初步检索的文档数量                    Returns:            (文档索引, 初步分数, 文档内容) 的列表        """        # TF-IDF检索        query_vector = self.vectorizer.transform([query])        similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()                # 获取top_k候选        top_indices = similarities.argsort()[-top_k:][::-1]                candidates = []        for idx in top_indices:            if similarities[idx] > 0.05:  # 过滤过低相似度                candidates.append((                    idx,                     float(similarities[idx]),                     self.knowledge_base[idx]                ))                return candidates        def calculate_rerank_score(self, query: str, document: str) -> float:        """        计算重排序分数                Args:            query: 查询文本            document: 文档文本                    Returns:            重排序分数        """        if self.reranker:            # 使用BGE重排序模型            score = self.reranker.predict([(query, document)])            return float(score[0])        else:            # 简化的重排序逻辑:基于关键词匹配和长度惩罚            query_words = set(jieba.cut(query.lower()))            doc_words = set(jieba.cut(document.lower()))                        # 计算词汇重叠            overlap = len(query_words & doc_words)            union = len(query_words | doc_words)                        if union == 0:                return 0.0                        # Jaccard相似度            jaccard = overlap / union                        # 长度惩罚:过长或过短的文档得分降低            doc_len = len(document)            length_penalty = 1.0            if doc_len < 20:                length_penalty = 0.7            elif doc_len > 500:                length_penalty = 0.8                        # 查询词在文档中的位置权重            position_weight = 1.0            doc_lower = document.lower()            for word in query_words:                if word in doc_lower:                    pos = doc_lower.find(word) / len(doc_lower)                    # 越靠前权重越高                    position_weight += (1 - pos) * 0.1                        return jaccard * length_penalty * position_weight        def rerank_candidates(self, query: str, candidates: List[Tuple[int, float, str]]) -> List[dict]:        """        对候选文档进行重排序                Args:            query: 查询文本            candidates: 候选文档列表                    Returns:            重排序后的结果        """        reranked_results = []                print(f"正在对{len(candidates)}个候选文档进行重排序...")                for doc_idx, initial_score, content in candidates:            # 计算重排序分数            rerank_score = self.calculate_rerank_score(query, content)                        reranked_results.append({                'document_id': doc_idx,                'content': content,                'initial_score': initial_score,                'rerank_score': rerank_score,                'final_score': rerank_score  # 可以结合initial_score            })                # 按重排序分数排序        reranked_results.sort(key=lambda x: x['rerank_score'], reverse=True)                return reranked_results        def retrieve(self, query: str, top_k: int = 5) -> List[dict]:        """        完整的检索流程:初步检索 + 重排序                Args:            query: 查询文本            top_k: 最终返回的文档数量                    Returns:            最终检索结果        """        print(f"查询: {query}")                # 步骤1: 初步检索        print("步骤1: 初步检索...")        candidates = self.initial_retrieval(query, top_k * 4)  # 多检索一些用于重排        print(f"初步检索到{len(candidates)}个候选文档")                if not candidates:            return []                # 步骤2: 重排序        print("步骤2: 重排序...")        reranked_results = self.rerank_candidates(query, candidates)                # 返回top_k结果        final_results = reranked_results[:top_k]                print(f"最终返回{len(final_results)}个结果")        return final_results# 使用示例和对比演示def demo_reranking():    """重排序检索演示"""        # 扩展的知识库    knowledge_base = [        "深度学习模型训练效率优化的关键技术包括混合精度训练、梯度累积、模型并行等方法,可以显著提升训练速度。AdamW优化器和学习率调度也很重要。",        "深度学习的基础理论涉及神经网络的反向传播算法、激活函数的选择以及损失函数的设计原理。这些是理解深度学习的基础。",        "混合精度训练是一种重要的优化技术,通过FP16和FP32的结合使用,在保持模型精度的同时减少显存占用并加速训练过程。",        "分布式训练技术如数据并行和模型并行可以充分利用多GPU资源,大幅缩短大模型的训练时间。这对大规模模型训练至关重要。",        "PyTorch和TensorFlow是两个主流的深度学习框架,各有优势。PyTorch更适合研究和原型开发,TensorFlow更适合生产环境。",        "Transformer架构revolutionized了自然语言处理领域,其自注意力机制成为现代大语言模型如GPT和BERT的基础架构。",        "卷积神经网络CNN在计算机视觉任务中表现出色,特别适用于图像分类、目标检测和图像分割等任务。",        "优化器的选择对深度学习训练效果有重大impact。AdamW相比传统SGD在处理稀疏梯度时有明显优势,特别适合Transformer模型。",        "梯度累积技术允许在内存受限的情况下模拟大批量训练,这对训练大型模型很有帮助。通过累积多个小批量的梯度再更新参数。",        "学习率调度策略如余弦退火、warmup等对训练稳定性和最终效果有重要影响。合适的学习率调度可以显著提升模型性能。"    ]        # 初始化检索器    retriever = RerankingRetriever(knowledge_base)        # 执行检索    query = "如何提高深度学习训练效率?"    results = retriever.retrieve(query, top_k=5)        # 展示结果    print(f"\n=== 重排序检索结果 ===")    for i, result in enumerate(results, 1):        print(f"\n第{i}个结果:")        print(f"重排序分数: {result['rerank_score']:.4f}")        print(f"初步检索分数: {result['initial_score']:.4f}")        print(f"内容: {result['content']}")        print("-" * 50)if __name__ == "__main__":    demo_reranking()

重排序的威力

重排序能够识别出真正与查询主题匹配的文档。比如对于"如何提高训练效率"的查询:

  • 初步检索可能因为"深度学习"这个词而返回基础理论文档

  • 重排序会发现"训练效率优化"、"混合精度训练"等文档更贴近查询意图,将它们排在前面


策略三:查询-文档双向扩展

核心思想

通过LLM的能力,我们可以:

  1. Query2Doc: 根据查询生成假设的文档内容,丰富查询表示

  2. Doc2Query: 为文档生成可能的问题,增加匹配机会


完整代码实现

import openaiimport pickleimport osfrom typing import List, Dict, Tupleimport jiebafrom sklearn.feature_extraction.text import TfidfVectorizerfrom sklearn.metrics.pairwise import cosine_similarityimport numpy as npclass QueryDocExpansionRetriever:    def __init__(self, api_key: str, knowledge_base: List[str], cache_file: str = "doc2query_cache.pkl"):        """        初始化查询-文档双向扩展检索器                Args:            api_key: OpenAI API密钥            knowledge_base: 知识库文档列表            cache_file: Doc2Query结果缓存文件        """        openai.api_key = api_key        self.knowledge_base = knowledge_base        self.cache_file = cache_file                # 加载或生成Doc2Query扩展        self.expanded_docs = self._load_or_generate_doc2query()                # 构建扩展文档的向量索引        self.vectorizer = TfidfVectorizer(            tokenizer=lambda x: list(jieba.cut(x)),            lowercase=False,            max_features=8000,            ngram_range=(1, 2)        )        self.doc_vectors = self.vectorizer.fit_transform(self.expanded_docs)                print(f"知识库初始化完成,共{len(self.knowledge_base)}个文档")        def _load_or_generate_doc2query(self) -> List[str]:        """        加载或生成Doc2Query扩展文档        """        if os.path.exists(self.cache_file):            print("从缓存加载Doc2Query扩展...")            with open(self.cache_file, 'rb') as f:                return pickle.load(f)        else:            print("生成Doc2Query扩展...")            expanded_docs = self._generate_doc2query_expansion()            # 保存到缓存            with open(self.cache_file, 'wb') as f:                pickle.dump(expanded_docs, f)            return expanded_docs        def _generate_doc2query_expansion(self) -> List[str]:        """        为每个文档生成可能的查询问题(Doc2Query)        """        expanded_docs = []                for i, doc in enumerate(self.knowledge_base):            print(f"处理文档 {i+1}/{len(self.knowledge_base)}")                        # 生成该文档可能回答的问题            generated_queries = self._generate_queries_for_doc(doc)                        # 将原文档与生成的问题拼接            expanded_content = doc            if generated_queries:                expanded_content += "\n\n可能相关的问题:\n" + "\n".join(generated_queries)                        expanded_docs.append(expanded_content)                return expanded_docs        def _generate_queries_for_doc(self, document: str, num_queries: int = 3) -> List[str]:        """        为单个文档生成可能的查询问题                Args:            document: 文档内容            num_queries: 生成问题数量                    Returns:            生成的问题列表        """        prompt = f"""        基于以下文档内容,生成{num_queries}个用户可能会问的、该文档能够回答的问题。        要求:        1. 问题要自然、具体        2. 问题应该能够通过文档内容得到答案        3. 问题表述要多样化        4. 每个问题占一行                文档内容:        {document}                生成的问题:        """                try:            response = openai.ChatCompletion.create(                model="gpt-3.5-turbo",                messages=[{"role": "user", "content": prompt}],                temperature=0.7,                max_tokens=200            )                        questions = response.choices[0].message.content.strip().split('\n')            questions = [q.strip() for q in questions if q.strip()]                        # 清理问题格式            clean_questions = []            for q in questions:                # 移除序号                q = q.strip()                q = q.lstrip('0123456789.-) ')                if len(q) > 5 and q.endswith('?'):                    clean_questions.append(q)                        return clean_questions[:num_queries]                    except Exception as e:            print(f"生成问题失败: {e}")            return []        def query2doc_expansion(self, query: str) -> str:        """        Query2Doc扩展:根据查询生成假设的文档内容                Args:            query: 原始查询                    Returns:            扩展后的查询内容        """        prompt = f"""        基于以下查询,生成一个假设的、相关的文档片段,这个文档片段应该能够回答这个查询。        要求:        1. 内容要专业、准确        2. 包含相关的关键概念和术语        3. 长度控制在100-200字        4. 不要生成过于具体的数字或引用                查询:{query}                假设的相关文档内容:        """                try:            response = openai.ChatCompletion.create(                model="gpt-3.5-turbo",                messages=[{"role": "user", "content": prompt}],                temperature=0.6,                max_tokens=250            )                        pseudo_doc = response.choices[0].message.content.strip()            return pseudo_doc                    except Exception as e:            print(f"Query2Doc扩展失败: {e}")            return ""        def retrieve_with_expansion(self, query: str, top_k: int = 5, use_query2doc: bool = True) -> List[dict]:        """        使用查询和文档扩展进行检索                Args:            query: 原始查询            top_k: 返回结果数量            use_query2doc: 是否使用Query2Doc扩展                    Returns:            检索结果列表        """        print(f"原始查询: {query}")                # Query2Doc扩展        expanded_query = query        if use_query2doc:            pseudo_doc = self.query2doc_expansion(query)            if pseudo_doc:                expanded_query = f"{query}\n\n相关内容:{pseudo_doc}"                print(f"Query2Doc扩展: {pseudo_doc}")                # 检索扩展后的文档集合        query_vector = self.vectorizer.transform([expanded_query])        similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()                # 获取top_k结果        top_indices = similarities.argsort()[-top_k * 2:][::-1]  # 多取一些候选                results = []        for idx in top_indices:            if similarities[idx] > 0.1:  # 过滤低相似度结果                # 返回原始文档内容,而不是扩展后的                results.append({                    'document_id': idx,                    'content': self.knowledge_base[idx],  # 原始文档                    'expanded_content': self.expanded_docs[idx],  # 扩展文档                    'score': float(similarities[idx])                })                return results[:top_k]        def compare_methods(self, query: str, top_k: int = 3):        """        对比不同检索方法的效果                Args:            query: 查询文本            top_k: 返回结果数量        """        print(f"=== 检索方法对比 ===")        print(f"查询: {query}\n")                # 方法1:基础检索(无扩展)        print("【方法1:基础TF-IDF检索】")        basic_vectorizer = TfidfVectorizer(            tokenizer=lambda x: list(jieba.cut(x)),            lowercase=False        )        basic_vectors = basic_vectorizer.fit_transform(self.knowledge_base)        query_vec = basic_vectorizer.transform([query])        basic_similarities = cosine_similarity(query_vec, basic_vectors).flatten()        basic_top = basic_similarities.argsort()[-top_k:][::-1]                for i, idx in enumerate(basic_top, 1):            print(f"  {i}. (分数: {basic_similarities[idx]:.3f}) {self.knowledge_base[idx][:100]}...")                # 方法2:仅Doc2Query扩展        print(f"\n【方法2:Doc2Query扩展检索】")        doc2query_vec = self.vectorizer.transform([query])        doc2query_similarities = cosine_similarity(doc2query_vec, self.doc_vectors).flatten()        doc2query_top = doc2query_similarities.argsort()[-top_k:][::-1]                for i, idx in enumerate(doc2query_top, 1):            print(f"  {i}. (分数: {doc2query_similarities[idx]:.3f}) {self.knowledge_base[idx][:100]}...")                # 方法3:Query2Doc + Doc2Query双向扩展        print(f"\n【方法3:Query2Doc + Doc2Query双向扩展】")        results = self.retrieve_with_expansion(query, top_k, use_query2doc=True)                for i, result in enumerate(results, 1):            print(f"  {i}. (分数: {result['score']:.3f}) {result['content'][:100]}...")# 使用示例和完整演示def demo_expansion_retrieval():    """查询-文档双向扩展检索演示"""        # 构建更丰富的知识库    knowledge_base = [        "深度学习模型训练效率优化包括混合精度训练、梯度累积、模型并行等技术。混合精度使用FP16计算可以减少显存占用并加速训练。",        "AdamW优化器结合了Adam的自适应学习率和权重衰减正则化,在Transformer模型训练中表现优异,特别适合大模型训练。",        "分布式训练通过数据并行将批次分配到多个GPU,模型并行将大模型拆分到多个设备,可以显著缩短训练时间。",        "学习率调度策略对训练效果至关重要。常用的有余弦退火、线性衰减、warmup等,需要根据具体任务调整。",        "梯度累积技术允许在显存有限的情况下模拟大批量训练,通过累积多个小批次的梯度再进行参数更新。",        "Transformer架构使用自注意力机制处理序列数据,已成为NLP和多模态任务的主流架构,GPT和BERT都基于此架构。",        "卷积神经网络通过卷积层提取图像特征,在计算机视觉任务中应用广泛,包括图像分类、目标检测、语义分割等。",        "深度学习框架PyTorch提供动态计算图和直观的API,适合研究和原型开发,而TensorFlow更适合生产部署。",        "正则化技术如Dropout、BatchNorm、LayerNorm可以防止过拟合并提升模型泛化能力,是深度学习中的重要技术。",        "预训练模型如BERT、GPT通过大规模无监督学习获得通用语言表示,然后在下游任务上微调,大幅提升了NLP任务效果。"    ]        # 初始化检索器(需要OpenAI API Key)    api_key = "your-openai-api-key"  # 请替换为实际的API Key    retriever = QueryDocExpansionRetriever(api_key, knowledge_base)        # 测试查询    test_queries = [        "如何加速深度学习模型训练?",        "什么是混合精度训练?",        "Transformer架构有什么特点?"    ]        for query in test_queries:        print(f"\n{'='*60}")        retriever.compare_methods(query, top_k=3)        print(f"{'='*60}")if __name__ == "__main__":    demo_expansion_retrieval()

双向扩展的优势

Doc2Query的威力:

  • 原始文档:"混合精度训练使用FP16计算减少显存占用"

  • 生成问题:"什么是混合精度训练?"、"如何减少训练显存占用?"

  • 效果:当用户问相关问题时,更容易匹配到这个文档

Query2Doc的作用:

  • 原始查询:"训练加速"

  • 生成伪文档:"训练加速可以通过优化器选择、批量大小调整、并行计算等方式实现..."

  • 效果:丰富了查询的语义表示,提高匹配精度


策略组合:构建高效RAG系统

在实际应用中,我们往往需要组合多种策略。下面是一个完整的组合示例:

class AdvancedRAGRetriever:    def __init__(self, api_key: str, knowledge_base: List[str]):        """        高级RAG检索器,组合多种优化策略        """        self.multi_query_retriever = MultiQueryRetriever(api_key, knowledge_base)        self.reranking_retriever = RerankingRetriever(knowledge_base)        self.expansion_retriever = QueryDocExpansionRetriever(api_key, knowledge_base)        def advanced_retrieve(self, query: str, top_k: int = 5) -> List[dict]:        """        高级检索:组合多种策略                策略组合:        1. 多查询扩展增加召回多样性        2. Query2Doc丰富查询语义        3. Doc2Query增强文档匹配能力        4. 重排序精确排序        """        print(f"=== 高级RAG检索 ===")        print(f"查询: {query}")                # 步骤1: 多查询检索        print("\n步骤1: 多查询检索...")        multi_results = self.multi_query_retriever.retrieve(query, top_k * 2)                # 步骤2: 扩展检索        print("\n步骤2: 双向扩展检索...")        expansion_results = self.expansion_retriever.retrieve_with_expansion(query, top_k * 2)                # 步骤3: 合并候选结果        print("\n步骤3: 合并候选结果...")        all_candidates = {}                # 添加多查询结果        for result in multi_results:            doc_id = result['document_id']            all_candidates[doc_id] = {                'content': result['content'],                'multi_query_score': result['score'],                'expansion_score': 0.0            }                # 添加扩展检索结果        for result in expansion_results:            doc_id = result['document_id']            if doc_id in all_candidates:                all_candidates[doc_id]['expansion_score'] = result['score']            else:                all_candidates[doc_id] = {                    'content': result['content'],                    'multi_query_score': 0.0,                    'expansion_score': result['score']                }                # 步骤4: 重排序        print("\n步骤4: 重排序最终结果...")        candidates_list = [            (doc_id, max(scores['multi_query_score'], scores['expansion_score']), scores['content'])            for doc_id, scores in all_candidates.items()        ]                # 使用重排序器        reranked_results = self.reranking_retriever.rerank_candidates(query, candidates_list)                # 添加组合分数        for result in reranked_results:            doc_id = result['document_id']            if doc_id in all_candidates:                result['multi_query_score'] = all_candidates[doc_id]['multi_query_score']                result['expansion_score'] = all_candidates[doc_id]['expansion_score']                # 组合最终分数                result['combined_score'] = (                    result['rerank_score'] * 0.5 +                    result['multi_query_score'] * 0.3 +                    result['expansion_score'] * 0.2                )                # 按组合分数重新排序        reranked_results.sort(key=lambda x: x['combined_score'], reverse=True)                return reranked_results[:top_k]# 完整演示def demo_advanced_rag():    """高级RAG系统演示"""    knowledge_base = [        # ... (使用前面定义的知识库)    ]        api_key = "your-openai-api-key"    advanced_retriever = AdvancedRAGRetriever(api_key, knowledge_base)        query = "如何优化大模型训练效率?"    results = advanced_retriever.advanced_retrieve(query, top_k=3)        print(f"\n=== 最终检索结果 ===")    for i, result in enumerate(results, 1):        print(f"\n第{i}个结果:")        print(f"  组合分数: {result['combined_score']:.4f}")        print(f"  重排序分数: {result['rerank_score']:.4f}")        print(f"  多查询分数: {result['multi_query_score']:.4f}")        print(f"  扩展检索分数: {result['expansion_score']:.4f}")        print(f"  内容: {result['content']}")

实践建议与注意事项

1. 性能优化

  • 缓存策略: Doc2Query扩展比较耗时,一定要缓存结果

  • 批处理: 重排序时可以批量处理多个查询-文档对

  • 索引优化: 使用Faiss等专业向量数据库提升检索速度

2. 参数调优

  • 检索数量: 初步检索可以多召回一些(如top_k的3-4倍)

  • 重排序模型: BGE-Reranker-Base是个好选择,中英文效果都不错

  • 组合权重: 不同策略的权重需要根据具体场景调整

3. 成本控制

  • API调用: Query2Doc和多查询扩展都需要调用LLM,注意成本

  • 模型选择: 可以用较小的模型(如gpt-3.5-turbo)进行扩展生成

4. 效果评估

def evaluate_retrieval_quality(retriever, test_queries, ground_truth):    """    评估检索质量        Args:        retriever: 检索器实例        test_queries: 测试查询列表        ground_truth: 每个查询的相关文档ID列表            Returns:        评估指标字典    """    total_precision = 0    total_recall = 0    total_ndcg = 0        for i, query in enumerate(test_queries):        results = retriever.retrieve(query, top_k=5)        retrieved_ids = [r['document_id'] for r in results]        relevant_ids = set(ground_truth[i])                # 计算Precision@5        precision = len(set(retrieved_ids) & relevant_ids) / len(retrieved_ids)        total_precision += precision                # 计算Recall@5        recall = len(set(retrieved_ids) & relevant_ids) / len(relevant_ids)        total_recall += recall        return {        'precision': total_precision / len(test_queries),        'recall': total_recall / len(test_queries)    }

总结

RAG召回优化的核心在于多策略组合:

  1. 多查询扩展解决表达多样性问题

  2. 重排序解决相关性精确排序问题

  3. 双向扩展解决语义匹配不充分问题

这三个策略相互补充,形成了一个完整的召回优化体系。在实际应用中,你可以根据业务场景选择合适的组合:

  • 追求速度:多查询 + 简单重排序

  • 追求精度:全策略组合 + 深度重排序模型

  • 平衡性能:扩展检索 + 轻量重排序

记住,最好的RAG系统不是技术最复杂的,而是最适合你的业务场景的。先用简单方法建立baseline,再逐步优化,这样能更好地理解每个策略的实际效果。

53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费POC验证,效果达标后再合作。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

添加专属顾问

回到顶部

加载中...

扫码咨询