GoogleNet网络训练集和测试集搭建

测试集和训练集都是在之前搭建好的基础上进行修改的,重点记录与之前不同的代码。

还是使用的花分类的数据集进行训练和测试的。

一、训练集

1、搭建网络

设置参数:使用辅助分类器,采用权重初始化

net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)

2、参数输出

之前的模型只有 1 个输出,但由于GoogleNet使用了两个辅助分类器,所以会有 3 个输出。

定义三个输出,分别计算主分类器、辅助分类器1、辅助分类器2的损失函数并相加,最后将损失函数反向传播,使用优化器更新参数模型。 

不单独放代码了,不知道哪里是改动的。图片中红色框中是改动的

整个训练集的代码

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import GoogleNet
import os
import json
import time

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)


# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
    json_file.write(json_str)


batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=0)

# test_data_iter = iter(validate_loader)
# test_image, test_label = next(test_data_iter)
#
# # 查看图片
# def imshow(img):
#     img = img / 2 + 0.5
#     nping = img.numpy()
#     plt.imshow(np.transpose(nping, (1, 2, 0)))
#     plt.show()
# # print labels
# print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))

net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()

optimizer = optim.Adam(net.parameters(), lr=0.0003)

best_acc = 0.0
save_path = './GoogleNet.pth'
# best_acc = 0.0
for epoch in range(2):
    # train
    net.train()
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits, aux_logits2, aux_logits1 = net(images.to(device))
        loss0 = loss_function(logits, labels.to(device))
        loss1 = loss_function(aux_logits1, labels.to(device))
        loss2 = loss_function(aux_logits2, labels.to(device))
        loss = loss0 + loss1 * 0.3 + loss2 * 0.3
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        rate = (step+1) / len(train_loader)
        a = "*" * int(rate*50)
        b = "." *int((1-rate)*50)
        print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1)

    net.eval()
    acc = 0.0
    with torch.no_grad():
        for data_test in validate_loader:
            test_images, test_labels = data_test
            outputs = net(test_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == test_labels.to(device)).sum().item()
        accurate_test = acc / val_num
        if accurate_test > best_acc:
            best_acc = accurate_test
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

训练完成 

 中间有几次报错,不过在看懂报错后很快改过来了。

二、测试集

载入模型

在创建模型的时候,aux_logits不会构建辅助分类器,但是之前训练的参数会保存。

所以,在载入模型的时候,要设置参数strict=False, 它可以精准匹配当前模型与所需要载入的权重模型的结构。

辅助分类器中的参数全部存放在unexpecte_keys中。

测试集全部代码

 可以自己找图片进行预测看准确率。

