訂閱
糾錯
加入自媒體

使用 TensorFlow.js 在瀏覽器上進(jìn)行自定義對象檢測

什么是物體檢測?

與許多計算機(jī)視覺認(rèn)知應(yīng)用相比,對象檢測是在圖像和視頻中識別和定位對象的常用技術(shù)之一。顧名思義——“計算機(jī)視覺”,是計算機(jī)獲得類似人類視覺以查看和識別物體的能力。目標(biāo)檢測可以被視為具有一些高級功能的圖像識別。該算法不僅可以識別/識別圖像/視頻中的對象,還可以對它們進(jìn)行定位。換句話說,算法在圖像或視頻幀中的對象周圍創(chuàng)建了一個邊界框。

物體檢測示例

各種物體檢測算法

以下是一些用于對象檢測的流行

R-CNN: 基于區(qū)域的卷積神經(jīng)網(wǎng)絡(luò)

Fast R-CNN: :基于區(qū)域的快速卷積神經(jīng)網(wǎng)絡(luò)

Faster R-CNN: 更快的基于區(qū)域的卷積網(wǎng)絡(luò)YOLO: 只看一次

SSD: 單鏡頭探測器每種算法都有自己的優(yōu)缺點。這些算法如何工作的細(xì)節(jié)超出了本文的范圍。

卷積神經(jīng)網(wǎng)絡(luò)的架構(gòu)

曾經(jīng)晚上放學(xué)回家,打開電視看最喜歡的動畫片的美好時光,可能大家都經(jīng)歷過。相信我們都喜歡看動畫片。那么,如何重溫那些日子呢?

今天,我們將學(xué)習(xí)如何使用 TensorFlow.js 創(chuàng)建端到端的自定義對象檢測 Web 應(yīng)用程序。我們將在自定義數(shù)據(jù)集上訓(xùn)練模型,并將其作為成熟的 Web 應(yīng)用程序部署在瀏覽器上。

如果你對構(gòu)建自己的對象檢測模型感到興奮,還等什么?讓我們深入了解。

本文將創(chuàng)建一個在瀏覽器上實時檢測卡通的模型。隨意選擇你自己的數(shù)據(jù)集,因為整個過程保持不變。

創(chuàng)建數(shù)據(jù)集

第一步是收集要檢測的對象的圖像。比如最喜歡的動畫片是機(jī)器貓,史酷比,米奇 老鼠,憨豆先生和麥昆。這些卡通形象構(gòu)成了這個模型的類。為這五個類中的每一個收集了大約 60 張圖像。這是數(shù)據(jù)集外觀。

記住:如果你給模型喂垃圾,你就會得到垃圾。為了獲得最佳結(jié)果,請確保為模型收集足夠的圖像以從中學(xué)習(xí)特征。

收集到足夠的數(shù)據(jù)后,讓我們繼續(xù)下一步。

標(biāo)記數(shù)據(jù)集

要標(biāo)記數(shù)據(jù)集中的對象,我們需要一個注釋/標(biāo)記工具。有很多注釋工具可以做到這一點,例如 LabelImg、Intel OpenVINO CVAT、VGG Image Annotator 等。

雖然這些都是業(yè)內(nèi)最好的注釋工具,但發(fā)現(xiàn) LabelImg 更容易使用。隨意選擇你喜歡的任何注釋工具,或者直接按照本文進(jìn)行操作。

下面是一個帶注釋的圖像的示例:圍繞感興趣區(qū)域(對象)及其標(biāo)簽名稱的邊界框。

圖片標(biāo)注

對于每個注釋的圖像,將生成一個相應(yīng)的 XML 文件,其中包含元數(shù)據(jù),例如邊界框的坐標(biāo)、類名、圖像名稱、圖像路徑等。

訓(xùn)練模型時需要這些信息。我們稍后會看到那部分。

下面是 XML 注釋文件的外觀示例。

注釋 XML 文件

