OpenVINO+OpenCV 文本檢測與識別
模型介紹
文本檢測模型
OpenVINO支持的場景文字檢測模型是基於MobileNet的PixelLink模型,該模型有兩個輸出,分別是像素分割(text/no-text)輸出與像素鏈接(link)輸出,結構如下:
下面是以VGG16作為backbone實現的PixelLink模型結構:
輸入格式:1x3x768x1280 BGR彩色圖像
輸出格式:
name: "model/link_logits_/add", [1x16x192x320] – PixelLink的像素鏈接輸出
name: "model/segm_logits/add", [1x2x192x320] – 像素分類 text/no-text
文本識別模型
這裡CNN使用類似VGG16的結構提取特徵,序列預測使用雙向LSTM網絡。
輸入格式: 1x1x32x120
輸出格式: 30x1x37
輸出的解析基於CTC貪心解碼(greedy decoding)方式。
代碼演示
01文本檢測
基於PixelLink完成文本檢測,其中載入模型並獲取輸入與輸出層名稱的代碼實現如下:
log.info("Creating Inference Engine")
# IECore must be instantiated: read_network/load_network are instance methods.
ie = IECore()
dete_net = ie.read_network(model=dete_text_xml, weights=dete_text_bin)
reco_net = ie.read_network(model=reco_text_xml, weights=reco_text_bin)

# Text-detection network: resolve input and output blob names.
log.info("加載文本檢測網(wǎng)絡(luò),解析輸入與輸出格式...")
input_it = iter(dete_net.input_info)
input_det_blob = next(input_it)
print(input_det_blob)

# The detection model has two outputs. Downstream code treats out_det_blob1 as
# the link logits and out_det_blob2 as the segmentation logits, so select them
# by name rather than relying on dict iteration order. If the expected
# substrings are absent, fall back to iteration order (the original behaviour).
det_output_names = list(dete_net.outputs)
out_det_blob1 = next((n for n in det_output_names if "link" in n), det_output_names[0])
out_det_blob2 = next((n for n in det_output_names if "segm" in n), det_output_names[1])

# Read the network input geometry (N, C, H, W); (dw, dh) is used to resize frames.
print(dete_net.input_info[input_det_blob].input_data.shape)
dn, dc, dh, dw = dete_net.input_info[input_det_blob].input_data.shape

# Load the model onto the CPU plugin.
det_exec_net = ie.load_network(network=dete_net, device_name="CPU")
print("out_det_blob1: ", out_det_blob1, "out_det_blob2: ", out_det_blob2)
執行推理與解析輸出的代碼如下:
image = cv.imread("D:/images/openvino_ocr.jpg")
# image = cv.imread("D:/facedb/tiaoma/1.png")
if image is None:
    # cv.imread reports failure by returning None instead of raising.
    raise FileNotFoundError("D:/images/openvino_ocr.jpg")
h, w, c = image.shape
cv.imshow("input", image)

# Resize to the network input size and reorder HWC -> CHW.
img_blob = cv.resize(image, (dw, dh))
img_blob = img_blob.transpose(2, 0, 1)

# Start sync inference
log.info("Starting inference in synchronous mode")
inf_start1 = time.time()
res = det_exec_net.infer(inputs={input_det_blob: [img_blob]})
inf_end1 = time.time() - inf_start1
print("inference time(ms) : %.3f" % (inf_end1 * 1000))

# Drop the batch axis and move channels last: (C, H, W) -> (H, W, C).
link_logits_ = res[out_det_blob1][0].transpose(1, 2, 0)
segm_logits = res[out_det_blob2][0].transpose(1, 2, 0)
print(link_logits_.shape, segm_logits.shape)

# Mark a pixel as foreground (255) when the raw logit of segmentation channel 1
# exceeds 1.0. Channel 1 is presumably the "text" class (TODO confirm). This is
# vectorised and sized from the real output instead of a hard-coded 192x320 loop.
pixel_mask = np.where(segm_logits[:, :, 1] > 1.0, 255, 0).astype(np.uint8)
mask = cv.resize(pixel_mask, (w, h))
cv.imshow("mask", mask)
# The run result is shown below.
02運行結果:
# IECore must be instantiated: read_network/load_network are instance methods.
ie = IECore()
reco_net = ie.read_network(model=reco_text_xml, weights=reco_text_bin)

