Opencv

tflite 사전학습 가중치를 사용한 OpenCV 이미지 분류

coding art 2023. 6. 2. 19:50
728x90

이미지 데이터

image_2.jpg
0.04MB

 

라벨값

labels.csv
0.00MB

 

Detection 결과

 

데이터

labels.csv
0.00MB

참조: Object detection with Tensorflow model and OpenCV

https://towardsdatascience.com/object-detection-with-tensorflow-model-and-opencv-d839f3e42849

 

import tensorflow_hub as hub  # 코코데이터로 학습한 가중치 라이브러리를 불러 오는 곳
import cv2
import numpy
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
 

width = 1028
height = 1028

#Load image by Opencv2
img = cv2.imread('image_2.jpg')   # 입력 이미지를 읽어들인다.

#Resize to respect the input_shape
inp = cv2.resize(img, (width , height ))

#Convert img to RGB
rgb = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)

# Converting to uint8
rgb_tensor = tf.convert_to_tensor(rgb, dtype=tf.uint8)

print(rgb_tensor.shape)  # (1028, 1028, 3)
print(rgb_tensor.numpy())
#Add dims to rgb_tensor
rgb_tensor = tf.expand_dims(rgb_tensor , 0)  # 텐서로 변환, axis=0 추가, (1, 1028, 1028, 3)
print(rgb_tensor.numpy())  # numpy 로 바꾸어 출력


# Loading model directly from TensorFlow Hub
detector = hub.load("https://tfhub.dev/tensorflow/efficientdet/lite2/detection/1")

# Loading csv with labels of classes
labels = pd.read_csv('labels.csv', sep=';', index_col='ID')
labels = labels['OBJECT (2017 REL.)']
print(labels)
# Creating prediction
boxes, scores, classes, num_detections = detector(rgb_tensor)

# Processing outputs
pred_labels = classes.numpy().astype('int')[0] 
pred_labels = [labels[i] for i in pred_labels]
pred_boxes = boxes.numpy()[0].astype('int')
pred_scores = scores.numpy()[0]

# Putting the boxes and labels on the image
for score, (ymin,xmin,ymax,xmax), label in zip(pred_scores, pred_boxes, pred_labels):
    if score < 0.5:
        continue

    score_txt = f'{100 * round(score)}%'
    img_boxes = cv2.rectangle(rgb,(xmin, ymax),(xmax, ymin),(0,255,0),2)      
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(img_boxes, label,(xmin, ymax-10), font, 1.5, (255,0,0), 2, cv2.LINE_AA)
    cv2.putText(img_boxes,score_txt,(xmax, ymax-10), font, 1.5, (255,0,0), 2, cv2.LINE_AA)

# Result
plt.imshow(img_boxes)