15
2022
06

使用PaddleX对大量图片进行分类(仅包含预测的内容)

# -*- coding: UTF-8 -*-

import os
import cv2
from shutil import copyfile
import numpy as np
import paddlex as pdx

#处理中文路径
import importlib,sys
importlib.reload(sys)

#Paddle加载模型
model = pdx.load_model('./inference_model')

def paddle_predict(source_path,result_path,threshold_value):
    #对源路径的图片进行推理
    #对来自参数一目录下的图片进行分类,按分类名保存在参数二的目录下,当预测的准确度大于参数三时,保存该图片到参数二对应目录下
    isExists=os.path.exists(result_path)#判断目标文件夹是否存在
    if not isExists:
        os.makedirs(result_path)
    for filename in os.listdir(source_path):
        try:
            im = cv2.imdecode(np.fromfile(source_path+'/'+filename,dtype=np.uint8),-1)
            im = im.astype('float32')
            result = model.predict(im)
            print(result)
            isExists=os.path.exists(result_path+'/'+result[0]['category'])#判断分类文件夹是否存在
            if not isExists:
                os.makedirs(result_path+'/'+result[0]['category'])
            if(result[0]['score']>threshold_value):
                copyfile(source_path+'/'+filename, result_path+'/'+result[0]['category']+'/'+filename)
                #os.remove(source_path+'/'+filename)
        except:
            print('ERROR:'+filename)

paddle_predict("C:/test","C:/output_new",0.9)

首先使用PaddleX进行训练,然后使用该Python程序对大量图片进行分类。

本程序支持中文路径和错误处理功能,能够稳定用于生产用途。

另附上支持对源路径的子文件夹进行遍历的程序

# -*- coding: UTF-8 -*-

import os
import cv2
from shutil import copyfile
import numpy as np
import paddlex as pdx

#处理中文路径
import importlib,sys
importlib.reload(sys)

#Paddle加载模型
model = pdx.load_model('./inference_model')

def paddle_predict(source_path,result_path,threshold_value):
    #对源路径的图片进行推理
    #对来自参数一目录下的图片进行分类,按分类名保存在参数二的目录下,当预测的准确度大于参数三时,保存该图片到参数二对应目录下
    isExists=os.path.exists(result_path)#判断目标文件夹是否存在
    if not isExists:
        os.makedirs(result_path)

    for dirpath, dirnames, filenames in os.walk(source_path):
        for filename in filenames:
            try:
                im = cv2.imdecode(np.fromfile(os.path.join(dirpath, filename),dtype=np.uint8),-1)
                im = im.astype('float32')
                result = model.predict(im)
                print(os.path.join(dirpath, filename))
                isExists=os.path.exists(result_path+'/'+result[0]['category'])#判断分类文件夹是否存在
                if not isExists:
                    os.makedirs(result_path+'/'+result[0]['category'])
                if(result[0]['score']>threshold_value):
                    copyfile(os.path.join(dirpath, filename), result_path+'/'+result[0]['category']+'/'+filename)
                    #os.remove(os.path.join(dirpath, filename))
            except:
                print('ERROR:'+os.path.join(dirpath, filename))

paddle_predict("C:/test","C:/output_new",0.9)

然后为了方便展示效果,进一步优化

# -*- coding: UTF-8 -*-

import os
import cv2
import time
import threading
from shutil import copyfile
import numpy as np
import paddlex as pdx

#处理中文路径
import importlib,sys
importlib.reload(sys)

TEMP_SEC_CALC=0
TEMP_SEC_CALC_ERROR=0
TEMP_SUM=0
TEMP_SUM_DONE=0

model = pdx.load_model('./inference_model') #Paddle加载模型

def time_convert(seconds):
    seconds = seconds % (24 * 3600)
    hour = seconds // 3600
    seconds %= 3600
    minutes = seconds // 60
    seconds %= 60
    return "%02d:%02d:%02d" % (hour, minutes, seconds)

def paddle_predict_calc():
    global TEMP_SEC_CALC
    global TEMP_SEC_CALC_ERROR
    global TEMP_SUM
    global TEMP_SUM_DONE
    while TEMP_SUM!=TEMP_SUM_DONE:
        if TEMP_SEC_CALC:
            print('FPS:'+str(TEMP_SEC_CALC)+' Remain:'+time_convert((TEMP_SUM-TEMP_SUM_DONE)/TEMP_SEC_CALC) + ' ERROR:'+str(TEMP_SEC_CALC_ERROR))
        TEMP_SEC_CALC = 0
        time.sleep(1)

