解决tfClassifier训练报错的问题 修正后python 适用于tensorflow2.x python3.x

chanra1n2020年12月09日AI6130
# -*- coding: utf-8 -*-
"""
Created on Sun Dec 29 19:21:08 2019

@原作者: xiuzhang Eastmount CSDN

@修改作者:ChanRa1n

修正问题:TensorFlow版本低,学习速率过高,修正为0.1,准确率达到94%

"""
import os
import glob
import cv2
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# 定义图片路径
path = 'data/'

#---------------------------------第一步 读取图像-----------------------------------
def read_img(path):
   cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
   imgs = []
   labels = []
   fpath = []
   for idx, folder in enumerate(cate):
       # 遍历整个目录判断每个文件是不是符合
       for im in glob.glob(folder + '/*.jpg'):
           #print('reading the images:%s' % (im))
           img = cv2.imread(im)             #调用opencv库读取像素点
           img = cv2.resize(img, (32, 32))  #图像像素大小一致
           imgs.append(img)                 #图像数据
           labels.append(idx)               #图像类标
           fpath.append(path+im)            #图像路径名
           #print(path+im, idx)
           
   return np.asarray(fpath, np.string_), np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

# 读取图像
fpaths, data, label = read_img(path)
print(data.shape)  # (1000, 256, 256, 3)
# 计算有多少类图片
num_classes = len(set(label))
print(num_classes)

# 生成等差数列随机调整图像顺序
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
fpaths = fpaths[arr]

# 拆分训练集和测试集 80%训练集 20%测试集
ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
fpaths_train = fpaths[:s]
x_val = data[s:]
y_val = label[s:]
fpaths_test = fpaths[s:]
print(len(x_train),len(y_train),len(x_val),len(y_val)) #800 800 200 200
print(y_val)


#---------------------------------第二步 建立神经网络-----------------------------------
# 定义Placeholder
xs = tf.placeholder(tf.float32, [None, 32, 32, 3])  #每张图片32*32*3个点
ys = tf.placeholder(tf.int32, [None])               #每个样本有1个输出
# 存放DropOut参数的容器
drop = tf.placeholder(tf.float32)                   #训练时为0.25 测试时为0

# 定义卷积层 conv0
conv0 = tf.layers.conv2d(xs, 20, 5, activation=tf.nn.relu)    #20个卷积核 卷积核大小为5 Relu激活
# 定义max-pooling层 pool0
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("Layer0:\n", conv0, pool0)

# 定义卷积层 conv1
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu) #40个卷积核 卷积核大小为4 Relu激活
# 定义max-pooling层 pool1
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("Layer1:\n", conv1, pool1)

# 将3维特征转换为1维向量
flatten = tf.layers.flatten(pool1)

# 全连接层 转换为长度为400的特征向量
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
print("Layer2:\n", fc)

# 加上DropOut防止过拟合
dropout_fc = tf.layers.dropout(fc, drop)

# 未激活的输出层
logits = tf.layers.dense(dropout_fc, num_classes)
print("Output:\n", logits)

# 定义输出结果
predicted_labels = tf.arg_max(logits, 1)


#---------------------------------第三步 定义损失函数和优化器---------------------------------

# 利用交叉熵定义损失
losses = tf.nn.softmax_cross_entropy_with_logits(
       labels = tf.one_hot(ys, num_classes),       #将input转化为one-hot类型数据输出
       logits = logits)

# 平均损失
mean_loss = tf.reduce_mean(losses)

# 定义优化器 学习效率设置为0.0001
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(losses)


#------------------------------------第四步 模型训练和预测-----------------------------------
# 用于保存和载入模型
saver = tf.train.Saver()
# 训练或预测
#train = True
train = False
# 模型文件路径
model_path = "model/image_model"

with tf.device('GPU:0'):
   with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
       if train:
           print("训练模式")
           # 训练初始化参数
           sess.run(tf.global_variables_initializer())
           # 定义输入和Label以填充容器 训练时dropout为0.25
           train_feed_dict = {
                   xs: x_train,
                   ys: y_train,
                   drop: 0.1
           }
           # 训练学习1000次
           for step in range(1000):
               _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
               if step % 50 == 0:  #每隔50次输出一次结果
                   print("step = {}\t mean loss = {}".format(step, mean_loss_val))
           # 保存模型
           saver.save(sess, model_path)
           print("训练结束,保存模型到{}".format(model_path))
       else:
           print("测试模式")
           # 测试载入参数
           saver.restore(sess, model_path)
           print("从{}载入模型".format(model_path))
           # label和名称的对照关系
           label_name_dict = {
               0: "人类",
               1: "沙滩",
               2: "建筑",
               3: "公交",
               4: "恐龙",
               5: "大象",
               6: "花朵",
               7: "野马",
               8: "雪山",
               9: "美食"
           }
           # 定义输入和Label以填充容器 测试时dropout为0
           test_feed_dict = {
               xs: x_val,
               ys: y_val,
               drop: 0
           }
           
           # 真实label与模型预测label
           predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
           for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):
               # 将label id转换为label名
               real_label_name = label_name_dict[real_label]
               predicted_label_name = label_name_dict[predicted_label]
               print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
           # 评价结果
           print("正确预测个数:", sum(y_val==predicted_labels_val))
           print("准确度为:", 1.0*sum(y_val==predicted_labels_val) / len(y_val))
               


