AI知识库 AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


AGI|三种高级RAG检索方法帮企业告别冗长文档!
浏览次数: 1559



基于LangChain实现的RAG检索方法

RAG(Retrieval-Augmented Generation)是当前企业问答应用中广泛使用的技术,它避免了重新训练大型私域知识模型的额外成本,同时保持了与后者相当的问答效果。然而,RAG面临的主要挑战是大型语言模型(LLM)的上下文长度限制,这限制了模型能够处理的企业私域知识文本的数量。


本篇文章将基于LangChain实现三种高级检索方法,帮您提高AI应用的准确度。



 作者 

孙泽文 | 神州数码

新手上路,多多指教~


Part1

前言


RAG(Retrieval-Augmented Generation)检索增强生成,是现如今基于企业私域知识的问答应用所使用的主流技术之一。相较于重新训练基于私域知识的大模型来说,RAG没有额外的预训练成本,且回答效果与之相当。


但在实际应用场景中,RAG所面临最大的问题是LLM的上下文长度限制。企业私域知识文本的数量十分庞大,不可能将其全部放在模型的prompt中,即使现在各类模型已经将上下文token从年初的2k、4k扩充到了128k、192k,但是这可能也就是一份合同、一份标书的长度。因此,如何减少传递给模型的内容数量且同时提高内容质量,是提升基于RAG的AI应用回答准确度的一个重要方法。


本篇文章将基于LangChain实现三种高级检索方法,句子窗口检索和自动合并检索旨在改善RAG流程的召回过程中存在的信息残缺的问题,而多路召回检索则保证了在多个文档中检索召回的准确性。



Part2

先验知识


● RAG简要流程



加载文档——切分划片——嵌入为向量表示——存入数据库



向量化问题——向量召回文档——合并放入Prompt——LLM生成答案



Part3

句子窗口检索


一、概念


在文档进行切片工作后,文档被分为若干个Langchain自定义的Document对象,该对象有两个属性,一是page_content即该切片的文本内容,二是meta_data即有关该切片的一些信息和可自定义封装进入的信息。


句子窗口检索方法,将每个切片的相邻切片的内容封装在切片的meta_data中。在检索和召回过程时,根据命中文档的meta_data可获得此段落的上下文信息,并将其封装进入命中文档的page_content中。组合完成的文档列表即可作为prompt交付给大模型生成。


在实际问答任务中,我们建议使用切片器将文档切分为较短的分片,或使用依据标点符号进行切分的切片器。保证整片文档拥有较细的颗粒度。同时在封装和召回阶段,适当扩大窗口大小,保证召回段落的完整性。


二、BERT


(1)元数据封装


def metadata_format(self, ordered_text, **kwargs):
        count = kwargs.get("split_count", 1)
        for i, document in enumerate(ordered_text):
            if i > 0:
                document.metadata['previous_page'] = ordered_text[i-count].page_content
            else:
                document.metadata['previous_page'] = ''

            if i < len(ordered_text) - 1:
                document.metadata['next_page'] = ordered_text[i+count].page_content
            else:
                document.metadata['next_page'] = ''
        return ordered_text


(2)数据重构


def search_and_format(self, databases, query, **kwargs):
        top_documents = []
        for db in databases:
            top_documents.append(db.similarity_search_with_score(query))
        docs = []
        for doc, _ in top_documents:
            doc.page_content = doc.metadata.get("previous_page") + doc.page_content + doc.metadata.get("next_page")
            docs.append(doc)
        return docs


(3)调用示例伪代码


# load document
......

#
 split
......
# use smartvision sdk to format
sentence_window_retrival = SentenceWindow()
formatted_documents = sentence_window_retrival.metadata_format(documents, split_count=2)

#
 embedding
......

#
 load in local vector db
......

#
 use smartvision sdk to do search and multiple recall
databases = [db]
query = "烟草专卖品的运输"
top_documents = sentence_window_retrival.search_and_format(databases, query)
print(top_documents)




Part4

自动合并检索


一、概念


自动合并检索方法,实现方法源自Llamaindex所封装的自动合并检索,但RAG全流程需要制定一套准确的规范,因此在用户文档完成读取和切片工作后,所得到的Langchain格式的Document对象需转化为Llamaindex定义的Document对象,便可通过Llamaindex的自定义算法自动划分整个切片列表的子节点和父节点,最后鉴于规范再重新转化为Langchain格式的Document对象,并将父节点信息、深度信息等封装进每个节点。


在检索阶段,召回最相关的若干个节点,遍历这些节点和附加信息,如若超过K个节点同时属于同一个节点(这里的K为用户自定义阈值,通常为一个节点所有子节点的半数)则执行合并该父节点下属所有子节点,即返回整个父节点内容。这使我们能够将可能不同的较小上下文合并到一个可能有助于综合的更大上下文中。


二、代码实现和调用


(1)元数据封装


