
Teachable Machine TFLite Model Python Test

If you have trained and tested your TFLite model with Teachable Machine, the next step is to test it in your own Python code. This example assumes an image classification model. It uses OpenCV to capture and preprocess camera frames, and it loads and runs the TFLite model with TensorFlow's tf.lite.Interpreter.
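A minimal sketch of the loading step, assuming either the lightweight `tflite_runtime` package (the alternative import commented out in the script below) or full TensorFlow is installed. The helper names `load_interpreter` and `input_geometry` are hypothetical. The point it illustrates is that `get_input_details()` typically reports an image model's input shape as `[batch, height, width, channels]`, while `cv2.resize` expects its size argument as `(width, height)`.

```python
def load_interpreter(model_path):
    """Create and allocate a TFLite interpreter.

    Prefers the tflite_runtime package and falls back to full TensorFlow
    if tflite_runtime is not installed.
    """
    try:
        from tflite_runtime.interpreter import Interpreter
    except ImportError:
        import tensorflow as tf
        Interpreter = tf.lite.Interpreter
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter


def input_geometry(input_detail):
    """Return (height, width, channels) for one entry of get_input_details().

    Assumes NHWC layout: [batch, height, width, channels].
    """
    _, height, width, channels = (int(d) for d in input_detail["shape"])
    return height, width, channels


# Usage with a real model file (not executed here):
# interpreter = load_interpreter("vww_96_grayscale_quantized.tflite")
# detail = interpreter.get_input_details()[0]
# height, width, channels = input_geometry(detail)   # e.g. (96, 96, 1)
# resized = cv2.resize(gray_frame, (width, height))   # cv2 wants (width, height)
```

For a square 96x96 model the order makes no visible difference, but swapping height and width silently breaks non-square inputs.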

Here is our complete code:

# This code works with Teachable Machine TFLite models

import cv2
import numpy as np
import tensorflow as tf
# Alternatively, with the lighter tflite_runtime package, import it like this
# and create tflite.Interpreter(...) instead of tf.lite.Interpreter(...) below:
# import tflite_runtime.interpreter as tflite

# Load the TFLite model
model_path = "vww_96_grayscale_quantized.tflite"
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The input shape is [batch, height, width, channels], e.g. [1, 96, 96, 1]
input_height, input_width = (int(d) for d in input_details[0]['shape'][1:3])
print(input_details[0]['shape'])

# Class names, in the same order as the classes in your Teachable Machine project
class_names = ['Note50', 'Note100', 'Empty']  # Replace with your class names

# Open a video capture object (0 for the default camera)
cap = cv2.VideoCapture(0)

while True:
    # Capture a frame from the video feed
    ret, frame = cap.read()
    if not ret:
        break

    # Preprocess the frame: grayscale, resize to the model input, add batch and channel axes
    gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    preprocessed_frame = cv2.resize(gray_frame, (input_width, input_height))  # cv2 takes (width, height)
    preprocessed_frame = np.expand_dims(preprocessed_frame, axis=0)  # -> (1, H, W)
    preprocessed_frame = np.expand_dims(preprocessed_frame, axis=3)  # -> (1, H, W, 1)
    preprocessed_frame = preprocessed_frame.astype('float32') / 255.0  # Normalize the frame to [0, 1]

    # Run inference on the preprocessed frame
    interpreter.set_tensor(input_details[0]['index'], preprocessed_frame)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])

    # Process the output: pick the class with the highest score
    predicted_class = int(np.argmax(output[0]))
    confidence = output[0][predicted_class]
    print(class_names[predicted_class], confidence)

    # Teachable Machine image projects produce classification models, so there are
    # no bounding boxes to draw here. With an object detection model you would read
    # the box coordinates from its outputs and draw them, for example:
    # cv2.rectangle(frame, (x, y), (x + width, y + height), (0, 255, 0), 2)

    # Display the result on the frame
    cv2.putText(frame, f"Class: {class_names[predicted_class]}, Confidence: {confidence:.2f}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # Show the frame
    cv2.imshow("Live Video", frame)

    # Break the loop if 'q' is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the video capture object and close all windows
cap.release()
cv2.destroyAllWindows()
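The script normalizes each frame to float32 values in [0, 1], which suits a model whose input tensor is float32. If your export's input tensor is `uint8` or `int8` instead, `set_tensor` raises a `ValueError` about the type mismatch. Below is a minimal sketch of a dtype-aware conversion, assuming the model's float input range is [0, 1] (the same `/ 255.0` normalization used above) and using the `(scale, zero_point)` pair that TFLite reports in each tensor detail's `quantization` field. `prepare_input` is a hypothetical helper name.

```python
import numpy as np


def prepare_input(resized_gray, input_detail):
    """Turn a resized HxW uint8 grayscale image into a 1xHxWx1 input tensor.

    Assumes the model was trained on pixels scaled to [0, 1], matching the
    `/ 255.0` normalization in the script above.
    """
    x = resized_gray.astype(np.float32) / 255.0
    x = x[np.newaxis, :, :, np.newaxis]  # add batch and channel axes

    dtype = input_detail["dtype"]
    if dtype == np.float32:
        return x

    # Quantized input: q = round(x / scale) + zero_point, clipped to the dtype range.
    scale, zero_point = input_detail["quantization"]
    if scale == 0:
        raise ValueError("integer input tensor has no quantization parameters")
    info = np.iinfo(dtype)
    q = np.round(x / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)
```

In the loop, `prepare_input(resized_frame, input_details[0])` would replace the two `expand_dims` calls and the `astype` line, where `resized_frame` stands for the output of `cv2.resize`.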

This code works well; here is an example of its output.
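The confidence the script prints is the raw value from the output tensor. For a float model that is typically a softmax probability; for a quantized output tensor it is an integer (for example 0 to 255 for `uint8`) and has to be dequantized with `(q - zero_point) * scale` before it reads as a probability. A minimal sketch, assuming the output is a `[1, num_classes]` tensor and that the class list follows the class order of your Teachable Machine project; `top_prediction` is a hypothetical helper.

```python
import numpy as np


def top_prediction(output, output_detail, class_names):
    """Return (label, confidence) for a [1, num_classes] classifier output.

    Integer (quantized) outputs are dequantized with (q - zero_point) * scale.
    """
    scores = np.asarray(output[0])
    if np.issubdtype(scores.dtype, np.integer):
        scale, zero_point = output_detail["quantization"]
        scores = (scores.astype(np.float32) - zero_point) * scale
    best = int(np.argmax(scores))
    return class_names[best], float(scores[best])
```

Inside the loop, `label, confidence = top_prediction(output, output_details[0], class_names)` could take the place of the lines that compute `predicted_class` and `confidence`.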

Remember that this code works with 96x96x1 grayscale images, because the training was done with ESP32-CAM pictures. Here is the code I took from this page; it works well for capturing the sample images with the ESP32-CAM.

// 26_Collect_Images.ino
#define MAX_RESOLUTION_XGA 1

/**
 * Run a development HTTP server to capture images
 * for TinyML tasks
 */

#include "esp32cam.h"
#include "esp32cam/http/FomoImageCollectionServer.h"

using namespace Eloquent::Esp32cam;


Cam cam;
Http::FOMO::CollectImagesServer http(cam);


void setup() {
    Serial.begin(115200);
    delay(3000);
    Serial.println("Init");

    /**
     * Replace with your camera model.
     * Available: aithinker, m5, m5wide, wrover, eye, ttgoLCD
     */
    cam.aithinker();
    cam.highQuality();
    cam.highestSaturation();
    cam.xga();

    while (!cam.begin())
        Serial.println(cam.getErrorMessage());

    // replace with your SSID and PASSWORD
    while (!cam.connect("SSID", "PASSWORD"))
        Serial.println(cam.getErrorMessage());

    while (!http.begin())
        Serial.println(http.getErrorMessage());

    Serial.println(http.getWelcomeMessage());
    cam.mDNS("esp32cam");
}


void loop() {
    http.handle();
}

By Abdul Rehman

My name is Abdul Rehman, and I love doing research in embedded systems, artificial intelligence, computer vision and other engineering fields. In more than 10 years of research and development in embedded systems, I have worked with many technologies, including web development and mobile application development. Now, through my social media presence, I like to share my knowledge and document everything I have learned and am still learning.
