|
@@ -1548,11 +1548,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|
|
{"-fa", "--flash-attn"}, "FA",
|
|
{"-fa", "--flash-attn"}, "FA",
|
|
|
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
|
|
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
|
|
|
[](common_params & params, const std::string & value) {
|
|
[](common_params & params, const std::string & value) {
|
|
|
- if (value == "on" || value == "enabled") {
|
|
|
|
|
|
|
+ if (value == "on" || value == "enabled" || value == "1") {
|
|
|
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
|
|
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
|
|
|
- } else if (value == "off" || value == "disabled") {
|
|
|
|
|
|
|
+ } else if (value == "off" || value == "disabled" || value == "0") {
|
|
|
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
|
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
|
|
- } else if (value == "auto") {
|
|
|
|
|
|
|
+ } else if (value == "auto" || value == "-1") {
|
|
|
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
|
|
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
|
|
|
} else {
|
|
} else {
|
|
|
throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
|
|
throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
|