# pokedex_rpi.py (author: michael.divia)
import argparse
import json
from pathlib import Path

import cv2
import numpy as np
from hailo_platform.pyhailort import HailoRT
# --- Command-line arguments ---
# --model picks which compiled network (and matching class list) is used.
parser = argparse.ArgumentParser(
    description="Take one webcam picture and predict which Pokémon it shows "
    "using a Hailo-compiled image classifier."
)
parser.add_argument(
    "--model",
    choices=["1", "2"],
    required=True,
    help="1 = ResNet50, 2 = Xception",
)
args = parser.parse_args()
# --- Model selection ---
# Resolve model files relative to this script instead of the current working
# directory, so the script works no matter where it is launched from. The
# models live in the `models` folder one level above the script's directory
# (the `../models` layout this script was written against).
_MODELS_DIR = Path(__file__).resolve().parent.parent / "models"

# --model choice -> (HEF file, class-name JSON, resize target (width, height))
_MODEL_FILES = {
    "1": ("ResNet50/pokedex_ResNet50.hef", "ResNet50/class_names.json", (224, 224)),
    "2": ("Xception/pokedex_Xception.hef", "Xception/class_names.json", (256, 256)),
}

# argparse `choices` guarantees args.model is one of the keys above.
_hef_rel, _json_rel, size = _MODEL_FILES[args.model]
hef_path = str(_MODELS_DIR / _hef_rel)
json_path = str(_MODELS_DIR / _json_rel)
# --- Load class names ---
# JSON text is UTF-8 by definition (RFC 8259). Decode it explicitly instead of
# relying on the platform's locale encoding, since Pokémon names may contain
# non-ASCII characters.
with open(json_path, "r", encoding="utf-8") as f:
    class_names = json.load(f)
# NOTE(review): the prediction step indexes this with an int
# (class_names[predicted_idx]), so the file is assumed to hold a JSON list
# ordered by class index. A dict keyed by "0", "1", ... would fail there;
# confirm against the training export.
# --- Hailo device and model setup ---
# Open the Hailo device, load the compiled model at `hef_path` and build a
# network group from it.
# NOTE(review): as far as I know, `HailoRT.Device`, `create_hef_group` and the
# `HailoRT.VirtualStreams` used for inference do not match the documented
# hailo_platform Python API. That API uses HEF, VDevice, ConfigureParams,
# InputVStreamParams/OutputVStreamParams and InferVStreams. Confirm against the
# installed HailoRT version.
# NOTE(review): the device is never explicitly released here; the documented
# VDevice is normally used as a context manager. Verify cleanup.
device = HailoRT.Device()
hef = HailoRT.Hef(hef_path)
configured_network_group = device.create_hef_group(hef)
# Only the first input and first output vstream are used, i.e. the model is
# assumed to have exactly one input and one output. TODO: confirm.
input_vstream_info = configured_network_group.get_input_vstream_infos()[0]
output_vstream_info = configured_network_group.get_output_vstream_infos()[0]
# --- Open webcam and grab a single frame ---
cap = cv2.VideoCapture(0)  # index 0 = default camera
if not cap.isOpened():
    # SystemExit with a message prints it to stderr and exits with status 1, so
    # the shell/caller can see the failure (a bare exit() would return 0).
    raise SystemExit("-- Unable to open webcam")

print("-- Taking picture...")
try:
    ret, frame = cap.read()
finally:
    # Free the camera as soon as the frame is grabbed, even if read() raises.
    cap.release()

if not ret:
    raise SystemExit("-- Failed to capture image")
# --- Preprocess image ---
# OpenCV returns frames in BGR channel order; convert to RGB.
# NOTE(review): assumes the model was trained on RGB images (the usual case for
# Keras/TF image loaders). Remove this line if it was trained on BGR frames.
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# cv2.resize takes (width, height); the configured sizes are square anyway.
image = cv2.resize(image, size)
image = image.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
# NOTE(review): if the HEF already contains a normalization layer or its input
# vstream expects uint8, this scaling/cast must be dropped. Confirm against the
# compiled model.
# Keep the HWC layout and only add a batch axis -> (1, H, W, 3). Do not
# transpose to NCHW: HailoRT vstreams normally take NHWC host buffers.
# TODO: confirm with input_vstream_info.shape for this HEF.
image = np.expand_dims(image, axis=0)
# --- Inference ---
# Send the single preprocessed batch through the first input vstream and read
# the network's output from the first output vstream.
# NOTE(review): the documented HailoRT flow activates the network group
# (`network_group.activate(...)`) and runs `InferVStreams.infer({name: data})`.
# Confirm that `VirtualStreams` exists and activates the group in the installed
# version.
with HailoRT.VirtualStreams(input_vstream_info, output_vstream_info, configured_network_group) as (input_vstreams, output_vstreams):
    input_vstreams[0].send(image)
    output_data = output_vstreams[0].recv()
# --- Postprocess: report the highest-scoring class ---
# Flatten the output (a leading batch axis of size 1 is harmless) and take the
# index of the largest score as the predicted class.
flat_scores = np.asarray(output_data).ravel()
predicted_idx = int(flat_scores.argmax())
predicted_name = class_names[predicted_idx]
print(f"-- Predicted Pokémon: {predicted_name}")