TinyML的Hello World代码分析

虽然之前有用过CMSIS-NN框架进行一些归类问题的学习,但是其实也没系统地分析过主流一些框架的具体做法,越来越多的人找到我希望能写点什么,终于拖到现在打算写一下.

截至发这篇博客,我检出的版本是:8855f56500ff8efa449662a95fe69f24bb78c0a6

主要例子的起始文件是hello_world_test.cc,具体找到tensorflow源码工程中以下代码:

一开始就引入了大量头文件:

//一个允许解释器加载我们的模型所需要使用的操作的类
#include "tensorflow/lite/micro/all_ops_resolver.h"
//我们转换后得到的模型,Flat的,二进制存在数组里的.
#include "tensorflow/lite/micro/examples/hello_world/model.h"
//一个日志用的调试类.
#include "tensorflow/lite/micro/micro_error_reporter.h"
//TensorFlow Lite for Microcontrollers解释器,他会运行我们的模型.
#include "tensorflow/lite/micro/micro_interpreter.h"
//测试框架
#include "tensorflow/lite/micro/testing/micro_test.h"
//定义数据结构schema用,用于理解model数据.
#include "tensorflow/lite/schema/schema_generated.h"

代码下一部分由测试框架代替,即上面引入的头文件中,其中由这么包裹着.

TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
 ......
}
TF_LITE_MICRO_TESTS_END

其中TF_LITE_MICRO_TEST传入的参数LoadModelAndPerformInference是测试的名称,他会喝测试结果一起输出,以便查看测试是通过还是失败,先不看具体代码,运行试试.

make -f tensorflow/lite/micro/tools/make/Makefile  test_hello_world_test

最终可以看到,测试成功.

tensorflow/lite/micro/tools/make/downloads/flatbuffers already exists, skipping the download.
tensorflow/lite/micro/tools/make/downloads/pigweed already exists, skipping the download.
g++ -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DLINUX -DTF_LITE_USE_CTIME -I. -Itensorflow/lite/micro/tools/make/downloads/gemmlowp -Itensorflow/lite/micro/tools/make/downloads/flatbuffers/include -Itensorflow/lite/micro/tools/make/downloads/ruy -Itensorflow/lite/micro/tools/make/downloads/kissfft -c tensorflow/lite/micro/examples/hello_world/hello_world_test.cc -o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/obj/tensorflow/lite/micro/examples/hello_world/hello_world_test.o
g++ -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DLINUX -DTF_LITE_USE_CTIME -I. -Itensorflow/lite/micro/tools/make/downloads/gemmlowp -Itensorflow/lite/micro/tools/make/downloads/flatbuffers/include -Itensorflow/lite/micro/tools/make/downloads/ruy -Itensorflow/lite/micro/tools/make/downloads/kissfft -o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/bin/hello_world_test tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/obj/tensorflow/lite/micro/examples/hello_world/hello_world_test.o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/obj/tensorflow/lite/micro/examples/hello_world/model.o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/lib/libtensorflow-microlite.a -Wl,--fatal-warnings -Wl,--gc-sections -lm
tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/bin/hello_world_test '~~~ALL TESTS PASSED~~~' linux
Testing LoadModelAndPerformInference
1/1 tests passed
~~~ALL TESTS PASSED~~~

代码一开始先设置了一个logger,这个logger用起来和printf差不多,紧接着的下面就有一个例子,我们可以试试修改这个逻辑,让他直接打印或者自己模仿写一个来试试.

// Set up logging
tflite::MicroErrorReporter micro_error_reporter;

if (model->version() != TFLITE_SCHEMA_VERSION) {
  TF_LITE_REPORT_ERROR(&micro_error_reporter,
                        "Model provided is schema version %d not equal "
                        "to supported version %d.\n",
                        model->version(), TFLITE_SCHEMA_VERSION);
}

他大致会输出这些内容(实际上由于版本相等,他不会输出!):

Model provided is schema version X not equal to supported version X.

在这个打印之前有一个GetModel的操作,就是从我们的数组里面读取Model,然后生成一个TF Lite的Model的对象,这个Model是从示例代码里面的create_sine_model.ipynb创建的.