def auto_merge_format(documents, **kwargs):
    if documents is None:
        raise ValueError('documents is required')
    formatted_documents = []
    doc_text = "\n\n".join([d.page_content for d in documents])
    docs = [Document(text=doc_text)]
    node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=kwargs.get("pc_chunk_size", [2048, 512, 128]),chunk_overlap=kwargs.get("pc_chunk_overlap", 10))
    nodes = node_parser.get_nodes_from_documents(docs)
    leaf_nodes = get_leaf_nodes(nodes)
    root_nodes = get_root_nodes(nodes)
    middle_nodes = get_middle_node(nodes, leaf_nodes, root_nodes)
    root_context_dict = {}
    for root_node in nodes:
        root_context_dict[root_node.node_id] = root_node.get_content()

    for node in nodes:
        if node.parent_node:
            node_id = node.node_id
            root_node_id = node.parent_node.node_id
            root_node_content = root_context_dict.get(node.parent_node.node_id)
            root_node_child_count = 0
            for parent_node in root_nodes + middle_nodes:
                if parent_node.node_id == node.parent_node.node_id:
                    root_node_child_count = len(parent_node.child_nodes)
                    break
            depth = 2 if node in middle_nodes else 3
            child_count = len(node.child_nodes) if node.child_nodes is not None else 0
            document = langchain.schema.Document(page_content=node.get_content(),metadata={"node_id": node_id, "root_node_id": root_node_id, "root_node_content": root_node_content, "root_node_child_count": root_node_child_count, "depth": depth, "child_count": child_count})
            formatted_documents.append(document)
    return formatted_documents


(2)数据重构


def search_and_format(self, databases, query, **kwargs):
        top_documents = []
        for db in databases:
            top_document = db.similarity_search_with_score(query)
            top_documents.append(top_document)
        leaf_nodes = [doc for doc, _ in top_documents]
        return do_merge(leaf_nodes, **kwargs)


def group_nodes_by_depth(nodes, depth):
    return [node for node in nodes if node.metadata.get("depth") == depth]

def process_group(nodes, threshold):
    grouped_by_root_id = {}
    for node in nodes:
        root_id = node.metadata.get("root_node_id")
        grouped_by_root_id.setdefault(root_id, []).append(node)

    merge_context = []
    for group in grouped_by_root_id.values():
        node_count = len(group)
        child_count = group[0].metadata.get("root_node_child_count")
        if node_count / child_count >= threshold:
            merge_context.append(langchain.schema.Document(
                page_content=group[0].metadata.get("root_node_content")
            ))
        else:
            for document in group:
                merge_context.append(document)
    return merge_context

def do_merge(nodes, **kwargs) -> List[langchain.schema.Document]:
    threshold = kwargs.get("threshold", 0.5)
    leaf_nodes = group_nodes_by_depth(nodes, 3)
    middle_nodes = group_nodes_by_depth(nodes, 2)
    leaf_merge_context = process_group(leaf_nodes, threshold)
    middle_merge_context = process_group(middle_nodes, threshold)
    merge_content = leaf_merge_context + middle_merge_context
    return merge_content

def get_middle_node(nodes, leaf_nodes, root_nodes):
    middle_node = []
    for node in nodes:
        if node not in leaf_nodes and node not in root_nodes:
            middle_node.append(node)
    return middle_node


(3)调用示例伪代码


# load document
......

#
split
......

#
 use smartvision sdk to format
auto_merge_retrival = AutoMergeRetrieval()
formatted_documents = auto_merge_retrival.metadata_format(documents,
                                                pc_chunk_size=[1024, 128, 32],
                                              pc_chunk_overlap=4)
#embedding
......

#
load in local vector db
......

#
 use smartvision sdk to do search and multiple recall
top_documents = auto_merge_retrival.search_and_format(databases, query, threshold=0.5)
print(top_documents)




Part5

多路召回检索


一、概念


多路召回检索方法,在元数据封装环节并未做任何操作,而在检索阶段他允许用户上传多个数据集或不同类型的向量数据库作为检索对象,以适应用户私域知识库文档类型不同,文档数量庞大的问题。从多个数据源检索得到文档列表,而后通过rerank模型对文档与问题的相关性进行评分,筛选出大于一定分值的文档,组合成为prompt。


由此可见,多路召回检索在数据源广而杂的情况下,富有更好的效果。此外,rerank模型虽能进行再次的重排以提高准确性,但是在牺牲速度和效率的前提下进行的,因此需充分考虑这个问题。


二、代码实现


(1)元数据封装


def metadata_format(self, ordered_text, **kwargs):
        """
            默认rag,不做任何处理
            """

        return ordered_text


(2)数据重构


def search_and_format(self, databases, query, **kwargs):
        top_documents = []
        result_data = []
        for db in databases:
            top_document = db.similarity_search_with_score(query)
            top_documents.append(top_document)
        pairs = [[query, item.page_content] for item in top_documents]
        with torch.no_grad():
            rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_FILE_PATH)
            inputs = rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            rerank_model = AutoModelForSequenceClassification.from_pretrained(RERANK_FILE_PATH)
            scores = rerank_model(**inputs, return_dict=True).logits.view(-1, ).float()
            for i, score in enumerate(scores):
                data = {
                    "text": top_documents[i].page_content,
                    "score": float(score)
                }
                result_data.append(data)
        return result_data



Part6

结语


本文提供的三种高级RAG检索方法,但仅改善了流程中检索召回环节的信息残缺问题,实质上RAG全流程均存在各种优化方法,但最有效的方法仍是改进或提供新的召回方式。


总结以上三种方法,均需要重点注意切片器的选用并控制切片大小,过大导致上下文长度过长,且有研究表明过长的prompt易使大模型忽略的中间部分的信息。过短则导致关键信息残缺,无法为大模型提供有效的上下文。因此开发者需根据文档类型和结构,谨慎选择并适当调节优化。

联系我们

售前咨询
186 6662 7370
产品演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询