|
|
@@ -22,7 +22,6 @@ import json
|
|
|
import struct
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
-
|
|
|
from sentencepiece import SentencePieceProcessor
|
|
|
|
|
|
if len(sys.argv) < 3:
|
|
|
@@ -101,12 +100,28 @@ for p in range(n_parts):
|
|
|
|
|
|
# Is this correct??
|
|
|
for i in range(32000):
|
|
|
- # TODO: this is probably wrong - not sure how this tokenizer works
|
|
|
- text = tokenizer.decode([29889, i]).encode('utf-8')
|
|
|
- # remove the first byte (it's always '.')
|
|
|
- text = text[1:]
|
|
|
- fout.write(struct.pack("i", len(text)))
|
|
|
- fout.write(text)
|
|
|
+ if tokenizer.is_unknown(i):
|
|
|
+ # "<unk>" token (translated as ??)
|
|
|
+ text = " \u2047 ".encode("utf-8")
|
|
|
+ fout.write(struct.pack("i", len(text)))
|
|
|
+ fout.write(text)
|
|
|
+ elif tokenizer.is_control(i):
|
|
|
+ # "<s>"/"</s>" tokens
|
|
|
+ fout.write(struct.pack("i", 0))
|
|
|
+ elif tokenizer.is_byte(i):
|
|
|
+ # "<U+XX>" tokens (which may be invalid UTF-8)
|
|
|
+ piece = tokenizer.id_to_piece(i)
|
|
|
+ if len(piece) != 6:
|
|
|
+ print("Invalid token: " + piece)
|
|
|
+ sys.exit(1)
|
|
|
+ byte_value = int(piece[3:-1], 16)
|
|
|
+ fout.write(struct.pack("i", 1))
|
|
|
+ fout.write(struct.pack("B", byte_value))
|
|
|
+ else:
|
|
|
+ # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
|
|
|
+ text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
|
|
|
+ fout.write(struct.pack("i", len(text)))
|
|
|
+ fout.write(text)
|
|
|
|
|
|
for k, v in model.items():
|
|
|
name = k
|