# pokedex_rpi.py: take one webcam picture and classify the Pokémon in it with a
# Hailo-compiled (.hef) model (ResNet50 or Xception).
import argparse
import json
from pathlib import Path

import cv2
import numpy as np
from hailo_platform.pyhailort import HailoRT

# --- Model catalogue ---
# Maps the --model CLI value to (model directory name, image size the captured
# frame is resized to before inference).
_MODEL_CONFIGS = {
    "1": ("ResNet50", (224, 224)),
    "2": ("Xception", (256, 256)),
}
# Resolve model files relative to this script instead of the current working
# directory, so the script also works when launched from another folder.
_MODELS_DIR = Path(__file__).resolve().parent.parent / "models"

# --- Command-line arguments ---
parser = argparse.ArgumentParser(
    description="Take one webcam picture and predict which Pokémon it shows."
)
parser.add_argument(
    "--model",
    choices=list(_MODEL_CONFIGS),
    required=True,
    help="1 = ResNet50, 2 = Xception",
)
args = parser.parse_args()

# --- Paths and input size for the selected model ---
# A dict lookup (instead of an if/elif chain) guarantees hef_path, json_path
# and size are always bound; argparse's `choices` already rejects other values.
model_name, size = _MODEL_CONFIGS[args.model]
model_dir = _MODELS_DIR / model_name
hef_path = str(model_dir / f"pokedex_{model_name}.hef")
json_path = str(model_dir / "class_names.json")

# --- Load class names ---
# Decode explicitly as UTF-8: JSON exchanged between systems is UTF-8
# (RFC 8259), Pokémon names may contain non-ASCII characters (e.g. "Flabébé"),
# and open()'s default encoding depends on the system locale.
with open(json_path, "r", encoding="utf-8") as f:
    class_names = json.load(f)

# --- Hailo device and network setup ---
# Opens the accelerator, loads the compiled .hef selected above and builds the
# network group whose stream descriptors are used for inference below.
# NOTE(review): HailoRT.Device / HailoRT.Hef / create_hef_group do not appear to
# match the documented pyhailort Python API (which exposes HEF, VDevice,
# ConfigureParams and InferVStreams from `hailo_platform`) -- confirm these
# names against the HailoRT version installed on the device.
device = HailoRT.Device()
hef = HailoRT.Hef(hef_path)
configured_network_group = device.create_hef_group(hef)
# Only the first input and the first output stream are used; this assumes the
# network has a single input and a single output (TODO confirm for the HEF).
input_vstream_info = configured_network_group.get_input_vstream_infos()[0]
output_vstream_info = configured_network_group.get_output_vstream_infos()[0]

# --- Capture a single frame from the first camera (device index 0) ---
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so the failure is visible to callers (a bare exit() reported success).
    raise SystemExit("-- Unable to open webcam")

print("-- Taking picture...")
try:
    ret, frame = cap.read()
finally:
    # Release the camera even if read() raises, so the device is not left busy.
    cap.release()

if not ret:
    raise SystemExit("-- Failed to capture image")

# --- Preprocess image ---
# NOTE(review): OpenCV returns frames in BGR channel order; if the model was
# trained on RGB images, convert with cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# before resizing -- confirm against the training pipeline.
image = cv2.resize(frame, size)
image = image.astype(np.float32) / 255.0  # Scale pixel values to [0, 1]
image = np.expand_dims(image, axis=0)     # Add batch dimension: (1, H, W, C)
# NHWC -> NCHW, if the network expects channels-first (check your model).
# np.transpose only returns a strided view; make it C-contiguous because
# native bindings that read the raw buffer may reject or misread a strided
# array.
image = np.ascontiguousarray(np.transpose(image, (0, 3, 1, 2)))

# --- Inference ---
# Inside the stream context, send the preprocessed image on the first input
# stream and receive the network result from the first output stream.
# NOTE(review): HailoRT.VirtualStreams does not appear to be a documented
# pyhailort name; the documented pattern is InferVStreams(...).infer(...)
# inside `with network_group.activate(...)` -- verify against the installed
# HailoRT version.
with HailoRT.VirtualStreams(input_vstream_info, output_vstream_info, configured_network_group) as (input_vstreams, output_vstreams):
    input_vstreams[0].send(image)
    output_data = output_vstreams[0].recv()

# --- Postprocess ---
# The prediction is the index of the largest output value; argmax flattens
# the output, which is fine because a single image (batch of 1) was sent.
predicted_idx = int(np.argmax(output_data))
# class_names.json may hold a list of names or an object mapping index to name.
# JSON object keys are always strings, so an object is looked up by str(index).
if isinstance(class_names, dict):
    predicted_name = class_names[str(predicted_idx)]
else:
    predicted_name = class_names[predicted_idx]
print(f"-- Predicted Pokémon: {predicted_name}")