"""
从model/image_model载入模型
b'data/data/6\\605.jpg' 花朵 => 花朵
b'data/data/3\\379.jpg' 公交 => 公交
b'data/data/0\\92.jpg'  人类 => 人类
b'data/data/4\\487.jpg' 恐龙 => 恐龙
b'data/data/4\\448.jpg' 恐龙 => 恐龙
b'data/data/3\\387.jpg' 公交 => 公交
b'data/data/7\\733.jpg' 野马 => 野马
b'data/data/8\\810.jpg' 雪山 => 雪山
b'data/data/5\\542.jpg' 大象 => 大象
b'data/data/0\\71.jpg'  人类 => 人类
b'data/data/2\\260.jpg' 建筑 => 建筑
b'data/data/3\\348.jpg' 公交 => 公交
b'data/data/7\\738.jpg' 野马 => 野马
b'data/data/8\\828.jpg' 雪山 => 雪山
b'data/data/8\\844.jpg' 雪山 => 雪山
b'data/data/1\\156.jpg' 沙滩 => 沙滩
b'data/data/2\\272.jpg' 建筑 => 建筑
b'data/data/3\\399.jpg' 公交 => 公交
b'data/data/9\\999.jpg' 美食 => 公交
b'data/data/3\\311.jpg' 公交 => 公交
b'data/data/3\\367.jpg' 公交 => 公交
b'data/data/1\\124.jpg' 沙滩 => 沙滩
b'data/data/0\\98.jpg'  人类 => 人类
b'data/data/7\\783.jpg' 野马 => 野马
b'data/data/1\\195.jpg' 沙滩 => 沙滩
b'data/data/6\\672.jpg' 花朵 => 花朵
b'data/data/9\\940.jpg' 美食 => 美食
b'data/data/3\\329.jpg' 公交 => 公交
b'data/data/3\\393.jpg' 公交 => 公交
b'data/data/0\\99.jpg'  人类 => 人类
b'data/data/4\\485.jpg' 恐龙 => 恐龙
b'data/data/4\\472.jpg' 恐龙 => 恐龙
b'data/data/3\\332.jpg' 公交 => 公交
b'data/data/6\\644.jpg' 花朵 => 花朵
b'data/data/2\\213.jpg' 建筑 => 建筑
b'data/data/7\\748.jpg' 野马 => 野马
b'data/data/4\\496.jpg' 恐龙 => 恐龙
b'data/data/5\\555.jpg' 大象 => 大象
b'data/data/2\\283.jpg' 建筑 => 美食
b'data/data/4\\491.jpg' 恐龙 => 恐龙
b'data/data/9\\970.jpg' 美食 => 美食
b'data/data/5\\562.jpg' 大象 => 大象
b'data/data/5\\506.jpg' 大象 => 大象
b'data/data/4\\497.jpg' 恐龙 => 恐龙
b'data/data/9\\910.jpg' 美食 => 美食
b'data/data/4\\402.jpg' 恐龙 => 恐龙
b'data/data/8\\802.jpg' 雪山 => 雪山
b'data/data/5\\531.jpg' 大象 => 大象
b'data/data/3\\389.jpg' 公交 => 公交
b'data/data/1\\143.jpg' 沙滩 => 沙滩
b'data/data/3\\327.jpg' 公交 => 公交
b'data/data/9\\958.jpg' 美食 => 美食
b'data/data/2\\297.jpg' 建筑 => 建筑
b'data/data/9\\918.jpg' 美食 => 美食
b'data/data/2\\208.jpg' 建筑 => 建筑
b'data/data/0\\15.jpg'  人类 => 人类
b'data/data/4\\443.jpg' 恐龙 => 恐龙
b'data/data/2\\231.jpg' 建筑 => 建筑
b'data/data/4\\475.jpg' 恐龙 => 恐龙
b'data/data/2\\217.jpg' 建筑 => 建筑
b'data/data/2\\206.jpg' 建筑 => 公交
b'data/data/4\\478.jpg' 恐龙 => 恐龙
b'data/data/8\\852.jpg' 雪山 => 雪山
b'data/data/2\\268.jpg' 建筑 => 建筑
b'data/data/1\\198.jpg' 沙滩 => 沙滩
b'data/data/0\\41.jpg'  人类 => 人类
b'data/data/5\\510.jpg' 大象 => 大象
b'data/data/7\\774.jpg' 野马 => 野马
b'data/data/9\\915.jpg' 美食 => 美食
b'data/data/8\\809.jpg' 雪山 => 雪山
b'data/data/4\\442.jpg' 恐龙 => 恐龙
b'data/data/8\\881.jpg' 雪山 => 沙滩
b'data/data/7\\720.jpg' 野马 => 野马
b'data/data/8\\811.jpg' 雪山 => 雪山
b'data/data/5\\560.jpg' 大象 => 大象
b'data/data/9\\956.jpg' 美食 => 美食
b'data/data/3\\326.jpg' 公交 => 公交
b'data/data/7\\747.jpg' 野马 => 野马
b'data/data/3\\310.jpg' 公交 => 公交
b'data/data/1\\160.jpg' 沙滩 => 野马
b'data/data/3\\394.jpg' 公交 => 公交
b'data/data/6\\613.jpg' 花朵 => 花朵
b'data/data/2\\263.jpg' 建筑 => 建筑
b'data/data/2\\259.jpg' 建筑 => 建筑
b'data/data/6\\607.jpg' 花朵 => 花朵
b'data/data/1\\118.jpg' 沙滩 => 沙滩
b'data/data/8\\826.jpg' 雪山 => 雪山
b'data/data/3\\314.jpg' 公交 => 公交
b'data/data/1\\166.jpg' 沙滩 => 沙滩
b'data/data/5\\577.jpg' 大象 => 大象
b'data/data/7\\701.jpg' 野马 => 野马
b'data/data/6\\609.jpg' 花朵 => 花朵
b'data/data/0\\42.jpg'  人类 => 人类
b'data/data/3\\334.jpg' 公交 => 公交
b'data/data/9\\965.jpg' 美食 => 美食
b'data/data/2\\294.jpg' 建筑 => 建筑
b'data/data/8\\898.jpg' 雪山 => 雪山
b'data/data/4\\476.jpg' 恐龙 => 恐龙
b'data/data/9\\960.jpg' 美食 => 美食
b'data/data/0\\0.jpg'   人类 => 人类
b'data/data/5\\535.jpg' 大象 => 花朵
b'data/data/3\\338.jpg' 公交 => 公交
b'data/data/1\\129.jpg' 沙滩 => 沙滩
b'data/data/3\\385.jpg' 公交 => 公交
b'data/data/4\\492.jpg' 恐龙 => 恐龙
b'data/data/8\\805.jpg' 雪山 => 雪山
b'data/data/2\\266.jpg' 建筑 => 建筑
b'data/data/5\\581.jpg' 大象 => 大象
b'data/data/0\\28.jpg'  人类 => 建筑
b'data/data/0\\69.jpg'  人类 => 人类
b'data/data/2\\242.jpg' 建筑 => 建筑
b'data/data/7\\759.jpg' 野马 => 野马
b'data/data/1\\121.jpg' 沙滩 => 沙滩
b'data/data/5\\519.jpg' 大象 => 建筑
b'data/data/4\\425.jpg' 恐龙 => 恐龙
b'data/data/5\\599.jpg' 大象 => 大象
b'data/data/8\\833.jpg' 雪山 => 雪山
b'data/data/4\\450.jpg' 恐龙 => 恐龙
b'data/data/1\\116.jpg' 沙滩 => 沙滩
b'data/data/4\\412.jpg' 恐龙 => 恐龙
b'data/data/3\\395.jpg' 公交 => 公交
b'data/data/0\\36.jpg'  人类 => 人类
b'data/data/6\\685.jpg' 花朵 => 花朵
b'data/data/6\\670.jpg' 花朵 => 花朵
b'data/data/4\\477.jpg' 恐龙 => 恐龙
b'data/data/7\\791.jpg' 野马 => 野马
b'data/data/5\\550.jpg' 大象 => 大象
b'data/data/8\\851.jpg' 雪山 => 雪山
b'data/data/8\\867.jpg' 雪山 => 雪山
b'data/data/5\\558.jpg' 大象 => 大象
b'data/data/4\\466.jpg' 恐龙 => 恐龙
b'data/data/5\\585.jpg' 大象 => 大象
b'data/data/3\\301.jpg' 公交 => 公交
b'data/data/8\\871.jpg' 雪山 => 沙滩
b'data/data/1\\183.jpg' 沙滩 => 沙滩
b'data/data/6\\601.jpg' 花朵 => 花朵
b'data/data/6\\637.jpg' 花朵 => 花朵
b'data/data/0\\95.jpg'  人类 => 人类
b'data/data/3\\339.jpg' 公交 => 公交
b'data/data/9\\968.jpg' 美食 => 美食
b'data/data/5\\508.jpg' 大象 => 大象
b'data/data/1\\167.jpg' 沙滩 => 沙滩
b'data/data/7\\746.jpg' 野马 => 野马
b'data/data/8\\819.jpg' 雪山 => 雪山
b'data/data/7\\726.jpg' 野马 => 野马
b'data/data/1\\172.jpg' 沙滩 => 沙滩
b'data/data/5\\574.jpg' 大象 => 大象
b'data/data/2\\293.jpg' 建筑 => 建筑
b'data/data/7\\754.jpg' 野马 => 野马
b'data/data/6\\658.jpg' 花朵 => 花朵
b'data/data/2\\211.jpg' 建筑 => 大象
b'data/data/5\\502.jpg' 大象 => 大象
b'data/data/0\\80.jpg'  人类 => 人类
b'data/data/6\\647.jpg' 花朵 => 花朵
b'data/data/5\\579.jpg' 大象 => 大象
b'data/data/2\\276.jpg' 建筑 => 建筑
b'data/data/5\\505.jpg' 大象 => 大象
b'data/data/4\\489.jpg' 恐龙 => 恐龙
b'data/data/6\\659.jpg' 花朵 => 花朵
b'data/data/2\\207.jpg' 建筑 => 公交
b'data/data/1\\150.jpg' 沙滩 => 沙滩
b'data/data/1\\133.jpg' 沙滩 => 沙滩
b'data/data/5\\521.jpg' 大象 => 大象
b'data/data/0\\88.jpg'  人类 => 人类
b'data/data/9\\944.jpg' 美食 => 美食
b'data/data/4\\424.jpg' 恐龙 => 恐龙
b'data/data/7\\718.jpg' 野马 => 野马
b'data/data/4\\409.jpg' 恐龙 => 恐龙
b'data/data/0\\19.jpg'  人类 => 人类
b'data/data/2\\265.jpg' 建筑 => 建筑
b'data/data/3\\312.jpg' 公交 => 公交
b'data/data/3\\352.jpg' 公交 => 公交
b'data/data/5\\559.jpg' 大象 => 大象
b'data/data/3\\335.jpg' 公交 => 公交
b'data/data/2\\212.jpg' 建筑 => 建筑
b'data/data/0\\82.jpg'  人类 => 人类
b'data/data/1\\168.jpg' 沙滩 => 沙滩
b'data/data/5\\545.jpg' 大象 => 大象
b'data/data/7\\700.jpg' 野马 => 野马
b'data/data/3\\309.jpg' 公交 => 公交
b'data/data/7\\787.jpg' 野马 => 野马
b'data/data/9\\957.jpg' 美食 => 美食
b'data/data/3\\363.jpg' 公交 => 公交
b'data/data/8\\839.jpg' 雪山 => 雪山
b'data/data/8\\896.jpg' 雪山 => 雪山
b'data/data/3\\306.jpg' 公交 => 公交
b'data/data/5\\565.jpg' 大象 => 大象
b'data/data/8\\885.jpg' 雪山 => 建筑
b'data/data/0\\26.jpg'  人类 => 人类
b'data/data/3\\380.jpg' 公交 => 公交
b'data/data/5\\552.jpg' 大象 => 大象
b'data/data/1\\175.jpg' 沙滩 => 沙滩
b'data/data/6\\646.jpg' 花朵 => 花朵
b'data/data/6\\610.jpg' 花朵 => 花朵
b'data/data/0\\48.jpg'  人类 => 人类
b'data/data/0\\32.jpg'  人类 => 人类
b'data/data/1\\154.jpg' 沙滩 => 沙滩
b'data/data/6\\633.jpg' 花朵 => 花朵
b'data/data/9\\954.jpg' 美食 => 美食
b'data/data/3\\316.jpg' 公交 => 公交
正确预测个数: 188
准确度为: 0.94
"""


请先 登录 再评论,若不是会员请先 注册