好的,一旦你正確注釋了所有圖像,按照目錄結(jié)構(gòu)的以下方式將數(shù)據(jù)集拆分為訓(xùn)練集和測試集:

數(shù)據(jù)集的目錄結(jié)構(gòu)

在 Google Drive 上上傳數(shù)據(jù)集登

錄你的 Google 帳戶并將壓縮的數(shù)據(jù)集上傳到你的 Google Drive。我們將在模型訓(xùn)練期間獲取此數(shù)據(jù)集。確保數(shù)據(jù)集的上傳沒有因網(wǎng)絡(luò)問題而中斷,并且已完全上傳。

Google Drive 上的數(shù)據(jù)集

在本地機(jī)器上克隆以下存儲庫

https://github.com/NSTiwari/TensorFlow.js-Custom-Object-Detection

此存儲庫包含一個名為:Custom_Object_Detection_using_TensorFlow_js.pynb的 Colab Notebook。

打開 Google Colab 并將此 Colab Notebook上傳到那里,F(xiàn)在,我們將開始實際訓(xùn)練我們的對象檢測模型。

我們正在使用 Google Colab,因此你無需在本地機(jī)器上安裝 TensorFlow 和其他庫,因此我們避免了手動安裝庫的不必要麻煩,如果安裝不當(dāng)可能會出錯。

配置 Google Colab

在 Google Colab 上上傳筆記本后,檢查運行時類型是否設(shè)置為“GPU”。為此,請單擊 Runtime –> Change runtime type.

Google Colab 設(shè)置

在筆記本設(shè)置中,如果硬件加速器設(shè)置為'GPU',如下圖,你就可以開始了。

Google Colab 設(shè)置

如果以上五個步驟都成功完成,那么就開始真正的游戲 —— 模型訓(xùn)練。

模型訓(xùn)練

配置所有必要的訓(xùn)練參數(shù)。

image.png

掛載 Google Drive:

訪問你在第 3 步中存儲在 Google Drive 上的數(shù)據(jù)集。

from google.colab import drive

drive.mount('/content/drive')

安裝 TensorFlow 對象檢測 API:

安裝和設(shè)置 TensorFlow 對象檢測 API、Protobuf 和其他必要的依賴項。

依賴項:

所需的大部分依賴項都預(yù)裝在 Google Colab 中。我們需要安裝的唯一附加包是 TensorFlow.js,它用于將我們訓(xùn)練的模型轉(zhuǎn)換為與網(wǎng)絡(luò)兼容的模型。

協(xié)議緩沖區(qū):

TensorFlow 對象檢測 API 依賴于所謂的協(xié)議緩沖區(qū)(也稱為 protobuf)。Protobuf 是一種描述信息的語言中立方式。這意味著你可以編寫一次 protobuf,然后編譯它以用于其他語言,如 Python、Java 或 C。下面使用的protoc命令正在為 Python 編譯 object_detection/protos 文件夾中的所有協(xié)議緩沖區(qū)。

環(huán)境:

要使用對象檢測 API,我們需要將它與包含用于訓(xùn)練和評估幾個廣泛使用的卷積神經(jīng)網(wǎng)絡(luò) (CNN) 圖像分類模型的代碼的 slim 添加到我們的 PYTHONPATH 中。

image.png

image.png

測試設(shè)置:

運行模型構(gòu)建器測試以驗證是否一切設(shè)置成功。

!python object_detection/builders/model_builder_tf1_test.py

從 Google Drive 復(fù)制數(shù)據(jù)集文件夾:

獲取保存在 Drive 上的圖像和注釋數(shù)據(jù)集。

!unzip /content/drive/MyDrive/TFJS-Custom-Detection -d /content/

%cd /content/

%mkdir data

加載 xml_to_csv.py 文件:

!wget https://raw.githubusercontent.com/NSTiwari/TensorFlow.js-Custom-Object-Detection/master/xml_to_csv.py -P /content/TFJS-Custom-Detection/

