fake_micro_context.cc 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "tensorflow/lite/micro/fake_micro_context.h"
  13. #include "tensorflow/lite/kernels/internal/compatibility.h"
  14. #include "tensorflow/lite/micro/micro_arena_constants.h"
  15. #include "tensorflow/lite/micro/micro_error_reporter.h"
  16. namespace tflite {
  17. using ::tflite::MicroArenaBufferAlignment;
  18. FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
  19. SimpleMemoryAllocator* allocator,
  20. MicroGraph* micro_graph)
  21. : MicroContext(nullptr, nullptr, micro_graph),
  22. tensors_(tensors),
  23. allocator_(allocator) {}
  24. TfLiteTensor* FakeMicroContext::GetTensor(int tensor_index) {
  25. return &tensors_[tensor_index];
  26. }
  27. TfLiteEvalTensor* FakeMicroContext::GetEvalTensor(int tensor_index) {
  28. TfLiteEvalTensor* eval_tensor =
  29. reinterpret_cast<TfLiteEvalTensor*>(allocator_->AllocateTemp(
  30. sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
  31. TFLITE_DCHECK(eval_tensor != nullptr);
  32. // In unit tests, the TfLiteTensor pointer contains the source of truth for
  33. // buffers and values:
  34. eval_tensor->data = tensors_[tensor_index].data;
  35. eval_tensor->dims = tensors_[tensor_index].dims;
  36. eval_tensor->type = tensors_[tensor_index].type;
  37. return eval_tensor;
  38. }
  39. void* FakeMicroContext::AllocatePersistentBuffer(size_t bytes) {
  40. // FakeMicroContext use SimpleMemoryAllocator, which does not automatically
  41. // apply the buffer alignment like MicroAllocator.
  42. // The buffer alignment is potentially wasteful but allows the
  43. // fake_micro_context to work correctly with optimized kernels.
  44. return allocator_->AllocateFromTail(bytes, MicroArenaBufferAlignment());
  45. }
  46. TfLiteStatus FakeMicroContext::RequestScratchBufferInArena(size_t bytes,
  47. int* buffer_index) {
  48. TFLITE_DCHECK(buffer_index != nullptr);
  49. if (scratch_buffer_count_ == kNumScratchBuffers_) {
  50. MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
  51. kNumScratchBuffers_);
  52. return kTfLiteError;
  53. }
  54. // For tests, we allocate scratch buffers from the tail and keep them around
  55. // for the lifetime of model. This means that the arena size in the tests will
  56. // be more than what we would have if the scratch buffers could share memory.
  57. scratch_buffers_[scratch_buffer_count_] =
  58. allocator_->AllocateFromTail(bytes, MicroArenaBufferAlignment());
  59. TFLITE_DCHECK(scratch_buffers_[scratch_buffer_count_] != nullptr);
  60. *buffer_index = scratch_buffer_count_++;
  61. return kTfLiteOk;
  62. }
  63. void* FakeMicroContext::GetScratchBuffer(int buffer_index) {
  64. TFLITE_DCHECK(scratch_buffer_count_ <= kNumScratchBuffers_);
  65. if (buffer_index >= scratch_buffer_count_) {
  66. return nullptr;
  67. }
  68. return scratch_buffers_[buffer_index];
  69. }
  70. } // namespace tflite