Tensorflow分类器项目自定义数据读入的方法介绍(代码示例)

本篇文章给大家带来的内容是关于Tensorflow分类器项目自定义数据读入的方法介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。

Tensorflow分类器项目自定义数据读入

在照着tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import osimport kerasimport matplotlib.pyplot as pltfrom PIL import Imagefrom keras.preprocessing.image import ImageDataGeneratorfrom sklearn.model_selection import train_test_split

登录后复制

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30IMG_SIZE_Y = 30

登录后复制

其次确定你图片的方式目录:

image_path = r'D:ProjectsImageClassifierdataset'path = ".data"# 你也可以使用相对路径的方式# image_path =os.path.join(path, "set")

登录后复制

目录下的结构如下:

210901301-5c57a5e09df2b_articlex.png

相应的label.txt如下:

动漫风景美女物语樱花

登录后复制

接下来是接在labels.txt,如下:

label_name = "labels.txt"label_path = os.path.join(path, label_name)class_names = np.loadtxt(label_path, type(""))

登录后复制

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = Falsere_build = False# re_load = Truere_build = Truedata_name = "data.npz"data_path = os.path.join(path, data_name)model_name = "model.h5"model_path = os.path.join(path, model_name)count = 0# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。if not os.path.exists(data_path) or re_load:    labels = []    images = []    print('Handle images')    # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取    for index, name in enumerate(class_names):        # 这里是拼接后的子目录path        classpath = os.path.join(image_path, name)        # 先判断一下是否是目录        if not os.path.isdir(classpath):            continue        # limit是测试时候用的这里可以去除        limit = 0        for image_name in os.listdir(classpath):            if limit >= max_size:                break            # 这里是拼接后的待处理的图片path            imagepath = os.path.join(classpath, image_name)            count = count + 1            limit = limit + 1            # 利用Image打开图片            img = Image.open(imagepath)            # 缩放到你最初确定要处理的图片分辨率大小            img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))            # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量            img = img.convert("L")            # 转为numpy数组            img = np.array(img)            # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放            img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))            # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引            labels.append([index])            # 添加到images中,最后统一处理            images.append(img)            # 循环中一些状态的输出,可以去除            print("{} class: {} {} limit: {} {}"                  .format(count, index + 1, class_names[index], limit, imagepath))    # 最后一次性将images和labels都转换成numpy数组    npy_data = np.array(images)    npy_labels = np.array(labels)    # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储    np.savez(data_path, x=npy_data, y=npy_labels)    print("Save images by npz")else:    # 如果存在序列化号的数据,便直接读取,提高速度    npy_data = np.load(data_path)["x"]    npy_labels = np.load(data_path)["y"]    print("Load images by npz")image_data = npy_datalabels_data = npy_labels

登录后复制

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据train_images, test_images, train_labels, test_labels =     train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

登录后复制

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")print("%-28s %-s" % ("Name", "Shape"))print("=================================================================")print("%-28s %-s" % ("Image Data", image_data.shape))print("%-28s %-s" % ("Labels Data", labels_data.shape))print("=================================================================")print('Split train and test data,p=%')print("_________________________________________________________________")print("%-28s %-s" % ("Name", "Shape"))print("=================================================================")print("%-28s %-s" % ("Train Images", train_images.shape))print("%-28s %-s" % ("Test Images", test_images.shape))print("%-28s %-s" % ("Train Labels", train_labels.shape))print("%-28s %-s" % ("Test Labels", test_labels.shape))print("=================================================================")

登录后复制

之后别忘了归一化哟:

print("Normalize images")train_images = train_images / 255.0test_images = test_images / 255.0

登录后复制

最后附上读取自定义数据的完整代码:

import osimport kerasimport matplotlib.pyplot as pltfrom PIL import Imagefrom keras.layers import *from keras.models import *from keras.optimizers import Adamfrom keras.preprocessing.image import ImageDataGeneratorfrom sklearn.model_selection import train_test_splitos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'# 支持中文plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号re_load = Falsere_build = False# re_load = Truere_build = Trueepochs = 50batch_size = 5count = 0max_size = 2000000000IMG_SIZE_X = 30IMG_SIZE_Y = 30np.random.seed(9277)image_path = r'D:ProjectsImageClassifierdataset'path = ".data"data_name = "data.npz"data_path = os.path.join(path, data_name)model_name = "model.h5"model_path = os.path.join(path, model_name)label_name = "labels.txt"label_path = os.path.join(path, label_name)class_names = np.loadtxt(label_path, type(""))print('Load class names')if not os.path.exists(data_path) or re_load:    labels = []    images = []    print('Handle images')    for index, name in enumerate(class_names):        classpath = os.path.join(image_path, name)        if not os.path.isdir(classpath):            continue        limit = 0        for image_name in os.listdir(classpath):            if limit >= max_size:                break            imagepath = os.path.join(classpath, image_name)            count = count + 1            limit = limit + 1            img = Image.open(imagepath)            img = img.resize((30, 30))            img = img.convert("L")            img = np.array(img)            img = np.reshape(img, (1, 30, 30))            # img = skimage.io.imread(imagepath, as_grey=True)            # if img.shape[2] != 3:            #     print("{} shape is {}".format(image_name, img.shape))            #     continue            # data = transform.resize(img, (IMG_SIZE_X, IMG_SIZE_Y))            labels.append([index])            images.append(img)            print("{} class: {} {} limit: {} {}"                  .format(count, index + 1, class_names[index], limit, imagepath))    npy_data = np.array(images)    npy_labels = np.array(labels)    np.savez(data_path, x=npy_data, y=npy_labels)    print("Save images by npz")else:    npy_data = np.load(data_path)["x"]    npy_labels = np.load(data_path)["y"]    print("Load images by npz")image_data = npy_datalabels_data = npy_labelsprint("_________________________________________________________________")print("%-28s %-s" % ("Name", "Shape"))print("=================================================================")print("%-28s %-s" % ("Image Data", image_data.shape))print("%-28s %-s" % ("Labels Data", labels_data.shape))print("=================================================================")train_images, test_images, train_labels, test_labels =     train_test_split(image_data, labels_data, test_size=0.2, random_state=6)print('Split train and test data,p=%')print("_________________________________________________________________")print("%-28s %-s" % ("Name", "Shape"))print("=================================================================")print("%-28s %-s" % ("Train Images", train_images.shape))print("%-28s %-s" % ("Test Images", test_images.shape))print("%-28s %-s" % ("Train Labels", train_labels.shape))print("%-28s %-s" % ("Test Labels", test_labels.shape))print("=================================================================")# 归一化# 我们将这些值缩小到 0 到 1 之间,然后将其馈送到神经网络模型。为此,将图像组件的数据类型从整数转换为浮点数,然后除以 255。以下是预处理图像的函数:# 务必要以相同的方式对训练集和测试集进行预处理:print("Normalize images")train_images = train_images / 255.0test_images = test_images / 255.0

