使用PaddleX对大量图片进行分类(仅包含预测的内容)
# -*- coding: UTF-8 -*-
# Batch image classifier (script 1): sorts the images found directly inside a
# folder into per-category sub-folders using an exported PaddleX model.
import os
import cv2
from shutil import copyfile
import numpy as np
import paddlex as pdx
# Intended to "handle Chinese (non-ASCII) paths".
# NOTE(review): importlib.reload(sys) looks like a leftover of the Python 2
# `reload(sys); sys.setdefaultencoding(...)` idiom and does not change path
# handling in Python 3; non-ASCII paths are actually handled by reading files
# with np.fromfile + cv2.imdecode. Confirm and consider removing it.
import importlib,sys
importlib.reload(sys)
# Load the exported PaddleX inference model once at import time; the
# module-level `model` is what paddle_predict uses for every image.
model = pdx.load_model('./inference_model')
def paddle_predict(source_path, result_path, threshold_value):
    """Classify every file directly inside ``source_path`` with ``model``.

    For each file the image is decoded, passed to ``model.predict`` and the
    raw prediction is printed. A sub-folder named after the predicted
    category is created under ``result_path``; when the predicted score is
    strictly greater than ``threshold_value`` the file is copied into it.

    Args:
        source_path: Folder whose files are classified (sub-folders are not
            descended into).
        result_path: Output root folder; created if it does not exist.
        threshold_value: Score that must be exceeded for a file to be copied.

    Files that cannot be read, decoded or classified are reported as
    ``ERROR:<filename> (<reason>)`` and processing continues with the next.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(result_path, exist_ok=True)
    for filename in os.listdir(source_path):
        src = os.path.join(source_path, filename)
        try:
            # Read raw bytes + cv2.imdecode instead of cv2.imread, which cannot
            # open non-ASCII (e.g. Chinese) paths on Windows.
            # NOTE(review): flag -1 (IMREAD_UNCHANGED) keeps grayscale/RGBA
            # images as-is; confirm the model accepts those channel counts.
            im = cv2.imdecode(np.fromfile(src, dtype=np.uint8), -1)
            if im is None:
                # imdecode signals "not an image" by returning None.
                raise ValueError('cannot decode image')
            im = im.astype('float32')
            result = model.predict(im)
            print(result)
            category_dir = os.path.join(result_path, result[0]['category'])
            os.makedirs(category_dir, exist_ok=True)
            if result[0]['score'] > threshold_value:
                copyfile(src, os.path.join(category_dir, filename))
        except Exception as e:
            # Report and keep going; a bare `except:` would also swallow
            # KeyboardInterrupt and hide why the file failed.
            print('ERROR:' + filename + ' (' + repr(e) + ')')
paddle_predict("C:/test","C:/output_new",0.9)
使用方法:先用PaddleX训练并导出模型(./inference_model),再运行该Python程序对大量图片进行分类。
本程序支持中文路径,并带有错误处理(单张图片出错不会中断整个流程),可稳定用于生产环境。
另附一个可遍历源路径下所有子文件夹的版本:
# -*- coding: UTF-8 -*-
# Batch image classifier (script 2): like script 1, but walks the source
# folder recursively, including all sub-folders.
import os
import cv2
from shutil import copyfile
import numpy as np
import paddlex as pdx
# Intended to "handle Chinese (non-ASCII) paths".
# NOTE(review): looks like a Python 2 leftover (reload(sys) +
# sys.setdefaultencoding); in Python 3 it does not affect path handling.
# Confirm and consider removing it.
import importlib,sys
importlib.reload(sys)
# Load the exported PaddleX inference model once at import time; the
# module-level `model` is what paddle_predict uses for every image.
model = pdx.load_model('./inference_model')
def paddle_predict(source_path, result_path, threshold_value):
    """Classify every file under ``source_path`` (recursively) with ``model``.

    Each file found by ``os.walk`` is decoded, classified, and its path is
    printed. A sub-folder named after the predicted category is created under
    ``result_path``; when the score is strictly greater than
    ``threshold_value`` the file is copied into it.

    Args:
        source_path: Root folder to scan, including all sub-folders.
        result_path: Output root folder; created if it does not exist.
        threshold_value: Score that must be exceeded for a file to be copied.

    Files that cannot be read, decoded or classified are reported as
    ``ERROR:<path> (<reason>)`` and processing continues with the next.

    NOTE(review): files are copied flat into ``<result_path>/<category>/``,
    so same-named files from different sub-folders overwrite each other.
    Also keep ``result_path`` outside ``source_path``; otherwise the walk can
    revisit the copies it just made.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(result_path, exist_ok=True)
    for dirpath, _dirnames, filenames in os.walk(source_path):
        for filename in filenames:
            src = os.path.join(dirpath, filename)
            try:
                # Read raw bytes + cv2.imdecode instead of cv2.imread, which
                # cannot open non-ASCII (e.g. Chinese) paths on Windows.
                im = cv2.imdecode(np.fromfile(src, dtype=np.uint8), -1)
                if im is None:
                    # imdecode signals "not an image" by returning None.
                    raise ValueError('cannot decode image')
                im = im.astype('float32')
                result = model.predict(im)
                print(src)
                category_dir = os.path.join(result_path, result[0]['category'])
                os.makedirs(category_dir, exist_ok=True)
                if result[0]['score'] > threshold_value:
                    copyfile(src, os.path.join(category_dir, filename))
            except Exception as e:
                # Report and keep going; a bare `except:` would also swallow
                # KeyboardInterrupt and hide why the file failed.
                print('ERROR:' + src + ' (' + repr(e) + ')')
