Running deep learning models on microcontrollers is not exactly a run-of-the-mill task due to the resource limitations involved. Microcontrollers are typically designed for low-power, low-cost embedded systems with minimal processing power and memory, while machine learning algorithms, especially deep learning models, often require significant computational resources and memory.
One of our clients is Nomo Smart Care:
The Nomo system is for caregivers who want to make sure a loved one is OK. The Nomo system uses sensors, not cameras, to monitor in-home motion. Data from sensors is sent to the Nomo mobile app and allows you, or a circle of trusted caregivers, to check in on your loved one from anywhere, any time.
Besides various other sensors, Nomo can use microphones to try to paint a picture of what is going on in the home of a care recipient. For example, if a fire alarm is going off, this is probably something a caregiver should know about. Since Nomo cares deeply about privacy, no audio recordings are allowed to leave users’ homes. Hence, audio classification must be performed on edge devices. In particular, the target system is based on the ESP32-PICO, meaning that we are working with the following constraints:
- a 240MHz 32-bit CPU, and
- 2MB of PSRAM that Nomo can utilize.
Luckily, LiteRT for Microcontrollers (formerly known as TensorFlow Lite for Microcontrollers) has been ported to the ESP32 architecture, allowing us to run basic machine learning models on low-resource devices.
This article illustrates a typical workflow for getting AI models running on microcontrollers. We demonstrate it by training and deploying a trivial model that recognizes handwritten digits from the MNIST dataset. The process is identical for the CNN models we use for audio classification.
Model training
We use Keras to train our TensorFlow models. MNIST images are grayscale, 28 pixels wide and 28 pixels tall, so the Input layer shape is (28, 28, 1). There are 10 digits, so the neural network has 10 outputs. We use a simple CNN in this example:
import numpy as np
import tensorflow as tf

def get_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(28, 28, 1), name="input"),
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation="softmax", name="output"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["sparse_categorical_accuracy"])
    return model
This model has approximately 260K parameters, resulting in a model size of roughly 1MB with 32-bit floating point weights. This size will be reduced after training using post-training quantization methods. The following code prepares the MNIST dataset and trains the model:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Convert pixel color values from integer [0, 255] to float [0.0, 1.0].
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
# Keras dataset has inputs of shape (28, 28) so a grayscale dimension is
# added resulting in shape (28, 28, 1).
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)
model = get_model()
model.fit(x_train, y_train, validation_split=0.2, epochs=10)
Finally, to make the model suitable for running on low-power edge devices, it needs to be quantized and converted to the LiteRT format.
Post-training quantization includes general techniques to reduce CPU and hardware accelerator latency, processing, power, and model size with little degradation in model accuracy.
This example uses full-integer quantization, which converts 32-bit floating point numbers to the nearest signed 8-bit integers.
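In this scheme, each quantized tensor carries a scale and a zero point, and the two representations are related by real_value ≈ scale * (quantized_value - zero_point), or equivalently quantized_value = round(real_value / scale) + zero_point. We will rely on exactly this relation later in C++ to quantize input pixels and de-quantize the model outputs.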
def convert_to_tflite_and_quantize(model, representative_dataset_generator):
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_generator
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    return converter.convert()

def create_representative_dataset_generator(x_train):
    def representative_dataset_generator():
        for x in x_train:
            yield {'input': np.expand_dims(x, axis=0)}
    return representative_dataset_generator

representative_dataset_generator = create_representative_dataset_generator(x_train)
tflite_model = convert_to_tflite_and_quantize(model, representative_dataset_generator)

with open('model.tflite', "wb") as f:
    f.write(tflite_model)
Running the three listings of Python code produces a model.tflite file that can be loaded by C++ code and used to recognize digits. The quantized model is approximately 260KB in size.
Performing inference in C++
This tutorial describes performing inference in C++ in great detail. In the tutorial, the model is converted to an unsigned char array by running:
xxd -i model.tflite > model_data.c
This produces a C file with the unsigned char array that should be compiled with the rest of the project sources in order to embed the model data into the binary.
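As a small illustration (this snippet is not part of the referenced tutorial; the symbol names are the ones xxd derives from the input file name, and the tflite-micro header path may differ depending on how the library is vendored), the embedded array can then be turned into a tflite::Model like this:
#include "tensorflow/lite/schema/schema_generated.h"  // provides tflite::GetModel()

// Symbols emitted by `xxd -i model.tflite` into model_data.c.
extern "C" {
extern unsigned char model_tflite[];
extern unsigned int model_tflite_len;
}

// Maps the embedded flatbuffer onto the tflite::Model structure. Nothing is
// copied; the interpreter later reads the weights directly from the binary.
const tflite::Model* load_embedded_model() {
  return tflite::GetModel(model_tflite);
}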
Alternatively, the model.tflite file can be loaded as a binary file at runtime. The full code is similar to the tflite-micro hello world example, so there is no point in repeating it here. An important difference is that, to run our MNIST model, the tflite::MicroMutableOpResolver object needs to register the following operations:
using MnistExampleOperationResolver = tflite::MicroMutableOpResolver<6>;

TfLiteStatus register_operations(MnistExampleOperationResolver& resolver) {
  TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());
  TF_LITE_ENSURE_STATUS(resolver.AddConv2D());
  TF_LITE_ENSURE_STATUS(resolver.AddMaxPool2D());
  TF_LITE_ENSURE_STATUS(resolver.AddSoftmax());
  TF_LITE_ENSURE_STATUS(resolver.AddMean());
  TF_LITE_ENSURE_STATUS(resolver.AddQuantize());
  return kTfLiteOk;
}
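To show where this resolver ends up, here is a minimal setup sketch; it assumes the model has already been obtained via tflite::GetModel, and the exact tflite::MicroInterpreter constructor parameters vary slightly between tflite-micro releases:
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/micro/micro_interpreter.h"

// Scratch memory for input, output, and intermediate tensors (see the note
// on the arena size below).
constexpr std::size_t kTensorArenaSize = 100 * 1024;
alignas(16) std::uint8_t tensor_arena[kTensorArenaSize];

void setup_interpreter(const tflite::Model* model) {
  static MnistExampleOperationResolver resolver;
  if (register_operations(resolver) != kTfLiteOk) {
    // handle registration failure
  }
  // Constructor parameters differ slightly between tflite-micro versions.
  static tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                              kTensorArenaSize);
  // Plans all tensors inside the arena; must succeed before Invoke().
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    // handle allocation failure (e.g. arena too small)
  }
}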
Additionally, the tensor arena needs to be larger than in the hello world example, so we set it to 100KB. Finally, the quantization parameters need to be extracted from the model so that the input images can be quantized before feeding them into the model, and so that the model output can be de-quantized.
struct QuantizationData final { float scale; int zero_point; };

QuantizationData get_quantization_data(const TfLiteQuantization& quant) {
  if (quant.type != kTfLiteAffineQuantization) {
    throw std::runtime_error{"no quantization"};
  }
  const auto* const quantization =
      static_cast<const TfLiteAffineQuantization*>(quant.params);
  const auto* const scales = quantization->scale;
  const auto* const zero_points = quantization->zero_point;
  if (quantization->quantized_dimension != 0 ||
      scales->size != 1 ||
      zero_points->size != 1) {
    throw std::runtime_error{"unexpected quantization parameters"};
  }
  return QuantizationData{scales->data[0], zero_points->data[0]};
}
This function throws an exception if no quantization is present, so only quantized models are supported. Additionally, only per-tensor quantization on the first dimension, with a single scale and a single zero point, is supported. At the time of writing this article, this should be the case for all TensorFlow models quantized and converted using the Python code provided above.
The last thing to do is to load an image, quantize it, perform inference, and de-quantize the model output. We use Boost GIL in our example to load an image and get a constant view over the image data, which we feed into the model input tensor.
std::array<float, 10> predict(boost::gil::gray8c_view_t image,
                              tflite::MicroInterpreter& interpreter) {
  auto* const input = interpreter.typed_input_tensor<std::int8_t>(0);
  const auto* const output =
      interpreter.typed_output_tensor<std::int8_t>(0);
  const auto input_quant =
      get_quantization_data(interpreter.input(0)->quantization);
  const auto output_quant =
      get_quantization_data(interpreter.output(0)->quantization);

  std::transform(image.begin(), image.end(), input,
      [input_quant](boost::gil::gray8_pixel_t pixel) -> std::int8_t {
        const auto pixel_as_float = pixel / 255.0f;
        const auto [scale, zero_point] = input_quant;
        return static_cast<std::int8_t>(
            static_cast<int>(std::round(pixel_as_float / scale))
            + zero_point
        );
      });

  if (interpreter.Invoke() != kTfLiteOk) {
    throw std::runtime_error{"interpreter invoke failed"};
  }

  std::array<float, 10> result;
  std::transform(output, output + 10, result.begin(),
      [output_quant](std::int8_t value) -> float {
        const auto [scale, zero_point] = output_quant;
        return scale * static_cast<float>(value - zero_point);
      });
  return result;
}
The first std::transform call iterates over the image pixels, converts each pixel value from the integer interval [0, 255] to the floating point interval [0.0, 1.0], quantizes it, and feeds it into the model input tensor. After invoking the interpreter, the second std::transform call reads the model outputs, de-quantizes them, and stores them in the result array. The result array contains the probability of each digit, so the recognized digit can be retrieved using a std::max_element call:
const auto result = predict(image, interpreter);
std::cout << "Digit: "
          << std::distance(result.cbegin(),
                           std::max_element(result.cbegin(),
                                            result.cend()));
Conclusion
While still a bit unorthodox, doing AI on microcontrollers is feasible. Standard C++17 runs on many embedded platforms, making LiteRT for Microcontrollers very portable. With support for many common neural network operations, a variety of AI tasks can be solved on edge devices, reducing communication and processing costs for application developers and enhancing user privacy.