|
|
@@ -36,8 +36,10 @@ class SentencePieceTokenTypes(IntEnum):
|
|
|
UNUSED = 5
|
|
|
BYTE = 6
|
|
|
|
|
|
+
|
|
|
AnyModel = TypeVar("AnyModel", bound="type[Model]")
|
|
|
|
|
|
+
|
|
|
class Model(ABC):
|
|
|
_model_classes: dict[str, type[Model]] = {}
|
|
|
|
|
|
@@ -187,6 +189,7 @@ class Model(ABC):
|
|
|
@classmethod
|
|
|
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
|
|
|
assert names
|
|
|
+
|
|
|
def func(modelcls: type[Model]):
|
|
|
for name in names:
|
|
|
cls._model_classes[name] = modelcls
|