debug.h 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #pragma once
#include "common.h"

#include <cstdint>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>
  6. // common debug functions and structs
  7. // Print a tensor's detailed data
  8. // data - the tensor's data in byte format
  9. // type - the tensor's quantization type
  10. // ne - the tensor dimensions array
  11. // nb - the tensor strides array
  12. // n - the number of rows/columns to fully print
  13. template <bool abort_on_nan> void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n);
  14. // Intended to use as callback for ggml_backend_sched_eval_callback
  15. // prints tensors that are processed in the computation graph
  16. // by default prints all tensors, but can be configured by creating a `base_callback_data` instance with
  17. // non-empty filter_patterns. See examples/debug.ccp for possible usage patterns
  18. // The template parameter determins whether an error should be thrown whenever a NaN is encountered
  19. // in a tensor (useful for stopping debug sessions on first erroneous tensor)
  20. // The callback data will be passed as the third parameter (user_data)
  21. template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data);
  22. struct base_callback_data {
  23. std::vector<uint8_t> data;
  24. std::vector<std::regex> tensor_filters;
  25. base_callback_data() = default;
  26. base_callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
  27. for (const auto & pattern : filter_patterns) {
  28. try {
  29. std::string anchored_pattern = "^" + pattern;
  30. tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
  31. } catch (const std::regex_error & e) {
  32. throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
  33. }
  34. }
  35. params.cb_eval = common_debug_cb_eval<false>;
  36. params.cb_eval_user_data = this;
  37. }
  38. };