將XML注釋轉(zhuǎn)換為 CSV 文件:

所有 PascalVOC 標(biāo)簽都轉(zhuǎn)換為 CSV 文件,用于訓(xùn)練和測試數(shù)據(jù)。

%cd /content/

!python TFJS-Custom-Detection/xml_to_csv.py

在數(shù)據(jù)文件夾中創(chuàng)建 labelmap.pbtxt 文件:考慮以下示例:

image.png

創(chuàng)建TFRecord:

下載 generate_tf_record.py 文件。

!wget https://raw.githubusercontent.com/NSTiwari/TensorFlow.js-Custom-Object-Detection/master/generate_tf_records.py -P /content/

!python generate_tf_records.py -l /content/data/labelmap.pbtxt -o data/train.record -i TFJS-Custom-Detection/images -csv TFJS-Custom-Detection/train_labels.csv

!python generate_tf_records.py -l /content/data/labelmap.pbtxt -o data/val.record -i TFJS-Custom-Detection/images -csv TFJS-Custom-Detection/val_labels.csv

導(dǎo)航到models/research目錄:

%cd /content/models/research

下載基本模型:

從頭開始訓(xùn)練模型可能需要大量計算時間。相反,我們選擇在預(yù)訓(xùn)練模型上應(yīng)用遷移學(xué)習(xí)。當(dāng)然,遷移學(xué)習(xí)在很大程度上有助于減少計算和時間。我們將使用的基本模型是非?斓 MobileNet 模型。

image.png

模型配置:

在訓(xùn)練開始之前,我們需要通過指定 labelmap、TFRecord 和 checkpoint 的路徑來配置訓(xùn)練管道。默認(rèn)批量大小為 128,這也需要更改,因為它太大而無法由 Colab 處理。

import re


from google.protobuf import text_format


from object_detection.utils import config_util

from object_detection.utils import label_map_util


pipeline_skeleton = '/content/models/research/object_detection/samples/configs/' + CONFIG_TYPE + '.config'

configs = config_util.get_configs_from_pipeline_file(pipeline_skeleton)

label_map = label_map_util.get_label_map_dict(LABEL_M(jìn)AP_PATH)

num_classes = len(label_map.keys())

meta_arch = configs["model"].WhichOneof("model")

