Simple Traffic Light Recognition (Traffic-Light-Classify)

Full project source code: GitHub

Introduction

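Having covered traffic sign recognition earlier, we now turn to recognizing traffic lights.
We will implement and refine the whole project following our own approach.
In this project we recognize traffic lights using the HSV color space. Possible improvements: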

  • Use Faster R-CNN or SSD to recognize traffic lights

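The first step is to load the data and visualize some of it in the RGB and HSV color spaces. The images come from the MIT self-driving car course.
There are three classes (red, green, yellow) and 1187 images in total: 723 red, 429 green and 35 yellow traffic lights.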
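The class counts above are far from balanced. A quick sketch of the arithmetic, using only the counts stated in this section (the `counts` dictionary is just an illustrative container):

```python
# Class counts of the training set, as stated above
counts = {"red": 723, "green": 429, "yellow": 35}

total = sum(counts.values())
# Fraction of the training images that belong to each class
proportions = {label: n / total for label, n in counts.items()}

print(total)  # 1187
for label, p in proportions.items():
    print(f"{label}: {p:.1%}")  # red: 60.9%, green: 36.1%, yellow: 2.9%
```

Yellow lights make up under 3% of the training images, so the yellow class has very few examples to tune thresholds against.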

Import libraries

# Import libraries
import cv2
import os
import glob
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# Image data directories
IMAGE_DIR_TRAINING = "traffic_light_images/training/"
IMAGE_DIR_TEST = "traffic_light_images/test/"
# Load data
def load_dataset(image_dir):
    '''
    Loads images and their labels and places them in a list of (image, label) tuples.
    image_dir: directory where the images are stored, one subfolder per color
    '''
    im_list = []
    image_types = ['red', 'yellow', 'green']
    # Iterate through each color folder
    for im_type in image_types:
        file_lists = glob.glob(os.path.join(image_dir, im_type, '*'))
        # Print the number of images found for this color
        print(len(file_lists))
        for file in file_lists:
            im = mpimg.imread(file)
            if im is not None:
                im_list.append((im, im_type))
    return im_list
IMAGE_LIST = load_dataset(IMAGE_DIR_TRAINING)
723
35
429

Visualize the data

This visualization:

  • shows each image
  • shows the image size (as the title)
  • shows the image's label (as an annotation)
_,ax = plt.subplots(1,3,figsize=(5,2))
#red
img_red = IMAGE_LIST[0][0]
ax[0].imshow(img_red)
ax[0].annotate(IMAGE_LIST[0][1],xy=(2,5),color='blue',fontsize='10')
ax[0].axis('off')
ax[0].set_title(img_red.shape,fontsize=10)
#yellow
img_yellow = IMAGE_LIST[730][0]
ax[1].imshow(img_yellow)
ax[1].annotate(IMAGE_LIST[730][1],xy=(2,5),color='blue',fontsize='10')
ax[1].axis('off')
ax[1].set_title(img_yellow.shape,fontsize=10)
#green
img_green = IMAGE_LIST[800][0]
ax[2].imshow(img_green)
ax[2].annotate(IMAGE_LIST[800][1],xy=(2,5),color='blue',fontsize='10')
ax[2].axis('off')
ax[2].set_title(img_green.shape,fontsize=10)
plt.show()

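The indices 0, 730 and 800 above only work because load_dataset reads the red, yellow and green folders in that order (723 red images first, then 35 yellow, then green). A sketch that looks up the first image of each label instead of hard-coding indices; `first_index_per_label` is a hypothetical helper, and the toy list stands in for IMAGE_LIST:

```python
def first_index_per_label(image_list):
    """Return {label: index of the first (image, label) pair carrying that label}."""
    first = {}
    for i, (_, label) in enumerate(image_list):
        # setdefault only records the first occurrence of each label
        first.setdefault(label, i)
    return first

# Toy stand-in for IMAGE_LIST: (image, label) pairs with images replaced by None
toy_list = [(None, "red")] * 3 + [(None, "yellow")] * 2 + [(None, "green")] * 4
print(first_index_per_label(toy_list))  # {'red': 0, 'yellow': 3, 'green': 5}
```

With the real data and the folder order shown above, `first_index_per_label(IMAGE_LIST)["yellow"]` would be 723.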

Preprocess Data

After loading the data, we need to standardize the inputs and outputs.

Input

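As the figure above shows, the images differ in size, so we need to standardize the input
by resizing every image to the same size.
A classification task applies the same algorithm to every image, so standardizing the images is especially important;
the features computed later also assume a fixed 32x32 image area.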

Output

Our labels are categorical ('red', 'yellow', 'green'), so we use one-hot encoding to convert them into numeric data.

# Standardize the input images: resize every image to 32x32x3.
# Cropping, translation or rotation could also be applied here.
def standardize(image_list):
    '''
    Takes a list of (RGB image, label) pairs and returns a standardized version
    image_list: list of (image, label) tuples
    '''
    standard_list = []
    # Iterate through all the image-label pairs
    for item in image_list:
        image = item[0]
        label = item[1]
        # Standardize the input
        standardized_im = standardize_input(image)
        # Standardize the output (one-hot)
        one_hot_label = one_hot_encode(label)
        # Append the image and its one-hot encoded label to the full, processed list of image data
        standard_list.append((standardized_im, one_hot_label))
    return standard_list
def standardize_input(image):
    # Resize all images to 32x32x3
    standard_im = cv2.resize(image, (32, 32))
    return standard_im
def one_hot_encode(label):
    '''
    Returns the one-hot encoded label:
    one_hot_encode("red") should return: [1, 0, 0]
    one_hot_encode("yellow") should return: [0, 1, 0]
    one_hot_encode("green") should return: [0, 0, 1]
    '''
    if label == 'red':
        return [1, 0, 0]
    elif label == 'yellow':
        return [0, 1, 0]
    else:
        return [0, 0, 1]
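The if/elif chain above is fine for three classes. A sketch of the same mapping driven by a class list, plus the inverse (decoding a one-hot vector with argmax), which is handy for printing predictions as color names. `CLASSES`, `one_hot_encode_np` and `one_hot_decode` are hypothetical names, not part of the original project:

```python
import numpy as np

# Class order matches the one-hot layout used in this project: [red, yellow, green]
CLASSES = ["red", "yellow", "green"]

def one_hot_encode_np(label):
    """Encode a color name as a one-hot Python list, e.g. 'yellow' -> [0, 1, 0]."""
    # Row i of the identity matrix is the one-hot vector for class i;
    # list.index raises ValueError for an unknown label
    return np.eye(len(CLASSES), dtype=int)[CLASSES.index(label)].tolist()

def one_hot_decode(one_hot):
    """Map a one-hot vector back to its color name via argmax."""
    return CLASSES[int(np.argmax(one_hot))]

for name in CLASSES:
    encoded = one_hot_encode_np(name)
    print(name, encoded, one_hot_decode(encoded))
```

One design difference: the original one_hot_encode silently maps any unexpected label to green, while this version raises ValueError, which surfaces typos in folder names early.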

Test your code

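With the standardization code in place, we need to make sure it is correct, so next we write a function that checks it.
To build automated tests in Python we need to organize test cases and run them; here we use unittest from the Python standard library.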

import unittest
from IPython.display import Markdown, display
# Helper function for printing markdown text (text in color/bold/etc.)
def printmd(string):
    display(Markdown(string))
# Print a "test failed" message
def print_fail():
    printmd('**<span style="color: red;">Test Failed</span>**')
# Print a "test passed" message
def print_pass():
    printmd('**<span style="color: green;">Test Passed</span>**')
# A class holding all tests
class Tests(unittest.TestCase):
    # Tests the 'one_hot_encode' function, which is passed in as an argument
    def test_one_hot(self, one_hot_function):
        # Test that the generated one-hot labels match the expected one-hot labels
        # for all three cases (red, yellow, green)
        try:
            self.assertEqual([1, 0, 0], one_hot_function('red'))
            self.assertEqual([0, 1, 0], one_hot_function('yellow'))
            self.assertEqual([0, 0, 1], one_hot_function('green'))
        # Catch the assertion failure
        except self.failureException as e:
            # Print out an error message
            print_fail()
            print('Your function did not return the expected one-hot label')
            print('\n' + str(e))
            return
        print_pass()
    # Test whether any misclassified images are red but mistakenly classified as green
    def test_red_as_green(self, misclassified_images):
        # Loop through each misclassified image and its labels
        for im, predicted_label, true_label in misclassified_images:
            # Check if the image is of a red light
            if true_label == [1, 0, 0]:
                try:
                    # A red light must NOT be predicted as green
                    self.assertNotEqual(predicted_label, [0, 0, 1])
                except self.failureException as e:
                    print_fail()
                    print('Warning: A red light is classified as green.')
                    print('\n' + str(e))
                    return
        print_pass()
tests = Tests()
tests.test_one_hot(one_hot_encode)

Test Passed
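The Tests class above calls its methods by hand and passes arguments to them; a unittest runner, by contrast, calls each test_* method with no arguments besides self. For comparison, a sketch of the conventional layout run through TestLoader and TextTestRunner. It carries its own copy of one_hot_encode so the block runs on its own, and `OneHotTests` is a hypothetical name:

```python
import unittest

def one_hot_encode(label):
    # Copy of the one_hot_encode defined earlier, so this block is self-contained
    if label == 'red':
        return [1, 0, 0]
    elif label == 'yellow':
        return [0, 1, 0]
    else:
        return [0, 0, 1]

class OneHotTests(unittest.TestCase):
    # The runner calls each test_* method with no arguments besides self
    def test_red(self):
        self.assertEqual(one_hot_encode('red'), [1, 0, 0])

    def test_yellow(self):
        self.assertEqual(one_hot_encode('yellow'), [0, 1, 0])

    def test_green(self):
        self.assertEqual(one_hot_encode('green'), [0, 0, 1])

# Build a suite from the class and run it; this works the same in a script or a notebook
suite = unittest.TestLoader().loadTestsFromTestCase(OneHotTests)
result = unittest.TextTestRunner(verbosity=2).run(suite)
print(result.wasSuccessful())  # True
```

The runner reports each failing assertion itself, which replaces the manual try/except and the print_fail/print_pass calls.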

Standardized_Train_List = standardize(IMAGE_LIST)

Feature Extraction

Here we use color spaces, shape analysis and feature construction.

RGB to HSV

#Visualize
image_num = 0
test_im = Standardized_Train_List[image_num][0]
test_label = Standardized_Train_List[image_num][1]
#convert to hsv
hsv = cv2.cvtColor(test_im, cv2.COLOR_RGB2HSV)
# Print image label
print('Label [red, yellow, green]: ' + str(test_label))
h = hsv[:,:,0]
s = hsv[:,:,1]
v = hsv[:,:,2]
# Plot the original image and the three channels
_, ax = plt.subplots(1, 4, figsize=(20,10))
ax[0].set_title('Standardized image')
ax[0].imshow(test_im)
ax[1].set_title('H channel')
ax[1].imshow(h, cmap='gray')
ax[2].set_title('S channel')
ax[2].imshow(s, cmap='gray')
ax[3].set_title('V channel')
ax[3].imshow(v, cmap='gray')
Label [red, yellow, green]: [1, 0, 0]

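For 8-bit images OpenCV stores hue as degrees divided by 2, so H runs from 0 to 179; the hue thresholds used later in this post (e.g. 150-180 for red) are on that scale. A stdlib-only sketch that converts a few pure RGB colors to hue on the same 0-179 scale. It uses colorsys instead of cv2 so it runs without OpenCV; `opencv_style_hue` is a hypothetical helper that mirrors OpenCV's convention rather than calling OpenCV:

```python
import colorsys

def opencv_style_hue(r, g, b):
    """Hue of an 8-bit RGB color on OpenCV's 0-179 scale (degrees / 2)."""
    h, _, _ = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)  # h is in [0, 1)
    return round(h * 180)

samples = {
    "red": (255, 0, 0),
    "yellow": (255, 255, 0),
    "green": (0, 255, 0),
    "cyan": (0, 255, 255),
    "blue": (0, 0, 255),
    "magenta": (255, 0, 255),
}
for name, rgb in samples.items():
    print(name, opencv_style_hue(*rgb))
# red 0, yellow 30, green 60, cyan 90, blue 120, magenta 150
```

Because hue is an angle, red sits at both ends of the scale (near 0 and near 179).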

# Create features
'''
HSV stands for Hue, Saturation, Value; it is also called HSB, where B stands for Brightness.
Hue (H) is the basic attribute of a color, i.e. what we usually call its name, such as red or yellow.
Saturation (S) is the purity of a color: the higher it is, the purer the color; as it drops the color fades toward gray. It ranges over 0-100%.
Value (V) is the brightness of the color, ranging over 0-100%.
Note: for 8-bit images OpenCV stores H in 0-179 (degrees / 2) and S, V in 0-255, which is the scale used by the thresholds in this project.
'''
def create_feature(rgb_image):
    '''
    Basic brightness feature: the average V (value) over the image
    rgb_image: a standardized 32x32 RGB image
    '''
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    sum_brightness = np.sum(hsv[:, :, 2])
    area = 32 * 32
    avg_brightness = sum_brightness / area  # Find the average
    return avg_brightness
def high_saturation_pixels(rgb_image, threshold=80):
    '''
    Returns the average red and green content of the high-saturation pixels.
    Usually the traffic light contains the most saturated pixels in the image.
    The threshold of 80 was determined experimentally.
    '''
    high_sat_pixels = []
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    for i in range(32):
        for j in range(32):
            if hsv[i][j][1] > threshold:
                high_sat_pixels.append(rgb_image[i][j])
    if not high_sat_pixels:
        return highest_sat_pixel(rgb_image)
    # Sum with Python ints so uint8 pixel values cannot overflow
    sum_red = sum(int(pixel[0]) for pixel in high_sat_pixels)
    sum_green = sum(int(pixel[1]) for pixel in high_sat_pixels)
    avg_red = sum_red / len(high_sat_pixels)
    avg_green = sum_green / len(high_sat_pixels) * 0.8
    return avg_red, avg_green
def highest_sat_pixel(rgb_image):
    '''
    Finds the most saturated pixel and checks whether it has a higher green
    or a higher red content
    '''
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    s = hsv[:, :, 1]
    x, y = np.unravel_index(np.argmax(s), s.shape)
    if rgb_image[x, y, 0] > rgb_image[x, y, 1] * 0.9:
        return 1, 0  # red has a higher content
    return 0, 1
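The double loop in high_saturation_pixels visits every pixel in Python; a NumPy boolean mask does the same selection in one step. A sketch that takes an RGB array together with its already-computed HSV array (so it runs without OpenCV). `high_sat_red_green` is a hypothetical name, and the threshold of 80 and the 0.8 green weight are copied from the function above:

```python
import numpy as np

def high_sat_red_green(rgb_image, hsv_image, threshold=80):
    """Average red and 0.8-weighted green of pixels whose saturation > threshold.

    Returns None when no pixel passes (the original falls back to highest_sat_pixel).
    """
    # Boolean mask over the H x W grid, True where saturation exceeds the threshold
    mask = hsv_image[:, :, 1] > threshold
    if not mask.any():
        return None
    # Indexing with the mask yields an (N, 3) array; float avoids uint8 overflow
    selected = rgb_image[mask].astype(np.float64)
    avg_red = selected[:, 0].mean()
    avg_green = selected[:, 1].mean() * 0.8
    return float(avg_red), float(avg_green)

# Toy 2x2 example: only the top-left pixel is highly saturated
rgb = np.array([[[200, 50, 0], [10, 10, 10]],
                [[20, 20, 20], [30, 30, 30]]], dtype=np.uint8)
hsv = np.zeros_like(rgb)
hsv[0, 0, 1] = 200
print(high_sat_red_green(rgb, hsv))  # (200.0, 40.0)
```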

Test dataset

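Next, we load the test set to see how accurate the method above is.
So far the method:
1. computes the average brightness
2. computes the average red and green content of the highly saturated pixels
One might ask why yellow is not checked at all, so we make the following improvement: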
reference url

Some of the thresholds below are taken directly from the HSV data on Wikipedia.

def estimate_label(rgb_image, display=False):
    '''
    rgb_image: standardized RGB image
    Returns the predicted one-hot label
    '''
    return red_green_yellow(rgb_image, display)
def findNoneZero(rgb_image):
    # Count the pixels that are not completely black (i.e. the pixels kept by a mask)
    rows, cols, _ = rgb_image.shape
    counter = 0
    for row in range(rows):
        for col in range(cols):
            pixels = rgb_image[row, col]
            # Use any() rather than sum(): summing uint8 values may overflow
            if np.any(pixels != 0):
                counter = counter + 1
    return counter
def red_green_yellow(rgb_image, display):
    '''
    Determines the red, green and yellow content in each image using HSV and experimentally
    determined thresholds. Returns a classification based on the values
    '''
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    sum_saturation = np.sum(hsv[:, :, 1])  # Sum the saturation values
    area = 32 * 32
    avg_saturation = sum_saturation / area  # Find the average
    sat_low = int(avg_saturation * 1.3)  # 1.3x the mean saturation, an empirical choice
    val_low = 140
    # Green
    lower_green = np.array([70, sat_low, val_low])
    upper_green = np.array([100, 255, 255])
    green_mask = cv2.inRange(hsv, lower_green, upper_green)
    green_result = cv2.bitwise_and(rgb_image, rgb_image, mask=green_mask)
    # Yellow
    lower_yellow = np.array([10, sat_low, val_low])
    upper_yellow = np.array([60, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
    yellow_result = cv2.bitwise_and(rgb_image, rgb_image, mask=yellow_mask)
    # Red
    lower_red = np.array([150, sat_low, val_low])
    upper_red = np.array([180, 255, 255])
    red_mask = cv2.inRange(hsv, lower_red, upper_red)
    red_result = cv2.bitwise_and(rgb_image, rgb_image, mask=red_mask)
    if display:
        _, ax = plt.subplots(1, 5, figsize=(20, 10))
        ax[0].set_title('rgb image')
        ax[0].imshow(rgb_image)
        ax[1].set_title('red result')
        ax[1].imshow(red_result)
        ax[2].set_title('yellow result')
        ax[2].imshow(yellow_result)
        ax[3].set_title('green result')
        ax[3].imshow(green_result)
        ax[4].set_title('hsv image')
        ax[4].imshow(hsv)
        plt.show()
    sum_green = findNoneZero(green_result)
    sum_red = findNoneZero(red_result)
    sum_yellow = findNoneZero(yellow_result)
    if sum_red >= sum_yellow and sum_red >= sum_green:
        return [1, 0, 0]  # Red
    if sum_yellow >= sum_green:
        return [0, 1, 0]  # Yellow
    return [0, 0, 1]  # Green
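red_green_yellow counts, for each color, how many pixels fall inside an HSV box and picks the largest count, breaking ties in the order red, yellow, green. A NumPy-only sketch of that decision: `in_range` mimics cv2.inRange (inclusive bounds on every channel) and `classify_hsv` is a hypothetical name. It counts mask pixels directly instead of counting non-black pixels after bitwise_and; the two counts agree here because val_low = 140 keeps any pure-black pixel out of every mask.

```python
import numpy as np

def in_range(hsv, lower, upper):
    """NumPy stand-in for cv2.inRange: True where every channel lies in [lower, upper]."""
    return np.all((hsv >= lower) & (hsv <= upper), axis=-1)

def classify_hsv(hsv, sat_low, val_low=140):
    """Return a one-hot [red, yellow, green] label using the same hue bands as above."""
    bands = {
        "green": ([70, sat_low, val_low], [100, 255, 255]),
        "yellow": ([10, sat_low, val_low], [60, 255, 255]),
        "red": ([150, sat_low, val_low], [180, 255, 255]),
    }
    counts = {name: int(np.count_nonzero(in_range(hsv, np.array(lo), np.array(hi))))
              for name, (lo, hi) in bands.items()}
    # Same tie-breaking as the original: red wins ties, then yellow
    if counts["red"] >= counts["yellow"] and counts["red"] >= counts["green"]:
        return [1, 0, 0]
    if counts["yellow"] >= counts["green"]:
        return [0, 1, 0]
    return [0, 0, 1]

# Toy 2x2 HSV image: every pixel has S = V = 200
hsv = np.zeros((2, 2, 3), dtype=np.uint8)
hsv[:, :, 1:] = 200
hsv[0, 0, 0] = hsv[0, 1, 0] = hsv[1, 0, 0] = 85  # three pixels in the green band
hsv[1, 1, 0] = 160                                # one pixel in the red band
print(classify_hsv(hsv, sat_low=100))  # [0, 0, 1]
```

On OpenCV's 0-179 hue scale, pure RGB red is H = 0 and pure RGB green is H = 60, so these bands look tuned to this dataset rather than to textbook hues: a pure green (H = 60) would be counted as yellow, and a red near H = 0 to 9 falls in no band at all.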

Test

Next, we pick three images to see how the classifier does.

# img_red, img_yellow and img_green were picked in the visualization step above
img_test = [(img_red, 'red'), (img_yellow, 'yellow'), (img_green, 'green')]
standardtest = standardize(img_test)
for img in standardtest:
    predicted_label = estimate_label(img[0], display=True)
    print('Predict label :', predicted_label)
    print('True label:', img[1])


Predict label : [1, 0, 0]
True label: [1, 0, 0]


Predict label : [0, 1, 0]
True label: [0, 1, 0]


Predict label : [0, 0, 1]
True label: [0, 0, 1]
# Load the test data with the load_dataset function defined above
TEST_IMAGE_LIST = load_dataset(IMAGE_DIR_TEST)
# Standardize the test data
STANDARDIZED_TEST_LIST = standardize(TEST_IMAGE_LIST)
# Shuffle the standardized test data
random.shuffle(STANDARDIZED_TEST_LIST)
181
9
107

Determine the Accuracy

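Next, let's look at the accuracy of our algorithm on the test set. The code below collects every misclassified image together with its predicted label and true label.
These are stored in MISCLASSIFIED.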

# Constructs a list of misclassified images given a list of test images and their labels
# This raises an AssertionError if the labels are not standardized (one-hot encoded)
def get_misclassified_images(test_images, display=False):
    misclassified_images_labels = []
    # Iterate through all the test images
    # Classify each image and compare it to the true label
    for image in test_images:
        # Get true data
        im = image[0]
        true_label = image[1]
        assert len(true_label) == 3, 'This true_label is not the expected length (3).'
        # Get the predicted label from the classifier
        predicted_label = estimate_label(im, display=display)
        assert len(predicted_label) == 3, 'This predicted_label is not the expected length (3).'
        # Compare true and predicted labels
        if predicted_label != true_label:
            # If these labels are not equal, the image has been misclassified
            misclassified_images_labels.append((im, predicted_label, true_label))
    # Return the list of misclassified (image, predicted_label, true_label) values
    return misclassified_images_labels
# Find all misclassified images in a given test set
MISCLASSIFIED = get_misclassified_images(STANDARDIZED_TEST_LIST, display=False)
# Accuracy calculations
total = len(STANDARDIZED_TEST_LIST)
num_correct = total - len(MISCLASSIFIED)
accuracy = num_correct / total
print('Accuracy:' + str(accuracy))
print('Number of misclassified images = ' + str(len(MISCLASSIFIED)) + ' out of ' + str(total))
Accuracy:0.9797979797979798
Number of misclassified images = 6 out of 297
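A single accuracy number hides how the errors are spread across classes, which matters here because yellow is rare in both sets. A sketch that tallies a confusion matrix from (predicted_label, true_label) one-hot pairs; `confusion_counts` is a hypothetical helper, and the pairs below are made-up toy values, not this project's results. With the real data, the pairs would come from running estimate_label over STANDARDIZED_TEST_LIST.

```python
import numpy as np

CLASSES = ["red", "yellow", "green"]

def confusion_counts(pairs):
    """3x3 matrix: rows are true classes, columns are predicted classes."""
    matrix = np.zeros((3, 3), dtype=int)
    for predicted, true in pairs:
        # argmax turns a one-hot list back into a class index
        matrix[int(np.argmax(true)), int(np.argmax(predicted))] += 1
    return matrix

# Toy (predicted, true) pairs, NOT results from the test set
pairs = [
    ([1, 0, 0], [1, 0, 0]),
    ([1, 0, 0], [1, 0, 0]),
    ([0, 0, 1], [1, 0, 0]),  # a red light predicted as green
    ([0, 1, 0], [0, 1, 0]),
    ([0, 0, 1], [0, 0, 1]),
]
m = confusion_counts(pairs)
print(m)
accuracy = np.trace(m) / m.sum()
print(accuracy)  # 0.8
# Entry [red, green]: red lights predicted as green, the most dangerous mistake
print(m[0, 2])  # 1
```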