微信扫码
添加专属顾问
我要投稿
掌握Embedding模型微调,提升NLP任务性能。 核心内容: 1. Embedding模型微调的核心概念与数据依赖 2. 构建高质量微调训练集与评估集的实用步骤 3. 正负样本构建、数据增强等关键技术的应用
本文主要面向希望在特定领域或任务中提升Embedding模型表现的初学者。希望读完之后,能帮助大家:
? 目录
Embedding模型通过将文本映射为向量,在语义理解相关的NLP任务中扮演着基础性角色。然而,通用的预训练模型在特定领域往往因缺乏专业知识而表现不佳。
为了解决这一问题,模型微调(Fine-tuning),它通过在领域数据上进一步训练,使模型适应特定语义特性。高质量的数据集是微调成功的核心。本文旨在探讨如何基于现有数据,高效构建用于Embedding模型微调的训练与评估数据集,以提升模型在特定场景下的表现。
在咱们具体聊怎么弄数据集之前,有几个核心的概念得先弄明白。
原理阐述:Embedding模型微调是指在一个已经经过大规模通用语料预训练的Embedding模型基础上,利用特定任务或特定领域的数据集进行进一步的训练,从而调整模型参数,使其能更好地捕捉和表达目标数据中的语义信息。
核心目标:
微调的核心目标可以从度量学习 (Metric Learning) 的视角来理解。它旨在优化文本在向量空间中的表示,使得在语义上相似或相关的文本对(例如,一个问题和它的高质量答案)在向量空间中的距离更近,而语义上不相似或不相关的文本对在向量空间中的距离更远。这可以通过一个损失函数来形式化,例如三元组损失 (Triplet Loss):
其中 是锚点(如查询)的向量, 是正例的向量, 是负例的向量, 是向量 和 之间的距离度量(如欧氏距离或余弦距离的相反数), 是一个超参数,用于确保正例对和负例对之间的间隔。目标是最小化这个损失。
这种优化使得模型在执行如相似度计算、信息检索等任务时更为精准。
专门为微调构建数据集主要基于以下两点考虑:
第一,弥合领域差异。通用预训练模型学习到的是广泛的语言知识,而特定应用场景往往有其独特的语言模式、术语体系和知识结构。例如,金融领域的文本与日常对话的语言风格和核心词汇差异巨大。微调数据集承载了这种特定领域的信息,帮助模型弥合通用知识与特定需求之间的差距。
第二,数据驱动的学习范式。模型的学习过程是数据驱动的。通过向模型展示精心构造的、能够反映目标任务需求的样本(例如,哪些文本应被视为相关,哪些不相关),模型能够逐步学习到在该特定场景下区分文本语义的有效模式。
一个典型的用于对比学习或度量学习的Embedding模型微调数据集,核心通常包含以下几种要素:
查询 (Query):
这是信息需求的表示,可以是一个用户提出的问题、一个检索关键词,或任何需要模型为其寻找相关信息的文本。
正例 (Positive Sample, pos
):
这是与给定查询(Query)高度相关或语义一致的文本。在训练过程中,模型会学习拉近查询向量与正例向量之间的距离。
负例 (Negative Sample, neg
):
这是与给定查询(Query)不相关、相关性低或语义不一致的文本。在训练过程中,模型会学习推远查询向量与负例向量之间的距离。负样本的质量和选择策略对模型学习区分细微语义差异至关重要。理解这些基本构成及其在模型训练中的作用,是后续高效构建数据集的基础。
接下来,我们将结合一个具体的案例场景(financial-qa-10K
数据集处理示例),分步骤阐述如何构建高质量的微调训练数据集。
微调的首要前提是拥有数据。需要评估并选择已有的、能够反映目标应用场景的数据资源。这可能包括用户行为日志、已有的问答对、文档库、知识库内容等。以金融问答为例,financial-qa-10K
这样的数据集包含了金融领域的问题、对应的答案以及答案所在的上下文,是非常适合用于微调的数据源。
选定数据源后,需要对数据进行初步理解。
这包括分析其原始数据结构,例如包含哪些字段,每个字段的含义是什么。
同时,进行必要的初步数据清洗,如去除无效字符、处理缺失值、统一文本编码等,确保数据质量。
在financial-qa-10K
的例子中,原始数据包含'question', 'answer', 'context', 'ticker', 'filing'等列,我们需要理解这些列如何服务于我们的微调目标。
其中的数据数据格式说明:
原始数据格式中各字段含义如下:
不同的微调框架或模型可能对输入数据格式有特定要求。一个常见且有效的结构是JSON Lines格式,其中每一行是一个JSON对象,代表一个训练样本。该对象通常包含查询、正例、负例等字段,例如:
{
"query": str, // 查询文本
"pos": List[str], // 正样本列表
"neg": List[str], // 负样本列表
"pos_scores": List[int], // 正样本得分列表
"neg_scores": List[int], // 负样本得分列表
"prompt": str, // 提示信息
"type": str // 数据类型
}
这里面,query
就是查询的句子,pos
是一个或者好几个正向的文本列表,neg
呢,也是一个或者好几个负向的文本列表。
要是用知识蒸馏的话,pos_scores
和 neg_scores
就可能用得上,它们代表了对应样本的得分。如果不用知识蒸馏的话,就不需要用上这里的两个参数。
prompt
就是给模型的一个提示,告诉它怎么处理这个查询。那个 type
字段。
下面是一个具体的单条训练数据样本示例:
{
"query": "什么是市盈率它如何帮助投资者评估股票价值",
"pos": [
"市盈率(Price-to-Earnings Ratio, P/E Ratio)是衡量股票价格相对于每股收益的指标。计算公式为:市盈率 = 当前股价 / 每股收益(EPS)。它反映了投资者愿意为每一元盈利支付多少价格。",
"投资者通常使用市盈率来判断股票估值是否合理。较低的市盈率可能意味着股票被低估,而较高的市盈率则可能表示股票被高估或市场预期其未来盈利高速增长。然而,比较市盈率时应考虑行业特性和公司成长阶段。"
],
"neg": [
"市净率(Price-to-Book Ratio, P/B Ratio)是股价与每股净资产的比率,常用于评估银行、保险等资产密集型公司的价值。",
"股息收益率(Dividend Yield Ratio)是指公司年度总派息额与当前市价的比率,是衡量股票投资回报的指标之一。",
"技术分析主要关注股票价格和交易量的历史数据,通过图表模式预测未来价格走势,与基本面分析方法不同。"
],
"prompt": "Represent this sentence for searching relevant documents: ",
"type": "normal"
}
在这个例子中:
"query"
是用户提出的关于市盈率的问题。"pos"
列表包含了两个与查询高度相关的正面回答/解释。"neg"
列表包含了一些金融领域相关但与“市盈率”查询不直接相关或错误的文本,例如关于市净率、股息收益率的定义,或完全无关的技术分析概念。"prompt"
是一个可选的指令,用于指导模型如何处理查询。"type"
是一个可选字段, 用于 bge-en-icl,包括 normal、symmetric_class、symmetric_clustering 等类型。根据定义好的目标格式,需要从原始数据中选取核心信息,并将其转换为目标结构中的对应字段。
在financial-qa-10K
的示例中,我们将原始的'question'列选作query
,将'context'(或'answer',取决于具体任务目标)选作pos
。这一步通常伴随着字段的重命名和数据类型的转换。
pos
) 的构建与优化所谓正样本(Positive Sample, pos
),指的是与我们关心的查询(Query)在语义上高度相关或匹配的文本。这些样本旨在教会模型理解“什么是相似的”或“哪些内容是针对查询的正确答案/相关文档”。 正样本的质量直接影响模型对“相关性”的理解。
构建正样本的核心原则是确保与对应的query
之间具有强相关性或语义一致性。这意味着需要从原始数据中准确地抽取或匹配那些真正能够回答查询、或与查询语义高度一致的内容作为正样本。
正样本的文本粒度(是选择一个完整的句子、一个段落还是整个文档)需要根据具体的应用场景和模型能力来确定。如果上下文对理解至关重要,那么选择包含更完整上下文的段落可能比单个句子更优。
在抽取正样本时,应尽量减少其中包含的无关信息或噪声。一个“干净”的正样本能让模型更高效地学习到核心的语义关联,避免受到无关文本片段的干扰。
通常将正样本pos
处理成列表形式(List[str]
)。这样做一方面可以支持一个查询对应多个高质量正例的场景,另一方面也使得数据处理流程更为统一,即便当前只有一个正例,也用列表包装。
neg
) 构建策略:提升模型辨识力相应地,负样本(Negative Sample, neg
)则是指与查询(Query)不相关、相关性低或语义不一致的文本。它们的作用是帮助模型学会区分“什么是不相似的”或“哪些内容不是查询想要的”。 负样本在Embedding模型微调中扮演着至关重要的角色,它们帮助模型学习区分看上去相似但实际不相关的文本,从而塑造有效的决策边界,防止模型将所有内容都视为相似(即模型坍塌)。所谓模型坍塌(Model Collapse),通常指的是在训练过程中,模型学习到的所有或大部分文本的向量表示都变得非常相似,失去了区分度,导致模型无法有效识别不同文本之间的语义差异。
高质量的负样本能够显著提升模型的细粒度语义辨识能力。它们迫使模型不仅仅是学习"什么是相关的",更要学习"什么是不相关的,以及为什么不相关"。
构建负样本的技术路径多种多样,每种方法各有特点和适用场景。以下我们将详细介绍几种常见且有效的负样本构建策略。
原理: 在训练过程中,对于当前处理批次(batch)内的每一个查询:
让我们通过一个具体的例子来理解批内负样本的构建过程。假设我们有一个批次,包含4个训练样本:
batch = [
{
"query": "什么是市盈率",
"pos": ["市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除以每股收益。"]
},
{
"query": "比特币是如何工作的",
"pos": ["比特币是一种基于区块链技术的加密货币,通过挖矿的方式生成,交易被记录在公共分布式账本上。"]
},
{
"query": "如何计算复利",
"pos": ["复利计算公式为A=P(1+r)^n,其中A为最终金额,P为本金,r为利率,n为时间周期。"]
},
{
"query": "什么是通货膨胀",
"pos": ["通货膨胀是指一般物价水平持续上涨,导致货币购买力下降的经济现象。"]
}
]
非对称任务场景(如问答)下的批内负样本构建:
对于每个查询,我们将批次中其他样本的正例作为该查询的负例:
# 为第一个查询"什么是市盈率"构建批内负样本
query_1 = batch[0]["query"]
pos_1 = batch[0]["pos"][0]
neg_1 = [batch[1]["pos"][0], batch[2]["pos"][0], batch[3]["pos"][0]]
# 最终第一个查询的训练数据结构
sample_1 = {
"query": "什么是市盈率",
"pos": ["市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除以每股收益。"],
"neg": [
"比特币是一种基于区块链技术的加密货币,通过挖矿的方式生成,交易被记录在公共分布式账本上。",
"复利计算公式为A=P(1+r)^n,其中A为最终金额,P为本金,r为利率,n为时间周期。",
"通货膨胀是指一般物价水平持续上涨,导致货币购买力下降的经济现象。"
]
}
对称任务场景(如语义相似度匹配)下的批内负样本构建:
在对称任务中,不仅其他样本的正例可以作为负例,其他样本的查询本身也可以作为负例:
# 为第一个查询"什么是市盈率"构建批内负样本
query_1 = batch[0]["query"]
pos_1 = batch[0]["pos"][0]
neg_1 = [
# 其他样本的查询
batch[1]["query"],
batch[2]["query"],
batch[3]["query"],
# 其他样本的正例
batch[1]["pos"][0],
batch[2]["pos"][0],
batch[3]["pos"][0]
]
# 最终第一个查询的训练数据结构
sample_1 = {
"query": "什么是市盈率",
"pos": ["市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除以每股收益。"],
"neg": [
# 其他查询作为负例
"比特币是如何工作的",
"如何计算复利",
"什么是通货膨胀",
# 其他正例作为负例
"比特币是一种基于区块链技术的加密货币,通过挖矿的方式生成,交易被记录在公共分布式账本上。",
"复利计算公式为A=P(1+r)^n,其中A为最终金额,P为本金,r为利率,n为时间周期。",
"通货膨胀是指一般物价水平持续上涨,导致货币购买力下降的经济现象。"
]
}
批内负样本的实现流程:
在实际训练中,批内负样本通常在数据加载器或训练循环中动态构建,而不是预先准备好。下面是一个简化的实现流程:
def construct_in_batch_negatives(batch, symmetric=False):
"""
为批次中的每个样本构建批内负样本
Args:
batch: 包含多个训练样本的批次
symmetric: 是否为对称任务
Returns:
包含批内负样本的增强批次
"""
enhanced_batch = []
for i, sample in enumerate(batch):
query = sample["query"]
pos = sample["pos"]
neg = []
# 从批次中其他样本收集负例
for j, other_sample in enumerate(batch):
if i != j: # 排除自身
# 对于对称任务,其他样本的查询也可作为负例
if symmetric:
neg.append(other_sample["query"])
# 其他样本的正例作为负例
neg.extend(other_sample["pos"])
# 构建增强样本
enhanced_sample = {
"query": query,
"pos": pos,
"neg": neg
}
enhanced_batch.append(enhanced_sample)
return enhanced_batch
批内负样本的优势在于它们是顺手添加过去的,不需要额外的数据准备,同时能提供比随机负样本更有挑战性的反例,因为它们来自同一批次,具有一定的相关性。不过,显而易见的是它们的质量和难度可能不如专门设计的难负样本。
难负样本对模型提出了更精细的辨识挑战,迫使模型学习更细微的语义差别,从而显著提升模型在真实应用场景中的准确性和鲁棒性。
定义: 指那些在语义上或文本表征上与查询具有较高相似度,容易被模型误判为正例,但实际上与查询不相关或相关性较低的样本。
挖掘技术:
让我们通过一个具体示例来说明如何使用BM25和预训练Embedding模型结合的方式挖掘难负样本:
from rank_bm25 import BM25Okapi
import numpy as np
from sentence_transformers import SentenceTransformer
import jieba # 添加jieba导入用于中文分词
# 1. 准备语料库
corpus = [
"市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除以每股收益。",
"市净率是股价与每股净资产的比率,常用于评估银行等资产密集型公司价值。",
"股息收益率是公司年度总派息额与股票现价之比,衡量投资回报的指标。",
"市销率是股票价格与每股销售收入的比值,适用于评估尚未盈利的成长型公司。",
"企业价值倍数是企业价值与EBITDA的比率,考虑了公司债务水平的估值指标。",
"现金流折现模型通过预测未来现金流并折现至今来评估公司内在价值。",
"技术分析主要关注股票价格和交易量的历史数据,预测未来趋势。",
"基本面分析关注公司财务状况、管理层质量和市场地位等因素。",
"投资组合理论主张通过资产多样化来分散风险,优化风险回报比。",
"被动投资策略通过购买指数基金或ETF来追踪特定市场指数表现。"
]
# 查询和已知的正例
query = "什么是市盈率如何使用它评估股票价值"
true_positive = corpus[0] # 第一条关于市盈率的文本是真正的正例
# 2. 使用jieba分词进行BM25检索(稀疏检索阶段)
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus] # 对corpus进行分词
tokenized_query = list(jieba.cut(query)) # 对query进行分词
bm25 = BM25Okapi(tokenized_corpus)
bm25_scores = bm25.get_scores(tokenized_query)
# 获取BM25排序后的文档索引(按相关性从高到低排序)
sorted_indices = np.argsort(bm25_scores)[::-1] # 降序排序
print("BM25检索结果排序:")
for idx in sorted_indices[:5]: # 取前5名
print(f"文档{idx} (得分: {bm25_scores[idx]:.4f}): {corpus[idx][:50]}...")
# 3. 使用Embedding模型进行重排序(稠密检索阶段)
# 加载预训练的Embedding模型
model = SentenceTransformer(r'C:\Users\k\Desktop\BaiduSyncdisk\baidu_sync_documents\hf_models\bge-m3', trust_remote_code=True) # 示例模型
# 计算查询和所有文档的嵌入向量
query_embedding = model.encode([query])[0]
corpus_embeddings = model.encode(corpus)
# 计算余弦相似度
from sklearn.metrics.pairwise import cosine_similarity
similarities = cosine_similarity([query_embedding], corpus_embeddings)[0]
# 获取嵌入模型排序后的文档索引
sorted_indices_emb = np.argsort(similarities)[::-1] # 降序排序
print("\n嵌入模型重排序结果:")
for idx in sorted_indices_emb[:5]: # 取前5名
print(f"文档{idx} (相似度: {similarities[idx]:.4f}): {corpus[idx][:50]}...")
# 4. 识别难负样本(高相似度但实际不相关的文档)
# 去除真正的正例
hard_negatives_candidates = [idx for idx in sorted_indices_emb if corpus[idx] != true_positive]
# 从候选中选择前N个作为难负样本
hard_negatives = [corpus[idx] for idx in hard_negatives_candidates[:2]] # 取前2个作为难负样本
# 5. 最终的训练样本结构
training_sample = {
"query": query,
"pos": [true_positive],
"neg": hard_negatives
}
print("\n最终构建的包含难负样本的训练数据:")
print(f"查询: {training_sample['query']}")
print(f"正例: {training_sample['pos'][0]}")
print("难负样本:")
for i, neg in enumerate(training_sample['neg']):
print(f" {i+1}. {neg}")
实际执行结果:
BM25检索结果排序:
文档0 (得分: 1.7982): 市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除...
文档5 (得分: 0.7829): 现金流折现模型通过预测未来现金流并折现至今来评估公司内在价值...
文档3 (得分: 0.7425): 市销率是股票价格与每股销售收入的比值,适用于评估尚未盈...
文档1 (得分: 0.7238): 市净率是股价与每股净资产的比率,常用于评估银行等资产...
文档9 (得分: 0.0000): 被动投资策略通过购买指数基金或ETF来追踪特定市场指数表现...
嵌入模型重排序结果:
文档0 (相似度: 0.8059): 市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除...
文档1 (相似度: 0.7493): 市净率是股价与每股净资产的比率,常用于评估银行等资产...
文档3 (相似度: 0.7044): 市销率是股票价格与每股销售收入的比值,适用于评估尚未盈...
文档2 (相似度: 0.5833): 股息收益率是公司年度总派息额与股票现价之比,衡量投资回报...
文档4 (相似度: 0.5568): 企业价值倍数是企业价值与EBITDA的比率,考虑了公司...
最终构建的包含难负样本的训练数据:
查询: 什么是市盈率如何使用它评估股票价值
正例: 市盈率是衡量股票价格相对于每股收益的指标,计算公式为股票价格除以每股收益。
难负样本:
1. 市净率是股价与每股净资产的比率,常用于评估银行等资产密集型公司价值。
2. 市销率是股票价格与每股销售收入的比值,适用于评估尚未盈利的成长型公司。
从这个实际运行结果中可以看出,难负样本往往是那些在字面上与查询有某种相似性(如都是讨论金融估值指标),甚至可能共享一些关键词("市X率"、"股票价格"),但实际上并不是查询真正想要的信息。典型的如"市净率"、"市销率"等指标,它们与"市盈率"形式类似,都是股票估值指标,但概念和使用场景不同。这种高相似度但实际不相关的文本正是最容易使模型产生混淆的地方,因此在训练中使用这类难负样本可以有效提升模型的细粒度语义区分能力。
负样本的数量通常会多于正样本(例如,每个正样本配备多个负样本)。同时,保证负样本的多样性也非常重要,即负样本应覆盖不同类型的不相关情况,而不仅仅是单一类型的易区分样本。需要在数量与多样性之间进行权衡,以达到最佳的训练效果和效率。
对比学习是一种表示学习方法,它通过构建正负样本对,让模型学习将语义相似的样本在表示空间中拉近,将不相似的样本推远。在Embedding微调中,这类方法特别有效,因为它们直接优化了向量空间中文本表示的分布,提高了语义相似性的准确性。
对比学习的核心是InfoNCE损失函数,它通过最大化正样本对的相似度,同时最小化负样本对的相似度来优化模型:
其中 是查询向量, 是正样本向量, 是负样本向量, 是相似度函数, 是温度参数。
关键参数说明:
通过最小化InfoNCE损失,模型可以:
SimCSE技术
DiffCSE技术
import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
"""
SimCSE (Simple Contrastive Sentence Embedding) 是一种通过对比学习改进句子嵌入的方法。
核心思想:
1. 无监督学习:利用同一句子通过不同dropout mask生成的两个表示作为正样本对
2. 同一批次中的其他句子作为负样本
3. 训练目标是使正样本对的表示相似,而与负样本的表示不相似
"""
# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese")
# 准备输入句子 - 有意设计了语义相似和不相似的句子对
sentences = [
"市盈率是衡量股票价格相对于每股收益的指标。", # 金融指标相关
"P/E比率用于评估股票估值的合理性。", # 与第一句语义相似
"通货膨胀是物价持续上涨的经济现象。", # 经济现象,与前两句相关但不同概念
"每股收益是公司净利润除以流通股数。" # 与第一句相关,都涉及每股收益
]
def get_sentence_embeddings(model, tokenizer, sentences, use_simcse=False):
"""获取句子嵌入,可选是否使用SimCSE方法"""
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
if use_simcse:
# SimCSE方法:对每个句子使用不同dropout生成两个表示,然后取平均
model.train() # 激活dropout
# 运行两次获取不同的表示
outputs1 = model(**inputs, output_hidden_states=True)
outputs2 = model(**inputs, output_hidden_states=True)
# 取CLS token
embeddings1 = outputs1.last_hidden_state[:, 0]
embeddings2 = outputs2.last_hidden_state[:, 0]
# 取平均作为最终表示
embeddings = (embeddings1 + embeddings2) else:
# 传统方法:直接获取句子表示
model.eval() # 关闭dropout
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
embeddings = outputs.last_hidden_state[:, 0]
return embeddings
# 演示SimCSE训练过程
def demonstrate_simcse_training():
print("=== SimCSE训练过程演示 ===")
# 将句子转换为模型输入,同一批次输入两次(以使用不同的dropout mask)
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
inputs_repeated = {k: torch.cat([v, v]) for k, v in inputs.items()}
# 前向传播,获取CLS表示
model.train() # 确保dropout被激活
outputs = model(**inputs_repeated, output_hidden_states=True)
last_hidden = outputs.last_hidden_state
cls_embeds = last_hidden[:, 0] # 取CLS token表示
# 分开原始样本和重复样本的表示
batch_size = len(sentences)
z1, z2 = torch.split(cls_embeds, batch_size)
# 计算余弦相似度
cosine_sim = torch.nn.functional.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=2)
# 计算对比学习损失(InfoNCE/NT-Xent)
# 对角线上的元素代表相同句子的两个不同表示之间的相似度(正样本)
# 非对角线元素代表不同句子之间的相似度(负样本)
# 训练目标是最大化对角线元素的值
labels = torch.arange(batch_size).to(cosine_sim.device)
temperature = 0.05 # 温度参数,控制分布的平滑程度
loss = torch.nn.CrossEntropyLoss()(cosine_sim temperature, labels)
print(f"SimCSE 对比损失:{loss.item():.4f}")
print("余弦相似度矩阵(训练中):")
print(cosine_sim.detach().numpy())
print("对角线元素(正样本对)平均相似度:", torch.mean(torch.diag(cosine_sim)).item())
print("非对角线元素(负样本对)平均相似度:",
(torch.sum(cosine_sim) - torch.sum(torch.diag(cosine_sim))) (batch_size * batch_size - batch_size))
# 比较使用SimCSE和不使用SimCSE的句子嵌入效果
def compare_embeddings():
print("\n=== 比较传统嵌入与SimCSE嵌入效果 ===")
# 获取传统句子嵌入
traditional_embeddings = get_sentence_embeddings(model, tokenizer, sentences, use_simcse=False)
traditional_embeddings = traditional_embeddings.detach().numpy()
# 获取SimCSE增强的句子嵌入
simcse_embeddings = get_sentence_embeddings(model, tokenizer, sentences, use_simcse=True)
simcse_embeddings = simcse_embeddings.detach().numpy()
# 计算相似度矩阵
traditional_sim = cosine_similarity(traditional_embeddings)
simcse_sim = cosine_similarity(simcse_embeddings)
# 显示结果
print("传统方法的相似度矩阵:")
print(np.round(traditional_sim, 3))
print("\nSimCSE方法的相似度矩阵:")
print(np.round(simcse_sim, 3))
print("\n句子对的语义关系:")
for i in range(len(sentences)):
for j in range(i+1, len(sentences)):
print(f"句子{i+1}与句子{j+1}:")
print(f" - 传统相似度: {traditional_sim[i,j]:.3f}")
print(f" - SimCSE相似度: {simcse_sim[i,j]:.3f}")
print(f" - 句子{i+1}: {sentences[i]}")
print(f" - 句子{j+1}: {sentences[j]}")
print()
# 运行演示
if __name__ == "__main__":
demonstrate_simcse_training()
compare_embeddings()
输出:
=== SimCSE训练过程演示 ===
SimCSE 对比损失:0.1850
余弦相似度矩阵(训练中):
[[0.8296242 0.79532236 0.7019348 0.736534 ]
[0.7617904 0.9060811 0.683814 0.7251259 ]
[0.5913595 0.67814845 0.8314158 0.6420494 ]
[0.6655865 0.657691 0.528461 0.886128 ]]
对角线元素(正样本对)平均相似度: 0.8633122444152832
非对角线元素(负样本对)平均相似度: tensor(0.6807, grad_fn=<DivBackward0>)
=== 比较传统嵌入与SimCSE嵌入效果 ===
传统方法的相似度矩阵:
[[1. 0.882 0.774 0.826]
[0.882 1. 0.736 0.778]
[0.774 0.736 1. 0.661]
[0.826 0.778 0.661 1. ]]
SimCSE方法的相似度矩阵:
[[1. 0.835 0.747 0.82 ]
[0.835 1. 0.737 0.804]
[0.747 0.737 1. 0.676]
[0.82 0.804 0.676 1. ]]
句子对的语义关系:
句子1与句子2:
- 传统相似度: 0.882
- SimCSE相似度: 0.835
- 句子1: 市盈率是衡量股票价格相对于每股收益的指标。
- 句子2: P/E比率用于评估股票估值的合理性。
句子1与句子3:
- 传统相似度: 0.774
- SimCSE相似度: 0.747
- 句子1: 市盈率是衡量股票价格相对于每股收益的指标。
- 句子3: 通货膨胀是物价持续上涨的经济现象。
句子1与句子4:
- 传统相似度: 0.826
- SimCSE相似度: 0.820
- 句子1: 市盈率是衡量股票价格相对于每股收益的指标。
- 句子4: 每股收益是公司净利润除以流通股数。
句子2与句子3:
- 传统相似度: 0.736
- SimCSE相似度: 0.737
- 句子2: P/E比率用于评估股票估值的合理性。
- 句子3: 通货膨胀是物价持续上涨的经济现象。
句子2与句子4:
- 传统相似度: 0.778
- SimCSE相似度: 0.804
- 句子2: P/E比率用于评估股票估值的合理性。
- 句子4: 每股收益是公司净利润除以流通股数。
句子3与句子4:
- 传统相似度: 0.661
- SimCSE相似度: 0.676
- 句子3: 通货膨胀是物价持续上涨的经济现象。
- 句子4: 每股收益是公司净利润除以流通股数。
当已有的标注数据量有限时,数据增强是一种有效的技术手段,可以在不显著增加人工标注成本的前提下,扩充训练集规模,提升模型的泛化能力。
数据增强通过对现有训练样本进行一系列变换来生成新的、合理的训练样本。这有助于模型学习到对输入文本中更多样变化的鲁棒性,减少过拟合风险,尤其是在小样本场景下。
相比大语言模型生成式的数据增强方法,基于词汇和语法的简单增强技术实现起来更加轻量和高效,适合快速扩充训练数据。以下是几种常用的简单增强方法:
"市盈率是衡量股票价格的重要指标"
"市盈率是评估股票价格的关键指标"
"如何计算股票的市盈率"
"How to calculate the P/E ratio of stocks?"
"怎样计算股票的市盈率指标"
"市盈率反映股票估值水平"
"市盈率准确反映当前股票估值水平"
"投资者经常使用市盈率进行股票估值"
"进行股票估值时,投资者经常使用市盈率"
句子改写/复述(Paraphrasing):利用Qwen系列、GPT系列等预训练语言模型来生成与原始文本语义相同但表达不同的新句子。大语言模型的强大语言能力使生成的文本更自然、多样。
多样化查询生成(Query Generation):针对已有的正例文档(pos
),可以利用大语言模型生成多种不同问法的查询。例如,对于一段关于"市盈率定义和计算"的文本,LLM可以生成"什么是P/E Ratio?"、"如何计算股票的市盈率"、"市盈率指标有什么用"等多种表述。
多样化正例生成 (Positive Sample Generation):对于给定查询,大语言模型可以基于其理解,生成多个语义相关但表述不同的正例文本,帮助模型学习对同一概念的不同表达方式。
难负样本候选生成:通过精心设计的提示(Prompt Engineering),引导大语言模型生成与查询主题相关但在细节上有误,或属于同一领域但讨论不同子话题的文本,作为高质量难负样本的候选。
指令驱动的文本改写:大语言模型可以遵循具体指令(如"将专业内容简化为普通人能理解的语言"、"保留核心信息但使表达更简洁")进行文本改写,创造出不同风格和复杂度的训练样本。
prompt
(提示语) 的功能在某些微调框架中,可以为查询添加一个prompt
或指令前缀。如financial-qa-10K
示例中添加的instruction = "Represent this sentence for searching relevant passages: "
。这个prompt
可以指导模型如何理解和处理查询,例如指明这是一个用于段落检索的句子,或者这是一个需要总结的问题。它有助于模型在不同任务或意图下产生更合适的向量表示。在推理时,在flagembedding这个prompt
通常会作为query_instruction_for_retrieval
使用。
更进一步地,我们可以为不同的任务类型定义不同的提示语,并在训练和推理时根据任务类型选择相应的提示。例如,一个配置文件可能包含如下的提示语定义,注意,不是所有文件都是这个示例的样子的,比如在jina-clip-v2里面就只有一个,而不是三个:
{
"prompts": {
"retrieval.query": "Represent the query for retrieving evidence documents: ",
"retrieval.document": "Represent the document for retrieval: ",
"classification": "Classify the text: "
},
"default_prompt_name": "retrieval.document" // 举例:默认使用的prompt名称
}
在这个例子中:
"retrieval.query"
: 当处理一个用于检索相关文档的查询时,可以在查询文本前添加这个提示,引导模型生成适合检索的查询向量。"retrieval.document"
: 当处理文档用于构建检索库时,可以在文档文本前添加这个提示,引导模型生成适合被检索的文档向量。"classification"
: 当任务是文本分类时,可以使用这个提示,引导模型生成有助于区分文本类别的向量表示。"default_prompt_name"
: 可以指定一个默认的提示语,当没有显式指定任务类型或找不到对应提示时使用。通过这种方式,同一个基础Embedding模型可以通过不同的提示语来适应多种下游任务,增强了模型的通用性和灵活性。在训练时,可以根据样本的type
字段(如果存在)或任务本身的性质来选择合适的prompt
。最后的部分,我提示下,不是每个模型都能支持这些提示词的,有的模型默认不设置提示词,有的模型设置默认提示词是retrieval.document等等,具体可以看模型的配置文件。
最后,将构建好的完整数据集按照一定的比例(例如8:1:1或9:1)划分为训练集(Training Set)、验证集(Validation Set,可选)和测试集(Test Set)。训练集用于模型参数的更新和学习。验证集用于在训练过程中监控模型性能,进行超参数调优,防止过拟合。测试集在模型训练完成后,用于最终评估模型在未见过数据上的泛化能力。划分时应注意数据的随机性和分层抽样(如果类别不平衡),确保各个集合的数据分布尽可能一致。
本文粗略地探讨了如何基于已有数据,为Embedding模型微调构建训练与评估数据集。核心步骤包括:
?补充信息与参考文献
本文内容主要基于(financial-qa-10K
数据集处理示例,因为是bge的教程里面使用的数据集),并结合了在Embedding模型微调、数据集构建、数据增强等领域的通用学术认知与公开技术文献的梳理。
参考文献:
这是官方文档,可以看看怎么加载、处理和操作各种数据集。
Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks.Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP). (arXiv:1908.10084) 这是一篇经典的讲怎么有监督地学习句子表征的论文,对理解怎么用成对的样本来训练模型挺有帮助的。
Karpukhin, V., Oguz, B., Min, S., Lewis, P., Wu, L., Edunov, S., ... & Yih, W. T. (2020). Dense Passage Retrieval for Open-Domain Question Answering.Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP). (arXiv:2004.04906) 这篇论文讨论了在稠密检索里,选好正负样本有多重要,特别是在开放领域问答这种场景下。
Settles, B. (2009). Active Learning Literature Survey.University of Wisconsin-Madison Department of Computer Sciences Technical Report 1648.这是主动学习领域一篇挺经典的综述文章(虽然年份稍微早了点,但里头的核心想法还是很有参考价值的)。
Gao, L., Ma, X., Lin, J., & Callan, J. (2021). Complementing Lexical Retrieval with Semantic Residual Embedding.Proceedings of the 44th International ACM SIGIR Conference on Research and Development in Information Retrieval. (arXiv:2109.04770) 这篇论文关注的是怎么把稀疏检索和稠密检索结合起来,里面也提到了怎么弄高质量的训练数据。
Gao, T., Yao, X., & Chen, D. (2021). SimCSE: Simple Contrastive Learning of Sentence Embeddings.Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP). (arXiv:2104.08821) SimCSE 是一个影响力很大的工作,它提出了一个非常简洁但效果显著的句子表示对比学习方法。如果你对如何通过简单的Dropout机制来构造正样本对,并结合批内负样本进行训练。知乎上也有一些不错的解读,比如 Maple小七的这篇文章(https://zhuanlan.zhihu.com/p/368353121),可以作为辅助理解。
Yoon, S., Kim, G., & Park, K. (2021). SSMix: Saliency-based Span Mixup for Text Classification.Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021. (arXiv:2106.08062) 这篇论文提出了基于显著性的片段混合方法(SSMix),是一种针对文本数据的创新性数据增强技术,通过智能地替换文本中的片段来提高模型的鲁棒性和泛化能力。
Chuang, Y. S., Li, R., Torralba, A., & Jegelka, S. (2022). DiffCSE: Difference-based Contrastive Learning for Sentence Embeddings.arXiv preprint arXiv:2204.10298.DiffCSE 是在 SimCSE 基础上的一个有趣改进,它关注于如何通过区分正负样本对之间的差异来学习更好的句子嵌入。
Wang, Z., Wu, W., Wang, H., Wu, H., & Wang, W. (2020). CLEAR: Contrastive Learning for Sentence Representation.arXiv preprint arXiv:2012.15466.CLEAR 也是对比学习在句子表示领域的一个重要工作,它探讨了如何结合词级别的扰动和批内负采样来增强表示学习。
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-05-21
可以将任何符合OpenAPI规范的接口转 MCP Server吗?
2025-05-21
深度剖析 MCP SDK 最新版: Streamable HTTP 模式正式发布,为你实测揭秘
2025-05-21
AI驱动的软件:为何强大的CI/CD基础至关重要
2025-05-20
RAG与微调,大语言模型的“大脑升级”,该选哪条路?(小白科普)
2025-05-20
两万字记录微软Build2025主题演讲谈了什么:萨提亚展现好人缘,奥特曼、马斯克、黄仁勋轮番出镜,软件工程的本质在于好工具
2025-05-20
百度自研Agent专用版模型上线千帆!已针对企业级大模型应用进行指令调优
2025-05-19
大模型微调
2025-05-19
一文读懂微调技术Lora和SFT的异同
2025-02-04
2025-02-04
2024-09-18
2024-07-11
2024-07-09
2024-07-11
2024-07-26
2025-02-05
2025-01-27
2025-02-01