微信扫码
添加专属顾问
我要投稿
本文重点:
重点代码实现,从 0 构建 Google Snake AI agent,中间 python 库不是官方自带库,需要各位 pip install,或者用 ide 集成。
涉及 AI 知识科普。
模仿学习(Imitation Learning),也称为学习从演(Learning from Demonstration,LfD)或行为克隆(Behavioral Cloning,BC),是一种机器学习方法,它允许机器通过观察和模仿专家的行为来学习任务。模仿学习的一个关键优势是它不需要显式的奖励函数,这在许多复杂任务中是难以定义的。然而,它也有一些局限性,比如可能会学习到专家的非最优行为,或者在未见过的情况下表现不佳。
数据收集有很多种方法,这里使用的方法依赖于 selenium,这是一种自动化浏览器导航和控制的 Python 工具,selenium 能将屏幕截图保存为具有适当标签的图像的代码, 主要是这个方便。
import base64
import io
import cv2
from PIL import Image
import numpy as np
import keyboard
import os
from datetime import datetime
from selenium import webdriver
from selenium.webdriver.common.by import By
# 初始化环境
isExist = os.path.exists("captures")
if isExist:
dir = "captures"
for f in os.listdir(dir):
os.remove(os.path.join(dir, f))
else:
os.mkdir("captures")
current_key = "1"
buffer = []
# 收集用户键盘反馈
def keyboardCallBack(key: keyboard.KeyboardEvent):
global current_key
if key.event_type == "down" and key.name not in buffer:
buffer.append(key.name)
if key.event_type == "up":
buffer.remove(key.name)
buffer.sort()
current_key = " ".join(buffer)
keyboard.hook(callback=keyboardCallBack)
# 获取浏览器上下文
driver = webdriver.Firefox()
# 导航到 Google Snake game
driver.get("<https://www.google.com/fbx?fbx=snake_arcade>")
frame_stack = deque(maxlen=4)
while True:
# 获取画布元素
canvas = driver.find_element(By.CSS_SELECTOR, "canvas")
# 获取画布数据,这里就是一张图,一帧
canvas_base64 = driver.execute_script(
"return arguments[0].toDataURL('image/png').substring(21);", canvas)
# Decode the base64 data to get the PNG image
canvas_png = base64.b64decode(canvas_base64)
image = cv2.cvtColor(
np.array(Image.open(io.BytesIO(canvas_png))), cv2.COLOR_BGR2RGB)
# 保存有用户键盘策略的图片和没有策略图片
if len(buffer) != 0:
cv2.imwrite(
"captures/" + str(datetime.now()).replace("-", "_").replace(":", "_").replace(" ", "_") + " "
+ current_key + ".png", image, )
else:
cv2.imwrite(
"captures/" + str(datetime.now()).replace("-", "_").replace(":", "_").replace(" ", "_") + " n"
+ ".png", image, )
# 计算每个标签预测值,用加权平均来计算
frame_stack.append(transformer(image))
input = torch.stack([*frame_stack], dim=1).to(device).squeeze().unsqueeze(0)
if len(frame_stack) == 4:
with torch.inference_mode():
outputs = model(input).to(device)
preds = torch.softmax(outputs, dim=1).argmax(dim=1)
if preds.item() != 0:
keyboard.press_and_release(label_keys[preds.item()])
运行此脚本后,会打开一个窗口来运行,然后即可开始玩游戏。在后台,脚本会不断保存游戏屏幕截图,并使用唯一时间戳和当前按下的键来命名图像。当没有按下任何键时,它会被标记为 n。
准备数据
将这些图像转换为带有文件名和相应操作的 csv 文件。
import pandas as pd
import matplotlib.pyplot as plt
import os
import csv
import os
# 创建目录,保存包含标签和图像文件名的 CSV 文件。
labels = []
dir = 'captures'
file_path = "data/labels_snake.csv"
if not os.path.exists(file_path):
os.mkdir('data')
# 读取文件名,从每张图片的文件名中提取按下的键。按下的键可以是左、右、上、下,或者根本没有键。
for f in os.listdir(dir):
key = f.rsplit('.',1)[0].rsplit(" ",1)[1]
# 根据所按的键,将每幅图像分为四类:0 表示未按任何键,1 表示左,2 表示上,3 表示右,4 表示下。
if key=="n":
labels.append({'file_name': f, 'class': 0})
elif key=="left":
labels.append({'file_name': f, 'class': 1})
elif key=="up":
labels.append({'file_name': f, 'class': 2})
elif key=="right":
labels.append({'file_name': f, 'class': 3})
elif key=="down":
labels.append({'file_name': f, 'class': 4})
# 创建标签文件,包含数据集来训练机器学习模型
field_names= ['file_name', 'class']
with open('data/labels_snake.csv', 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=field_names)
writer.writeheader()
writer.writerows(labels)
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import os
from PIL import Image
import torch
from sklearn.model_selection import train_test_split
import pandas as pd
from torchvision.transforms import transforms, Compose, ToTensor, Resize, Normalize, CenterCrop, Grayscale
from torch import nn
from tqdm import tqdm
from torchinfo import summary
import numpy as np
import math
from torchvision.models.video import r3d_18, R3D_18_Weights, mc3_18, MC3_18_Weights
# 数据集由按时间顺序排列的四张图像堆栈组成。
# 从数据集中提取的每个项目代表四帧序列,其中最后一帧与按键相关联。
# 本质上,此数据集通过最后四帧捕获运动并将其与按键相关联。
# 其中 stack_size 会影响后续权重,这个解释下,是堆叠在一起作为单个数据点的图像数量。
class SnakeDataSet(Dataset):
# 包含有关图像和标签信息的数据集
def __init__(self, dataframe, root_dir, stack_size, transform=None):
self.stack_size = stack_size
self.key_frame = dataframe
self.root_dir = root_dir
self.transform = transform
# 返回数据集的长度,即数据点的总数。
# 长度计算为 的长度key_frame减去 的三倍 stack_size。
# 这表明数据集预计包含图像序列,并且每个数据点由一堆图像组成。
def __len__(self):
return len(self.key_frame) - self.stack_size * 3
# 获取索引idx并返回相应的数据点
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.to_list()
try:
img_names = [os.path.join(self.root_dir, self.key_frame.iloc[idx + i, 0]) for i in range(self.stack_size)]
images = [Image.open(img_name) for img_name in img_names]
# 使用 tensor 提取人工数据标签
label = torch.tensor(self.key_frame.iloc[idx + self.stack_size, 1])
# 图片转化,将变换应用于序列中的每个图像。
if self.transform:
images = [self.transform(image) for image in images]
except:
img_names = [os.path.join(self.root_dir, self.key_frame.iloc[0 + i, 0]) for i in range(self.stack_size)]
images = [Image.open(img_name) for img_name in img_names]
# 如果遇到错误,使用数据第一个标签,处理兼容形式,也就是 left 形式
label = torch.tensor(self.key_frame.iloc[0 + self.stack_size, 1])
# 图片转化,将变换应用于序列中的每个图像。
if self.transform:
images = [self.transform(image) for image in images]
# 将图像沿着新维度堆叠 torch.stack(images, dim=1),然后使用 删除单例维度squeeze()。这会产生一个表示堆叠图像的张量。
# 堆叠图像张量和标签张量作为元组返回。
return torch.stack(images, dim=1).squeeze(), label
STACK_SIZE = 4
BATCH_SIZE = 32
# 区分数据集和结果验证集,这里设置结果验证,占全样本 20%
train, test = train_test_split(pd.read_csv("data/labels_snake.csv"), test_size=0.2, shuffle=False)
classes = ["n", "left", "up", "right", "down"]
# 设置实验数据
# 通过将所有计数的总和除以该类别出现的次数来计算每个类别的权重。
labels_unique, counts = np.unique(train["class"], return_counts=True)
class_weights = [sum(counts)/c for c in counts]
# 下一步是通过为示例分配类权重来获取每个示例的权重。
# 这是通过遍历数据集并根据其类标签为每个示例分配权重来完成的。
example_weights = np.array([class_weights[l] for l in train['class']])
# 根据堆栈大小滚动示例权重,因为与特定图像相关联的标签实际上是该图像索引的标签 + STACK_SIZE。
# 这可确保根据其类标签为每个样本赋予正确的权重。
# 0 建设数据 14w 数据集,所以这里需要减去 14w,要么权重其实比其它低的
example_weights = np.roll(example_weights, -STACK_SIZE)
sampler = WeightedRandomSampler(example_weights, len(train))
# 重新设置一遍测试验证结果集
labels_unique, counts = np.unique(test["class"], return_counts=True)
class_weights = [sum(counts)/c for c in counts]
test_example_weights = np.array([class_weights[l] for l in test['class']])
test_example_weights = np.roll(test_example_weights, -STACK_SIZE)
test_sampler = WeightedRandomSampler(test_example_weights, len(test))
# 初始化模型 loader
dataset = SnakeDataSet(root_dir="captures", dataframe = train, stack_size=STACK_SIZE, transform=transformer)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, drop_last= True)
test_dataset = SnakeDataSet(root_dir="captures", dataframe = test, stack_size=STACK_SIZE, transform=transformer)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, sampler = test_sampler, drop_last=True)
通过+-加权法,计算数据集中每个样本的权重是机器学习任务中的关键步骤。它有助于平衡数据集并确保每个类别对学习过程的贡献相同。该过程包括将数据集拆分为训练集和测试集,获取唯一标签和每个标签的计数,计算每个类别的权重,并根据其类别标签为每个样本分配权重。
from torchvision.transforms import transforms, Compose, Normalize, CenterCrop
from torchvision.models.video import r3d_18, R3D_18_Weights, mc3_18, MC3_18_Weights
# 计算数据集图像的平均值和标准差
def compute_mean_std(dataloader):
# source: <https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_std_mean.py>
# var[X] = E[X**2] - E[X]**2
channels_sqrd_sum, num_batches = 0, 0, 0
for batch_images, labels in tqdm(dataloader): # (B,H,W,C)
batch_images = batch_images.permute(0,3,4,2,1)
channels_sum += torch.mean(batch_images, dim=[0, 1, 2, 3])
channels_sqrd_sum += torch.mean(batch_images ** 2, dim=[0, 1, 2,3])
num_batches += 1
mean = channels_sum / num_batches
std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
return mean, std
compute_mean_std(dataloader)
# 将图像大小调整为 84x84,转换为张量并对其进行规范化。
transformer = Compose([
antialias=True),
CenterCrop(84),
ToTensor(),
[ -0.7138, -2.9883, 1.5832], std =[0.2253, 0.2192, 0.2149]) =
])
# 使用 PyTorch 提供的 r3d 模型(ResNet架构)
r3d_18(weights = R3D_18_Weights.DEFAULT) =
nn.Linear(in_features=512, out_features=5, bias=True) =
(32,3,4,84,84))
训练模型
# 设置 epochs 次数,epochs 理解是一个批量处理,10w 次
num_epochs = 2
# 设置环境,损失函数交叉熵
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(model.parameters(), 10e-5, weight_decay=0.1)
model.to(device)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
total_loss = 0.0
correct_predictions = 0
total_samples = 0
val_loss = 0.0
val_correct_predictions = 0
val_total_samples = 0
# 开始训练
model.train()
# 显示进度条
pbar = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=True)
# 从 pbar 批处理种提取 inputs 和 labels
for inputs, labels in pbar:
labels = inputs.to(device), labels.to(device)
outputs = model(inputs.to(device))
loss = criterion(outputs, labels)
# 向后传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 更新梯度和参数
total_loss += loss.item()
predicted = torch.max(torch.softmax(outputs,1), 1)
correct_predictions += (predicted == labels).sum().item()
total_samples += labels.size(0)
# 更新损失率和精度
total_loss / total_samples, 'Accuracy': correct_predictions / total_samples}) :
steps = steps + 1
# 结果评估
model.eval()
with torch.inference_mode():
for inputs, labels in test_dataloader:
labels = inputs.to(device), labels.to(device)
outputs = model(inputs.to(device))
loss = criterion(outputs, labels)
# 更新损失率和精度
val_loss += loss.item()
predicted = torch.max(torch.softmax(outputs,1), 1)
val_correct_predictions += (predicted == labels).sum().item()
val_total_samples += labels.size(0)
# 最终结果评估和输出
epoch_loss = val_loss / val_total_samples
epoch_accuracy = val_correct_predictions / val_total_samples
{epoch + 1}/{num_epochs}, Val Loss: {epoch_loss:.4f}, Val Accuracy: {epoch_accuracy:.4f}')
"model_r3d.pth")
结论
https://github.com/akshayballal95/autodrive-snake/tree/blog
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-04-30
Mockaroo - 模拟生成测试数据
2025-04-30
MCP实战:将公众号接口做成mcp后,我终于实现了,一句话让AI自己搜索、撰文、配图、排版并发布公众号
2025-04-29
AI时代软件测试的认知革命与架构重塑
2025-04-29
Prompt 练习|教育中的等待现象
2025-04-29
AI 友好架构:AI 编程最佳范式,构建 10x 效率提升的代码库(万字长文)
2025-04-29
Fetch MCP网页内容抓取实操:抓取“刘强东送外卖”新闻案例详细教程!
2025-04-29
技术为何无法帮助我们思考?从笔记软件的局限性谈起
2025-04-29
豆包是懂PDF论文阅读的
2025-03-06
2024-09-04
2025-01-25
2024-09-26
2024-10-30
2024-09-03
2024-12-11
2024-12-25
2024-10-30
2025-02-18