接下来代码创建定义各种东西,首先是创建一个主要的操作类,然后tensor_arena是TF Lite的工作内存,在单片机中应该用mallloc之类管理,他应该多大这个很难确定,一般就是先设定一个较大的数,然后逐步缩小,确定一个稳定且较为节约的数值,最后把这些东西连接起来成为一个interpreter,即字面意思:解释器.

  // This pulls in all the operation implementations we need
  tflite::AllOpsResolver resolver;

  constexpr int kTensorArenaSize = 2000;
  uint8_t tensor_arena[kTensorArenaSize];

  // Build an interpreter to run the model with
  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kTensorArenaSize, &micro_error_reporter);
  // Allocate memory from the tensor_arena for the model's tensors
  TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

接下来要把一些输入的内容放进去,先申请一个输入空间,然后确定申请到的不是NULL(就像内存分配一样,能分配到空间才行啊),一旦申请成功后,模型就已经成功加载(因为之前声明解释器时候已经确定了模型),而后面几个EQ是断言这个模型的大小规格,也就是说,实际上整个申请只有interpreter.input(0)这么一句.

  // Obtain a pointer to the model's input tensor
  TfLiteTensor* input = interpreter.input(0);

  // Make sure the input has the properties we expect
  TF_LITE_MICRO_EXPECT_NE(nullptr, input);
  // The property "dims" tells us the tensor's shape. It has one element for
  // each dimension. Our input is a 2D tensor containing 1 element, so "dims"
  // should have size 2.
  TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
  // The value of each element gives the length of the corresponding tensor.
  // We should expect two single element tensors (one is contained within the
  // other).
  TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
  TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
  // The input is an 8 bit integer value
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input->type);

这么就可以开始进行推断了,这里分别三段是获取输入的量化参数,将输入的浮点数量化为整数(为了优化速度),将量化的输入放在模型的输入张量中,然后运行模型.

  // Get the input quantization parameters
  float input_scale = input->params.scale;
  int input_zero_point = input->params.zero_point;

  // Quantize the input from floating-point to integer
  int8_t x_quantized = x / input_scale + input_zero_point;
  // Place the quantized input in the model's input tensor
  input->data.int8[0] = x_quantized;

  // Run the model and check that it succeeds
  TfLiteStatus invoke_status = interpreter.Invoke();
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);

然后读取其输出,然后断言其数据正确性,再把输出还原成float类型.

  // Obtain a pointer to the output tensor and make sure it has the
  // properties we expect. It should be the same as the input tensor.
  TfLiteTensor* output = interpreter.output(0);
  TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
  TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
  TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[1]);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);

  // Get the output quantization parameters
  float output_scale = output->params.scale;
  int output_zero_point = output->params.zero_point;

  // Obtain the quantized output from model's output tensor
  int8_t y_pred_quantized = output->data.int8[0];
  // Dequantize the output from integer to floating-point
  float y_pred = (y_pred_quantized - output_zero_point) * output_scale;

然后测试误差在不在范围内,后续几个测试都是这个意思.

  float epsilon = 0.05f;
  TF_LITE_MICRO_EXPECT_NEAR(y_true, y_pred, epsilon);

  // Run inference on several more values and confirm the expected outputs
  x = 1.f;
  y_true = sin(x);
  input->data.int8[0] = x / input_scale + input_zero_point;
  interpreter.Invoke();
  y_pred = (output->data.int8[0] - output_zero_point) * output_scale;
  TF_LITE_MICRO_EXPECT_NEAR(y_true, y_pred, epsilon);

这里Invoke有很多个用途,模型输入用了他(输入就会产生输出,所以输出部分看不到~),用来推断数据也用了他.

如果修改输入模型的参数,或者让误差变得更严格(超过模型本身能力),就会出现错误.

除了这个代码,还有很多其他文件夹里面包含了不同微控制器用的代码.

比如看到ESP的代码里只有这些.

从头文件能看出来,其他内容从main_functions.cc开始,这个文件一开始有很多熟悉的include,不用展开应该都能猜到具体意思了.

#include "tensorflow/lite/micro/examples/hello_world/main_functions.h"

#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/examples/hello_world/constants.h"
#include "tensorflow/lite/micro/examples/hello_world/model.h"
#include "tensorflow/lite/micro/examples/hello_world/output_handler.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"

然后定义了全局使用的一些变量,这些变量名应该都很熟悉,在我们的测试例子里就是同样的名字,唯一的新鲜事物是inference_count,他指示这个程序进行了多少次推断.

// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;

constexpr int kTensorArenaSize = 2000;
uint8_t tensor_arena[kTensorArenaSize];
}  // namespace

玩过Arduino应该都不陌生,一开始有个setup函数,然后有个loop函数.setup函数只执行一次,loop函数一直执行.

setup还是熟悉的代码,有点疑惑的是他在setup过程就访问了输出?不是的,其实只是给output这个ptr分配一下内存~

// The name of this function is important for Arduino compatibility.
void setup() {
  tflite::InitializeTarget();

  // Set up logging. Google style is to avoid globals or statics because of
  // lifetime uncertainty, but since this has a trivial destructor it's okay.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::MicroErrorReporter micro_error_reporter;
  error_reporter = &micro_error_reporter;

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing, it's a very lightweight operation.
  model = tflite::GetModel(g_model);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    TF_LITE_REPORT_ERROR(error_reporter,
                         "Model provided is schema version %d not equal "
                         "to supported version %d.",
                         model->version(), TFLITE_SCHEMA_VERSION);
    return;
  }

  // This pulls in all the operation implementations we need.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::AllOpsResolver resolver;

  // Build an interpreter to run the model with.
  static tflite::MicroInterpreter static_interpreter(
      model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
    return;
  }

  // Obtain pointers to the model's input and output tensors.
  input = interpreter->input(0);
  output = interpreter->output(0);

  // Keep track of how many inferences we have performed.
  inference_count = 0;
}

loop代码里也相对的熟悉,其中kXrange,kInferencesPerCycle是定义在常量里的,kXrange是2pi,kInferencesPerCycle是限制inference_count的每个周期的最大推理次数,一旦到达整个周期,将会回到0继续推理(因为这是个周期正弦),要记住,这里依然是x输入得到y,代码和测试样例里差不多,除了HandleOutput这个额外的.

// The name of this function is important for Arduino compatibility.
void loop() {
  // Calculate an x value to feed into the model. We compare the current
  // inference_count to the number of inferences per cycle to determine
  // our position within the range of possible x values the model was
  // trained on, and use this to calculate a value.
  float position = static_cast<float>(inference_count) /
                   static_cast<float>(kInferencesPerCycle);
  float x = position * kXrange;

  // Quantize the input from floating-point to integer
  int8_t x_quantized = x / input->params.scale + input->params.zero_point;
  // Place the quantized input in the model's input tensor
  input->data.int8[0] = x_quantized;

  // Run inference, and report any error
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
                         static_cast<double>(x));
    return;
  }

  // Obtain the quantized output from model's output tensor
  int8_t y_quantized = output->data.int8[0];
  // Dequantize the output from integer to floating-point
  float y = (y_quantized - output->params.zero_point) * output->params.scale;

  // Output the results. A custom HandleOutput function can be implemented
  // for each supported hardware target.
  HandleOutput(error_reporter, x, y);

  // Increment the inference_counter, and reset it if we have reached
  // the total number per cycle
  inference_count += 1;
  if (inference_count >= kInferencesPerCycle) inference_count = 0;
}

HandleOutput即关联到硬件上的输出,看看Arduino例子的Handle是这么实现的.

// Animates a dot across the screen to represent the current x and y values
void HandleOutput(tflite::ErrorReporter* error_reporter, float x_value,
                  float y_value) {
  // Do this only once
  if (!initialized) {
    // Set the LED pin to output
    pinMode(led, OUTPUT);
    initialized = true;
  }

  // Calculate the brightness of the LED such that y=-1 is fully off
  // and y=1 is fully on. The LED's brightness can range from 0-255.
  int brightness = (int)(127.5f * (y_value + 1));

  // Set the brightness of the LED. If the specified pin does not support PWM,
  // this will result in the LED being on when y > 127, off otherwise.
  analogWrite(led, brightness);

  // Log the current brightness value for display in the Arduino plotter
  TF_LITE_REPORT_ERROR(error_reporter, "%d\n", brightness);
}

OK,这样有一个基于机器学习(逼格高)的呼吸灯,现在还没说怎么让他在实际硬件中跑,所以,我们可以这么测试.

 make -f tensorflow/lite/micro/tools/make/Makefile hello_world

执行其生成的文件即可,但是输出的数据是以2的幂形式输出的,具体自己手动换算一下就得了.

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注