Skip to content
Snippets Groups Projects
Commit 6cb09929 authored by michael.divia's avatar michael.divia
Browse files

MayBe

parent a4d04f4e
No related branches found
No related tags found
No related merge requests found
import cv2
import numpy as np
import json import json
import argparse import argparse
import os import os
from hailo_platform.pyhailort import HailoRT import numpy as np
from picamera2 import Picamera2
from picamera2.devices.hailo import Hailo
import cv2
# --- Argparse --- # --- Argparse ---
parser = argparse.ArgumentParser(description="Pokémon Classifier Inference with Hailo-8") parser = argparse.ArgumentParser(description="Pokémon Classifier Inference with Hailo-8")
...@@ -14,11 +15,9 @@ args = parser.parse_args() ...@@ -14,11 +15,9 @@ args = parser.parse_args()
if args.model == "1": if args.model == "1":
hef_path = "../models/ResNet50/pokedex_ResNet50.hef" hef_path = "../models/ResNet50/pokedex_ResNet50.hef"
json_path = "../models/ResNet50/class_names.json" json_path = "../models/ResNet50/class_names.json"
input_shape = (224, 224)
elif args.model == "2": elif args.model == "2":
hef_path = "../models/Xception/pokedex_Xception.hef" hef_path = "../models/Xception/pokedex_Xception.hef"
json_path = "../models/Xception/class_names.json" json_path = "../models/Xception/class_names.json"
input_shape = (256, 256)
else: else:
raise ValueError("Invalid model selection") raise ValueError("Invalid model selection")
...@@ -26,57 +25,37 @@ else: ...@@ -26,57 +25,37 @@ else:
with open(json_path, "r") as f: with open(json_path, "r") as f:
class_names = json.load(f) class_names = json.load(f)
# --- Setup device and network --- # --- Run inference ---
device = HailoRT.Device() with Hailo(hef_path) as hailo:
hef = HailoRT.Hef(hef_path) # Get model input shape (e.g., 224x224x3 or 256x256x3)
network_group = device.create_hef_group(hef) model_h, model_w, _ = hailo.get_input_shape()
input_info = network_group.get_input_vstream_infos()[0] # Setup and start the camera
output_info = network_group.get_output_vstream_infos()[0] picam2 = Picamera2()
main = {'size': (model_w, model_h), 'format': 'RGB888'}
# --- Open webcam and capture image --- config = picam2.create_preview_configuration(main)
cap = cv2.VideoCapture(0) picam2.start(config)
if not cap.isOpened():
print("-- Unable to open webcam") print("-- Capturing image...")
exit(1) frame = picam2.capture_array()
print("-- Capturing image...") # Optionally display the captured image
ret, frame = cap.read() try:
cap.release() cv2.imshow("Captured Image", frame)
print("-- Press any key to continue...")
if not ret: cv2.waitKey(0)
print("-- Failed to capture image") cv2.destroyAllWindows()
exit(1) except cv2.error:
print("-- GUI display failed, saving and showing with feh instead...")
# --- Try to display the captured image --- output_path = "/tmp/captured.png"
try: cv2.imwrite(output_path, frame)
cv2.imshow("Captured Image", frame) os.system(f"feh --fullscreen {output_path}")
print("-- Press any key to continue...")
cv2.waitKey(0) # Run inference
cv2.destroyAllWindows() print("-- Running inference...")
except cv2.error as e: inference_results = hailo.run(frame)
print("-- GUI display failed, saving and showing with feh instead...")
output_path = "/tmp/captured.png" # Postprocess: find predicted class
cv2.imwrite(output_path, frame) predicted_idx = int(np.argmax(inference_results))
os.system(f"feh --fullscreen {output_path}") predicted_name = class_names[predicted_idx]
print(f"-- Predicted Pokémon: {predicted_name}")
# --- Preprocess image ---
image = cv2.resize(frame, input_shape)
image = image.astype(np.float32)
# Standard Hailo normalization
image -= [123.68, 116.779, 103.939]
# NHWC → NCHW
image = np.transpose(image, (2, 0, 1)) # (H, W, C) → (C, H, W)
image = np.expand_dims(image, axis=0) # Add batch dimension → (1, C, H, W)
# --- Inference ---
with HailoRT.VirtualStreams(input_info, output_info, network_group) as (input_vstreams, output_vstreams):
input_vstreams[0].send(image)
output_data = output_vstreams[0].recv()
# --- Postprocess ---
predicted_idx = int(np.argmax(output_data))
predicted_name = class_names[predicted_idx]
print(f"-- Predicted Pokémon: {predicted_name}")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment