@@ -75,6 +75,7 @@
 #include <forward_list>
 #include <fstream>
 #include <functional>
+#include <future>
 #include <initializer_list>
 #include <locale>
 #include <map>
@@ -2985,6 +2986,7 @@ struct llama_model_loader {
     size_t n_bytes = 0;
 
     bool use_mmap = false;
+    bool check_tensors;
 
     llama_files files;
     llama_ftype ftype;
@@ -3018,7 +3020,7 @@ struct llama_model_loader {
     std::string arch_name;
     LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
-    llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) {
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
         int trace = 0;
         if (getenv("LLAMA_TRACE")) {
             trace = atoi(getenv("LLAMA_TRACE"));
@@ -3223,6 +3225,7 @@ struct llama_model_loader {
         }
 
         this->use_mmap = use_mmap;
+        this->check_tensors = check_tensors;
     }
 
     ~llama_model_loader() {
@@ -3481,6 +3484,10 @@ struct llama_model_loader {
             file->seek(w.offs, SEEK_SET);
             file->read_raw(cur->data, ggml_nbytes(cur));
         }
+
+        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
+            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+        }
     }
 
     size_t size_done = 0;
@@ -3497,6 +3504,8 @@ struct llama_model_loader {
         GGML_ASSERT(size_data != 0 && "call init_mappings() first");
 
         std::vector<no_init<uint8_t>> read_buf;
+        std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
+
         for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
             const auto * weight = get_weight(ggml_get_name(cur));
             if (weight == nullptr) {
@@ -3518,37 +3527,66 @@ struct llama_model_loader {
                 if (bufs_mmap.count(weight->idx)) {
                     buf_mmap = bufs_mmap.at(weight->idx);
                 }
+                uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
+
+                if (check_tensors) {
+                    validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
+                        return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
+                    }));
+                }
+
                 GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
                 if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs);
+                    ggml_backend_tensor_alloc(buf_mmap, cur, data);
                     if (lmlocks) {
                         const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + ggml_nbytes(cur));
+                        lmlock->grow_to(weight->offs + n_size);
                     }
 
                     auto & mmap_used = mmaps_used[weight->idx];
                     mmap_used.first = std::min(mmap_used.first, weight->offs);
                     mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
                 } else {
-                    ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size);
+                    ggml_backend_tensor_set(cur, data, 0, n_size);
                 }
             } else {
                 GGML_ASSERT(weight->idx < files.size());
                 const auto & file = files.at(weight->idx);
                 if (ggml_backend_buffer_is_host(cur->buffer)) {
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, ggml_nbytes(cur));
+                    file->read_raw(cur->data, n_size);
+                    if (check_tensors) {
+                        validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
+                            return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
+                        }));
+                    }
                 } else {
-                    read_buf.resize(ggml_nbytes(cur));
+                    read_buf.resize(n_size);
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(read_buf.data(), ggml_nbytes(cur));
+                    file->read_raw(read_buf.data(), n_size);
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+                    if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                    }
                 }
             }
 
             size_done += n_size;
         }
 
+        // check validation results
+        bool validation_failed = false;
+        for (auto & future : validation_result) {
+            auto result = future.get();
+            if (!result.second) {
+                LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
+                validation_failed = true;
+            }
+        }
+        if (validation_failed) {
+            throw std::runtime_error("found tensors with invalid data");
+        }
+
         // check if this is the last call and do final cleanup
         if (size_done >= size_data) {
             // unmap offloaded tensors and metadata
@@ -5975,7 +6013,7 @@ static bool llm_load_tensors(
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
+        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
         model.hparams.vocab_only = params.vocab_only;
 
@@ -14360,14 +14398,20 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
 }
 
 static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
-    std::mutex mutex;
-    int64_t counter = 0;
-    size_t new_size = 0;
     if (nthread < 2) {
         // single-thread
-        return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
+        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
+        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
+            throw std::runtime_error("quantized data validation failed");
+        }
+        return new_size;
     }
-    auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size,
+
+    std::mutex mutex;
+    int64_t counter = 0;
+    size_t new_size = 0;
+    bool valid = true;
+    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
             nrows, n_per_row, imatrix]() {
         const int64_t nrows_per_chunk = chunk_size / n_per_row;
         size_t local_size = 0;
@@ -14382,7 +14426,17 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
             }
             lock.unlock();
             const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
-            local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
+            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
+            local_size += this_size;
+
+            // validate the quantized data
+            const size_t row_size = ggml_row_size(new_type, n_per_row);
+            void * this_data = (char *) new_data + first_row * row_size;
+            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
+                std::unique_lock<std::mutex> lock(mutex);
+                valid = false;
+                break;
+            }
         }
     };
     for (int it = 0; it < nthread - 1; ++it) {
@@ -14391,6 +14445,9 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
     compute();
     for (auto & w : workers) { w.join(); }
     workers.clear();
+    if (!valid) {
+        throw std::runtime_error("quantized data validation failed");
+    }
     return new_size;
 }
 
@@ -14453,7 +14510,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, kv_overrides);
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
     llama_model model;
@@ -14814,7 +14871,7 @@ static int llama_apply_lora_from_file_internal(
     std::unique_ptr<llama_model_loader> ml;
     if (path_base_model) {
         LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr));
+        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
         ml->init_mappings(/*prefetch*/ false); // no prefetching
     }
 
@@ -15073,6 +15130,7 @@ struct llama_model_params llama_model_default_params() {
         /*.vocab_only =*/ false,
         /*.use_mmap =*/ true,
         /*.use_mlock =*/ false,
+        /*.check_tensors =*/ false,
     };
 
 #ifdef GGML_USE_METAL