微信扫码
添加专属顾问
我要投稿
探索知识图谱与多模态学习在药物预测领域的革命性应用。核心内容:1. 知识图谱在多模态学习中的核心作用及其在药物预测中的潜力2. KG4MM如何整合分子图像和文本描述,提升药物相互作用预测的准确性3. 利用图神经网络连接知识图谱与多模态数据,实现更可解释的预测结果
点击上方↗️「活水智能」,关注 + 星标?
作者:Eleventh Hour Enthusiast
编译:活水智能
知识图谱已成为表示不同实体间如何相互关联的重要工具。它们将信息编码为节点和边,从而直观地展示实体之间的关联。知识图谱用于多模态学习 (KG4MM) 在此基础上进一步发展,它利用知识图谱来指导从图像和文本中学习的过程。在 KG4MM 中,图谱扮演着一张"地图"的角色,明确标示出在训练过程中需要重点关注的每种数据类型的部分。这种引导有助于系统将注意力集中在图像中最相关的特征和文本中最具信息量的词语上。
在药物相互作用预测领域,KG4MM 提供了明显的优势。图结构在一个统一框架下整合了药物的分子图像和文本描述。这种统一视角有助于提高预测准确性,因为它同时捕捉了化学结构和药理背景信息。此外,知识图谱创建了一条从输入到输出的透明路径,使得理解模型为何得出特定预测变得更容易。
本文将解释 KG4MM 在实践中如何用于预测成对药物的相互作用。它将逐步讲解构建知识图谱以及整合分子和文本信息的步骤。通过具体示例,说明受知识图谱引导的多模态学习如何解决医学和医疗研究中的实际挑战。目标是展示 KG4MM 如何在实际药物相互作用任务中提高预测准确性和可解释性。
KG4MM 方法将知识图谱置于整个流程的核心。图谱指导每种数据类型的处理和理解方式。在药物相互作用示例中,图谱中的每个药物节点都关联两种信息。第一种是源自其 SMILES 分子式的分子图像,第二种是包含其类别、官能团及其他关键细节的文本描述。
KG4MM 的独特之处在于它利用图神经网络 (GNN) 来连接图结构和多模态数据。GNN 根据药物在图中的位置,确定其图像的哪些部分以及描述中的哪些词语最值得关注。图中的边——显示药物如何与蛋白质、疾病及其他药物相关联——帮助网络确定哪些视觉和文本特征最具重要性。通过这种方式,知识图谱不仅仅是提供额外背景信息;它更主动地引导模型关注最具信息价值的数据元素。
KG4MM 的优势在于能够结合模式识别神经网络与明确的关系图谱。GNN 在处理关联数据方面具有显著优势,因此模型能够基于现有的药物相互作用和生化特性知识进行构建。这种受引导的学习不仅提高了预测准确性,还通过突出显示影响每次预测的具体图谱连接,产生了清晰、可解释的结果。
该系统围绕一个核心知识图谱构建,该图谱整合了所有组成部分。这个图谱捕获了药物、蛋白质和疾病之间的有向关系,例如药物"结合到" (binds_to) 蛋白质、"抑制" (inhibits) 靶点或"治疗" (treats) 疾病。通过将图谱置于设计的核心,处理的每一步都依赖于其结构化的医学知识图谱。
为了准备数据,系统将每个药物节点与两种表示形式关联起来:一种是分子图像,另一种是文本描述。第一种是使用 RDKit 从其 SMILES 分子式生成的分子图像。第二种是概括药物类别、官能团及其他相关细节的文本描述。图像和文本均直接连接至图谱中对应的药物节点,从而确保视觉和语言特征与底层知识结构保持一致。
对图谱本身进行建模依赖于图卷积网络 (GCN)。这些网络从每个节点的位置及其在图中的连接中学习,创建编码药物、蛋白质和疾病如何相互关联的嵌入。同时,多模态编码器将图像和文本转换为特征向量:一个 ResNet 处理分子图像,而一个 BERT 模型转换文本描述。
最终,一个图注意力网络 (GAT) 融合图嵌入与视觉和文本特征。注意力机制利用图结构对来自各模态的最重要特征进行加权。组合后的表示然后馈入预测模块,该模块确定两种药物是否会相互作用。同时,注意力权重揭示了哪些图谱连接、图像区域或文本元素对模型的决策贡献最大,从而为每次预测提供清晰的解释。
这一步骤确保了所有必要的深度学习、图处理和化学信息学软件包在环境中可用。实现过程首先安装并导入所需的库。它安装了用于神经网络的 PyTorch 和 torchvision,用于文本编码的 HuggingFace Transformers,用于图操作的 NetworkX 和 torch-geometric,用于处理分子结构的 RDKit 和 OpenBabel,以及 pandas、NumPy 和 Matplotlib 等支持库。安装完成后,导入所需的库和模块,以便在后续单元中使用。
# install necessary packages
!pip install torch torchvision transformers networkx spacy rdflib rdkit pillow scikit-learn matplotlib seaborn torch-geometric
# pip did not work
!apt-get install openbabel
!pip install openbabel-wheel# import libraries
import torch
import torch.nnas nn
from torch.utils.dataimportDataset, DataLoader
import torchvision.modelsas models
import torchvision.transformsas transforms
from transformers importBertModel, BertTokenizer
import networkx as nx
import numpy as np
import matplotlib.pyplotas plt
import pandas as pd
import json
import os
from rdkit importChem
from rdkit.ChemimportDraw
fromPILimportImage
import io
import base64
from openbabel import openbabel
from torch_geometric.dataimportData
import torch_geometric.nnas geom_nn
首先创建一个目录用于存放药物图像。然后从公共仓库下载一个简化的 DrugBank 样本,并保存为 TSV 文件。该文件被加载到 pandas DataFrame 中,生成一张表格,包含每种药物的唯一标识符、名称、用于分子结构的 InChI 字符串,以及类别和组等描述性元数据。这个结构化数据集为后续步骤中生成视觉和文本表示奠定了基础。
# create directory for data storage
!mkdir -p data/drug_images# download DrugBank sample data (simplified version for demonstration)
!wget -q -O data/drugbank_sample.tsv https://raw.githubusercontent.com/dhimmel/drugbank/gh-pages/data/drugbank-slim.tsv# load DrugBank data
drug_df = pd.read_csv('data/drugbank_sample.tsv', sep='\t')
分子结构以 InChI 格式提供。这些表示需要通过 OpenBabel 转换为 SMILES 格式。SMILES 代表简化分子输入线表示法 (Simplified Molecular Input Line Entry System),提供了一种简洁、基于文本的方式来描述化学结构。SMILES 字符串与 RDKit 等工具兼容,RDKit 可以从 SMILES 字符串生成分子图像。下面的代码展示了如何进行此转换。
# create a SMILES column by converting InChI to SMILES
def inchi_to_smiles_openbabel(inchi_str):
try: # create Open Babel OBMol object from InChI
obConversion = openbabel.OBConversion()
obConversion.SetInAndOutFormats("inchi", "smiles")
mol = openbabel.OBMol() # convert InChI to molecule # also remove extra newlines or spaces
if obConversion.ReadString(mol, inchi_str):
return obConversion.WriteString(mol).strip()
else:
return None
except Exception as e:
print(f"Error converting InChI to SMILES: {inchi_str}. Error: {e}")
return None# apply the conversion to each InChI in the dataframe
drug_df['smiles'] = drug_df['inchi'].apply(inchi_to_smiles_openbabel)
系统构建了一个有向医学知识图谱,用于捕获药物、蛋白质和疾病之间的关系。每个节点代表一种药物、蛋白质或疾病,每条边编码了一种相互作用,如 binds_to、inhibits 或 treats。这些连接存储了关于药物如何影响生物靶点和病症的专家知识。
该图谱充当了一个结构化的关系信息来源,模型在处理图像和文本特征的同时会利用这些信息。通过明确地表示领域知识,图谱增强了预测准确性以及解释两种药物为何可能相互作用的能力。
# initialize a medical knowledge graph
medical_kg = nx.DiGraph()# extract drug entities from DrugBank
# limit to 50 drugs for demo
drug_entities = drug_df['name'].dropna().unique().tolist()[:50]# create drug nodes
for drug in drug_entities:
medical_kg.add_node(drug, type='drug')# add biomedical entities (proteins, targets, diseases)
protein_entities = ["Cytochrome P450", "Albumin", "P-glycoprotein", "GABA Receptor",
"Serotonin Receptor", "Beta-Adrenergic Receptor", "ACE", "HMGCR"]
disease_entities = ["Hypertension", "Diabetes", "Depression", "Epilepsy",
"Asthma", "Rheumatoid Arthritis", "Parkinson's Disease"]for protein in protein_entities:
medical_kg.add_node(protein, type='protein')for disease in disease_entities:
medical_kg.add_node(disease, type='disease')# add relationships (based on common drug mechanisms and interactions)
# drug-protein relationships
drug_protein_relations = [
("Warfarin", "binds_to", "Albumin"),
("Atorvastatin", "inhibits", "HMGCR"),
("Diazepam", "modulates", "GABA Receptor"),
("Fluoxetine", "inhibits", "Serotonin Receptor"),
("Phenytoin", "induces", "Cytochrome P450"),
("Metoprolol", "blocks", "Beta-Adrenergic Receptor"),
("Lisinopril", "inhibits", "ACE"),
("Rifampin", "induces", "P-glycoprotein"),
("Carbamazepine", "induces", "Cytochrome P450"),
("Verapamil", "inhibits", "P-glycoprotein")
]# drug-disease relationships
drug_disease_relations = [
("Lisinopril", "treats", "Hypertension"),
("Metformin", "treats", "Diabetes"),
("Fluoxetine", "treats", "Depression"),
("Phenytoin", "treats", "Epilepsy"),
("Albuterol", "treats", "Asthma"),
("Methotrexate", "treats", "Rheumatoid Arthritis"),
("Levodopa", "treats", "Parkinson's Disease")
]# known drug-drug interactions (based on actual medical knowledge)
drug_drug_interactions = [
("Goserelin", "interacts_with", "Desmopressin", "increases_anticoagulant_effect"),
("Goserelin", "interacts_with", "Cetrorelix", "increases_bleeding_risk"),
("Cyclosporine", "interacts_with", "Felypressin", "decreases_efficacy"),
("Octreotide", "interacts_with", "Cyanocobalamin", "increases_hypoglycemia_risk"),
("Tetrahydrofolic acid", "interacts_with", "L-Histidine", "increases_statin_concentration"),
("S-Adenosylmethionine", "interacts_with", "Pyruvic acid", "decreases_efficacy"),
("L-Phenylalanine", "interacts_with", "Biotin", "increases_sedation"),
("Choline", "interacts_with", "L-Lysine", "decreases_efficacy")
]# add all relationships to the knowledge graph
for s, r, o in drug_protein_relations:
if s in medical_kg and o in medical_kg:
medical_kg.add_edge(s, o, relation=r)for s, r, o in drug_disease_relations:
if s in medical_kg and o in medical_kg:
medical_kg.add_edge(s, o, relation=r)for s, r, o, mechanism in drug_drug_interactions:
if s in medical_kg and o in medical_kg:
medical_kg.add_edge(s, o, relation=r, mechanism=mechanism)
每种药物都由三种互补的数据类型表示。首先,其 SMILES 表示法被转换为分子对象,并使用 RDKit 渲染成图像。
# function to generate molecular structure images using RDKit
def generate_molecule_image(smiles_string, size=(224, 224)):
try:
mol = Chem.MolFromSmiles(smiles_string)
if mol:
img = Draw.MolToImage(mol, size=size)
return img
else:
return None
except:
return None
其次,通过结合药物名称、类别、组信息和任何可用的元数据来构建描述性文本。
# function to create text description for drugs combining various information
def create_drug_description(row):
description = f"Drug name: {row['name']}. " if pd.notna(row.get('category')):
description += f"Category: {row['category']}. " if pd.notna(row.get('groups')):
description += f"Groups: {row['groups']}. " if pd.notna(row.get('description')):
description += f"Description: {row['description']}"
第三,图谱被嵌入其中,具体而言,每个节点和关系最初是随机向量,然后进行迭代调整,使得对于每个真实的连接,一个实体的向量加上其关系的向量会使其接近与之相连实体的向量。经过多次迭代,这形成了一个嵌入空间,其中相连元素自然聚类,并且关系方向由关系向量编码。结果是一对查找表,将每个节点和关系映射到紧凑、可训练的坐标,反映了知识图谱的完整结构。
# convert NetworkX graph to PyG graph for modern graph neural network processing
def convert_nx_to_pyg(nx_graph): # create node mappings
node_to_idx = {node: i for i, node in enumerate(nx_graph.nodes())} # create edge lists
src_nodes = []
dst_nodes = []
edge_types = []
edge_type_to_idx = {} for u, v, data in nx_graph.edges(data=True):
relation = data.get('relation', 'unknown')
if relation not in edge_type_to_idx:
edge_type_to_idx[relation] = len(edge_type_to_idx)
src_nodes.append(node_to_idx[u])
dst_nodes.append(node_to_idx[v])
edge_types.append(edge_type_to_idx[relation]) # create PyG graph
edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
edge_type = torch.tensor(edge_types, dtype=torch.long) # create node features
node_types = []
for node in nx_graph.nodes():
node_type = nx_graph.nodes[node].get('type', 'unknown')
node_types.append(node_type) # one-hot encode node types
unique_node_types = sorted(set(node_types))
node_type_to_idx = {nt: i for i, nt in enumerate(unique_node_types)}
node_type_features = torch.zeros(len(node_types), len(unique_node_types))
for i, nt in enumerate(node_types):
node_type_features[i, node_type_to_idx[nt]] = 1.0 # create PyG Data object with the proper attributes
g = Data(
edge_index=edge_index,
edge_type=edge_type,
x=node_type_features # node features in PyG are stored in 'x'
) # create reverse mappings for later use
idx_to_node = {idx: node for node, idx in node_to_idx.items()}
idx_to_edge_type = {idx: edge_type for edge_type, idx in edge_type_to_idx.items()} return g, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type# convert medical_kg to DGL graph
pyg_graph, node_to_idx, idx_to_node, edge_type_to_idx, idx_to_edge_type = convert_nx_to_pyg(medical_kg)
这些视觉、文本和结构化表示被保存,以便模型可以将其融合用于相互作用预测。
# process drug data to create multi-modal representations
drug_data = []for idx, row in drug_df.iterrows():
if row['name'] in drug_entities and pd.notna(row.get('smiles')): # generate molecule image
img = generate_molecule_image(row['smiles']) if img:
img_path = f"data/drug_images/{row['drugbank_id']}.png"
img.save(img_path) # Create text description
description = create_drug_description(row) # Store drug information
drug_data.append({
'id': row['drugbank_id'],
'name': row['name'],
'smiles': row['smiles'],
'description': description,
'image_path': img_path
})drug_data_df = pd.DataFrame(drug_data)
MultimodalNodeEncoder 创建了一个单一编码器,将每个节点的分子图像及其文本摘要转换为兼容的特征向量。首先,它对原始化学图应用深度卷积网络,将其提炼成紧凑的视觉指纹。同时,它通过预训练语言模型处理药物描述,提取语义摘要。然后将两者的输出映射到同一个向量空间中,以便视觉和文本信号可以在知识图谱结构的引导下有意义地结合。
# processes visual and textual features for nodes
classMultimodalNodeEncoder(nn.Module): def __init__(self, output_dim=128):
super(MultimodalNodeEncoder, self).__init__()
# image encoder (ResNet)
resnet = models.resnet18(pretrained=True)
# remove the final fully connected layer to get 512 features
self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
self.image_projection = nn.Linear(512, output_dim) # text encoder (BERT)
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
# BERT base outputs 768 features
self.text_projection = nn.Linear(768, output_dim) def forward(self, image, text):
# image encoding
img_features = self.image_encoder(image).squeeze(-1).squeeze(-1)
img_features = self.image_projection(img_features) # text encoding
encoded_input = self.tokenizer(text, padding=True, truncation=True,
return_tensors="pt", max_length=128)
# move encoded input to the same device as the image
input_ids = encoded_input['input_ids'].to(image.device)
attention_mask = encoded_input['attention_mask'].to(image.device) text_outputs = self.text_encoder(input_ids=input_ids,
attention_mask=attention_mask)
# use the [CLS] token embedding (first token)
text_features = text_outputs.last_hidden_state[:, 0, :]
text_features = self.text_projection(text_features) return img_features, text_features
KG引导的多模态模型融合每个节点的视觉、文本和类型嵌入,并在知识图谱的指导下预测药物-药物相互作用。它首先将每个节点的图像和文本输出投影到共享空间,并为其节点类型分配单独的嵌入。这些嵌入随后在图上传播,使得每个节点将其自身特征与其邻居收集到的信号相融合。注意力步骤根据连接的强度和类型重新加权这些融合后的特征。评估一对药物时,模型获取它们精炼后的节点嵌入,通过连接、元素级乘积和差值进行组合,然后将结果馈入预测头,产生相互作用的概率。通过让图谱的拓扑决定多模态信号如何融合,模型生成的预测既准确又可直接追溯到潜在的网络结构。
# define KG-guided MultimodalModel
classKGGuidedMultimodalModel(nn.Module): def __init__(self, pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node, hidden_dim=128):
super(KGGuidedMultimodalModel, self).__init__()
self.pyg_graph = pyg_graph
self.node_to_idx = node_to_idx
self.idx_to_node = idx_to_node
self.hidden_dim = hidden_dim # multimodal encoder for processing node-associated data
self.multimodal_encoder = MultimodalNodeEncoder(output_dim=hidden_dim) # node type embeddings
self.node_type_embedding = nn.Embedding(num_node_types, hidden_dim) # GraphNeuralNetwork layers for knowledge graph processing (PyGGCNConv instead of dglnn.GraphConv)
self.gnn_layers = nn.ModuleList([
geom_nn.GCNConv(hidden_dim, hidden_dim),
geom_nn.GCNConv(hidden_dim, hidden_dim),
]) # GraphAttentionNetworkfor integrating multimodal features with graph structure (PyGGATConv)
# explicitly set output dimension so total output is hidden_dim (not hidden_dim * num_heads)
self.gat_layer = geom_nn.GATConv(hidden_dim, hidden_dim // 4, heads=4) # relation prediction layer - updated to match the actual input dimensions we'll have
self.relation_prediction = nn.Sequential(
nn.Linear(hidden_dim * 4, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, 1)
) def get_node_representation(self, node_name, image=None, text=None):
if node_name not in self.node_to_idx:
# handle unknown nodes
return torch.zeros(self.hidden_dim, device=self.pyg_graph.edge_index.device) node_idx = self.node_to_idx[node_name] # get node type features - use x instead of ndata['type']
node_type_feat = self.pyg_graph.x[node_idx]
node_type_embedding = self.node_type_embedding(torch.argmax(node_type_feat)) # if multimodal data is provided, process it
if image is not None and text is not None:
img_feat, text_feat = self.multimodal_encoder(image, text) # squeeze out the batch dimension to match shapes
img_feat = img_feat.squeeze(0)
text_feat = text_feat.squeeze(0) # knowledge graph structure guides how multimodal features are integrated
# use node_type_embedding as a query to attend to multimodal features
attention_weights = torch.softmax(
torch.matmul(
torch.stack([img_feat, text_feat, node_type_embedding]),
node_type_embedding
),
dim=0
) # weighted combination of features
combined_feat = (
attention_weights[0] * img_feat +
attention_weights[1] * text_feat +
attention_weights[2] * node_type_embedding
) return combined_feat
else:
# for nodes without multimodal data, just use type embedding
return node_type_embedding def forward(self, drug1_image, drug1_text, drug1_name, drug2_image, drug2_text, drug2_name):
# process the entire graph first
device = self.pyg_graph.edge_index.device
x = torch.zeros((self.pyg_graph.x.size(0), self.hidden_dim), device=device) # initialize known node features
for i, node_name inenumerate([drug1_name, drug2_name]):
if node_name in self.node_to_idx:
node_idx = self.node_to_idx[node_name]
if i == 0:
x[node_idx] = self.get_node_representation(node_name, drug1_image, drug1_text)
else:
x[node_idx] = self.get_node_representation(node_name, drug2_image, drug2_text) # apply graph convolutions to propagate information - PyG style
edge_index = self.pyg_graph.edge_index
for layer in self.gnn_layers:
x = layer(x, edge_index)
x = torch.relu(x) # apply graph attention to integrate features - PyG style
x = self.gat_layer(x, edge_index) # get final representations for the two drugs
drug1_idx = self.node_to_idx.get(drug1_name, 0)
drug2_idx = self.node_to_idx.get(drug2_name, 0) drug1_repr = x[drug1_idx]
drug2_repr = x[drug2_idx] # predict interaction
# concatenate representations in multiple ways to capture relationship
concat_repr = torch.cat([
drug1_repr,
drug2_repr,
drug1_repr * drug2_repr,
torch.abs(drug1_repr - drug2_repr)
], dim=0) interaction_prob = torch.sigmoid(self.relation_prediction(concat_repr.unsqueeze(0)).squeeze())
return interaction_prob
更大图谱中如何关联,会构建一个焦点子图。它首先会在图谱中查找两种药物之间是否存在任何直接边,如果找到则记录其属性。接下来,它识别与两种药物都关联的蛋白质或疾病,揭示共享机制。最后,它追踪所有不超过给定长度的简单路径,以揭示通过中间节点的间接连接。结果是由关键节点和边组成的紧凑网络,它捕获了预测相互作用背后的领域知识,并指导下游层强调最相关的多模态特征。
# function to retrieve knowledge subgraph relevant to a drug pair
def retrieve_knowledge_subgraph(graph, drug1, drug2, max_path_length=3):
relevant_knowledge = {
'direct_interaction': None,
'common_targets': [],
'paths': []
} # check for direct interaction
if graph.has_edge(drug1, drug2):
edge_data = graph.get_edge_data(drug1, drug2)
relevant_knowledge['direct_interaction'] = edge_data # find common targets (proteins, diseases)
drug1_neighbors = set(graph.neighbors(drug1)) if drug1 in graph elseset()
drug2_neighbors = set(graph.neighbors(drug2)) if drug2 in graph elseset() common_neighbors = drug1_neighbors.intersection(drug2_neighbors)
for common_node incommon_neighbors:
node_type = graph.nodes[common_node].get('type', '')
if node_type == 'protein' or node_type == 'disease':
relevant_knowledge['common_targets'].append(common_node) # find paths between drugs (up to max_path_length)
try:
paths = list(nx.all_simple_paths(graph, drug1, drug2, cutoff=max_path_length))
relevant_knowledge['paths'] = paths
except (nx.NetworkXError, nx.NodeNotFound):
# Handle cases where paths do not exist or nodes are not in graph
pass return relevant_knowledge
该函数负责准备每个训练批次,首先剔除所有加载失败或不完整的样本。然后,它将所有有效的分子图像堆叠成这对药物的批量张量,同时将其相应的文本摘要和标识符收集到并行列表中。相互作用标签也类似地组合成一个张量。通过返回一个包含这些批量化组件的统一字典——如果剩余样本无效则返回空占位符——它确保模型总是接收到结构良好、同质的输入,尽管底层数据存在异构性和偶尔缺失。
# custom collate function to handle None values
def custom_collate_fn(batch):
# filter out None values
batch = [item for item in batch if item is not None] # return empty batch if all items were None
iflen(batch) == 0:
return {
'drug1_img': torch.tensor([]),
'drug1_text': [],
'drug1_name': [],
'drug2_img': torch.tensor([]),
'drug2_text': [],
'drug2_name': [],
'label': torch.tensor([])
} # process non-None items
drug1_imgs = torch.stack([item['drug1_img'] for item in batch])
drug1_texts = [item['drug1_text'] for item in batch]
drug1_names = [item['drug1_name'] for item in batch] drug2_imgs = torch.stack([item['drug2_img'] for item in batch])
drug2_texts = [item['drug2_text'] for item in batch]
drug2_names = [item['drug2_name'] for item in batch] labels = torch.stack([item['label'] for item in batch]) return {
'drug1_img': drug1_imgs,
'drug1_text': drug1_texts,
'drug1_name': drug1_names,
'drug2_img': drug2_imgs,
'drug2_text': drug2_texts,
'drug2_name': drug2_names,
'label': labels
}
训练示例结合所有真实相互作用和一组匹配的随机非相互作用对。该过程首先提取所有已知的药物相互作用对,然后采样等量的负样本对以平衡数据集。获取一个示例时,加载并预处理每种药物的分子图像和文本摘要,跳过任何缺少数据的对,生成包含两种药物的图像、描述、名称和一个二元标签的记录。通过将正样本与负样本配对、应用一致的图像变换以及稳健处理缺失数据,数据集提供了可靠、即用型的批次,用于训练相互作用预测模型。
# define dataset for DDI prediction
classDDIDataset(Dataset): def __init__(self, drug_data_df, drug_drug_interactions, medical_kg, node_to_idx, transform=None):
self.drug_data = drug_data_df
self.drug_name_to_idx = {row['name']: i for i, row in drug_data_df.iterrows()}
self.node_to_idx = node_to_idx
self.transform = transform or transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) # create pairs of drugs with interaction labels
self.pairs = []
drug_names = list(self.drug_name_to_idx.keys()) # positive samples (known interactions)
for interaction indrug_drug_interactions:
drug1, _, drug2, _ = interaction
if drug1 in drug_names and drug2 indrug_names:
# 1for positive interaction
self.pairs.append((drug1, drug2, 1))
positive_pairs = set((d1, d2) for d1, d2, _ in self.pairs) # generate some negative samples
np.random.seed(42)
neg_count = 0
max_neg = len(self.pairs)
while neg_count < max_neg:
i, j = np.random.choice(len(drug_names), 2, replace=False)
drug1, drug2 = drug_names[i], drug_names[j]
if (drug1, drug2) not in positive_pairs and (drug2, drug1) not inpositive_pairs:
# 0for negative interaction
self.pairs.append((drug1, drug2, 0))
neg_count += 1 def __len__(self):
returnlen(self.pairs) def __getitem__(self, idx):
try:
drug1_name, drug2_name, label = self.pairs[idx] # get drug1 data
drug1_idx = self.drug_name_to_idx[drug1_name]
drug1_data = self.drug_data.iloc[drug1_idx] # load drug1 image with error handling
try:
drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
drug1_img = self.transform(drug1_img)
except Exceptionase:
print(f"Error loading drug1 image for {drug1_name}: {str(e)}")
returnNone drug1_text = drug1_data['description'] # get drug2 data
drug2_idx = self.drug_name_to_idx[drug2_name]
drug2_data = self.drug_data.iloc[drug2_idx] # load drug2 image with error handling
try:
drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
drug2_img = self.transform(drug2_img)
except Exceptionase:
print(f"Error loading drug2 image for {drug2_name}: {str(e)}")
returnNone drug2_text = drug2_data['description'] return {
'drug1_img': drug1_img,
'drug1_text': drug1_text,
'drug1_name': drug1_name,
'drug2_img': drug2_img,
'drug2_text': drug2_text,
'drug2_name': drug2_name,
'label': torch.tensor(label, dtype=torch.float32)
}
except Exceptionase:
print(f"Error in __getitem__ for index {idx}: {str(e)}")
return None
模型训练开始时,将模型及其图谱移至选定的设备(GPU 或 CPU),并进行设定的轮数(epochs),每轮分为训练阶段和验证阶段。训练期间,成对药物的批次通过网络传递以产生相互作用分数,计算二元交叉熵损失,Adam 优化器通过反向传播更新所有参数。损失和正确预测计数被汇总,以便在每轮结束时报告平均训练损失和准确率,为保持稳定性跳过空或格式错误的批次。然后过程切换到评估模式——运行相同的批次但不进行梯度更新——以衡量验证损失和准确率。
# training function
def train_kg4mm_model(model, train_loader, val_loader, epochs=5):
device = torch.device('cuda'if torch.cuda.is_available() else'cpu')
model = model.to(device)
model.pyg_graph = model.pyg_graph.to(device) criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) for epoch inrange(epochs):
# training phase
model.train()
train_loss = 0
train_correct = 0
batch_count = 0 for batch intrain_loader:
# skip empty batches
iflen(batch['drug1_img']) == 0:
print("Skipping empty batch")
continue batch_count += 1 try:
drug1_img = batch['drug1_img'].to(device)
drug1_text = batch['drug1_text']
drug1_name = batch['drug1_name']
drug2_img = batch['drug2_img'].to(device)
drug2_text = batch['drug2_text']
drug2_name = batch['drug2_name']
labels = batch['label'].to(device) # forward pass - processing one pair at a time for clarity
batch_size = len(drug1_name)
outputs = torch.zeros(batch_size, 1, device=device) for i inrange(batch_size):
# this loop is for illustration - in practice, handle batch processing more efficiently
output = model(
drug1_img[i].unsqueeze(0),
[drug1_text[i]],
drug1_name[i],
drug2_img[i].unsqueeze(0),
[drug2_text[i]],
drug2_name[i]
)
outputs[i] = output # calculate loss
loss = criterion(outputs, labels.unsqueeze(1)) # backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item() # calculate accuracy
predictions = (outputs >= 0.5).float()
train_correct += (predictions == labels.unsqueeze(1)).sum().item() print(f"Batch {batch_count}: Loss: {loss.item():.4f}") except Exceptionase:
print(f"Error processing batch {batch_count}: {str(e)}")
import traceback
traceback.print_exc()
continue avg_train_loss = train_loss / max(1, batch_count)
train_acc = train_correct / max(1, batch_count * batch['drug1_img'].size(0)) print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}') # validation phase
model.eval()
val_loss = 0
val_correct = 0
val_batch_count = 0 with torch.no_grad():
for batch inval_loader:
# skip empty batches
iflen(batch['drug1_img']) == 0:
continue val_batch_count += 1 try:
drug1_img = batch['drug1_img'].to(device)
drug1_text = batch['drug1_text']
drug1_name = batch['drug1_name']
drug2_img = batch['drug2_img'].to(device)
drug2_text = batch['drug2_text']
drug2_name = batch['drug2_name']
labels = batch['label'].to(device) # forward pass - processing one pair at a time for clarity
batch_size = len(drug1_name)
outputs = torch.zeros(batch_size, 1, device=device) for i inrange(batch_size):
output = model(
drug1_img[i].unsqueeze(0),
[drug1_text[i]],
drug1_name[i],
drug2_img[i].unsqueeze(0),
[drug2_text[i]],
drug2_name[i]
)
outputs[i] = output # calculate loss
loss = criterion(outputs, labels.unsqueeze(1))
val_loss += loss.item() # calculate accuracy
predictions = (outputs >= 0.5).float()
val_correct += (predictions == labels.unsqueeze(1)).sum().item() except Exceptionase:
print(f"Error processing validation batch {val_batch_count}: {str(e)}")
continue avg_val_loss = val_loss / max(1, val_batch_count)
val_acc = val_correct / max(1, val_batch_count * 4) # Assuming batch_size=4 print(f'Epoch {epoch+1}/{epochs}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}') return model
准备好的数据集然后生成成对的药物示例——每个示例都包含其分子图像和文本摘要——并将它们分割为训练集和验证集,用于性能跟踪。数据加载器将这些多模态示例(图像、描述和标签)打包成批次,以便它们顺畅地输入模型。KG引导的预测网络被实例化,其维度源自图的节点和边类型,确保其层与知识图谱的结构对齐。最后,训练循环运行固定轮数,交替在训练数据上更新模型并在验证集上衡量其准确率。这一序列完成了从数据准备到主动、图驱动学习的转变。
# initialize dataset and model
ddi_dataset = DDIDataset(drug_data_df, drug_drug_interactions, medical_kg, node_to_idx)# split dataset into train and validation sets
train_size = int(0.8 \* len(ddi_dataset))
val_size = len(ddi_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(ddi_dataset, [train_size, val_size])# create data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=custom_collate_fn)# initialize the model with the DGL graph
num_node_types = pyg_graph.x.shape[1]
num_edge_types = len(edge_type_to_idx)# initialize the KG-guided multimodal model
model = KGGuidedMultimodalModel(pyg_graph, num_node_types, num_edge_types, node_to_idx, idx_to_node)# train the model
trained_model = train_kg4mm_model(model, train_loader, val_loader, epochs=5)
进行预测时,模型首先加载每种药物处理后的图像和文本摘要,并确定每种药物在知识图谱中的位置。然后,它产生一个概率分数,显示视觉、文本和图谱信息如何协同作用。同时,系统检查图谱是否存在两种药物之间的任何直接链接、它们都连接的任何蛋白质或疾病,以及连接它们的任何不超过给定长度的简单路径。该概率被转换为低、中或高风险等级。然后构建解释,突出显示已知的相互作用机制、共享靶点以及指导决策的关键图谱路径。最后,系统根据风险等级提供示例性的临床建议,清楚展示知识图谱如何塑造预测及其解释。
def predict_interaction(model, drug1_name, drug2_name, drug_data_df, medical_kg):
device = torch.device('cuda'if torch.cuda.is_available() else'cpu')
model = model.to(device)
model.eval() # get drug indices
drug1_idx = drug_data_df[drug_data_df['name'] == drug1_name].index[0]
drug2_idx = drug_data_df[drug_data_df['name'] == drug2_name].index[0] # get drug data
drug1_data = drug_data_df.iloc[drug1_idx]
drug2_data = drug_data_df.iloc[drug2_idx] # prepare images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) drug1_img = Image.open(drug1_data['image_path']).convert('RGB')
drug1_img = transform(drug1_img).unsqueeze(0).to(device)
drug1_text = [drug1_data['description']] drug2_img = Image.open(drug2_data['image_path']).convert('RGB')
drug2_img = transform(drug2_img).unsqueeze(0).to(device)
drug2_text = [drug2_data['description']] # get knowledge subgraph for the drug pair
knowledge = retrieve_knowledge_subgraph(medical_kg, drug1_name, drug2_name) # make prediction
with torch.no_grad():
interaction_prob = model(
drug1_img,
drug1_text,
drug1_name,
drug2_img,
drug2_text,
drug2_name
) return interaction_prob.item(), knowledgedef explain_interaction_prediction(drug1_name, drug2_name, probability, knowledge):
explanation = f"KG-guided multimodal analysis for interaction between {drug1_name} and {drug2_name}:\n\n" # interpret the probability
if probability > 0.8:
risk_level = "High"
elif probability > 0.5:
risk_level = "Moderate"
else:
risk_level = "Low" explanation += f"Interaction Risk Level: {risk_level} (Probability: {probability:.2f})\n\n" # explain based on knowledge graph structure
explanation += "Knowledge Graph Analysis:\n" if knowledge['direct_interaction']:
mechanism = knowledge['direct_interaction'].get('mechanism', 'unknown mechanism')
explanation += f"✓ Direct Connection: The knowledge graph contains a documented interaction between these drugs with {mechanism}.\n\n" if knowledge['common_targets']:
explanation += "✓ Common Target Nodes: These drugs connect to shared entities in the knowledge graph:\n"
for target in knowledge['common_targets']:
explanation += f" - {target}\n"
explanation += " This graph structure suggests potential interaction through common binding sites or pathways.\n\n" if knowledge['paths'] andlen(knowledge['paths']) > 0:
explanation += "✓ Knowledge Graph Pathways: The model identified these connecting paths in the graph:\n"
for i, path inenumerate(knowledge['paths'][:3]):
path_str = " → ".join(path)
explanation += f" - Path {i+1}: {path_str}\n"
explanation += " These graph structures guided the multimodal feature integration for prediction.\n\n" # focus on how KG structure guided the interpretation
explanation += "Multimodal Integration Process:\n"
explanation += " - Knowledge graph structure determined which drug properties were most relevant\n"
explanation += " - Graph neural networks analyzed the local neighborhood of both drug nodes\n"
explanation += " - Node position in the graph guided the weighting of visual and textual features\n\n" # clinical implications (example - in a real system, this would be more comprehensive)
if probability > 0.5:
explanation += "Clinical Recommendations (based on graph analysis):\n"
explanation += " - Consider alternative medications not connected in similar graph patterns\n"
explanation += " - If co-administration is necessary, monitor for interaction effects\n"
explanation += " - Review other drugs connected to the same nodes for potential complications\n"
else:
explanation += "Clinical Recommendations (based on graph analysis):\n"
explanation += " - Standard monitoring advised\n"
explanation += " - The knowledge graph structure suggests minimal interaction concerns\n" return explanation
为了说明完整的工作流程,选择两种药物,并加载并像训练时一样预处理它们预先生成的图像和文本摘要。这些多模态输入然后通过训练好的模型传递——此时处于评估模式——产生一个量化其相互作用风险的概率分数。同时,为了可视化和解释,该过程通过收集所有直接连接、共享生物靶点以及连接它们的任何不超过给定长度的简单路径,提取知识图谱的相关部分,然后通过增加一层直接邻居来扩充此子图以获取更广泛的上下文。
提取出的子图采用清晰的配色方案进行布局,有效区分了两种药物、蛋白质、疾病及其他实体,使网络结构一目了然,增强了可读性和分析效率。紧随其后的是清晰的自然语言解释,通过突出显示任何已记录的相互作用机制、共享靶点和关键连接路径,将概率分数与这些图谱特征关联起来。风险估计、颜色编码可视化和叙述性解释共同说明了知识图谱的拓扑如何指导了视觉和文本信号的融合,并为模型的预测提供了透明的理由。
# example usage
drug_pair = ("Goserelin", "Desmopressin")
prob, knowledge = predict_interaction(trained_model, drug_pair[0], drug_pair[1], drug_data_df, medical_kg)print(f"Predicted interaction probability between {drug_pair[0]} and {drug_pair[1]}: {prob:.4f}")print("\nKnowledge Graph Structure Analysis:")
print(f"Direct connection: {knowledge['direct_interaction']}")
print(f"Common target nodes: {knowledge['common_targets']}")
print(f"Graph paths connecting drugs:")
for path in knowledge['paths']:
print(f" {' -> '.join(path)}")# visualize the subgraph for these drugs to show the KG-guided approach
plt.figure(figsize=(12, 8))
subgraph_nodes = set([drug_pair[0], drug_pair[1]])
# add intermediate nodes in paths to highlight the KG structure
for path in knowledge['paths']:
subgraph_nodes.update(path) # add a level of neighbors to show context in KG
neighbors_to_add = set()
for node in subgraph_nodes:
if node in medical_kg:
neighbors_to_add.update(list(medical_kg.neighbors(node))[:3])
subgraph_nodes.update(neighbors_to_add)subgraph = medical_kg.subgraph(subgraph_nodes)# use different colors for node types to emphasize KG structure
node_colors = []
for node in subgraph.nodes():
if node == drug_pair[0] or node == drug_pair[1]:
node_colors.append('lightcoral')
elif subgraph.nodes[node].get('type') == 'protein':
node_colors.append('lightblue')
elif subgraph.nodes[node].get('type') == 'disease':
node_colors.append('lightgreen')
else:
node_colors.append('lightgray')pos = nx.spring_layout(subgraph, seed=42)
nx.draw(subgraph, pos, with_labels=True, node_color=node_colors,
node_size=2000, arrows=True, arrowsize=20)edge_labels = {(s, o): subgraph[s][o]['relation'] for s, o in subgraph.edges()}
nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels)plt.title(f"Knowledge Graph Structure Guiding {drug_pair[0]} and {drug_pair[1]} Interaction Analysis")
plt.savefig('kg_guided_interaction_analysis.png')
plt.show()# show explanation
explanation = explain_interaction_prediction(drug_pair[0], drug_pair[1], prob, knowledge)
print(explanation)
在戈舍瑞林 (Goserelin) 和去氨加压素 (Desmopressin) 上测试时,模型返回了 0.54 的概率,将其归类为中等风险对。知识图谱揭示了两种药物之间存在一个直接的"相互作用"(interacts_with)关系,该关系的具体描述/标签为"增加抗凝作用"(increases_anticoagulant_effect),没有共享的蛋白质或疾病连接,因此模型主要关注了该机制。在子图绘制中,两种药物以红色突出显示,单条有向边突出显示,清晰显示是哪种关系驱动了预测。
KG4MM 的研究表明,将知识图谱作为工作流程的核心,可以更好地融合分子图像和文本,效果优于单一来源的方法。每个预测都由清晰的图谱证据支持——直接边、共享靶点和连接路径——这使得结果与真实的生物关系关联起来。通过这样做,KG4MM 在生物化学、材料科学和医学诊断等领域都提供了更强的预测能力和内置的可解释性。
若要了解更多知识图谱或图数据库相关教学,你可以查看公众号的其他文章:
活水智能,成立于北京,专注通过AI教育、AI软件及高质量社群,持续提升知识工作者的生产力。
10+ 人气AI课程:线下工作坊与实操训练,聚焦最新AI应用。
2600+深度成员社群:知识星球汇聚大厂程序员、企业高管、律师、创业者等各领域精英。
城市分舵:北/上/广/深/杭/成/渝等城市均有线下组织,连接志同道合的伙伴。
? 近期开课:
AI学术分析首期 「即将开班,提前预约 ??重磅!如何在AI时代脱颖而出?「AI 分析三课」正式上线」
AI通识二期「即将开班,提前预约 ??AI通识课来啦!知识工作者的AI时代生存指南」
? 福利群开放加入
每周独家AI新知、专属优惠券、干货方法论、同学交流心得,更有不定期赠书活动,等你来参与!
??????
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-07-17
2025-01-02
2024-08-13
2024-08-27
2024-07-11
2025-01-03
2024-06-24
2024-07-13
2024-07-12
2024-06-10
2025-05-20
2025-04-20
2025-04-15
2025-04-09
2025-03-29
2025-02-13
2025-01-14
2025-01-10