paddle_predict("C:/test","C:/output_new",0.9)
然后,为了方便展示运行进度,进一步优化(用多线程实时显示处理速度、剩余时间和出错数量):
# -*- coding: UTF-8 -*-
# Batch image classifier (script 3): recursive classification on a worker
# thread, plus a second thread that prints progress about once per second.
import os
import cv2
import time
import threading
from shutil import copyfile
import numpy as np
import paddlex as pdx
# Intended to "handle Chinese (non-ASCII) paths".
# NOTE(review): looks like a Python 2 leftover (reload(sys) +
# sys.setdefaultencoding); in Python 3 it does not affect path handling.
# Confirm and consider removing it.
import importlib,sys
importlib.reload(sys)
# Progress counters shared between the worker thread (paddle_predict) and the
# progress thread (paddle_predict_calc). They are updated without a lock, so
# the reported numbers are approximate.
# Files finished since the progress thread last consumed this counter.
TEMP_SEC_CALC=0
# Cumulative number of files that failed to be processed.
TEMP_SEC_CALC_ERROR=0
# Total number of files counted by predict() before processing starts.
TEMP_SUM=0
# Number of files the worker thread has finished so far.
TEMP_SUM_DONE=0
model = pdx.load_model('./inference_model') # Load the PaddleX inference model once at import time
def time_convert(seconds):
    """Format a duration given in seconds as ``HH:MM:SS``.

    Hours are NOT wrapped at 24: the value is used for remaining-time
    estimates, which can exceed a day for large batches (90000 -> "25:00:00").
    Fractional seconds (the caller passes a float) are truncated by ``%02d``.
    """
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return "%02d:%02d:%02d" % (hours, minutes, secs)
def paddle_predict_calc():
    """Progress-thread body: print throughput and ETA about once per second.

    Each report line shows how many files finished since the previous report
    (``FPS``), the estimated remaining time, and the cumulative error count.
    The counters are module globals written by the worker thread
    (``paddle_predict``). Returns once ``TEMP_SUM_DONE`` reaches ``TEMP_SUM``.

    NOTE(review): if the worker thread dies early, or files disappear after
    they were counted, ``TEMP_SUM_DONE`` never reaches ``TEMP_SUM`` and this
    loop keeps running.
    """
    global TEMP_SEC_CALC
    # `<` instead of `!=`: the counter is only sampled once per second, so if
    # more files are processed than were counted (e.g. the output folder lies
    # inside the source folder), `!=` could miss the equal moment and loop
    # forever.
    while TEMP_SUM_DONE < TEMP_SUM:
        # Snapshot the counter once and consume it immediately. Reading it
        # twice and zeroing it after the (slow) print lost files the worker
        # finished in between.
        finished = TEMP_SEC_CALC
        TEMP_SEC_CALC -= finished
        if finished:
            remaining = time_convert((TEMP_SUM - TEMP_SUM_DONE) / finished)
            print('FPS:' + str(finished) + ' Remain:' + remaining
                  + ' ERROR:' + str(TEMP_SEC_CALC_ERROR))
        time.sleep(1)
