微信扫码
添加专属顾问
我要投稿
Vanna用RAG技术将AI生成SQL的准确率从3%提升到80%,彻底改变了数据库查询的游戏规则。 核心内容: 1. AI生成SQL面临的数据库结构理解难题 2. Vanna基于RAG架构的智能SQL生成框架解析 3. 模块化设计带来的技术突破与实战效果
❝"在数据驱动的时代,让AI理解你的数据库就像教会外星人说人话一样困难。但Vanna做到了,而且做得相当优雅。"
想象一下这样的场景:你兴冲冲地打开ChatGPT,输入"帮我查询一下德国有多少客户",期待着AI能够生成一条完美的SQL语句。结果呢?AI给你返回了一个看起来很专业的查询:
SELECT COUNT(*) FROM customers WHERE country = 'Germany';
看起来不错对吧?但当你兴奋地复制到数据库中执行时,系统无情地抛出了错误:Table 'customers' doesn't exist
。
这就是当前AI生成SQL面临的核心困境:LLM虽然掌握了SQL语法的精髓,却对你的具体数据库结构一无所知。就像一个语言天才试图在不了解当地文化的情况下进行深度交流一样,注定会闹出笑话。
但是,如果我告诉你有一个开源项目能够将SQL生成的准确率从令人绝望的3%提升到令人惊艳的80%,你会相信吗?这就是我们今天要深入探讨的主角——Vanna。
Vanna不是简单的"ChatGPT + SQL"的组合,而是一个基于RAG(Retrieval-Augmented Generation)架构的智能SQL生成框架。它的核心理念可以用一句话概括:让AI不仅懂SQL语法,更要懂你的数据。
从技术架构上看,Vanna采用了经典的RAG模式:
graph TD
A[用户问题] --> B[向量化检索]
B --> C[相关上下文]
C --> D[LLM生成SQL]
D --> E[执行验证]
E --> F[结果反馈]
F --> G[自动训练]
G --> B
这个架构的精妙之处在于,它不是简单地把问题扔给LLM,而是先从知识库中检索出最相关的上下文信息,然后再让LLM基于这些信息生成SQL。这就像给一个外国朋友不仅提供了字典,还提供了当地的文化背景和使用习惯。
让我们深入Vanna的技术内核。通过分析其源码结构,我们可以发现Vanna采用了高度模块化的设计:
# Vanna的核心抽象基类
class VannaBase(ABC):
def __init__(self, config=None):
self.config = config
self.run_sql_is_set = False
self.static_documentation = ""
self.dialect = self.config.get("dialect", "SQL")
self.language = self.config.get("language", None)
self.max_tokens = self.config.get("max_tokens", 14000)
@abstractmethod
def generate_embedding(self, data: str, **kwargs) -> List[float]:
"""生成文本嵌入向量"""
pass
@abstractmethod
def get_similar_question_sql(self, question: str, **kwargs) -> list:
"""检索相似的问题-SQL对"""
pass
@abstractmethod
def submit_prompt(self, prompt, **kwargs) -> str:
"""提交提示词到LLM"""
pass
这种设计的巧妙之处在于,它将复杂的SQL生成过程分解为三个可插拔的组件:
Vanna的另一个令人印象深刻的特点是其广泛的生态系统支持。从项目结构可以看出,它支持:
LLM提供商(9+):
向量数据库(10+):
关系数据库(10+):
这种"大一统"的设计哲学让Vanna能够适应几乎任何技术栈,这在企业级应用中尤为重要。
在深入Vanna的解决方案之前,我们先来理解传统方法的局限性。Vanna团队进行了一项令人印象深刻的实验,使用Cybersyn SEC数据集测试了不同方法的SQL生成准确率:
实验设置:
结果令人震惊:
这个结果揭示了一个重要的洞察:上下文比模型更重要。即使是最强大的GPT-4,在没有合适上下文的情况下,准确率也只有可怜的10%。
Vanna的成功秘诀在于其精心设计的三层上下文策略:
def add_ddl(self, ddl: str, **kwargs) -> str:
"""添加数据定义语言到训练数据"""
id = deterministic_uuid(ddl) + "-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=id,
)
return id
这一层提供了数据库的结构信息,包括表名、字段名、数据类型等。但仅有这些还不够,因为它无法告诉AI如何正确地使用这些表。
def add_documentation(self, documentation: str, **kwargs) -> str:
"""添加业务文档到训练数据"""
id = deterministic_uuid(documentation) + "-doc"
self.documentation_collection.add(
documents=documentation,
embeddings=self.generate_embedding(documentation),
ids=id,
)
return id
这一层包含了业务规则、字段含义、计算逻辑等信息。比如"revenue"字段的具体定义,或者某个表中数据的业务含义。
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
"""添加问题-SQL对到训练数据"""
question_sql_json = json.dumps({
"question": question,
"sql": sql,
}, ensure_ascii=False)
id = deterministic_uuid(question_sql_json) + "-sql"
self.sql_collection.add(
documents=question_sql_json,
embeddings=self.generate_embedding(question_sql_json),
ids=id,
)
return id
这是最关键的一层,它提供了具体的问题-SQL对应关系,让AI能够学习到如何将自然语言问题转换为正确的SQL查询。
Vanna的核心创新在于其智能检索机制。当用户提出问题时,系统会:
def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
# 检索相似的问题-SQL对
question_sql_list = self.get_similar_question_sql(question, **kwargs)
# 检索相关的DDL信息
ddl_list = self.get_related_ddl(question, **kwargs)
# 检索相关的文档
doc_list = self.get_related_documentation(question, **kwargs)
# 构建提示词
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
# 提交给LLM生成SQL
llm_response = self.submit_prompt(prompt, **kwargs)
return self.extract_sql(llm_response)
这种方法的精妙之处在于,它不是简单地把所有信息都塞给LLM,而是智能地选择最相关的信息。这样既保证了上下文的质量,又避免了超出LLM的上下文窗口限制。
让我们通过一个具体的例子来看看Vanna是如何工作的:
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
# 创建自定义的Vanna类
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
# 初始化
vn = MyVanna(config={
'api_key': 'your-openai-key',
'model': 'gpt-4'
})
# 连接数据库
vn.connect_to_postgres(
host="localhost",
dbname="ecommerce",
user="admin",
password="password"
)
训练Vanna就像教一个新员工熟悉公司的数据库:
# 1. 添加表结构信息
vn.train(ddl="""
CREATE TABLE customers (
id SERIAL PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
country VARCHAR(50),
created_at TIMESTAMP
);
""")
# 2. 添加业务文档
vn.train(documentation="""
customers表存储了所有注册用户的基本信息。
country字段使用ISO 3166-1标准的国家代码。
created_at表示用户注册时间。
""")
# 3. 添加示例查询
vn.train(
question="德国有多少客户?",
sql="SELECT COUNT(*) FROM customers WHERE country = 'DE';"
)
vn.train(
question="最近一个月新注册的用户数量?",
sql="""
SELECT COUNT(*)
FROM customers
WHERE created_at >= CURRENT_DATE - INTERVAL '1 month';
"""
)
训练完成后,你就可以开始享受AI助手的服务了:
# 提问
sql, df, fig = vn.ask("显示每个国家的客户数量,按数量降序排列")
# Vanna会自动:
# 1. 理解问题意图
# 2. 检索相关上下文
# 3. 生成SQL查询
# 4. 执行查询
# 5. 返回结果和可视化图表
生成的SQL可能是这样的:
SELECT
country,
COUNT(*) as customer_count
FROM customers
GROUP BY country
ORDER BY customer_count DESC;
def ask(self, question: str, auto_train: bool = True, **kwargs):
# ... 生成和执行SQL ...
# 如果查询成功且auto_train=True,自动添加到训练数据
if len(df) > 0 and auto_train:
self.add_question_sql(question=question, sql=sql)
这个特性让Vanna能够从每次成功的查询中学习,不断改进自己的性能。
# 当需要探索数据时,Vanna可以生成中间查询
if 'intermediate_sql' in llm_response:
intermediate_sql = self.extract_sql(llm_response)
df = self.run_sql(intermediate_sql)
# 基于中间结果生成最终SQL
prompt = self.get_sql_prompt(
# ... 包含中间结果的上下文 ...
doc_list=doc_list + [f"中间查询结果: \n{df.to_markdown()}"],
)
这个特性让AI能够像人类分析师一样,先探索数据再生成最终查询。
def _response_language(self) -> str:
if self.language is None:
return ""
return f"Respond in the {self.language} language."
Vanna支持多种语言的问答,这对国际化企业尤为重要。
通过Vanna团队的详细实验,我们可以清晰地看到不同策略对准确率的影响:
实验数据深度分析:
Schema-only方法的失败原因:
静态示例的局限性:
上下文相关方法的优势:
有趣的是,实验结果显示了不同LLM在不同上下文策略下的表现差异:
GPT-4.1:
Google Bison:
GPT-4.1-mini:
基于实验结果,我们可以总结出几个关键的性能优化策略:
def get_sql_prompt(self, question: str, **kwargs):
# 动态调整检索数量
n_results = min(10, max(3, len(question.split()) // 2))
question_sql_list = self.get_similar_question_sql(
question, n_results=n_results
)
# ...
# 高质量的核心示例
core_examples = [
{"question": "...", "sql": "...", "priority": "high"},
# ...
]
# 自动生成的示例
auto_examples = [
{"question": "...", "sql": "...", "priority": "medium"},
# ...
]
def continuous_learning(self):
# 定期分析查询日志
successful_queries = self.get_successful_queries()
# 自动提取新的训练样本
for query in successful_queries:
if self.is_novel_pattern(query):
self.add_question_sql(query.question, query.sql)
在企业环境中部署Vanna需要考虑更多的因素:
# 企业级配置示例
class EnterpriseVanna:
def __init__(self):
self.config = {
# 多模型支持
'primary_llm': 'gpt-4',
'fallback_llm': 'gpt-3.5-turbo',
# 向量数据库集群
'vector_store': {
'type': 'qdrant',
'cluster_urls': ['http://qdrant-1:6333', 'http://qdrant-2:6333'],
'collection_name': 'enterprise_sql_kb'
},
# 安全配置
'security': {
'enable_query_validation': True,
'allowed_operations': ['SELECT'],
'max_result_rows': 10000,
'query_timeout': 30
},
# 监控配置
'monitoring': {
'enable_logging': True,
'log_level': 'INFO',
'metrics_endpoint': '/metrics'
}
}
企业级部署必须考虑安全性:
class SecureVanna(VannaBase):
def validate_sql(self, sql: str) -> bool:
"""SQL安全验证"""
# 检查危险操作
dangerous_keywords = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'TRUNCATE']
sql_upper = sql.upper()
for keyword in dangerous_keywords:
if keyword in sql_upper:
raise SecurityError(f"Dangerous operation detected: {keyword}")
# 检查表访问权限
tables = self.extract_table_names(sql)
for table in tables:
ifnot self.user_has_access(table):
raise PermissionError(f"Access denied to table: {table}")
returnTrue
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
# 验证SQL安全性
self.validate_sql(sql)
# 添加行数限制
if'LIMIT'notin sql.upper():
sql += f" LIMIT {self.config['max_result_rows']}"
return super().run_sql(sql, **kwargs)
import logging
from prometheus_client import Counter, Histogram, Gauge
class MonitoredVanna(VannaBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Prometheus指标
self.query_counter = Counter('vanna_queries_total', 'Total queries')
self.query_duration = Histogram('vanna_query_duration_seconds', 'Query duration')
self.accuracy_gauge = Gauge('vanna_accuracy_rate', 'Current accuracy rate')
# 日志配置
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def ask(self, question: str, **kwargs):
start_time = time.time()
self.query_counter.inc()
try:
result = super().ask(question, **kwargs)
# 记录成功查询
self.logger.info(f"Successful query: {question}")
duration = time.time() - start_time
self.query_duration.observe(duration)
return result
except Exception as e:
# 记录失败查询
self.logger.error(f"Failed query: {question}, Error: {str(e)}")
raise
class MultiTenantVanna(VannaBase):
def __init__(self, tenant_id: str, **kwargs):
self.tenant_id = tenant_id
# 租户隔离的配置
config = kwargs.get('config', {})
config['collection_name'] = f"vanna_{tenant_id}"
super().__init__(config=config)
def add_question_sql(self, question: str, sql: str, **kwargs):
# 添加租户标识
metadata = kwargs.get('metadata', {})
metadata['tenant_id'] = self.tenant_id
return super().add_question_sql(
question, sql, metadata=metadata, **kwargs
)
背景:某大型电商公司有复杂的数据仓库,包含用户、订单、商品、物流等多个业务域的数据。业务分析师经常需要进行复杂的数据查询。
挑战:
Vanna解决方案:
# 训练数据示例
training_examples = [
{
"question": "最近30天每日GMV趋势",
"sql": """
SELECT
DATE(order_time) as date,
SUM(total_amount) as gmv
FROM orders
WHERE order_time >= CURRENT_DATE - INTERVAL '30 days'
AND order_status = 'completed'
GROUP BY DATE(order_time)
ORDER BY date;
"""
},
{
"question": "各品类的复购率",
"sql": """
WITH user_category_orders AS (
SELECT
u.user_id,
p.category_id,
COUNT(DISTINCT o.order_id) as order_count
FROM users u
JOIN orders o ON u.user_id = o.user_id
JOIN order_items oi ON o.order_id = oi.order_id
JOIN products p ON oi.product_id = p.product_id
WHERE o.order_status = 'completed'
GROUP BY u.user_id, p.category_id
)
SELECT
c.category_name,
COUNT(CASE WHEN uco.order_count > 1 THEN 1 END) * 100.0 / COUNT(*) as repurchase_rate
FROM user_category_orders uco
JOIN categories c ON uco.category_id = c.category_id
GROUP BY c.category_name;
"""
}
]
效果:
背景:金融公司需要实时监控各种风险指标,业务人员需要快速获取风控数据。
特殊要求:
Vanna定制方案:
class FinanceVanna(VannaBase):
def __init__(self, user_role: str, **kwargs):
super().__init__(**kwargs)
self.user_role = user_role
self.audit_logger = AuditLogger()
def ask(self, question: str, **kwargs):
# 审计日志
self.audit_logger.log_query_request(
user_role=self.user_role,
question=question,
timestamp=datetime.now()
)
# 基于角色的查询限制
if self.user_role == 'analyst':
# 分析师只能查询汇总数据
kwargs['aggregation_only'] = True
elif self.user_role == 'manager':
# 经理可以查询详细数据但有行数限制
kwargs['max_rows'] = 1000
result = super().ask(question, **kwargs)
# 记录查询结果
self.audit_logger.log_query_result(
user_role=self.user_role,
question=question,
sql=result[0] if result elseNone,
row_count=len(result[1]) if result and result[1] isnotNoneelse0
)
return result
背景:制造企业有大量IoT设备数据,需要进行设备状态监控和预测性维护分析。
技术挑战:
解决方案:
# 专门的时序数据训练
time_series_examples = [
{
"question": "设备A最近24小时的温度异常点",
"sql": """
WITH temp_stats AS (
SELECT
AVG(temperature) as avg_temp,
STDDEV(temperature) as std_temp
FROM sensor_data
WHERE device_id = 'A'
AND timestamp >= NOW() - INTERVAL '24 hours'
)
SELECT
timestamp,
temperature,
ABS(temperature - ts.avg_temp) / ts.std_temp as z_score
FROM sensor_data sd
CROSS JOIN temp_stats ts
WHERE sd.device_id = 'A'
AND sd.timestamp >= NOW() - INTERVAL '24 hours'
AND ABS(sd.temperature - ts.avg_temp) / ts.std_temp > 2
ORDER BY timestamp;
"""
}
]
未来的Vanna可能支持:
# 未来可能的功能
class SmartVanna(VannaBase):
def auto_discover_schema(self):
"""自动发现和理解数据库结构"""
pass
def suggest_data_quality_checks(self):
"""基于查询模式建议数据质量检查"""
pass
def auto_generate_documentation(self):
"""自动生成数据字典和业务文档"""
pass
# 跨数据源查询
vn.ask("比较我们在MySQL中的销售数据和Snowflake中的财务数据")
# 自动生成跨数据源的查询计划
Vanna这样的工具正在推动"数据民主化"的实现:
# 安装Vanna
pip install vanna
# 安装可选依赖
pip install vanna[openai,chromadb,postgres]
import os
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
class EcommerceVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
# 初始化
vn = EcommerceVanna(config={
'api_key': os.getenv('OPENAI_API_KEY'),
'model': 'gpt-4',
'path': './vanna_db'# ChromaDB存储路径
})
# 连接数据库
vn.connect_to_postgres(
host="localhost",
dbname="ecommerce",
user="postgres",
password="password"
)
# 训练数据
def train_ecommerce_model():
# 添加表结构
ddl_statements = [
"""
CREATE TABLE users (
user_id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
country VARCHAR(2)
);
""",
"""
CREATE TABLE products (
product_id SERIAL PRIMARY KEY,
product_name VARCHAR(200) NOT NULL,
category_id INTEGER,
price DECIMAL(10,2),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""",
"""
CREATE TABLE orders (
order_id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(user_id),
order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
total_amount DECIMAL(10,2),
status VARCHAR(20) DEFAULT 'pending'
);
"""
]
for ddl in ddl_statements:
vn.train(ddl=ddl)
# 添加业务文档
documentation = [
"用户表(users)存储所有注册用户信息,country字段使用ISO 3166-1 alpha-2标准",
"订单表(orders)记录所有订单,status字段可能值:pending, completed, cancelled",
"产品表(products)存储商品信息,price字段为美元价格"
]
for doc in documentation:
vn.train(documentation=doc)
# 添加示例查询
examples = [
{
"question": "今天有多少新用户注册?",
"sql": "SELECT COUNT(*) FROM users WHERE DATE(created_at) = CURRENT_DATE;"
},
{
"question": "最近7天的日均订单金额是多少?",
"sql": """
SELECT AVG(daily_total) as avg_daily_amount
FROM (
SELECT DATE(order_date) as date, SUM(total_amount) as daily_total
FROM orders
WHERE order_date >= CURRENT_DATE - INTERVAL '7 days'
AND status = 'completed'
GROUP BY DATE(order_date)
) daily_totals;
"""
},
{
"question": "哪个国家的用户最多?",
"sql": """
SELECT country, COUNT(*) as user_count
FROM users
WHERE country IS NOT NULL
GROUP BY country
ORDER BY user_count DESC
LIMIT 1;
"""
}
]
for example in examples:
vn.train(question=example["question"], sql=example["sql"])
# 执行训练
train_ecommerce_model()
# 开始使用
if __name__ == "__main__":
whileTrue:
question = input("请输入你的问题(输入'quit'退出): ")
if question.lower() == 'quit':
break
try:
sql, df, fig = vn.ask(question)
print(f"\n生成的SQL:\n{sql}")
print(f"\n查询结果:\n{df}")
if fig:
fig.show() # 显示图表
except Exception as e:
print(f"查询失败: {str(e)}")
class AdvancedEcommerceVanna(EcommerceVanna):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.business_metrics = {
'GMV': 'Gross Merchandise Value - 总商品交易额',
'AOV': 'Average Order Value - 平均订单价值',
'LTV': 'Lifetime Value - 客户生命周期价值',
'CAC': 'Customer Acquisition Cost - 客户获取成本'
}
def preprocess_question(self, question: str) -> str:
"""预处理问题,替换业务术语"""
for abbr, full_name in self.business_metrics.items():
if abbr.lower() in question.lower():
question = question.replace(abbr, full_name)
return question
def ask(self, question: str, **kwargs):
# 预处理问题
processed_question = self.preprocess_question(question)
# 添加业务上下文
if any(metric in processed_question.lower() for metric in self.business_metrics.values()):
kwargs['include_business_context'] = True
return super().ask(processed_question, **kwargs)
def generate_business_report(self, period: str = "last_30_days"):
"""生成业务报告"""
questions = [
f"What was the GMV for the {period}?",
f"What was the AOV for the {period}?",
f"How many new customers did we acquire in the {period}?",
f"What was the top-selling product category in the {period}?"
]
report = {}
for question in questions:
try:
sql, df, _ = self.ask(question, print_results=False)
report[question] = {
'sql': sql,
'result': df.to_dict('records') if df isnotNoneelseNone
}
except Exception as e:
report[question] = {'error': str(e)}
return report
通过深入分析Vanna项目,我们看到了RAG技术在SQL生成领域的巨大潜力。从3%到80%的准确率提升不仅仅是一个数字的变化,更代表着一种全新的数据分析范式的诞生。
Vanna只是AI+SQL领域的一个开始。随着技术的不断发展,我们可以期待:
对于技术从业者:
对于企业决策者:
读到这里,相信你对Vanna和RAG技术有了深入的了解。但学习的旅程永远不会结束,我特别想听听你的想法:
我为大家准备了一个小挑战:
挑战题目:基于本文介绍的Vanna架构,设计一个针对你所在行业的AI SQL助手。请考虑:
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费POC验证,效果达标后再合作。零风险落地应用大模型,已交付160+中大型企业
2025-05-30
2025-06-05
2025-06-06
2025-06-05
2025-06-05
2025-06-20
2025-06-24
2025-07-15
2025-06-20
2025-06-24
2025-08-25
2025-08-20
2025-08-11
2025-08-05
2025-07-28
2025-07-09
2025-07-04
2025-07-01