# Text-recognition network: resolve input and output blob names.
log.info("加載文本識別網(wǎng)絡(luò),解析輸入與輸出格式...")
input_rec_it = iter(reco_net.input_info)
input_rec_blob = next(input_rec_it)
print(input_rec_blob)
output_rec_it = iter(reco_net.outputs)
out_rec_blob = next(output_rec_it)

# Read the network input geometry (N, C, H, W); ROIs are resized to (rw, rh).
print(reco_net.input_info[input_rec_blob].input_data.shape)
rn, rc, rh, rw = reco_net.input_info[input_rec_blob].input_data.shape

# Load the model onto the CPU plugin.
rec_exec_net = ie.load_network(network=reco_net, device_name="CPU")
print("out_rec_blob1: ", out_rec_blob)
# Text recognition: find text regions with morphology, then recognise each region.
image = cv.imread("D:/images/zsxq/ocr3.png")
if image is None:
    # cv.imread reports failure by returning None instead of raising.
    raise FileNotFoundError("D:/images/zsxq/ocr3.png")
gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
# Inverted Otsu threshold: dark text becomes white foreground.
ret, binary = cv.threshold(gray, 0, 255, cv.THRESH_BINARY_INV | cv.THRESH_OTSU)
# A 5-wide, 1-high kernel joins horizontally adjacent foreground pixels.
se = cv.getStructuringElement(cv.MORPH_RECT, (5, 1))
binary = cv.dilate(binary, se)
cv.imshow("binary", binary)
cv.waitKey(0)

contours, _hierarchy = cv.findContours(binary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
for contour in contours:
    x, y, iw, ih = cv.boundingRect(contour)
    # Crop from the grayscale image: the recognition input has a single channel.
    roi = gray[y:y + ih, x:x + iw]
    rec_roi = cv.resize(roi, (rw, rh))
    rec_roi_blob = np.expand_dims(rec_roi, 0)  # (H, W) -> (1, H, W); batch axis added below
    # Start sync inference
    log.info("Starting inference in synchronous mode")
    inf_start1 = time.time()
    res = rec_exec_net.infer(inputs={input_rec_blob: [rec_roi_blob]})
    inf_end1 = time.time() - inf_start1
    print("inference time(ms) : %.3f" % (inf_end1 * 1000))
    res = res[out_rec_blob]
    txt = greedy_prase_text(res)
    cv.putText(image, txt, (x, y), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), 1, 8)

cv.imshow("recognition text demo", image)
cv.waitKey(0)
cv.destroyAllWindows()
運行結果如下:
重新整理了一下解析部分的代碼函數。不用看公式,看完你會暈倒而且寫不出代碼!
實現如下:
def ctc_soft_max(data):
    """Return the arg-max class of one score vector and its softmax probability.

    Args:
        data: 1-D sequence (list or numpy array) of raw scores for one time step.

    Returns:
        ``(index, prob)``: ``index`` is the position of the largest score
        (first one on ties) and ``prob`` is ``softmax(data)[index]``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    scores = np.asarray(data, dtype=np.float64)
    index = np.argmax(scores)
    # Subtract the maximum before exponentiating for numerical stability. The
    # winning entry then contributes exp(0) == 1, so its probability is 1/total.
    total = np.exp(scores - scores.max()).sum()
    return index, 1.0 / total
def greedy_prase_text(res, alphabet=None):
    """Greedy (best-path) CTC decoding of a recognition-network output.

    At each time step the highest-scoring class is taken. Consecutive repeats
    are collapsed and the blank symbol ``'#'`` is dropped. A character is
    therefore emitted only when it differs from the previous frame's character;
    a blank in between allows a genuine double letter.

    Args:
        res: numpy array of per-time-step class scores, iterated over its
            first axis. Each frame may be ``(C,)`` or ``(1, C)``.
        alphabet: indexable mapping from class index to character. Defaults
            to the module-level ``digit_nums`` (defined outside this view).

    Returns:
        The decoded string. ``res.shape`` and the result are also printed.
    """
    if alphabet is None:
        alphabet = digit_nums
    print(res.shape)
    decoded = []
    prev_char = None  # character picked at the previous frame (may be the blank)
    for frame in res:
        # argmax over the flattened frame ignores a singleton axis. The softmax
        # probability is not needed for greedy decoding, so it is not computed.
        char = alphabet[int(np.argmax(frame))]
        if char != '#' and char != prev_char:
            decoded.append(char)
        prev_char = char
    ocrstr = "".join(decoded)
    print(ocrstr)
    return ocrstr
評論