def paddle_predict(source_path, result_path, threshold_value):
    """Worker-thread body: classify every file under ``source_path``.

    Each file found by ``os.walk`` is decoded and classified with ``model``.
    A sub-folder named after the predicted category is created under
    ``result_path``, and the file is copied there when its score is strictly
    greater than ``threshold_value``.

    Progress is published through module globals read by
    ``paddle_predict_calc``. Every file, successful or not, increments
    ``TEMP_SUM_DONE`` and ``TEMP_SEC_CALC``; failures also increment
    ``TEMP_SEC_CALC_ERROR``. Per-file messages are not printed, so the
    progress line stays readable.

    Args:
        source_path: Root folder to scan, including all sub-folders.
        result_path: Output root folder; created if it does not exist.
        threshold_value: Score that must be exceeded for a file to be copied.
    """
    global TEMP_SEC_CALC_ERROR
    global TEMP_SUM_DONE
    global TEMP_SEC_CALC
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(result_path, exist_ok=True)
    for dirpath, _dirnames, filenames in os.walk(source_path):
        for filename in filenames:
            src = os.path.join(dirpath, filename)
            try:
                # Read raw bytes + cv2.imdecode instead of cv2.imread, which
                # cannot open non-ASCII (e.g. Chinese) paths on Windows.
                im = cv2.imdecode(np.fromfile(src, dtype=np.uint8), -1)
                if im is None:
                    # imdecode signals "not an image" by returning None.
                    raise ValueError('cannot decode image')
                im = im.astype('float32')
                result = model.predict(im)
                category_dir = os.path.join(result_path, result[0]['category'])
                os.makedirs(category_dir, exist_ok=True)
                if result[0]['score'] > threshold_value:
                    copyfile(src, os.path.join(category_dir, filename))
            except Exception:
                # Count the failure and continue with the next file (a bare
                # `except:` would also catch SystemExit and the like).
                TEMP_SEC_CALC_ERROR += 1
            # Counted for every file, so the progress thread's
            # done-vs-total condition can actually be reached.
            TEMP_SUM_DONE += 1
            TEMP_SEC_CALC += 1
def predict(source_path,result_path,threshold_value):
    """Count the files under ``source_path``, then classify them in the background.

    Adds the number of files found (recursively) to ``TEMP_SUM`` and prints
    it, then pauses two seconds so the count can be read. It then starts two
    threads: one runs ``paddle_predict`` to do the classification, and one
    runs ``paddle_predict_calc`` to print progress. Returns right after
    starting them, without waiting for either.
    """
    global TEMP_SUM
    TEMP_SUM += sum(len(names) for _root, _dirs, names in os.walk(source_path))
    print('总共发现了:' + str(TEMP_SUM) + '个文件!')
    time.sleep(2)
    worker = threading.Thread(
        target=paddle_predict,
        args=(source_path, result_path, threshold_value),
    )
    reporter = threading.Thread(target=paddle_predict_calc)
    worker.start()
    reporter.start()
predict("C:/test","C:/output_new1",0.9)
运行效果:

