|
|
@@ -333,6 +333,7 @@ class TensorNameMap:
|
|
|
tensor_name = tensor_names.get(tensor)
|
|
|
if tensor_name is None:
|
|
|
continue
|
|
|
+ mapping[tensor_name] = (tensor, tensor_name)
|
|
|
for key in keys:
|
|
|
mapping[key] = (tensor, tensor_name)
|
|
|
for bid in range(n_blocks):
|
|
|
@@ -341,11 +342,12 @@ class TensorNameMap:
|
|
|
if tensor_name is None:
|
|
|
continue
|
|
|
tensor_name = tensor_name.format(bid = bid)
|
|
|
+ mapping[tensor_name] = (tensor, tensor_name)
|
|
|
for key in keys:
|
|
|
key = key.format(bid = bid)
|
|
|
mapping[key] = (tensor, tensor_name)
|
|
|
|
|
|
- def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None:
|
|
|
+ def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
|
|
|
result = self.mapping.get(key)
|
|
|
if result is not None:
|
|
|
return result
|
|
|
@@ -356,13 +358,13 @@ class TensorNameMap:
|
|
|
return (result[0], result[1] + suffix)
|
|
|
return None
|
|
|
|
|
|
- def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
|
|
|
+ def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
|
|
|
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
|
|
|
if result is None:
|
|
|
return None
|
|
|
return result[1]
|
|
|
|
|
|
- def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
|
|
|
+ def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
|
|
|
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
|
|
|
if result is None:
|
|
|
return None
|