登录后复制

以上就是Tensorflow分类器项目自定义数据读入的方法介绍(代码示例)的详细内容,更多请关注【创想鸟】其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至253000106@qq.com举报,一经查实,本站将立刻删除。

发布者:PHP中文网,转转请注明出处:https://www.chuangxiangniao.com/p/2534036.html

(0)
上一篇 2025年3月5日 21:25:07
下一篇 2025年2月17日 23:36:12

AD推荐 黄金广告位招租... 更多推荐

相关推荐

  • 怎么找到黑客的联系方式?

    如果你想要找到黑客的联系方式,那么你可能面临以下难题:黑客往往会隐藏他们的身份,并且他们的联系方式很难被发现。php小编草莓在这里为你提供了一份指南,旨在帮助你找到黑客的联系方式。在本指南中,我们将介绍一些常见的黑客使用的联系方式,并提供一…

    2025年3月5日
    200
  • Python正则表达式和re库的相关内容介绍(代码示例)

    本篇文章给大家带来的内容是关于python正则表达式和re库的相关内容介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。 正则表达式是定义搜索模式的字符序列。通常这种模式被字符串搜索算法用于字符串上的“查找”或…

    编程技术 2025年3月5日
    200
  • Python中浮点型的基本内容介绍(代码示例)

    本篇文章给大家带来的内容是关于python中浮点型的基本内容介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。 1.浮点数的介绍 float(浮点型)是Python基本数据类型中的一种,Python的浮点数类似…

    编程技术 2025年3月5日
    200
  • Python中整型的基本介绍(代码示例)

    本篇文章给大家带来的内容是关于python中整型的基本介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。 Python中有以下几个基本的数据类型: 整数 int 字符串 str 浮点数 float 立即学习“P…

    编程技术 2025年3月5日
    200
  • Python中的命名空间和范围

    在python中,每个包、模块、类、函数和方法函数都拥有一个“名称空间”,其中解析了变量名称。下面本篇文章就来带大家认识一下python中的命名空间和范围,希望对大家有所帮助。 什么是命名空间: 命名空间是一个系统,用于确保程序中的所有名称…

    2025年3月5日
    200
  • 如何使用Python压缩/解压缩zip文件?(代码示例)

    在批量交换大文件和多个文件时,使用zip文件是非常方便的。下面本篇文章就来带大家认识解一下zip文件,介绍使用python压缩或解压缩zip文件的方法,希望对大家有所帮助。【视频教程推荐:python教程】 什么是zip文件? zip文件是…

    2025年3月5日
    200
  • Python和Go之间的区别是什么?

    python和go都是用于编写web应用程序的强大的高级编程语言,它们之间有什么区别吗?下面本篇文章就来带大家认识一下python和go语言,简单比较一下python和go,让大家了解python和go之间的区别有哪些,希望对大家有所帮助。…

    2025年3月5日
    200
  • python中re模块与正则表达式的介绍(附代码)

    本篇文章给大家带来的内容是关于python中re模块与正则表达式的介绍(附代码),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。 正则表达式(英语:Regular Expression,在代码中常简写为regex、regex…

    编程技术 2025年3月5日
    200
  • Python中Pandas读取修改excel操作攻略(代码示例)

    本篇文章给大家带来的内容是关于python中pandas读取修改excel操作攻略(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。 环境:python 3.6.8 以某米赛尔号举个例子吧: 立即学习“Python…

    2025年3月5日 编程技术
    200
  • Python线程中定位与销毁的详细介绍(附示例)

    本篇文章给大家带来的内容是关于python线程中定位与销毁的详细介绍(附示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。 开工前我就觉得有什么不太对劲,感觉要背锅。这可不,上班第三天就捅锅了。 我们有个了不起的后台程序…

    2025年3月5日 编程技术
    200

发表回复

登录后才能评论