为了方便对网页图片进行判断,又修改出一个精简版:
# -*- coding: UTF-8 -*-
# Minimal variant: classify a single image fetched from a URL.
import cv2
import paddlex as pdx
# Intended to "handle Chinese (non-ASCII) paths".
# NOTE(review): looks like a Python 2 leftover (reload(sys) +
# sys.setdefaultencoding); in Python 3 it does not affect path handling.
# Confirm and consider removing it.
import importlib
import sys
importlib.reload(sys)
Paddle_Func = pdx.load_model('./inference_model') # Load the PaddleX inference model once at import time
def Paddle_Url_Predit(Pic_Url):
    """Fetch one image from ``Pic_Url`` and print the model's prediction.

    The image is opened with ``cv2.VideoCapture``, its first frame is
    classified with ``Paddle_Func``, and the raw prediction is printed.

    Prints ``Download Failure!`` when the source cannot be opened or no frame
    can be read. Prints ``Unknown Error!`` plus the reason for any other
    failure. Always returns ``None``.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(Pic_Url)
        if not cap.isOpened():
            print('Download Failure!')
            return
        ret, im = cap.read()
        if not ret or im is None:
            # read() reports failure via `ret`; don't wait for im.astype to
            # blow up on None.
            print('Download Failure!')
            return
        im = im.astype('float32')
        result = Paddle_Func.predict(im)
        print(result)
    except Exception as e:
        print('Unknown Error!', repr(e))
    finally:
        # Always release the capture so its file/network handle is not leaked.
        if cap is not None:
            cap.release()
Paddle_Url_Predit('图片url地址')
进一步添加通过Socket传输图片url地址的功能:
服务端:
# -*- coding: UTF-8 -*-
# Socket server variant: receives image URLs over TCP and replies with the
# model's prediction for each one.
import cv2
import paddlex as pdx
# Intended to "handle Chinese (non-ASCII) paths".
# NOTE(review): looks like a Python 2 leftover (reload(sys) +
# sys.setdefaultencoding); in Python 3 it does not affect path handling.
# Confirm and consider removing it.
import importlib
import sys
importlib.reload(sys)
Paddle_Func = pdx.load_model('./inference_model') # Load the PaddleX inference model once at import time
def Paddle_Url_Predit(Pic_Url):
    """Fetch one image from ``Pic_Url`` and return the model's prediction.

    Returns the raw ``Paddle_Func.predict`` result on success. Returns the
    string ``'Download Failure!'`` when the source cannot be opened or no
    frame can be read, and ``'Unknown Error!'`` for any other failure (the
    reason is printed on the server side). The socket loop sends ``str()`` of
    the returned value back to the client.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(Pic_Url)
        if not cap.isOpened():
            return('Download Failure!')
        ret, im = cap.read()
        if not ret or im is None:
            # read() reports failure via `ret`; don't wait for im.astype to
            # blow up on None.
            return('Download Failure!')
        im = im.astype('float32')
        return Paddle_Func.predict(im)
    except Exception as e:
        print('Predict error:', repr(e))
        return('Unknown Error!')
    finally:
        # Always release the capture so its file/network handle is not leaked.
        if cap is not None:
            cap.release()
import os
import stat
import socket
# Create the server socket; it serves a single client connection on
# 127.0.0.1:8898 and stops when that client goes away.
sk = socket.socket()
sk.bind(('127.0.0.1', 8898))
sk.listen()
conn, addr = sk.accept()
try:
    while True:
        # Each recv() is treated as one image URL of at most 1024 bytes.
        # NOTE(review): TCP has no message boundaries; this relies on the
        # client sending one URL and waiting for the reply before the next.
        ret = conn.recv(1024)
        if not ret:
            # b'' means the client closed the connection; without this check
            # the loop would spin forever on the closed socket.
            break
        if ret == b'bye':
            # The Python client ends its session when it receives b'bye'.
            conn.sendall(b'bye')
            break
        # errors='replace': a URL cut mid-character by the 1024-byte limit must
        # not crash the server with UnicodeDecodeError.
        Socket_rst = Paddle_Url_Predit(ret.decode('utf-8', errors='replace'))
        print(Socket_rst)
        try:
            # sendall(): send() may transmit only part of the buffer.
            conn.sendall(bytes(str(Socket_rst), encoding='utf-8'))
        except OSError:
            print('Connect Error!')
            break
finally:
    # Close the client connection
    conn.close()
# Close the server socket
sk.close()
客户端(Python,参考自 https://zhuanlan.zhihu.com/p/279968757):
import socket
# Create the client socket
sk = socket.socket()
# Connect to the prediction server
sk.connect(('127.0.0.1',8898))
while True:
    # Read one image URL from the console and send it to the server.
    try:
        info = input('>>>')
    except EOFError:
        # Ctrl+D / Ctrl+Z ends the session instead of crashing.
        break
    if not info:
        # Sending b'' transmits nothing, so the recv() below would block
        # forever waiting for a reply that never comes.
        continue
    sk.sendall(bytes(info,encoding='utf-8'))
    # Receive the server's reply
    ret = sk.recv(1024)
    if not ret:
        # b'' means the server closed the connection.
        break
    # The session ends when the server replies 'bye'
    if ret == b'bye':
        sk.send(b'bye')
        break
    # Print the reply; errors='replace' guards against a reply cut
    # mid-character by the 1024-byte limit.
    print(ret.decode('utf-8', errors='replace'))
# Close the client socket
sk.close()
客户端(PHP):
待续



