| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
- #include "main_functions.h"
- #include "detection_responder.h"
- #include "image_provider.h"
- #include "model_settings.h"
- #include "tensorflow/lite/micro/micro_error_reporter.h"
- #include "tensorflow/lite/micro/micro_interpreter.h"
- #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
- #include "tensorflow/lite/micro/models/person_detect_model_data.h"
- #include "tensorflow/lite/micro/system_setup.h"
- #include "tensorflow/lite/schema/schema_generated.h"
- // Globals, used for compatibility with Arduino-style sketches.
- namespace {
- tflite::ErrorReporter* error_reporter = nullptr;
- const tflite::Model* model = nullptr;
- tflite::MicroInterpreter* interpreter = nullptr;
- TfLiteTensor* input = nullptr;
- // In order to use optimized tensorflow lite kernels, a signed int8_t quantized
- // model is preferred over the legacy unsigned model format. This means that
- // throughout this project, input images must be converted from unisgned to
- // signed format. The easiest and quickest way to convert from unsigned to
- // signed 8-bit integers is to subtract 128 from the unsigned value to get a
- // signed value.
- // An area of memory to use for input, output, and intermediate arrays.
- constexpr int kTensorArenaSize = 136 * 1024;
- static uint8_t tensor_arena[kTensorArenaSize];
- } // namespace
- // The name of this function is important for Arduino compatibility.
- void setup() {
- tflite::InitializeTarget();
- // Set up logging. Google style is to avoid globals or statics because of
- // lifetime uncertainty, but since this has a trivial destructor it's okay.
- // NOLINTNEXTLINE(runtime-global-variables)
- static tflite::MicroErrorReporter micro_error_reporter;
- error_reporter = µ_error_reporter;
- // Map the model into a usable data structure. This doesn't involve any
- // copying or parsing, it's a very lightweight operation.
- model = tflite::GetModel(g_person_detect_model_data);
- if (model->version() != TFLITE_SCHEMA_VERSION) {
- TF_LITE_REPORT_ERROR(error_reporter,
- "Model provided is schema version %d not equal "
- "to supported version %d.",
- model->version(), TFLITE_SCHEMA_VERSION);
- return;
- }
- // Pull in only the operation implementations we need.
- // This relies on a complete list of all the ops needed by this graph.
- // An easier approach is to just use the AllOpsResolver, but this will
- // incur some penalty in code space for op implementations that are not
- // needed by this graph.
- //
- // tflite::AllOpsResolver resolver;
- // NOLINTNEXTLINE(runtime-global-variables)
- static tflite::MicroMutableOpResolver<5> micro_op_resolver;
- micro_op_resolver.AddAveragePool2D();
- micro_op_resolver.AddConv2D();
- micro_op_resolver.AddDepthwiseConv2D();
- micro_op_resolver.AddReshape();
- micro_op_resolver.AddSoftmax();
- // Build an interpreter to run the model with.
- // NOLINTNEXTLINE(runtime-global-variables)
- static tflite::MicroInterpreter static_interpreter(
- model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
- interpreter = &static_interpreter;
- // Allocate memory from the tensor_arena for the model's tensors.
- TfLiteStatus allocate_status = interpreter->AllocateTensors();
- if (allocate_status != kTfLiteOk) {
- TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
- return;
- }
- // Get information about the memory area to use for the model's input.
- input = interpreter->input(0);
- }
- // The name of this function is important for Arduino compatibility.
- void loop() {
- // Get image from provider.
- if (kTfLiteOk != GetImage(error_reporter, kNumCols, kNumRows, kNumChannels,
- input->data.int8)) {
- TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
- }
- // Run the model on this input and make sure it succeeds.
- if (kTfLiteOk != interpreter->Invoke()) {
- TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");
- }
- TfLiteTensor* output = interpreter->output(0);
- // Process the inference results.
- int8_t person_score = output->data.uint8[kPersonIndex];
- int8_t no_person_score = output->data.uint8[kNotAPersonIndex];
- RespondToDetection(error_reporter, person_score, no_person_score);
- }
- /**
- * @brief 进行人形识别
- *
- * @param in int8[96*96] 输入分辨率为96*96的灰度数据,每个像素1byte
- * @param out 输出结果, int8[2]
- * @return int 成功返回0, 否则返回负值
- */
- int do_person_detection(signed char* in, signed char* out) {
- if (input == NULL)
- return -1;
- memcpy(input->data.int8, in, kNumCols * kNumRows * kNumChannels);
- if (kTfLiteOk != interpreter->Invoke()) {
- return -2;
- }
- TfLiteTensor* output = interpreter->output(0);
- memcpy(out, output->data.uint8, 2);
- //int8_t person_score = output->data.uint8[kPersonIndex];
- //int8_t no_person_score = output->data.uint8[kNotAPersonIndex];
- return 0;
- }
|