override_dict = {

 'model.{}.num_classes'.format(meta_arch): num_classes,

 'train_config.batch_size': 24,

 'train_input_path': TRAIN_RECORD_PATH,

 'eval_input_path': VAL_RECORD_PATH,

 'train_config.fine_tune_checkpoint': os.path.join(CHECKPOINT_PATH, 'model.ckpt'),

 'label_map_path': LABEL_M(jìn)AP_PATH


configs = config_util.merge_external_params_with_configs(configs, kwargs_dict=override_dict)

pipeline_config = config_util.create_pipeline_proto_from_configs(configs)

config_util.save_pipeline_config(pipeline_config, DATA_PATH)

開始訓(xùn)練:

運行下面的單元格以開始訓(xùn)練模型。通過調(diào)用model_main腳本并將以下參數(shù)傳遞給它來調(diào)用訓(xùn)練

· 我們創(chuàng)建的pipeline.config 的位置。

· 我們想要保存模型的位置。

· 我們想要訓(xùn)練模型的步驟數(shù)(訓(xùn)練時間越長,學(xué)習(xí)的潛力就越大)。

· 評估步驟的數(shù)量(或測試模型的頻率)讓我們了解模型的表現(xiàn)。

!rm -rf $OUTPUT_PATH

!python -m object_detection.model_main
       --pipeline_config_path=$DATA_PATH/pipeline.config
       --model_dir=$OUTPUT_PATH
       --num_train_steps=$NUM_TRAIN_STEPS
       --num_eval_steps=100

導(dǎo)出推理圖:

每 500 個訓(xùn)練步驟后生成檢查點。每個檢查點都是你的模型在該訓(xùn)練點的快照。

如果由于某種原因訓(xùn)練因網(wǎng)絡(luò)或電源故障而崩潰,那么你可以從最后一個檢查點繼續(xù)訓(xùn)練,而不是從頭開始。

import os

import re

regex = re.compile(r"model.ckpt-([0-9]+).index")

numbers = [int(regex.search(f).group(1)) for f in os.listdir(OUTPUT_PATH) if regex.search(f)]

TRAINED_CHECKPOINT_PREFIX = os.path.join(OUTPUT_PATH, 'model.ckpt-{}'.format(max(numbers)))


print(f'Using {TRAINED_CHECKPOINT_PREFIX}')

!rm -rf $EXPORTED_PATH

!python -m object_detection.export_inference_graph 

 --pipeline_config_path=$DATA_PATH/pipeline.config 

 --trained_checkpoint_prefix=$TRAINED_CHECKPOINT_PREFIX 

 --output_directory=$EXPORTED_PATH

測試模型:

現(xiàn)在,讓我們在一些圖像上測試模型。請記住,該模型僅訓(xùn)練了 500 步。所以,準(zhǔn)確度可能不會那么高。運行下面的單元格來親自測試模型并了解模型的訓(xùn)練效果。

注意:有時,此命令不運行,可以嘗試重新運行它。此外,嘗試將模型訓(xùn)練 5,000 步,看看準(zhǔn)確性如何變化。

from IPython.display import display, Javascript, Image

from google.colab.output import eval_js

from base64 import b64decode

import tensorflow as tf

# Use javascipt to take a photo.

def take_photo(filename, quality=0.8):

    js = Javascript('''

     async function takePhoto(quality) {

     const div = document.createElement('div');

     const capture = document.createElement('button');

     capture.textContent = 'Capture';

     div.a(chǎn)ppendChild(capture);


const video = document.createElement('video');

     video.style.display = 'block';

     const stream = await navigator.mediaDevices.getUserMedia({video: true});

document.body.a(chǎn)ppendChild(div);

     div.a(chǎn)ppendChild(video);

     video.srcObject = stream;

     await video.play();

// Resize the output to fit the video element.

     google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

// Wait for Capture to be clicked.

     await new Promise((resolve) => capture.onclick = resolve);

const canvas = document.createElement('canvas');

     canvas.width = video.videoWidth;

     canvas.height = video.videoHeight;

     canvas.getContext('2d').drawImage(video, 0, 0);

     stream.getVideoTracks()[0].stop();

     div.remove();

     return canvas.toDataURL('image/jpeg', quality);

   }

   ''')

  display(js)

  data = eval_js('takePhoto({})'.format(quality))

  binary = b64decode(data.split(',')[1])

  with open(filename, 'wb') as f:

     f.write(binary)

  return filename


try:

 take_photo('/content/photo.jpg')

except Exception as err:

 # Errors will be thrown if the user does not have a webcam or if they do not

 # grant the page permission to access it.

 print(str(err))

# Use the captured photo to make predictions

%matplotlib inline

import os

import numpy as np

from matplotlib import pyplot as plt

from PIL import Image as PImage

from object_detection.utils import visualization_utils as vis_util

from object_detection.utils import label_map_util

# Load the labels

category_index = label_map_util.create_category_index_from_labelmap(LABEL

MAP_PATH, use_display_name=True)

# Load the model

path_to_frozen_graph = os.path.join(EXPORTED_PATH, 'frozen_inference_graph.pb')

detection_graph = tf.Graph()

with detection_graph.a(chǎn)s_default():

 od_graph_def = tf.GraphDef()

 with tf.gfile.GFile(path_to_frozen_graph, 'rb') as fid:

   serialized_graph = fid.read()

   od_graph_def.ParseFromString(serialized_graph)

   tf.import_graph_def(od_graph_def, name='')

with detection_graph.a(chǎn)s_default():

 with tf.Session(graph=detection_graph) as sess:

   # Definite input and output Tensors for detection_graph

   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

   # Each box represents a part of the image where a particular object was detected.

   detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

   # Each score represent how level of confidence for each of the objects.

   # Score is shown on the result image, together with the class label.

   detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')

   detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')

   num_detections = detection_graph.get_tensor_by_name('num_detections:0')

   image = PImage.open('/content/photo.jpg')

   # the array based representation of the image will be used later in order to prepare the

   # result image with boxes and labels on it.

   (im_width, im_h(yuǎn)eight) = image.size

   image_np = np.a(chǎn)rray(image.getdata()).reshape((im_h(yuǎn)eight, im_width, 3)).a(chǎn)stype(np.uint8)

   # Expand dimensions since the model expects images to have shape: [1, None, None, 3]

   image_np_expanded = np.expand_dims(image_np, axis=0)

   # Actual detection.

   (boxes, scores, classes, num) = sess.run(

       [detection_boxes, detection_scores, detection_classes, num_detections],

       feed_dict={image_tensor: image_np_expanded})

   # Visualization of the results of a detection.

   vis_util.visualize_boxes_and_labels_on_image_array(

      image_np,

       np.squeeze(boxes),

       np.squeeze(classes).a(chǎn)stype(np.int32),

       np.squeeze(scores),

       category_index,

       use_normalized_coordinates=True,

       line_thickness=8)

   plt.figure(figsize=(12, 8))

   plt.imshow(image_np)

將模型轉(zhuǎn)換為 TFJS:

我們導(dǎo)出的模型適用于 Python。但是,要將其部署在 Web 瀏覽器上,我們需要將其轉(zhuǎn)換為 TensorFlow.js,以便兼容直接在瀏覽器上運行

此外,該模型僅將對象檢測為label_map.pbtxt. 因此,我們還需要為所有可以映射到 ID 的標(biāo)簽創(chuàng)建一個 JSON 列表。

image.png

下載模型:

現(xiàn)在可以下載 TFJS 模型了。

注意:有時,此命令不會運行或會引發(fā)錯誤。請嘗試再次運行它。

你還可以通過右鍵單擊左側(cè)邊欄文件檢查器中的 model_web.zip 文件來下載模型。

from google.colab import files

files.download('/content/model_web.zip')

如果你順利到達(dá)這里,恭喜你,你已經(jīng)成功地訓(xùn)練了模型。

使用 TensorFlow.js 在 Web 應(yīng)用程序上部署模型。下載 TFJS 模型后,復(fù)制TensorFlow.js-Custom-Object-Detection/React_Web_App/public目錄中的 model_web 文件夾 。

現(xiàn)在,運行以下命令:

cd TensorFlow.js-Custom-Object-Detection/React_Web_App

npm install

npm start

現(xiàn)在,最后在你的 Web 瀏覽器上打開localhost:3000并親自測試模型。

TF.js 模型的對象檢測輸出

因此,恭喜你使用 TensorFlow 創(chuàng)建了端到端的自定義對象檢測模型,并將其部署在使用 TensorFlow.js 的 Web 應(yīng)用程序上。

       原文標(biāo)題 : 使用 TensorFlow.js 在瀏覽器上進(jìn)行自定義對象檢測

聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權(quán)或其他問題,請聯(lián)系舉報。

發(fā)表評論

0條評論,0人參與

請輸入評論內(nèi)容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續(xù)

暫無評論

暫無評論

人工智能 獵頭職位 更多
掃碼關(guān)注公眾號
OFweek人工智能網(wǎng)
獲取更多精彩內(nèi)容
文章糾錯
x
*文字標(biāo)題:
*糾錯內(nèi)容:
聯(lián)系郵箱:
*驗 證 碼:

粵公網(wǎng)安備 44030502002758號