import torch
import matplotlib.pyplot as plt
import json
from model import GoogleNet
from PIL import Image
from torchvision import transforms

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("8.jpeg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)

# read class_indent
try:
    json_file = open('./class_indices,json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = GoogleNet(num_classes=5, aux_logits=False)
model_weight_path = "./GoogleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

准确率好低,可能是模型训练的还不够吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/555648.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

web--弱口令安全

字典(一种是产品初始化的密码,一种是改变的密码 对爆破密码进行加密 先这个 对账号和密码同时爆破 设置两个要用这个模式 ssh,rdp远程终端 linux的用户名为root,windows为administrator 这就被爆破了 zip,word文件猜解

【Python系列】非异步方法调用异步方法

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Linux http协议与实现http服务器

目录 一、HTTP与URL 1、HTTP协议 2、URL 3、URL编码 4、报文与报头 报文(Message) 报头(Header) 二、HTTP(超文本传输协议)的内部运作机理 请求部分: 响应部分: 三、实现…

实验二: ping命令的使用

1.实验环境 同实验案例一环境 2.需求描述 熟悉ping命令的用法并熟悉ping命令的各种参数 3.推荐步骤 分别ping一个存在的和不存在的P 地址,观察返回的信息分别测试ping命令的相关参数 4.实验步骤 4.1、分别ping一个存在的和不存在的IP 地址 C:\>ping 192.1…

C语言如何使⽤指针?

一、问题 指针变量在初始化以后就可以使⽤和参与操作了,那么就要⽤到对指针变量最常⽤的两个操作符——> * 和 & 。 二、解答 这⾥⼜要提到始终贯穿着指针的⼀个符号“ * ”,但是这⾥的“ * ”是作为指针运算符使⽤的,叫做取内…

8个十分受小企业欢迎的电子商务网站制作工具,自助建站,一键生成免安装

如果您的小型企业需要一个网站,通常会考虑使用自助建站系统来构建与自己业务有关的电子商务站或小程序商城。 在这个时代,电子商务不仅迅速成为许多网站上的热门功能,而且几乎成为许多企业生存的必要条件,因为越来越多的人将大部分…

【C++】list的介绍及使用说明

目录 00.引言 01.list的介绍 模版类 独立节点存储 list的使用 1.构造函数 2.迭代器的使用 分类 运用 3.容量管理 empty(): size(): 4.元素访问 5.增删查改 00.引言 我们学习数据结构时,学过两个基本的数据结构:顺序表和链表。顺…

华为机考入门python3--(16)牛客16-购物单最大满意度

分类:动态规划,组合,最大值,装箱问题 知识点: 生成递减数 100, 90, 80, ..., 0 range(100, -1, -10) 访问列表的下标key for key, value in enumerate(my_list): 动态规划-捆绑装箱问题 a. 把有捆绑约束的物…

【技术变现之道】如何打造IT行业的超级个体?

前言 在当今的数字化时代,IT行业蓬勃发展,为具备技术专长的个人提供了无限的可能性。想要成为IT行业的超级个体,实现知识与技能的变现吗?以下是一些高效途径,助你一臂之力! 1. 独立接单外包 1&#xff09…

Flask项目在Pycharm中设置局域网访问

打开PyCharm导入本应用。点击Run标签中的Edit Configurations 其中Target type选择Script path,Target填入本项目中app.py的路径,Additional optional填入--host0.0.0.0(不要有空格)。 再重新运行项目,会观察到除了原本的http://127.0.0.1:50…

上线流程及操作

上节回顾 1 搜索功能-前端:搜索框,搜索结果页面-后端:一种类型课程-APIResponse(actual_courseres.data.get(results),free_course[],light_course[])-搜索,如果数据量很大,直接使用mysql,效率非常低--》E…

分类预测 | Matlab实现PSO-LSSVM粒子群算法优化最小二乘支持向量机数据分类预测

分类预测 | Matlab实现PSO-LSSVM粒子群算法优化最小二乘支持向量机数据分类预测 目录 分类预测 | Matlab实现PSO-LSSVM粒子群算法优化最小二乘支持向量机数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 1.Matlab实现PSO-LSSVM粒子群算法优化最小二乘支持向量…

FebHost:注册.CA域名的企业有什么限制?

在加拿大,只要满足加拿大互联网注册管理局的“加拿大注册要求”,任何类型的企业都可以注册.CA域名。这些要求的目的是为了确保.CA域名空间作为一个重要的公共资源得到合理的使用和开发,以促进所有加拿大人的社会和经济发展。 以下是一些主要…

0418WeCross搭建 + Caliper测试TPS

1. 基本信息 虚拟机名称:Pure-Ununtu18.04 WeCross位置:/root/wecross-demo 2. 搭建并启动WeCross 参考官方指导文档 https://wecross.readthedocs.io/zh-cn/v1.2.0/docs/tutorial/demo/demo.html 访问WeCross网页管理平台 http://localhost:8250/s/…

嵌入式科普(15)小米su7成本分析和拆解之智驶、座舱分析

目录 一、概述 二、小米su7成本分析 2.1 整车成本构成 2.2 三电系统 2.3 车身与底盘 2.3 智能网联 2.4 内外饰 三、小米su7拆解之智驶、座舱分析 3.1 主要芯片 3.2 智能驾驶&智能座舱 四、NXP S32K324汽车通用微控制器 嵌入式科普(15)小米su7成本分析和拆解之智…

问答营销之官方号问答推广技巧

问答营销作为一种网络推广的重要手段,受到各大品牌企业的关注。实战中,问答营销有新起提问再回答和直接回复老问题两种形式,一般做企业官方号问答营销都是选择后者。这里小马识途营销顾问详细解析下开展老问题回复营销的思路和步骤。 一、分析…

2024最新大厂C++面试真题合集,玩转互联网公司面试!

小米C 1. 进程和线程的区别 进程是操作系统分配资源和调度的独立单位,拥有自己的地址空间和系统资源。线程是进程内部的执行单元,共享属于相同进程的资源,但是执行切换代价更小。进程间相互独立,稳定性较高;线程间共…

C++修炼之路之反向迭代器和非模板参数,模板特化,分离编译

目录 前言 一:反向迭代器 二:非类型模板参数 三:模板的特化 四:模板的分离编译 五:模板的优点与缺点 接下来的日子会顺顺利利,万事胜意,生活明朗-----------林辞忧 前言 在vector&am…

代码随想录第40天|343. 整数拆分

343. 整数拆分 343. 整数拆分 - 力扣(LeetCode) 代码随想录 (programmercarl.com) 动态规划,本题关键在于理解递推公式!| LeetCode:343. 整数拆分_哔哩哔哩_bilibili 给定一个正整数 n ,将其拆分为 k 个 正…

2024-4-18 群讨论:Java Agent,JFR 与 JIT 的一些讨论

以下来自本人拉的一个关于 Java 技术的讨论群。关注公众号:hashcon,私信进群拉你 命令行中带 -XX:StartFlightRecording 启动,同时带 -javaagent,那么谁先启动?jfr能采集到agent启动前后资源消耗情况不? 不…
最新文章