# embd_input.py — Python driver for the embd-input example (libembdinput.so)
  1. #!/usr/bin/env python3
  2. import ctypes
  3. from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
  4. import numpy as np
  5. import os
  6. libc = cdll.LoadLibrary("./libembdinput.so")
  7. libc.sampling.restype=c_char_p
  8. libc.create_mymodel.restype=c_void_p
  9. libc.eval_string.argtypes=[c_void_p, c_char_p]
  10. libc.sampling.argtypes=[c_void_p]
  11. libc.eval_float.argtypes=[c_void_p, POINTER(c_float), c_int]
  12. class MyModel:
  13. def __init__(self, args):
  14. argc = len(args)
  15. c_str = [c_char_p(i.encode()) for i in args]
  16. args_c = (c_char_p * argc)(*c_str)
  17. self.model = c_void_p(libc.create_mymodel(argc, args_c))
  18. self.max_tgt_len = 512
  19. self.print_string_eval = True
  20. def __del__(self):
  21. libc.free_mymodel(self.model)
  22. def eval_float(self, x):
  23. libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
  24. def eval_string(self, x):
  25. libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
  26. if self.print_string_eval:
  27. print(x)
  28. def eval_token(self, x):
  29. libc.eval_id(self.model, x)
  30. def sampling(self):
  31. s = libc.sampling(self.model)
  32. return s
  33. def stream_generate(self, end="</s>"):
  34. ret = b""
  35. end = end.encode()
  36. for _ in range(self.max_tgt_len):
  37. tmp = self.sampling()
  38. ret += tmp
  39. yield tmp
  40. if ret.endswith(end):
  41. break
  42. def generate_with_print(self, end="</s>"):
  43. ret = b""
  44. for i in self.stream_generate(end=end):
  45. ret += i
  46. print(i.decode(errors="replace"), end="", flush=True)
  47. print("")
  48. return ret.decode(errors="replace")
  49. def generate(self, end="</s>"):
  50. text = b"".join(self.stream_generate(end=end))
  51. return text.decode(errors="replace")
  52. if __name__ == "__main__":
  53. model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
  54. model.eval_string("""user: what is the color of the flag of UN?""")
  55. x = np.random.random((5120,10))# , dtype=np.float32)
  56. model.eval_float(x)
  57. model.eval_string("""assistant:""")
  58. for i in model.generate():
  59. print(i.decode(errors="replace"), end="", flush=True)