
py : type-check all Python scripts with Pyright (#8341)

* py : type-check all Python scripts with Pyright

* server-tests : use trailing slash in openai base_url

* server-tests : add more type annotations

* server-tests : strip "chat" from base_url in oai_chat_completions

* server-tests : model metadata is a dict
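The three server-tests bullets above all stem from moving to the openai 1.x client: the global base_url setting replaces the old api_base, request paths are resolved relative to it (hence the trailing slash), and model metadata comes back as a plain dict (e.g. model.meta["n_ctx_train"] in the diff below). A rough sketch of the resulting call shape against a local llama.cpp server; the address, port, and prompt are placeholders, not values from this commit:

    import openai

    # Placeholder endpoint; the tests build this from the behave context instead.
    openai.api_key = "nope"                        # the test server does not check the key
    openai.base_url = "http://localhost:8080/v1/"  # trailing slash so relative paths append cleanly

    # List the served models and reuse the first id for a chat completion.
    models = openai.models.list().data
    print(models[0].id)

    completion = openai.chat.completions.create(
        model=models[0].id,
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=8,
    )
    print(completion.choices[0].message.content)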

* ci : disable pip cache in type-check workflow

The cache is not shared between branches, and it's 250MB in size,
so it would become quite a big part of the 10GB cache limit of the repo.

* py : fix new type errors from master branch

* tests : fix test-tokenizer-random.py

Apparently, gcc applies optimisations even when pre-processing,
which confuses pycparser.
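The fix itself is outside this excerpt, but the general pattern is worth spelling out: pycparser only understands plain C, so the header has to go through the preprocessor with GNU-specific constructs hidden from it. A minimal, hypothetical sketch of that pattern (the header path and the -D defines are illustrative, not the exact flags used by test-tokenizer-random.py):

    import pycparser

    # Preprocess with gcc, then parse the pure-C output with pycparser.
    ast = pycparser.parse_file(
        "llama.h",                  # placeholder path
        use_cpp=True,
        cpp_path="gcc",
        cpp_args=[
            "-E",                   # preprocess only
            "-P",                   # omit #line markers
            "-D__attribute__(x)=",  # hide GNU extensions that pycparser cannot parse
            "-D__restrict=",
        ],
    )
    print(len(ast.ext), "top-level declarations")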

* ci : only show warnings and errors in python type-check

The "information" level otherwise has entries
from 'examples/pydantic_models_to_grammar.py',
which could be confusing for someone trying to figure out what failed,
considering that these messages can safely be ignored
even though they look like errors.
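The "information" severity is the same mechanism used inside the diff below, where pydantic_models_to_grammar.py downgrades one rule with a file-level comment directive; running the pyright action with level: warning then keeps those entries out of the CI output. A small standalone illustration of that directive (the class and attribute are made up for the example):

    # pyright: reportAttributeAccessIssue=information

    class Config:
        pass

    cfg = Config()
    setattr(cfg, "name", "demo")  # set dynamically, so the attribute is invisible to Pyright
    print(cfg.name)               # flagged, but only at "information" severity due to the directive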
compilade 1 year ago
parent
commit
3fd62a6b1c
33 changed files with 297 additions and 173 deletions
  1. .devops/nix/package.nix (+16, -0)
  2. .github/workflows/python-type-check.yml (+38, -0)
  3. convert_hf_to_gguf.py (+8, -12)
  4. convert_llama_ggml_to_gguf.py (+2, -1)
  5. examples/convert_legacy_llama.py (+12, -9)
  6. examples/finetune/convert_finetune_checkpoint_to_gguf.py (+1, -1)
  7. examples/json_schema_pydantic_example.py (+4, -1)
  8. examples/json_schema_to_grammar.py (+7, -5)
  9. examples/llava/convert_image_encoder_to_gguf.py (+6, -4)
  10. examples/llava/llava_surgery_v2.py (+7, -3)
  11. examples/pydantic_models_to_grammar.py (+19, -16)
  12. examples/pydantic_models_to_grammar_examples.py (+4, -3)
  13. examples/server/bench/bench.py (+7, -4)
  14. examples/server/tests/features/steps/steps.py (+51, -49)
  15. examples/server/tests/requirements.txt (+2, -2)
  16. examples/server_embd.py (+3, -1)
  17. examples/train-text-from-scratch/convert_train_checkpoint_to_gguf.py (+1, -1)
  18. ggml/ggml_vk_generate_shaders.py (+2, -2)
  19. gguf-py/gguf/gguf_reader.py (+3, -3)
  20. gguf-py/gguf/lazy.py (+17, -14)
  21. gguf-py/scripts/__init__.py (+2, -0)
  22. gguf-py/scripts/gguf_hash.py (+3, -3)
  23. gguf-py/scripts/gguf_new_metadata.py (+2, -0)
  24. gguf-py/tests/test_gguf.py (+1, -1)
  25. pyrightconfig.json (+19, -1)
  26. requirements/requirements-all.txt (+12, -0)
  27. requirements/requirements-compare-llama-bench.txt (+2, -0)
  28. requirements/requirements-pydantic.txt (+2, -0)
  29. requirements/requirements-test-tokenizer-random.txt (+1, -0)
  30. scripts/check-requirements.sh (+6, -6)
  31. scripts/compare-llama-bench.py (+4, -4)
  32. scripts/gen-unicode-data.py (+9, -7)
  33. tests/test-tokenizer-random.py (+24, -20)

+ 16 - 0
.devops/nix/package.nix

@@ -89,6 +89,22 @@ let
       ps.tiktoken
       ps.torchWithoutCuda
       ps.transformers
+
+      # server bench
+      ps.matplotlib
+
+      # server tests
+      ps.openai
+      ps.behave
+      ps.prometheus-client
+
+      # for examples/pydantic-models-to-grammar-examples.py
+      ps.docstring-parser
+      ps.pydantic
+
+      # for scripts/compare-llama-bench.py
+      ps.gitpython
+      ps.tabulate
     ]
   );
 

+ 38 - 0
.github/workflows/python-type-check.yml

@@ -0,0 +1,38 @@
+name: Python Type-Check
+
+on:
+  push:
+    paths:
+      - '.github/workflows/python-type-check.yml'
+      - '**.py'
+      - '**/requirements*.txt'
+  pull_request:
+    paths:
+      - '.github/workflows/python-type-check.yml'
+      - '**.py'
+      - '**/requirements*.txt'
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  python-type-check:
+    runs-on: ubuntu-latest
+    name: pyright type-check
+    steps:
+      - name: Check out source repository
+        uses: actions/checkout@v4
+      - name: Set up Python environment
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+      - name: Install Python dependencies
+        # TODO: use a venv
+        run: pip install -r requirements/requirements-all.txt
+      - name: Type-check with Pyright
+        uses: jakebailey/pyright-action@v2
+        with:
+          version: 1.1.370
+          level: warning
+          warnings: true

+ 8 - 12
convert_hf_to_gguf.py

@@ -265,7 +265,7 @@ class Model:
                     break

             for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
-                data: np.ndarray = data  # type hint
+                data: np.ndarray  # type hint
                 n_dims = len(data.shape)
                 data_dtype = data.dtype
                 data_qtype: gguf.GGMLQuantizationType | None = None
@@ -599,10 +599,6 @@ class Model:

         tokenizer_path = self.dir_model / 'tokenizer.model'

-        tokens: list[bytes] = []
-        scores: list[float] = []
-        toktypes: list[int] = []
-
         if not tokenizer_path.is_file():
             raise FileNotFoundError(f"File not found: {tokenizer_path}")
 
@@ -2120,7 +2116,7 @@ class InternLM2Model(Model):
             logger.error(f'Error: Missing {tokenizer_path}')
             sys.exit(1)

-        sentencepiece_model = model.ModelProto()
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
 
@@ -2972,16 +2968,16 @@ class T5Model(Model):
         if not tokenizer_path.is_file():
             raise FileNotFoundError(f"File not found: {tokenizer_path}")

-        sentencepiece_model = model.ModelProto()
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())

         # some models like Pile-T5 family use BPE tokenizer instead of Unigram
-        if sentencepiece_model.trainer_spec.model_type == 2: # BPE
+        if sentencepiece_model.trainer_spec.model_type == 2:  # BPE
             # assure the tokenizer model file name is correct
             assert tokenizer_path.name == 'tokenizer.model'
             return self._set_vocab_sentencepiece()
         else:
-            assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM

         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
         remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
@@ -3152,7 +3148,7 @@ class JaisModel(Model):
             # but Jais's PyTorch model simply precalculates the slope values and places them
             # in relative_pes.slopes
             n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
-            first_val = float(data_torch._data[0])
+            first_val = float(data_torch[0].item())
             self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)

             return tensors
@@ -3186,7 +3182,7 @@ class ChatGLMModel(Model):
     def set_vocab_chatglm3(self):
         dir_model = self.dir_model
         hparams = self.hparams
-        tokens: list[bytearray] = []
+        tokens: list[bytes] = []
         toktypes: list[int] = []
         scores: list[float] = []
 
@@ -3335,7 +3331,7 @@ class ChatGLMModel(Model):
         special_vocab.add_to_gguf(self.gguf_writer)

     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
+        self.gguf_writer.add_name(self.hparams["_name_or_path"].split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
         n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
         n_head_kv = self.hparams.get("multi_query_group_num", n_head)

+ 2 - 1
convert_llama_ggml_to_gguf.py

@@ -354,7 +354,8 @@ class GGMLToGGUF:


 def handle_metadata(cfg, hp):
-    import convert
+    import examples.convert_legacy_llama as convert
+
     assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory'
     hf_config_path   = cfg.model_metadata_dir / "config.json"
     orig_config_path = cfg.model_metadata_dir / "params.json"

+ 12 - 9
examples/convert_legacy_llama.py

@@ -353,7 +353,7 @@ class Metadata:
     version: Optional[str] = None
     url: Optional[str] = None
     description: Optional[str] = None
-    licence: Optional[str] = None
+    license: Optional[str] = None
     source_url: Optional[str] = None
     source_hf_repo: Optional[str] = None
 
@@ -492,12 +492,13 @@ class LazyTensor:

 LazyModel: TypeAlias = 'dict[str, LazyTensor]'

+ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']

 @dataclass
 class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
-    format: Literal['ggml', 'torch', 'safetensors', 'none']
+    format: ModelFormat
     vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.
 
 
@@ -536,7 +537,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:


 def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
-    formats = set(mp.format for mp in models_plus)
+    formats: set[ModelFormat] = set(mp.format for mp in models_plus)
     assert len(formats) == 1, "different formats?"
     format = formats.pop()
     paths = [path for mp in models_plus for path in mp.paths]
@@ -555,7 +556,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
     else:
         model = merge_sharded([mp.model for mp in models_plus])

-    return ModelPlus(model, paths, format, vocab)  # pytype: disable=wrong-arg-types
+    return ModelPlus(model, paths, format, vocab)


 def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
@@ -805,7 +806,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

-    def add_meta_model(self, params: Params, metadata: Metadata) -> None:
+    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
@@ -827,8 +828,8 @@ class OutputFile:
                 self.gguf.add_url(metadata.url)
             if metadata.description is not None:
                 self.gguf.add_description(metadata.description)
-            if metadata.licence is not None:
-                self.gguf.add_licence(metadata.licence)
+            if metadata.license is not None:
+                self.gguf.add_licence(metadata.license)
             if metadata.source_url is not None:
                 self.gguf.add_source_url(metadata.source_url)
             if metadata.source_hf_repo is not None:
@@ -943,7 +944,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -977,7 +978,7 @@
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata = None,
+        metadata: Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -1396,6 +1397,8 @@ def main(args_in: list[str] | None = None) -> None:
     if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
         vocab = model_plus.vocab

+    assert params is not None
+
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     model   = model_plus.model

+ 1 - 1
examples/finetune/convert_finetune_checkpoint_to_gguf.py

@@ -74,7 +74,7 @@ class Tensor:
             if len(self.ne) == 0:
                 self.nbytes = 0
             else:
-                self.nbytes = int(np.product(self.ne)) * 4
+                self.nbytes = int(np.prod(self.ne)) * 4
         else:
             raise ValueError(f"Unhandled data type '{self.dtype}'")
 

+ 4 - 1
examples/json_schema_pydantic_example.py

@@ -3,7 +3,7 @@
 #! pip install pydantic
 #! python json_schema_pydantic_example.py

-from pydantic import BaseModel, Extra, TypeAdapter
+from pydantic import BaseModel, Field, TypeAdapter
 from annotated_types import MinLen
 from typing import Annotated, List, Optional
 import json, requests
@@ -17,6 +17,9 @@ if True:

         The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
         '''
+        response_format = None
+        type_adapter = None
+
         if response_model:
             type_adapter = TypeAdapter(response_model)
             schema = type_adapter.json_schema()

+ 7 - 5
examples/json_schema_to_grammar.py

@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import argparse
 import itertools
 import json
@@ -188,7 +190,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
     raise RuntimeError("At least one of min_value or max_value must be set")

 class BuiltinRule:
-    def __init__(self, content: str, deps: list = None):
+    def __init__(self, content: str, deps: list | None = None):
         self.content = content
         self.deps = deps or []
 
@@ -248,7 +250,7 @@ class SchemaConverter:

     def _format_literal(self, literal):
         escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
-            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
         )
         return f'"{escaped}"'
 
@@ -403,11 +405,11 @@ class SchemaConverter:
         i = 0
         length = len(pattern)

-        def to_rule(s: Tuple[str, bool]) -> str:
+        def to_rule(s: tuple[str, bool]) -> str:
             (txt, is_literal) = s
             return "\"" + txt + "\"" if is_literal else txt

-        def transform() -> Tuple[str, bool]:
+        def transform() -> tuple[str, bool]:
             '''
                 Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
             '''
@@ -420,7 +422,7 @@ class SchemaConverter:
             # We only need a flat structure here to apply repetition operators to the last item, and
             # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
             # (GBNF's syntax is luckily very close to regular expressions!)
-            seq: list[Tuple[str, bool]] = []
+            seq: list[tuple[str, bool]] = []

             def get_dot():
                 if self._dotall:

+ 6 - 4
examples/llava/convert_image_encoder_to_gguf.py

@@ -185,6 +185,8 @@ else:
     fout.add_description("two-tower CLIP model")

 if has_text_encoder:
+    assert t_hparams is not None
+    assert tokens is not None
     # text_model hparams
     fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
     fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
@@ -259,8 +261,8 @@ if has_vision_encoder:


     if processor is not None:
-        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
-        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
+        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean  # pyright: ignore[reportAttributeAccessIssue]
+        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std  # pyright: ignore[reportAttributeAccessIssue]
     else:
         image_mean = args.image_mean if args.image_mean is not None else default_image_mean
         image_std = args.image_std if args.image_std is not None else default_image_std
@@ -272,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu)


 if has_llava_projector:
-    model.vision_model.encoder.layers.pop(-1)
+    model.vision_model.encoder.layers.pop(-1)  # pyright: ignore[reportAttributeAccessIssue]
     projector = torch.load(args.llava_projector)
     for name, data in projector.items():
         name = get_tensor_name(name)
@@ -286,7 +288,7 @@ if has_llava_projector:

     print("Projector tensors added\n")

-state_dict = model.state_dict()
+state_dict = model.state_dict()  # pyright: ignore[reportAttributeAccessIssue]
 for name, data in state_dict.items():
     if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
         # we don't need this

+ 7 - 3
examples/llava/llava_surgery_v2.py

@@ -2,7 +2,9 @@ import argparse
 import glob
 import os
 import torch
-from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
+from safetensors import safe_open
+from safetensors.torch import save_file
+from typing import Any, ContextManager, cast

 # Function to determine if file is a SafeTensor file
 def is_safetensor_file(file_path):
@@ -13,7 +15,7 @@ def is_safetensor_file(file_path):
 def load_model(file_path):
     if is_safetensor_file(file_path):
         tensors = {}
-        with safe_open(file_path, framework="pt", device="cpu") as f:
+        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
             for key in f.keys():
                 tensors[key] = f.get_tensor(key).clone()
                 # output shape
@@ -134,7 +136,7 @@ if len(mm_tensors) == 0:
     if last_checkpoint is not None:
         for k, v in last_checkpoint.items():
             print(k)
-    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
+    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
     print("No tensors found. Is this a LLaVA model?")
     exit()
 
@@ -143,8 +145,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
 # projector = {name: checkpoint.[name].float() for name in mm_tensors}
 projector = {}
 for name in mm_tensors:
+    assert last_checkpoint is not None
     projector[name] = last_checkpoint[name].float()
 for name in first_mm_tensors:
+    assert first_checkpoint is not None
     projector[name] = first_checkpoint[name].float()

 if len(projector) > 0:

+ 19 - 16
examples/pydantic_models_to_grammar.py

@@ -6,10 +6,10 @@ import re
 from copy import copy
 from enum import Enum
 from inspect import getdoc, isclass
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin

 from docstring_parser import parse
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, create_model

 if TYPE_CHECKING:
     from types import GenericAlias
@@ -17,6 +17,9 @@ else:
     # python 3.8 compat
     from typing import _GenericAlias as GenericAlias

+# TODO: fix this
+# pyright: reportAttributeAccessIssue=information
+

 class PydanticDataType(Enum):
     """
@@ -234,8 +237,9 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None

     # Define the integer part rule
     integer_part_rule = (
-        "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
-        f"-min{min_digit}" if min_digit is not None else "")
+        "integer-part"
+        + (f"-max{max_digit}" if max_digit is not None else "")
+        + (f"-min{min_digit}" if min_digit is not None else "")
     )

     # Define the fractional part rule based on precision constraints
@@ -458,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
     if not issubclass(model, BaseModel):
         # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
         if hasattr(model, "__annotations__") and model.__annotations__:
-            model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
+            model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}  # pyright: ignore[reportGeneralTypeIssues]
         else:
             init_signature = inspect.signature(model.__init__)
             parameters = init_signature.parameters
@@ -680,7 +684,7 @@ def generate_markdown_documentation(
         str: Generated text documentation.
     """
     documentation = ""
-    pyd_models = [(model, True) for model in pydantic_models]
+    pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
     for model, add_prefix in pyd_models:
         if add_prefix:
             documentation += f"{model_prefix}: {model.__name__}\n"
@@ -700,7 +704,7 @@ def generate_markdown_documentation(
             # Indenting the fields section
             documentation += f"  {fields_prefix}:\n"
         else:
-            documentation += f"  Fields:\n"
+            documentation += f"  Fields:\n"  # noqa: F541
         if isclass(model) and issubclass(model, BaseModel):
             for name, field_type in model.__annotations__.items():
                 # if name == "markdown_code_block":
@@ -778,7 +782,7 @@ def generate_field_markdown(
         return field_text

     if field_description != "":
-        field_text += f"        Description: " + field_description + "\n"
+        field_text += f"        Description: {field_description}\n"

     # Check for and include field-specific examples if available
     if hasattr(model, "Config") and hasattr(model.Config,
@@ -833,7 +837,7 @@ def generate_text_documentation(
         str: Generated text documentation.
     """
     documentation = ""
-    pyd_models = [(model, True) for model in pydantic_models]
+    pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
     for model, add_prefix in pyd_models:
         if add_prefix:
             documentation += f"{model_prefix}: {model.__name__}\n"
@@ -1164,7 +1168,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
         dynamic_fields[param.name] = (
             param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
     # Creating the dynamic model
-    dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)  # type: ignore[call-overload]
+    dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)

     for name, param_doc in param_docs:
         dynamic_model.model_fields[name].description = param_doc.description
@@ -1228,9 +1232,6 @@ def map_grammar_names_to_pydantic_model_class(pydantic_model_list):
     return output


-from enum import Enum
-
-
 def json_schema_to_python_types(schema):
     type_map = {
         "any": Any,
@@ -1275,7 +1276,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
                     if items != {}:
                         array = {"properties": items}
                         array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
-                        fields[field_name] = (List[array_type], ...)  # type: ignore[valid-type]
+                        fields[field_name] = (List[array_type], ...)
                     else:
                         fields[field_name] = (list, ...)
                 elif field_type == "object":
@@ -1285,7 +1286,8 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
                     required = field_data.get("enum", [])
                     for key, field in fields.items():
                         if key not in required:
-                            fields[key] = (Optional[fields[key][0]], ...)
+                            optional_type = fields[key][0]
+                            fields[key] = (Optional[optional_type], ...)
                 else:
                     field_type = json_schema_to_python_types(field_type)
                     fields[field_name] = (field_type, ...)
@@ -1305,6 +1307,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
         required = dictionary.get("required", [])
         for key, field in fields.items():
             if key not in required:
-                fields[key] = (Optional[fields[key][0]], ...)
+                optional_type = fields[key][0]
+                fields[key] = (Optional[optional_type], ...)
     custom_model = create_model(model_name, **fields)
     return custom_model

+ 4 - 3
examples/pydantic_models_to_grammar_examples.py

@@ -1,6 +1,7 @@
 # Function calling example using pydantic models.
+from __future__ import annotations
+
 import datetime
-import importlib
 import json
 from enum import Enum
 from typing import Optional, Union
@@ -215,9 +216,9 @@ for call in json_data:
     if call["function"] == "Calculator":
         print(Calculator(**call["params"]).run())
     elif call["function"] == "get_current_datetime":
-        print(current_datetime_model(**call["params"]).run())
+        print(current_datetime_model(**call["params"]).run())  # pyright: ignore[reportAttributeAccessIssue]
     elif call["function"] == "get_current_weather":
-        print(current_weather_tool_model(**call["params"]).run())
+        print(current_weather_tool_model(**call["params"]).run())  # pyright: ignore[reportAttributeAccessIssue]
 # Should output something like this:
 # 2024-01-14 13:36:06
 # {"location": "London", "temperature": "42", "unit": "celsius"}

+ 7 - 4
examples/server/bench/bench.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import json
 import os
@@ -59,10 +61,11 @@ def main(args_in: list[str] | None = None) -> None:
         sys.exit(1)

     # start the benchmark
+    iterations = 0
+    data = {}
     try:
         start_benchmark(args)

-        iterations = 0
         with open("results.github.env", 'w') as github_env:
             # parse output
             with open('k6-results.json', 'r') as bench_results:
@@ -129,7 +132,7 @@ def main(args_in: list[str] | None = None) -> None:
                 timestamps, metric_values = zip(*values)
                 metric_values = [float(value) for value in metric_values]
                 prometheus_metrics[metric] = metric_values
-                timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
+                timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
                 plt.figure(figsize=(16, 10), dpi=80)
                 plt.plot(timestamps_dt, metric_values, label=metric)
                 plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
@@ -156,7 +159,7 @@ def main(args_in: list[str] | None = None) -> None:
                 plt.close()

                 # Mermaid format in case images upload failed
-                with (open(f"{metric}.mermaid", 'w') as mermaid_f):
+                with open(f"{metric}.mermaid", 'w') as mermaid_f:
                     mermaid = (
                     f"""---
 config:
@@ -278,7 +281,7 @@ def start_server_background(args):
     }
     server_process = subprocess.Popen(
         args,
-        **pkwargs)
+        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

     def server_log(in_stream, out_stream):
         for line in iter(in_stream.readline, b''):

+ 51 - 49
examples/server/tests/features/steps/steps.py

@@ -1,5 +1,4 @@
 import asyncio
-import collections
 import json
 import os
 import re
@@ -8,19 +7,23 @@ import subprocess
 import sys
 import threading
 import time
+from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
+from typing import Any, Literal, cast

 import aiohttp
 import numpy as np
 import openai
-from behave import step
+from openai.types.chat import ChatCompletionChunk
+from behave import step  # pyright: ignore[reportAttributeAccessIssue]
 from behave.api.async_step import async_run_until_complete
 from prometheus_client import parser

+# pyright: reportRedeclaration=false

 @step("a server listening on {server_fqdn}:{server_port}")
-def step_server_config(context, server_fqdn, server_port):
+def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_fqdn = server_fqdn
     context.server_port = int(server_port)
     context.n_threads = None
@@ -74,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port):


 @step('a model file {hf_file} from HF repo {hf_repo}')
-def step_download_hf_model(context, hf_file, hf_repo):
+def step_download_hf_model(context, hf_file: str, hf_repo: str):
     context.model_hf_repo = hf_repo
     context.model_hf_file = hf_file
     context.model_file = os.path.basename(hf_file)


 @step('a model file {model_file}')
-def step_model_file(context, model_file):
+def step_model_file(context, model_file: str):
     context.model_file = model_file


 @step('a model url {model_url}')
-def step_model_url(context, model_url):
+def step_model_url(context, model_url: str):
     context.model_url = model_url


 @step('a model alias {model_alias}')
-def step_model_alias(context, model_alias):
+def step_model_alias(context, model_alias: str):
     context.model_alias = model_alias


 @step('{seed:d} as server seed')
-def step_seed(context, seed):
+def step_seed(context, seed: int):
     context.server_seed = seed


 @step('{ngl:d} GPU offloaded layers')
-def step_n_gpu_layer(context, ngl):
+def step_n_gpu_layer(context, ngl: int):
     if 'N_GPU_LAYERS' in os.environ:
         new_ngl = int(os.environ['N_GPU_LAYERS'])
         if context.debug:
@@ -111,37 +114,37 @@ def step_n_gpu_layer(context, ngl):


 @step('{n_threads:d} threads')
-def step_n_threads(context, n_threads):
+def step_n_threads(context, n_threads: int):
     context.n_thread = n_threads


 @step('{draft:d} as draft')
-def step_draft(context, draft):
+def step_draft(context, draft: int):
     context.draft = draft


 @step('{n_ctx:d} KV cache size')
-def step_n_ctx(context, n_ctx):
+def step_n_ctx(context, n_ctx: int):
     context.n_ctx = n_ctx


 @step('{n_slots:d} slots')
-def step_n_slots(context, n_slots):
+def step_n_slots(context, n_slots: int):
     context.n_slots = n_slots


 @step('{n_predict:d} server max tokens to predict')
-def step_server_n_predict(context, n_predict):
+def step_server_n_predict(context, n_predict: int):
     context.n_server_predict = n_predict


 @step('{slot_save_path} as slot save path')
-def step_slot_save_path(context, slot_save_path):
+def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path


 @step('using slot id {id_slot:d}')
-def step_id_slot(context, id_slot):
+def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
 
 
@@ -191,7 +194,7 @@ def step_start_server(context):

 @step("the server is {expecting_status}")
 @async_run_until_complete
-async def step_wait_for_the_server_to_be_started(context, expecting_status):
+async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
     match expecting_status:
         case 'healthy':
             await wait_for_health_status(context, context.base_url, 200, 'ok',
@@ -221,7 +224,7 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):

 @step('all slots are {expected_slot_status_string}')
 @async_run_until_complete
-async def step_all_slots_status(context, expected_slot_status_string):
+async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
     match expected_slot_status_string:
         case 'idle':
             expected_slot_status = 0
@@ -237,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string):

 @step('a completion request with {api_error} api error')
 @async_run_until_complete
-async def step_request_completion(context, api_error):
+async def step_request_completion(context, api_error: Literal['raised'] | str):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await request_completion(context.prompts.pop(),
@@ -777,8 +780,8 @@ def step_assert_metric_value(context, metric_name, metric_value):
 def step_available_models(context):
     # openai client always expects an api_key
     openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
-    openai.api_base = f'{context.base_url}/v1'
-    context.models = openai.Model.list().data
+    openai.base_url = f'{context.base_url}/v1/'
+    context.models = openai.models.list().data


 @step('{n_model:d} models are supported')
@@ -789,7 +792,7 @@ def step_supported_models(context, n_model):


 @step('model {i_model:d} is {param} {preposition} {param_value}')
-def step_supported_models(context, i_model, param, preposition, param_value):
+def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str):
     assert i_model < len(context.models)
     model = context.models[i_model]
 
@@ -798,7 +801,7 @@ def step_supported_models(context, i_model, param, preposition, param_value):
         case 'identified':
             value = model.id
         case 'trained':
-            value = str(model.meta.n_ctx_train)
+            value = str(model.meta["n_ctx_train"])
         case _:
             assert False, "param {param} not supported"
     assert param_value == value, f"model param {param} {value} != {param_value}"
@@ -810,6 +813,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
         print(f"starting {context.n_prompts} concurrent completion requests...")
     assert context.n_prompts > 0
     seeds = await completions_seed(context)
+    assert seeds is not None
     for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
@@ -861,7 +865,7 @@ async def request_completion(prompt,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
-                             temperature=None):
+                             temperature=None) -> int | dict[str, Any]:
     if debug:
         print(f"Sending completion request: {prompt}")
     origin = "my.super.domain"
@@ -899,8 +903,8 @@ async def oai_chat_completions(user_prompt,
 async def oai_chat_completions(user_prompt,
                                seed,
                                system_prompt,
-                               base_url,
-                               base_path,
+                               base_url: str,
+                               base_path: str,
                                async_client,
                                debug=False,
                                temperature=None,
@@ -909,7 +913,7 @@ async def oai_chat_completions(user_prompt,
                                enable_streaming=None,
                                response_format=None,
                                user_api_key=None,
-                               expect_api_error=None):
+                               expect_api_error=None) -> int | dict[str, Any]:
     if debug:
         print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
@@ -989,32 +993,35 @@ async def oai_chat_completions(user_prompt,
     else:
         try:
             openai.api_key = user_api_key
-            openai.api_base = f'{base_url}{base_path}'
-            chat_completion = openai.Completion.create(
+            openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
+            assert model is not None
+            chat_completion = openai.chat.completions.create(
                 messages=payload['messages'],
                 model=model,
                 max_tokens=n_predict,
                 stream=enable_streaming,
-                response_format=payload.get('response_format'),
+                response_format=payload.get('response_format') or openai.NOT_GIVEN,
                 seed=seed,
                 temperature=payload['temperature']
             )
-        except openai.error.AuthenticationError as e:
+        except openai.AuthenticationError as e:
             if expect_api_error is not None and expect_api_error:
                 return 401
             else:
                 assert False, f'error raised: {e}'

         if enable_streaming:
+            chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
             for chunk in chat_completion:
                 assert len(chunk.choices) == 1
                 delta = chunk.choices[0].delta
-                if 'content' in delta:
-                    completion_response['content'] += delta['content']
+                if delta.content is not None:
+                    completion_response['content'] += delta.content
                     completion_response['timings']['predicted_n'] += 1
                 completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
         else:
             assert len(chat_completion.choices) == 1
+            assert chat_completion.usage is not None
             completion_response = {
                 'content': chat_completion.choices[0].message.content,
                 'timings': {
@@ -1028,7 +1035,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response


-async def request_embedding(content, seed, base_url=None):
+async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
@@ -1041,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None):
 
 
 async def request_oai_embeddings(input, seed,
 async def request_oai_embeddings(input, seed,
                                  base_url=None, user_api_key=None,
                                  base_url=None, user_api_key=None,
-                                 model=None, async_client=False):
+                                 model=None, async_client=False) -> list[list[float]]:
     # openai client always expects an api_key
     # openai client always expects an api_key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
     if async_client:
@@ -1063,7 +1070,7 @@ async def request_oai_embeddings(input, seed,
                 response_json = await response.json()
                 response_json = await response.json()
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['object'] == 'list'
                 assert response_json['object'] == 'list'
-                if isinstance(input, collections.abc.Sequence):
+                if isinstance(input, Sequence):
                     embeddings = []
                     embeddings = []
                     for an_oai_embeddings in response_json['data']:
                     for an_oai_embeddings in response_json['data']:
                         embeddings.append(an_oai_embeddings['embedding'])
                         embeddings.append(an_oai_embeddings['embedding'])
@@ -1072,19 +1079,14 @@ async def request_oai_embeddings(input, seed,
                 return embeddings
                 return embeddings
     else:
     else:
         openai.api_key = user_api_key
         openai.api_key = user_api_key
-        openai.api_base = f'{base_url}/v1'
-        oai_embeddings = openai.Embedding.create(
+        openai.base_url = f'{base_url}/v1/'
+        assert model is not None
+        oai_embeddings = openai.embeddings.create(
             model=model,
             model=model,
             input=input,
             input=input,
         )
         )
 
 
-        if isinstance(input, collections.abc.Sequence):
-            embeddings = []
-            for an_oai_embeddings in oai_embeddings.data:
-                embeddings.append(an_oai_embeddings.embedding)
-        else:
-            embeddings = [oai_embeddings.data.embedding]
-        return embeddings
+        return [e.embedding for e in oai_embeddings.data]
 
 
 
 
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
@@ -1122,7 +1124,7 @@ def assert_all_predictions_equal(completion_responses):
             if i == j:
             if i == j:
                 continue
                 continue
             content_j = response_j['content']
             content_j = response_j['content']
-        assert content_i == content_j, "contents not equal"
+            assert content_i == content_j, "contents not equal"
 
 
 
 
 def assert_all_predictions_different(completion_responses):
 def assert_all_predictions_different(completion_responses):
@@ -1136,7 +1138,7 @@ def assert_all_predictions_different(completion_responses):
             if i == j:
             if i == j:
                 continue
                 continue
             content_j = response_j['content']
             content_j = response_j['content']
-        assert content_i != content_j, "contents not different"
+            assert content_i != content_j, "contents not different"
 
 
 
 
 def assert_all_token_probabilities_equal(completion_responses):
 def assert_all_token_probabilities_equal(completion_responses):
@@ -1153,7 +1155,7 @@ def assert_all_token_probabilities_equal(completion_responses):
                 if i == j:
                 if i == j:
                     continue
                     continue
                 probs_j = response_j['completion_probabilities'][pos]['probs']
                 probs_j = response_j['completion_probabilities'][pos]['probs']
-            assert probs_i == probs_j, "contents not equal"
+                assert probs_i == probs_j, "contents not equal"
 
 
 
 
 async def gather_tasks_results(context):
 async def gather_tasks_results(context):
@@ -1343,7 +1345,7 @@ def start_server_background(context):
     }
     }
     context.server_process = subprocess.Popen(
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
         [str(arg) for arg in [context.server_path, *server_args]],
-        **pkwargs)
+        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]
 
 
     def server_log(in_stream, out_stream):
     def server_log(in_stream, out_stream):
         for line in iter(in_stream.readline, b''):
         for line in iter(in_stream.readline, b''):
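For reference, a minimal sketch of the openai>=1.0 interface these tests now target, written against a hypothetical local server at http://localhost:8080 (the URL and model name are placeholders, not values taken from the tests):

import openai

openai.api_key = "nope"                        # the client insists on some key, even for local servers
openai.base_url = "http://localhost:8080/v1/"  # openai>=1.0 expects the trailing slash

# Non-streaming: the result is a ChatCompletion object with attribute access, not a dict.
resp = openai.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=8,
)
print(resp.choices[0].message.content, resp.usage)

# Streaming: iterate ChatCompletionChunk objects and read delta.content (which may be None).
stream = openai.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="")

# Embeddings follow the same shape: openai.embeddings.create(...).data[i].embedding.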

+ 2 - 2
examples/server/tests/requirements.txt

@@ -1,6 +1,6 @@
 aiohttp~=3.9.3
 aiohttp~=3.9.3
 behave~=1.2.6
 behave~=1.2.6
 huggingface_hub~=0.20.3
 huggingface_hub~=0.20.3
-numpy~=1.24.4
-openai~=0.25.0
+numpy~=1.26.4
+openai~=1.30.3
 prometheus-client~=0.20.0
 prometheus-client~=0.20.0

+ 3 - 1
examples/server_embd.py

@@ -1,13 +1,15 @@
 import asyncio
 import asyncio
+import asyncio.threads
 import requests
 import requests
 import numpy as np
 import numpy as np
 
 
+
 n = 8
 n = 8
 
 
 result = []
 result = []
 
 
 async def requests_post_async(*args, **kwargs):
 async def requests_post_async(*args, **kwargs):
-    return await asyncio.to_thread(requests.post, *args, **kwargs)
+    return await asyncio.threads.to_thread(requests.post, *args, **kwargs)
 
 
 async def main():
 async def main():
     model_url = "http://127.0.0.1:6900"
     model_url = "http://127.0.0.1:6900"
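As a side note, asyncio.threads.to_thread is the same callable as the more commonly spelled asyncio.to_thread (Python 3.9+); a minimal sketch of the pattern, with the endpoint path assumed:

import asyncio
import requests

async def post_async(url: str, **kwargs):
    # run the blocking requests.post on a worker thread so the event loop stays free
    return await asyncio.to_thread(requests.post, url, **kwargs)

async def demo():
    r = await post_async("http://127.0.0.1:6900/embedding", json={"content": "hello"})
    print(r.status_code)

if __name__ == "__main__":
    asyncio.run(demo())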

+ 1 - 1
examples/train-text-from-scratch/convert_train_checkpoint_to_gguf.py

@@ -66,7 +66,7 @@ class Tensor:
             if len(self.ne) == 0:
             if len(self.ne) == 0:
                 self.nbytes = 0
                 self.nbytes = 0
             else:
             else:
-                self.nbytes = int(np.product(self.ne)) * 4
+                self.nbytes = int(np.prod(self.ne)) * 4
         else:
         else:
             raise ValueError(f"Unhandled data type '{self.dtype}'")
             raise ValueError(f"Unhandled data type '{self.dtype}'")
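np.product was only a deprecated alias of np.prod (and is removed in NumPy 2.x); a tiny sketch of the byte-size computation with made-up dimensions:

import numpy as np

ne = [4096, 32]                  # hypothetical tensor shape
nbytes = int(np.prod(ne)) * 4    # 4 bytes per float32 element
print(nbytes)                    # 524288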
 
 

+ 2 - 2
ggml/ggml_vk_generate_shaders.py

@@ -99,6 +99,8 @@ async def main():
 
 
     tasks = []
     tasks = []
 
 
+    base_dict = {"FLOAT_TYPE": "float"}
+
     for fp16 in (False, True):
     for fp16 in (False, True):
         # MUL_MAT
         # MUL_MAT
         matmul_shaders(tasks, fp16, False)
         matmul_shaders(tasks, fp16, False)
@@ -106,8 +108,6 @@ async def main():
         matmul_shaders(tasks, fp16, True)
         matmul_shaders(tasks, fp16, True)
 
 
     for tname in type_names:
     for tname in type_names:
-        base_dict = {"FLOAT_TYPE": "float"}
-
         # mul mat vec
         # mul mat vec
         data_a_key = f"DATA_A_{tname.upper()}"
         data_a_key = f"DATA_A_{tname.upper()}"
         shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"
         shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"

+ 3 - 3
gguf-py/gguf/gguf_reader.py

@@ -67,7 +67,7 @@ class ReaderTensor(NamedTuple):
 
 
 class GGUFReader:
 class GGUFReader:
     # I - same as host, S - swapped
     # I - same as host, S - swapped
-    byte_order: Literal['I'] | Literal['S'] = 'I'
+    byte_order: Literal['I', 'S'] = 'I'
     alignment: int = GGUF_DEFAULT_ALIGNMENT
     alignment: int = GGUF_DEFAULT_ALIGNMENT
     data_offset: int
     data_offset: int
 
 
@@ -86,7 +86,7 @@ class GGUFReader:
         GGUFValueType.BOOL:    np.bool_,
         GGUFValueType.BOOL:    np.bool_,
     }
     }
 
 
-    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
         self.data = np.memmap(path, mode = mode)
         self.data = np.memmap(path, mode = mode)
         offs = 0
         offs = 0
 
 
@@ -140,7 +140,7 @@ class GGUFReader:
         return self.tensors[idx]
         return self.tensors[idx]
 
 
     def _get(
     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
     ) -> npt.NDArray[Any]:
     ) -> npt.NDArray[Any]:
         count = int(count)
         count = int(count)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
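typing.Literal accepts several values in a single subscription, so Literal['I', 'S'] denotes the same type as the chained single-value form; a small illustration (the names here are illustrative, not from gguf_reader.py):

from typing import Literal

ByteOrder = Literal['I', 'S']   # I - same as host, S - swapped

def describe(order: ByteOrder) -> str:
    return "host byte order" if order == 'I' else "byte-swapped"

print(describe('I'))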

+ 17 - 14
gguf-py/gguf/lazy.py

@@ -16,16 +16,16 @@ logger = logging.getLogger(__name__)
 class LazyMeta(ABCMeta):
 class LazyMeta(ABCMeta):
 
 
     def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
     def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
-        def __getattr__(self, __name: str) -> Any:
-            meta_attr = getattr(self._meta, __name)
+        def __getattr__(self, name: str) -> Any:
+            meta_attr = getattr(self._meta, name)
             if callable(meta_attr):
             if callable(meta_attr):
                 return type(self)._wrap_fn(
                 return type(self)._wrap_fn(
-                    (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
+                    (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                     use_self=self,
                     use_self=self,
                 )
                 )
             elif isinstance(meta_attr, self._tensor_type):
             elif isinstance(meta_attr, self._tensor_type):
                 # e.g. self.T with torch.Tensor should still be wrapped
                 # e.g. self.T with torch.Tensor should still be wrapped
-                return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
+                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
             else:
             else:
                 # no need to wrap non-tensor properties,
                 # no need to wrap non-tensor properties,
                 # and they likely don't depend on the actual contents of the tensor
                 # and they likely don't depend on the actual contents of the tensor
@@ -141,19 +141,21 @@ class LazyBase(ABC, metaclass=LazyMeta):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
 
             if isinstance(res, cls._tensor_type):
             if isinstance(res, cls._tensor_type):
-                def collect_replace(t: LazyBase):
-                    if collect_replace.shared_lazy is None:
-                        collect_replace.shared_lazy = t._lazy
-                    else:
-                        collect_replace.shared_lazy.extend(t._lazy)
-                        t._lazy = collect_replace.shared_lazy
+                class CollectSharedLazy:
+                    # emulating a static variable
+                    shared_lazy: None | deque[LazyBase] = None
 
 
-                # emulating a static variable
-                collect_replace.shared_lazy = None
+                    @staticmethod
+                    def collect_replace(t: LazyBase):
+                        if CollectSharedLazy.shared_lazy is None:
+                            CollectSharedLazy.shared_lazy = t._lazy
+                        else:
+                            CollectSharedLazy.shared_lazy.extend(t._lazy)
+                            t._lazy = CollectSharedLazy.shared_lazy
 
 
-                LazyBase._recurse_apply(args, collect_replace)
+                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
 
 
-                shared_lazy = collect_replace.shared_lazy
+                shared_lazy = CollectSharedLazy.shared_lazy
 
 
                 return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
                 return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
             else:
             else:
@@ -184,6 +186,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                 lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
                 lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
                 lt._data = lt._func(lt._args)
                 lt._data = lt._func(lt._args)
                 # sanity check
                 # sanity check
+                assert lt._data is not None
                 assert lt._data.dtype == lt._meta.dtype
                 assert lt._data.dtype == lt._meta.dtype
                 assert lt._data.shape == lt._meta.shape
                 assert lt._data.shape == lt._meta.shape
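Attaching attributes to a function object works at runtime but is opaque to pyright; a class-level attribute expresses the same shared mutable state in a checkable way. A contrived sketch of the pattern, independent of the lazy-tensor code:

from __future__ import annotations

from collections import deque

class Collector:
    shared: deque[int] | None = None  # emulating a static variable

    @staticmethod
    def collect(items: deque[int]) -> None:
        if Collector.shared is None:
            Collector.shared = items
        else:
            Collector.shared.extend(items)

Collector.collect(deque([1, 2]))
Collector.collect(deque([3]))
print(Collector.shared)  # deque([1, 2, 3])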
 
 

+ 2 - 0
gguf-py/scripts/__init__.py

@@ -1,3 +1,5 @@
+# pyright: reportUnusedImport=false
+
 from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
 from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
 from .gguf_dump import main as gguf_dump_entrypoint
 from .gguf_dump import main as gguf_dump_entrypoint
 from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
 from .gguf_set_metadata import main as gguf_set_metadata_entrypoint

+ 3 - 3
gguf-py/scripts/gguf_hash.py

@@ -63,9 +63,9 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
         bar.update(sum_weights_in_tensor)
         bar.update(sum_weights_in_tensor)
 
 
         sha1_layer = hashlib.sha1()
         sha1_layer = hashlib.sha1()
-        sha1_layer.update(tensor.data)
-        sha1.update(tensor.data)
-        uuidv5_sha1.update(tensor.data)
+        sha1_layer.update(tensor.data.data)
+        sha1.update(tensor.data.data)
+        uuidv5_sha1.update(tensor.data.data)
         print("sha1    {0}  {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
         print("sha1    {0}  {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
 
 
     # Flush Hash Progress Bar
     # Flush Hash Progress Bar
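tensor.data here is a NumPy array, so the extra .data is the memoryview over its underlying buffer, which is the bytes-like object hashlib expects. A small sketch with a stand-in array:

import hashlib
import numpy as np

arr = np.arange(8, dtype=np.float32)  # stand-in for tensor.data
sha1 = hashlib.sha1()
sha1.update(arr.data)                 # the memoryview exposes the raw, contiguous bytes without copying
print(sha1.hexdigest())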

+ 2 - 0
gguf-py/scripts/gguf_new_metadata.py

@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import logging
 import logging
 import argparse
 import argparse
 import os
 import os
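The future import turns every annotation into a string at runtime, so PEP 604 unions such as str | None can be used in annotations even on Python 3.8/3.9. A tiny, hypothetical example:

from __future__ import annotations

def pick(value: str | None) -> list[str] | None:   # parses on 3.8+ thanks to the future import
    return None if value is None else [value]

print(pick("general.name"))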

+ 1 - 1
gguf-py/tests/test_gguf.py

@@ -1,4 +1,4 @@
-import gguf  # noqa: F401
+import gguf  # noqa: F401  # pyright: ignore[reportUnusedImport]
 
 
 # TODO: add tests
 # TODO: add tests
 
 

+ 19 - 1
pyrightconfig.json

@@ -1,3 +1,21 @@
 {
 {
   "extraPaths": ["gguf-py"],
   "extraPaths": ["gguf-py"],
-}
+  "pythonVersion": "3.9",
+  "pythonPlatform": "All",
+  "reportUnusedImport": "warning",
+  "reportDuplicateImport": "error",
+  "reportDeprecated": "warning",
+  "reportUnnecessaryTypeIgnoreComment": "warning",
+  "executionEnvironments": [
+    {
+      // TODO: make this version override work correctly
+      "root": "gguf-py",
+      "pythonVersion": "3.8",
+    },
+    {
+      // uses match expressions in steps.py
+      "root": "examples/server/tests",
+      "pythonVersion": "3.10",
+    },
+  ],
+}
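pyright discovers this file on its own when run from the repository root; a minimal sketch of driving it from Python, assuming the pyright wrapper package from PyPI (or the npm CLI) is installed:

import subprocess

# no flags needed: pyright reads pyrightconfig.json from the working directory
result = subprocess.run(["pyright"])
raise SystemExit(result.returncode)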

+ 12 - 0
requirements/requirements-all.txt

@@ -0,0 +1,12 @@
+-r ../examples/llava/requirements.txt
+-r ../examples/server/bench/requirements.txt
+-r ../examples/server/tests/requirements.txt
+
+-r ./requirements-compare-llama-bench.txt
+-r ./requirements-pydantic.txt
+-r ./requirements-test-tokenizer-random.txt
+
+-r ./requirements-convert_hf_to_gguf.txt
+-r ./requirements-convert_hf_to_gguf_update.txt
+-r ./requirements-convert_legacy_llama.txt
+-r ./requirements-convert_llama_ggml_to_gguf.txt

+ 2 - 0
requirements/requirements-compare-llama-bench.txt

@@ -0,0 +1,2 @@
+tabulate~=0.9.0
+GitPython~=3.1.43

+ 2 - 0
requirements/requirements-pydantic.txt

@@ -0,0 +1,2 @@
+docstring_parser~=0.15
+pydantic~=2.6.3

+ 1 - 0
requirements/requirements-test-tokenizer-random.txt

@@ -0,0 +1 @@
+cffi~=1.16.0

+ 6 - 6
scripts/check-requirements.sh

@@ -108,6 +108,11 @@ check_convert_script() {
         fatal "$py missing requirements. Expected: $reqs"
         fatal "$py missing requirements. Expected: $reqs"
     fi
     fi
 
 
+    # Check that all sub-requirements are added to top-level requirements.txt
+    if ! grep -qF "$reqs" requirements.txt; then
+        fatal "$reqs needs to be added to requirements.txt"
+    fi
+
     local venv="$workdir/$pyname-venv"
     local venv="$workdir/$pyname-venv"
     python3 -m venv "$venv"
     python3 -m venv "$venv"
 
 
@@ -134,12 +139,7 @@ EOF
 
 
 readonly ignore_eq_eq='check_requirements: ignore "=="'
 readonly ignore_eq_eq='check_requirements: ignore "=="'
 
 
-for req in "$reqs_dir"/*; do
-    # Check that all sub-requirements are added to top-level requirements.txt
-    if ! grep -qF "$req" requirements.txt; then
-        fatal "$req needs to be added to requirements.txt"
-    fi
-
+for req in */**/requirements*.txt; do
     # Make sure exact release versions aren't being pinned in the requirements
     # Make sure exact release versions aren't being pinned in the requirements
     # Filters out the ignore string
     # Filters out the ignore string
     if grep -vF "$ignore_eq_eq" "$req" | grep -q '=='; then
     if grep -vF "$ignore_eq_eq" "$req" | grep -q '=='; then

+ 4 - 4
scripts/compare-llama-bench.py

@@ -123,13 +123,13 @@ builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
 
 
 try:
 try:
     repo = git.Repo(".", search_parent_directories=True)
     repo = git.Repo(".", search_parent_directories=True)
-except git.exc.InvalidGitRepositoryError:
+except git.InvalidGitRepositoryError:
     repo = None
     repo = None
 
 
 
 
-def find_parent_in_data(commit):
+def find_parent_in_data(commit: git.Commit):
     """Helper function to find the most recent parent measured in number of commits for which there is data."""
     """Helper function to find the most recent parent measured in number of commits for which there is data."""
-    heap = [(0, commit)]
+    heap: list[tuple[int, git.Commit]] = [(0, commit)]
     seen_hexsha8 = set()
     seen_hexsha8 = set()
     while heap:
     while heap:
         depth, current_commit = heapq.heappop(heap)
         depth, current_commit = heapq.heappop(heap)
@@ -144,7 +144,7 @@ def find_parent_in_data(commit):
     return None
     return None
 
 
 
 
-def get_all_parent_hexsha8s(commit):
+def get_all_parent_hexsha8s(commit: git.Commit):
     """Helper function to recursively get hexsha8 values for all parents of a commit."""
     """Helper function to recursively get hexsha8 values for all parents of a commit."""
     unvisited = [commit]
     unvisited = [commit]
     visited   = []
     visited   = []
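A short sketch of the GitPython calls relied on here, runnable from inside any git checkout:

import git

try:
    repo = git.Repo(".", search_parent_directories=True)
except git.InvalidGitRepositoryError:
    repo = None

if repo is not None:
    head = repo.head.commit
    # each parent is itself a git.Commit, so the commit graph can be walked as in the helpers above
    print(head.hexsha[:8], [p.hexsha[:8] for p in head.parents])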

+ 9 - 7
scripts/gen-unicode-data.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import array
 import array
 import unicodedata
 import unicodedata
 import requests
 import requests
@@ -133,7 +135,7 @@ table_nfd.sort()
 
 
 
 
 # group ranges with same flags
 # group ranges with same flags
-ranges_flags = [(0, codepoint_flags[0])]  # start, flags
+ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])]  # start, flags
 for codepoint, flags in enumerate(codepoint_flags):
 for codepoint, flags in enumerate(codepoint_flags):
     if flags != ranges_flags[-1][1]:
     if flags != ranges_flags[-1][1]:
         ranges_flags.append((codepoint, flags))
         ranges_flags.append((codepoint, flags))
@@ -141,11 +143,11 @@ ranges_flags.append((MAX_CODEPOINTS, 0x0000))
 
 
 
 
 # group ranges with same nfd
 # group ranges with same nfd
-ranges_nfd = [(0, 0, 0)]  # start, last, nfd
+ranges_nfd: list[tuple[int, int, int]] = [(0, 0, 0)]  # start, last, nfd
 for codepoint, norm in table_nfd:
 for codepoint, norm in table_nfd:
     start = ranges_nfd[-1][0]
     start = ranges_nfd[-1][0]
     if ranges_nfd[-1] != (start, codepoint - 1, norm):
     if ranges_nfd[-1] != (start, codepoint - 1, norm):
-        ranges_nfd.append(None)
+        ranges_nfd.append(None)  # type: ignore[arg-type]  # dummy, will be replaced below
         start = codepoint
         start = codepoint
     ranges_nfd[-1] = (start, codepoint, norm)
     ranges_nfd[-1] = (start, codepoint, norm)
 
 
@@ -179,13 +181,13 @@ for codepoint in table_whitespace:
 out("};\n")
 out("};\n")
 
 
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
-for tuple in table_lowercase:
-    out("{0x%06X, 0x%06X}," % tuple)
+for tuple_lw in table_lowercase:
+    out("{0x%06X, 0x%06X}," % tuple_lw)
 out("};\n")
 out("};\n")
 
 
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
-for tuple in table_uppercase:
-    out("{0x%06X, 0x%06X}," % tuple)
+for tuple_up in table_uppercase:
+    out("{0x%06X, 0x%06X}," % tuple_up)
 out("};\n")
 out("};\n")
 
 
 out("const std::vector<range_nfd> unicode_ranges_nfd = {  // start, last, nfd")
 out("const std::vector<range_nfd> unicode_ranges_nfd = {  // start, last, nfd")

+ 24 - 20
tests/test-tokenizer-random.py

@@ -6,6 +6,8 @@
 #   python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
 #   python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
 #
 #
 
 
+from __future__ import annotations
+
 import time
 import time
 import logging
 import logging
 import argparse
 import argparse
@@ -13,7 +15,9 @@ import subprocess
 import random
 import random
 import unicodedata
 import unicodedata
 
 
-from typing import Iterator
+from pathlib import Path
+from typing import Any, Iterator, cast
+from typing_extensions import Buffer
 
 
 import cffi
 import cffi
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
@@ -28,15 +32,15 @@ class LibLlama:
     DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
     DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
     DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON
     DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON
 
 
-    def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None):
+    def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
         path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
         path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
         (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
         (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
         self.lib.llama_backend_init()
         self.lib.llama_backend_init()
 
 
-    def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str):
-        cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
+    def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]:
+        cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
         cmd += ["-I" + path for path in path_includes] + [path_llama_h]
         cmd += ["-I" + path for path in path_includes] + [path_llama_h]
         res = subprocess.run(cmd, stdout=subprocess.PIPE)
         res = subprocess.run(cmd, stdout=subprocess.PIPE)
         assert (res.returncode == 0)
         assert (res.returncode == 0)
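A reduced sketch of the cffi flow used in this test: pre-process the header with gcc, hand the declarations to ffi.cdef(), then dlopen the shared library. The paths are placeholders and the command simply mirrors the one above:

import subprocess

import cffi

header = "./include/llama.h"          # placeholder path
lib_path = "./build/src/libllama.so"  # placeholder path

cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", header]
res = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)

ffi = cffi.FFI()
ffi.cdef(res.stdout.decode(), override=True)  # declarations must be plain C for pycparser
lib = ffi.dlopen(lib_path)                    # functions are then available as lib.llama_backend_init(), etc.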
@@ -68,7 +72,7 @@ class LibLlama:
 class LibLlamaModel:
 class LibLlamaModel:
 
 
     def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
     def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
-        self.lib = libllama.lib
+        self.lib: Any = libllama.lib
         self.ffi = libllama.ffi
         self.ffi = libllama.ffi
         if isinstance(mparams, dict):
         if isinstance(mparams, dict):
             mparams = libllama.model_default_params(**mparams)
             mparams = libllama.model_default_params(**mparams)
@@ -94,11 +98,11 @@ class LibLlamaModel:
         self.lib = None
         self.lib = None
 
 
     def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
     def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
-        text = text.encode("utf-8")
-        num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+        encoded_text: bytes = text.encode("utf-8")
+        num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
         while num < 0 and len(self.token_ids) < (16 << 20):
         while num < 0 and len(self.token_ids) < (16 << 20):
             self.token_ids = self.ffi.new("llama_token[]", -2 * num)
             self.token_ids = self.ffi.new("llama_token[]", -2 * num)
-            num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+            num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
         return list(self.token_ids[0:num])
         return list(self.token_ids[0:num])
 
 
     def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
     def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
@@ -110,7 +114,7 @@ class LibLlamaModel:
         while num < 0 and len(self.text_buff) < (16 << 20):
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
             num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
-        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+        return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
 
 
 
 
 class Tokenizer:
 class Tokenizer:
@@ -152,7 +156,7 @@ class TokenizerGroundtruth (Tokenizer):
 
 
 class TokenizerLlamaCpp (Tokenizer):
 class TokenizerLlamaCpp (Tokenizer):
 
 
-    libllama: LibLlama = None
+    libllama: LibLlama | None = None
 
 
     def __init__(self, vocab_file: str):
     def __init__(self, vocab_file: str):
         if not self.libllama:
         if not self.libllama:
@@ -404,7 +408,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
 
-    def find_first_mismatch(ids1: list[int], ids2: list[int]):
+    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
             if a != b:
             if a != b:
                 return i
                 return i
@@ -433,7 +437,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     decode_errors = 0
     decode_errors = 0
     MAX_ERRORS = 10
     MAX_ERRORS = 10
 
 
-    logger.info("%s: %s" % (generator.__name__, "ini"))
+    logger.info("%s: %s" % (generator.__qualname__, "ini"))
     for text in generator:
     for text in generator:
         # print(repr(text), text.encode())
         # print(repr(text), text.encode())
         # print(repr(text), hex(ord(text[0])), text.encode())
         # print(repr(text), hex(ord(text[0])), text.encode())
@@ -472,13 +476,13 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
             break
             break
 
 
     t_total = time.perf_counter() - t_start
     t_total = time.perf_counter() - t_start
-    logger.info(f"{generator.__name__}: end,  {t_encode1=:.3f} {t_encode2=:.3f}  {t_decode1=:.3f} {t_decode2=:.3f}  {t_total=:.3f}")
+    logger.info(f"{generator.__qualname__}: end,  {t_encode1=:.3f} {t_encode2=:.3f}  {t_decode1=:.3f} {t_decode2=:.3f}  {t_total=:.3f}")
 
 
 
 
-def main(argv: list[str] = None):
+def main(argv: list[str] | None = None):
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
-    parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
-    parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+    parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
+    parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     args = parser.parse_args(argv)
     args = parser.parse_args(argv)
 
 
@@ -520,7 +524,7 @@ if __name__ == "__main__":
         format   = "%(levelname)s %(message)s",
         format   = "%(levelname)s %(message)s",
     )
     )
 
 
-    path_tokenizers   = "./models/tokenizers/"
+    path_tokenizers   = Path("./models/tokenizers/")
     path_vocab_format = "./models/ggml-vocab-%s.gguf"
     path_vocab_format = "./models/ggml-vocab-%s.gguf"
 
 
     tokenizers = [
     tokenizers = [
@@ -556,6 +560,6 @@ if __name__ == "__main__":
     for tokenizer in tokenizers:
     for tokenizer in tokenizers:
         logger.info("-" * 50)
         logger.info("-" * 50)
         logger.info(f"TOKENIZER: '{tokenizer}'")
         logger.info(f"TOKENIZER: '{tokenizer}'")
-        vocab_file = path_vocab_format % tokenizer
-        dir_tokenizer = path_tokenizers + "/" + tokenizer
-        main([vocab_file, dir_tokenizer, "--verbose"])
+        vocab_file = Path(path_vocab_format % tokenizer)
+        dir_tokenizer = path_tokenizers / tokenizer
+        main([str(vocab_file), str(dir_tokenizer), "--verbose"])