|
|
@@ -17,6 +17,7 @@
|
|
|
#include <cmath>
|
|
|
#include <functional>
|
|
|
#include <map>
|
|
|
+#include <regex>
|
|
|
#include <sstream>
|
|
|
#include <stdexcept>
|
|
|
|
|
|
@@ -378,9 +379,12 @@ struct llama_model::impl {
|
|
|
layer_dev dev_input = {};
|
|
|
layer_dev dev_output = {};
|
|
|
std::vector<layer_dev> dev_layer;
|
|
|
+
|
|
|
+ bool has_tensor_overrides;
|
|
|
};
|
|
|
|
|
|
llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
|
|
|
+ pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
|
|
|
}
|
|
|
|
|
|
llama_model::~llama_model() {}
|
|
|
@@ -1571,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
|
GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
|
|
|
}
|
|
|
|
|
|
- ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
|
|
|
+ ggml_backend_buffer_type_t buft = nullptr;
|
|
|
+
|
|
|
+ // check overrides
|
|
|
+ if (ml.tensor_buft_overrides) {
|
|
|
+ std::string tensor_name = tn.str();
|
|
|
+ for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
|
|
|
+ std::regex pattern(overrides->pattern);
|
|
|
+ if (std::regex_search(tensor_name, pattern)) {
|
|
|
+ LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
|
|
|
+ buft = overrides->buft;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
if (!buft) {
|
|
|
- throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
|
|
|
+ buft = select_weight_buft(hparams, t_meta, op, *buft_list);
|
|
|
+ if (!buft) {
|
|
|
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// avoid using a host buffer when using mmap
|
|
|
@@ -4151,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
+bool llama_model::has_tensor_overrides() const {
|
|
|
+ return pimpl->has_tensor_overrides;
|
|
|
+}
|
|
|
+
|
|
|
const ggml_tensor * llama_model::get_tensor(const char * name) const {
|
|
|
auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
|
|
|
[name](const std::pair<std::string, ggml_tensor *> & it) {
|
|
|
@@ -12319,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
|
llama_model_params llama_model_default_params() {
|
|
|
llama_model_params result = {
|
|
|
/*.devices =*/ nullptr,
|
|
|
+ /*.tensor_buft_overrides =*/ nullptr,
|
|
|
/*.n_gpu_layers =*/ 0,
|
|
|
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
|
|
|
/*.main_gpu =*/ 0,
|