Port pokedex_rpi.py from the raw HailoRT / OpenCV-webcam pipeline to
Picamera2 and its Hailo helper.

Summary of what the hunks below do:
  * Imports: drop hailo_platform.pyhailort.HailoRT; add Picamera2 and
    picamera2.devices.hailo.Hailo.
  * Model selection: remove the hard-coded per-model input_shape; the input
    size is now read from hailo.get_input_shape().
  * Capture: replace cv2.VideoCapture(0) with a Picamera2 preview
    configuration sized to the model input, format 'RGB888'.
  * Preprocessing: remove the manual resize, mean subtraction and
    NHWC -> NCHW transpose; the captured frame is passed to hailo.run().
  * Post-processing (argmax over the result, class-name lookup) is kept,
    now inside the Hailo context manager.

NOTE(review): this patch was recovered from a copy whose line breaks had
been collapsed. Blank lines were placed so that every hunk matches its
declared line counts; indentation and the spacing before inline comments
were reconstructed and should be confirmed against the repository before
applying.

diff --git a/python/pokedex_rpi.py b/python/pokedex_rpi.py
index 449a764553385aff6c7a8ddc1dd028ef645242c9..91020193dd81bb1ec87e6655e29a0af9935dedf3 100644
--- a/python/pokedex_rpi.py
+++ b/python/pokedex_rpi.py
@@ -1,9 +1,10 @@
-import cv2
-import numpy as np
 import json
 import argparse
 import os
-from hailo_platform.pyhailort import HailoRT
+import numpy as np
+from picamera2 import Picamera2
+from picamera2.devices.hailo import Hailo
+import cv2
 
 # --- Argparse ---
 parser = argparse.ArgumentParser(description="Pokémon Classifier Inference with Hailo-8")
@@ -14,11 +15,9 @@ args = parser.parse_args()
 if args.model == "1":
     hef_path = "../models/ResNet50/pokedex_ResNet50.hef"
     json_path = "../models/ResNet50/class_names.json"
-    input_shape = (224, 224)
 elif args.model == "2":
     hef_path = "../models/Xception/pokedex_Xception.hef"
     json_path = "../models/Xception/class_names.json"
-    input_shape = (256, 256)
 else:
     raise ValueError("Invalid model selection")
 
@@ -26,57 +25,37 @@ else:
 with open(json_path, "r") as f:
     class_names = json.load(f)
 
-# --- Setup device and network ---
-device = HailoRT.Device()
-hef = HailoRT.Hef(hef_path)
-network_group = device.create_hef_group(hef)
-
-input_info = network_group.get_input_vstream_infos()[0]
-output_info = network_group.get_output_vstream_infos()[0]
-
-# --- Open webcam and capture image ---
-cap = cv2.VideoCapture(0)
-if not cap.isOpened():
-    print("-- Unable to open webcam")
-    exit(1)
-
-print("-- Capturing image...")
-ret, frame = cap.read()
-cap.release()
-
-if not ret:
-    print("-- Failed to capture image")
-    exit(1)
-
-# --- Try to display the captured image ---
-try:
-    cv2.imshow("Captured Image", frame)
-    print("-- Press any key to continue...")
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
-except cv2.error as e:
-    print("-- GUI display failed, saving and showing with feh instead...")
-    output_path = "/tmp/captured.png"
-    cv2.imwrite(output_path, frame)
-    os.system(f"feh --fullscreen {output_path}")
-
-# --- Preprocess image ---
-image = cv2.resize(frame, input_shape)
-image = image.astype(np.float32)
-
-# Standard Hailo normalization
-image -= [123.68, 116.779, 103.939]
-
-# NHWC → NCHW
-image = np.transpose(image, (2, 0, 1))  # (H, W, C) → (C, H, W)
-image = np.expand_dims(image, axis=0)  # Add batch dimension → (1, C, H, W)
-
-# --- Inference ---
-with HailoRT.VirtualStreams(input_info, output_info, network_group) as (input_vstreams, output_vstreams):
-    input_vstreams[0].send(image)
-    output_data = output_vstreams[0].recv()
-
-# --- Postprocess ---
-predicted_idx = int(np.argmax(output_data))
-predicted_name = class_names[predicted_idx]
-print(f"-- Predicted Pokémon: {predicted_name}")
+# --- Run inference ---
+with Hailo(hef_path) as hailo:
+    # Get model input shape (e.g., 224x224x3 or 256x256x3)
+    model_h, model_w, _ = hailo.get_input_shape()
+
+    # Setup and start the camera
+    picam2 = Picamera2()
+    main = {'size': (model_w, model_h), 'format': 'RGB888'}
+    config = picam2.create_preview_configuration(main)
+    picam2.start(config)
+
+    print("-- Capturing image...")
+    frame = picam2.capture_array()
+
+    # Optionally display the captured image
+    try:
+        cv2.imshow("Captured Image", frame)
+        print("-- Press any key to continue...")
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+    except cv2.error:
+        print("-- GUI display failed, saving and showing with feh instead...")
+        output_path = "/tmp/captured.png"
+        cv2.imwrite(output_path, frame)
+        os.system(f"feh --fullscreen {output_path}")
+
+    # Run inference
+    print("-- Running inference...")
+    inference_results = hailo.run(frame)
+
+    # Postprocess: find predicted class
+    predicted_idx = int(np.argmax(inference_results))
+    predicted_name = class_names[predicted_idx]
+    print(f"-- Predicted Pokémon: {predicted_name}")