def paddle_predict(source_path,result_path,threshold_value):
    #对源路径的图片进行推理
    #对来自参数一目录下的图片进行分类,按分类名保存在参数二的目录下,当预测的准确度大于参数三时,保存该图片到参数二对应目录下
    global TEMP_SEC_CALC_ERROR
    global TEMP_SUM_DONE
    global TEMP_SEC_CALC
    isExists=os.path.exists(result_path)#判断目标文件夹是否存在
    if not isExists:
        os.makedirs(result_path)
    for dirpath, dirnames, filenames in os.walk(source_path):
        for filename in filenames:
            try:
                im = cv2.imdecode(np.fromfile(os.path.join(dirpath, filename),dtype=np.uint8),-1)
                im = im.astype('float32')
                result = model.predict(im)
                #print(os.path.join(dirpath, filename))
                isExists=os.path.exists(result_path+'/'+result[0]['category'])#判断分类文件夹是否存在
                if not isExists:
                    os.makedirs(result_path+'/'+result[0]['category'])
                if(result[0]['score']>threshold_value):
                    copyfile(os.path.join(dirpath, filename), result_path+'/'+result[0]['category']+'/'+filename)
                    #os.remove(os.path.join(dirpath, filename))  
            except:
                #print('ERROR:'+os.path.join(dirpath, filename))
                TEMP_SEC_CALC_ERROR = TEMP_SEC_CALC_ERROR + 1
            TEMP_SUM_DONE = TEMP_SUM_DONE + 1
            TEMP_SEC_CALC = TEMP_SEC_CALC + 1

def predict(source_path,result_path,threshold_value):
    #对来自参数一目录下的图片进行分类,按分类名保存在参数二的目录下,当预测的准确度大于参数三时,保存该图片到参数二对应目录下
    global TEMP_SUM
    for dirpath, dirnames, filenames in os.walk(source_path):
        for filename in filenames:
            TEMP_SUM = TEMP_SUM + 1
    print('总共发现了:'+str(TEMP_SUM)+'个文件!')
    time.sleep(2)

    main_func = threading.Thread(target=paddle_predict, args=(source_path,result_path,threshold_value))
    calc_func = threading.Thread(target=paddle_predict_calc)
    main_func.start()
    calc_func.start()

predict("C:/test","C:/output_new1",0.9)

运行效果

image.png

为了方便对网页图片进行判断,修改了一个精简版的

# -*- coding: UTF-8 -*-

import cv2
import paddlex as pdx

# 处理中文路径
import importlib
import sys
importlib.reload(sys)

Paddle_Func = pdx.load_model('./inference_model')  # Paddle加载模型

def Paddle_Url_Predit(Pic_Url):
    try:
        cap = cv2.VideoCapture(Pic_Url)
        if(cap.isOpened()):
            ret, im = cap.read()
            im = im.astype('float32')
            result = Paddle_Func.predict(im)
            print(result)
        else:
            print('Download Failure!')
    except:
        print('Unknown Error!')

Paddle_Url_Predit('图片url地址')

进一步添加Socket传输图片的url地址

服务端:

# -*- coding: UTF-8 -*-

import cv2
import paddlex as pdx

# 处理中文路径
import importlib
import sys
importlib.reload(sys)

Paddle_Func = pdx.load_model('./inference_model')  # Paddle加载模型

def Paddle_Url_Predit(Pic_Url):
    try:
        cap = cv2.VideoCapture(Pic_Url)
        if(cap.isOpened()):
            ret, im = cap.read()
            im = im.astype('float32')
            result = Paddle_Func.predict(im)
            return(result)
        else:
            return('Download Failure!')
    except:
        return('Unknown Error!')

import os
import stat
import socket

# 创建服务器端套接字
sk = socket.socket()
sk.bind(('127.0.0.1', 8898))
sk.listen()
conn, addr = sk.accept()
while True:
    ret = conn.recv(1024)
    # 打印客户端信息
    Socket_rst = Paddle_Url_Predit(ret.decode('utf-8'))
    print(Socket_rst)
    try:
        conn.send(bytes(str(Socket_rst), encoding='utf-8'))
    except:
        print('Connect Error!')
# 关闭客户端链接
conn.close()
# 关闭服务器套接字
sk.close()

客户端(Python,参考自https://zhuanlan.zhihu.com/p/279968757):

import socket
# 创建客户端套接字
sk = socket.socket()          
# 尝试连接服务器
sk.connect(('127.0.0.1',8898))
while True:
    # 信息发送
    info = input('>>>')
    sk.send(bytes(info,encoding='utf-8'))
    # 信息接收
    ret = sk.recv(1024)
    # 结束会话
    if ret == b'bye':
        sk.send(b'bye')
        break
    # 信息打印
    print(ret.decode('utf-8'))
# 关闭客户端套接字
sk.close()

客户端(PHP):

待续


« 上一篇

返回顶部
请先 登录 再评论,若不是会员请先 注册