# Model validation

# Imports
import os
import cv2
import glob
import pickle
import pandas as pd
import numpy as np
from utils import *
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder 
# Load data. In this script only the 'Class' column is used, to fit the
# label encoder below.
df = pd.read_csv("../data/color_data.csv")

# Encode categorical labels as integers. The fitted encoder is kept so that
# the model's integer predictions can be mapped back to class names.
# NOTE(review): presumably this reproduces the encoding used at training
# time -- confirm against the training script.
labelencoder = LabelEncoder()
df['Class'] = labelencoder.fit_transform(df['Class'])

# Load model.
# NOTE(review): unpickling can execute arbitrary code; only load model files
# that come from a trusted source.
filename = '../data/model.sav'
with open(filename, 'rb') as model_file:
    # The context manager closes the file handle even if unpickling fails.
    model = pickle.load(model_file)
# Number of images whose complete prediction matched the label in the filename
correct = 0

# Loop over every image. The 1-based index drives the periodic preview below.
image_paths = glob.glob('../data/resistor_images/*jpg')
for index, image_path in enumerate(image_paths, start=1):

    # Read image; cv2.imread returns None when the file cannot be decoded
    image = cv2.imread(image_path)
    if image is None:
        # Nothing to classify. The file is still part of the directory, so it
        # counts as a miss in the final accuracy.
        continue

    # Read label: the first three characters after the last underscore of the
    # file name. The basename is used because the directory name itself
    # contains an underscore ('resistor_images').
    label = os.path.basename(image_path).split('_')[-1][0:3]

    # Extract color band contours
    bands = extract_color_bands(image)

    # Classify every extracted band and concatenate the decoded class names.
    # NOTE(review): the length check below assumes each class name is a
    # single character -- confirm against the 'Class' column.
    prediction = ''
    for band in bands:

        # Predict the encoded class of this band
        pred = model.predict([band])

        # Convert the encoded class back to its name
        prediction += labelencoder.inverse_transform(pred)[0]

    # Only predictions of exactly three characters are scored and drawn
    if len(prediction) == 3:

        # Accuracy
        if prediction == label:
            correct += 1

        # Plot text: the raw prediction followed by its decoded form
        cv2.putText(
            image,
            text=prediction + " - " + decode(prediction),
            org=(150, 250),
            fontFace=cv2.FONT_HERSHEY_TRIPLEX,
            fontScale=3,
            color=(0, 255, 0),
            thickness=3,
        )

    # Show every 20th image. Sampling is keyed on the image index rather than
    # on `correct`: the running count does not advance on a miss, so the old
    # check displayed every image while the count sat at 0 or a multiple of 20.
    if index % 20 == 0:
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        plt.show()
../_images/da8b7ddbd54984fcd921499ce9f12696fbb2e4178c3526130d92c51e27d13324.png ../_images/4384d8b4b73045684ad499a181ea6449b9e96eef8ba5d8f9e39254a78dde63dd.png ../_images/c0936efcbaab87bb64a688679d1ef445a0b422829d92c254be6005b64220a56e.png ../_images/32c4495788898bf1fee7f321befc582e08ef662eec3235ed42d31df8f6967807.png ../_images/3b8345c5bed94fa27bff17f0af9e54660e24fbd408180a8862dfc6dfc7b4fc49.png ../_images/7bd27286f316a2ca898ef42f135f1de905770d0dbf4409e413944dc601b1e3fc.png ../_images/a196a9cc05f54572a82f46982c2528515c0d688721492e91ff3c000430aeb44a.png ../_images/640925efd11ccbddf748129d38ed583ba4956fbc970073834dd9d7f2bb7583a6.png ../_images/936812fbf23a9a27b32eb8e6de3769be770e5f605aec3c518f074c0520218d6c.png ../_images/8366fb02966c4f97e9b133884bf0593a558507a02d9ea62f7082b84d47eb1ef1.png
# Compute accuracy over the same '*jpg' matches that the validation loop
# iterates. os.listdir would also count any other directory entry (hidden
# files, non-image files) and deflate the percentage.
total = len(glob.glob('../data/resistor_images/*jpg'))
if total:
    print("Accuracy: " + str(correct / total * 100) + ' %')
else:
    # No matching images: report that instead of dividing by zero
    print("Accuracy: n/a (no images found)")
Accuracy: 100.0 %