【全球新视野】在树莓派上使用numpy实现简单的神经网络推理,pytorch在服务器或PC上训练好模型保存成numpy格式的数据,推理在树莓派上加载模型
2023-05-30 17:13:38 来源:博客园
(资料图片仅供参考)
这几天又在玩树莓派,先是搞了个物联网,又在尝试在树莓派上搞一些简单的神经网络,这次搞得是mlp识别mnist手写数字识别
训练代码在电脑上,cpu就能训练,很快的:
1 import torch 2 import torch.nn as nn 3 import torch.optim as optim 4 from torchvision import datasets, transforms 5 6 # 设置随机种子 7 torch.manual_seed(42) 8 9 # 定义MLP模型10 class MLP(nn.Module):11 def __init__(self):12 super(MLP, self).__init__()13 self.fc1 = nn.Linear(784, 256)14 self.fc2 = nn.Linear(256, 128)15 self.fc3 = nn.Linear(128, 10)16 17 def forward(self, x):18 x = x.view(-1, 784)19 x = torch.relu(self.fc1(x))20 x = torch.relu(self.fc2(x))21 x = self.fc3(x)22 return x23 24 # 加载MNIST数据集25 transform = transforms.Compose([26 transforms.ToTensor(),27 # transforms.Normalize((0.1307,), (0.3081,))28 ])29 30 train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)31 test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)32 33 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)34 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)35 36 # 创建模型实例37 model = MLP()38 39 # 定义损失函数和优化器40 criterion = nn.CrossEntropyLoss()41 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)42 43 # 训练模型44 def train(model, train_loader, optimizer, criterion, epochs):45 model.train()46 for epoch in range(1, epochs + 1):47 for batch_idx, (data, target) in enumerate(train_loader):48 optimizer.zero_grad()49 output = model(data)50 loss = criterion(output, target)51 loss.backward()52 optimizer.step()53 54 if batch_idx % 100 == 0:55 print("Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(56 epoch, batch_idx * len(data), len(train_loader.dataset),57 100. * batch_idx / len(train_loader), loss.item()))58 59 # 训练模型60 train(model, train_loader, optimizer, criterion, epochs=5)61 62 # 保存模型为NumPy格式63 numpy_model = {}64 numpy_model["fc1.weight"] = model.fc1.weight.detach().numpy()65 numpy_model["fc1.bias"] = model.fc1.bias.detach().numpy()66 numpy_model["fc2.weight"] = model.fc2.weight.detach().numpy()67 numpy_model["fc2.bias"] = model.fc2.bias.detach().numpy()68 numpy_model["fc3.weight"] = model.fc3.weight.detach().numpy()69 numpy_model["fc3.bias"] = model.fc3.bias.detach().numpy()70 71 # 保存为NumPy格式的数据72 import numpy as np73 np.savez("mnist_model.npz", **numpy_model)然后需要自己倒出一些图片在dataset里:我保存在了mnist_pi文件夹下,“_”后面的是标签,主要是在pc端导出保存到树莓派下
树莓派推理端的代码,需要numpy手动重新搭建网络,然后加载那些保存的矩阵参数,做矩阵乘法和加法
1 import numpy as np 2 import os 3 from PIL import Image 4 5 # 加载模型 6 model_data = np.load("mnist_model.npz") 7 weights1 = model_data["fc1.weight"] 8 biases1 = model_data["fc1.bias"] 9 weights2 = model_data["fc2.weight"]10 biases2 = model_data["fc2.bias"]11 weights3 = model_data["fc3.weight"]12 biases3 = model_data["fc3.bias"]13 14 # 进行推理15 def predict(image, weights1, biases1,weights2, biases2,weights3, biases3):16 image = image.flatten()/255 # 将输入图像展平并进行归一化17 output = np.dot(weights1, image) + biases118 output = np.dot(weights2, output) + biases219 output = np.dot(weights3, output) + biases320 predicted_class = np.argmax(output)21 return predicted_class22 23 24 25 26 folder_path = "./mnist_pi" # 替换为图片所在的文件夹路径27 def infer_images_in_folder(folder_path):28 for file_name in os.listdir(folder_path):29 file_path = os.path.join(folder_path, file_name)30 if os.path.isfile(file_path) and file_name.endswith((".jpg", ".jpeg", ".png")):31 image = Image.open(file_path)32 label = file_name.split(".")[0].split("_")[1]33 image = np.array(image)34 print("file_path:",file_path,"img size:",image.shape,"label:",label)35 predicted_class = predict(image, weights1, biases1,weights2, biases2,weights3, biases3)36 print("Predicted class:", predicted_class)37 38 infer_images_in_folder(folder_path)结果:
效果还不错:
这次内容就到这里了,下次争取做一个卷积的神经网络在树莓派上推理,然后争取做一个目标检测的模型在树莓派上
关键词:
相关新闻
- 【全球新视野】在树莓派上使用numpy实现简单的神经网络推理,pytorch在服务器或PC上训练好模型保存成numpy格式的数据,推理在树莓派上加载模型
- 百万医疗险榜单前十保险公司有哪些?前十保险公司哪个好?
- 多个OTA平台收到航司调整国内航班燃油附加费征收标准通知-每日速递
- 深“V”:A股三大指数午后跌超1%后反抽收涨,中特估走强
- 今日pd990钯金回收价格查询(2023年05月30日) 每日视点
- 世俱杯决赛历史比分查询(世俱杯决赛历史比分)
- 2023年张江镇科技节开幕,送上更开放、更多元、更前沿科学知识
- 世界即时看!星辰变剧情简述
- 【世界新要闻】2023广州海尔以旧换新活动官网入口
- 长城为何捅破这层“窗户纸”?
- 姨妈期间能吃香蕉吗月经期间能吃香蕉吗_姨妈期间能吃香蕉吗
- 天天时讯:2023黑龙江七台河市文化广电和旅游局所属事业单位招聘工作人员拟进入体检和考察环节人员公示
- 图片报:多特和E-阿尔瓦雷斯达成加盟协议,转会费约3500万欧
- 高通与中国合作伙伴共赴智能网联汽车新未来
- 金科股份:财聚投资增持410万股股份,一笔公司债本息逾期_世界热推荐
- 今日看点:罗聪鹏:点亮乡村孩子的音乐梦想
- 生态文明与产城一体化的理论与实践_关于生态文明与产城一体化的理论与实践介绍-世界即时
- 基金观察|A股短期承压,中期关注“顺周期”整体机会
- 焦点报道:摊铺机型号大全一览表(摊铺机型号大全)
- 前沿资讯!云南南华县:市场主体增量提质
- 意媒:加斯佩里尼可能今夏告别亚特兰大帅位|快播报
- 手部湿疹汗疱疹怎么治疗_湿疹与汗疱疹有何区别|世界报道
- 大溪地黑珍珠真假鉴别方法_大溪地黑珍珠真假|世界动态
- 焦点讯息:深圳社会保障局电话人工服务_深圳社会保障局
