app.context.tsx 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import React, { createContext, useContext, useEffect, useState } from 'react';
  2. import {
  3. APIMessage,
  4. CanvasData,
  5. Conversation,
  6. Message,
  7. PendingMessage,
  8. ViewingChat,
  9. } from './types';
  10. import StorageUtils from './storage';
  11. import {
  12. filterThoughtFromMsgs,
  13. normalizeMsgsForAPI,
  14. getSSEStreamAsync,
  15. } from './misc';
  16. import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
  17. import { matchPath, useLocation, useNavigate } from 'react-router';
  /**
   * Everything exposed to consumers of the app-wide context.
   */
  interface AppContextValue {
    // ---- conversations and messages ----

    /** The conversation selected by the URL, with its messages from every branch. */
    viewingChat: ViewingChat | null;
    /** Assistant replies currently being streamed, keyed by conversation id. */
    pendingMessages: Record<Conversation['id'], PendingMessage>;
    /** Whether a reply is currently being generated for `convId`. */
    isGenerating: (convId: string) => boolean;
    /**
     * Appends a user message and generates a reply. A null/empty `convId` or a
     * null `leafNodeId` starts a new conversation. Resolves to false when
     * nothing was sent or generation failed.
     */
    sendMessage: (
      convId: string | null,
      leafNodeId: Message['id'] | null,
      content: string,
      extra: Message['extra'],
      onChunk: CallbackGeneratedChunk
    ) => Promise<boolean>;
    /** Cancels the generation running for `convId`, if any. */
    stopGenerating: (convId: string) => void;
    /**
     * Generates a new reply under `parentNodeId` (the parent of the message
     * being replaced). When `content` is non-null, a user message with that
     * content is appended under `parentNodeId` first and the reply goes under it.
     */
    replaceMessageAndGenerate: (
      convId: string,
      parentNodeId: Message['id'],
      content: string | null,
      extra: Message['extra'],
      onChunk: CallbackGeneratedChunk
    ) => Promise<void>;

    // ---- canvas ----

    canvasData: CanvasData | null;
    setCanvasData: (data: CanvasData | null) => void;

    // ---- config ----

    config: typeof CONFIG_DEFAULT;
    saveConfig: (config: typeof CONFIG_DEFAULT) => void;
    showSettings: boolean;
    setShowSettings: (show: boolean) => void;
  }
  /**
   * Called while a reply streams in and once it is complete. Consumers use it
   * to scroll the chat to the bottom and, when an id is supplied, to switch
   * the view to that node.
   */
  export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
  // React only uses this default for consumers rendered outside a provider.
  // Asserting to the interface (rather than `any`) keeps the context typed
  // without needing an eslint suppression.
  const AppContext = createContext<AppContextValue>({} as AppContextValue);
  /**
   * Loads a conversation plus all of its messages from every branch (not
   * filtered down to the path of the current leaf node).
   * Resolves to null when storage returns no conversation for `convId`.
   */
  const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
    const conversation = await StorageUtils.getOneConversation(convId);
    if (conversation) {
      const allMessages = await StorageUtils.getMessages(convId);
      return { conv: conversation, messages: allMessages };
    }
    return null;
  };
  /**
   * Provides conversation, generation, canvas and config state to the app.
   *
   * The viewed conversation comes from the `/chat/:convId` route. It is
   * reloaded whenever that id changes or storage reports a change to it.
   */
  export const AppContextProvider = ({
    children,
  }: {
    children: React.ReactElement;
  }) => {
    const { pathname } = useLocation();
    const navigate = useNavigate();
    const params = matchPath('/chat/:convId', pathname);
    const convId = params?.params?.convId;

    const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
    const [pendingMessages, setPendingMessages] = useState<
      Record<Conversation['id'], PendingMessage>
    >({});
    const [aborts, setAborts] = useState<
      Record<Conversation['id'], AbortController>
    >({});
    // Lazy initializer: read the stored config once on mount, not on every render.
    const [config, setConfig] = useState(() => StorageUtils.getConfig());
    const [canvasData, setCanvasData] = useState<CanvasData | null>(null);
    const [showSettings, setShowSettings] = useState(false);

    // Reload the viewed conversation whenever the convId in the URL changes.
    useEffect(() => {
      // also reset the canvas data
      setCanvasData(null);
      // Cleared on cleanup so a load started for a previous convId (or before
      // unmount) cannot overwrite the state belonging to the current one.
      let active = true;
      const loadViewingChat = async (id: string) => {
        const chat = await getViewingChat(id);
        if (active) setViewingChat(chat);
      };
      const handleConversationChange = async (changedConvId: string) => {
        if (changedConvId !== convId) return;
        await loadViewingChat(changedConvId);
      };
      StorageUtils.onConversationChanged(handleConversationChange);
      loadViewingChat(convId ?? '').catch(console.error);
      return () => {
        active = false;
        StorageUtils.offConversationChanged(handleConversationChange);
      };
    }, [convId]);

    /** Sets the pending reply for `convId`, or removes it when `pendingMsg` is null. */
    const setPending = (convId: string, pendingMsg: PendingMessage | null) => {
      if (!pendingMsg) {
        setPendingMessages((prev) => {
          const newState = { ...prev };
          delete newState[convId];
          return newState;
        });
      } else {
        setPendingMessages((prev) => ({ ...prev, [convId]: pendingMsg }));
      }
    };

    /** Stores the abort handle for the generation running in `convId`. */
    const setAbort = (convId: string, controller: AbortController) => {
      setAborts((prev) => ({ ...prev, [convId]: controller }));
    };

    /**
     * Removes the abort handle for `convId`, but only if it is still
     * `controller`; a newer generation may already have replaced it.
     */
    const clearAbort = (convId: string, controller: AbortController) => {
      setAborts((prev) => {
        if (prev[convId] !== controller) return prev;
        const newState = { ...prev };
        delete newState[convId];
        return newState;
      });
    };

    ////////////////////////////////////////////////////////////////////////
    // public functions

    // NOTE(review): reads `pendingMessages` from this render's closure, so two
    // calls made before React re-renders can both observe `false`.
    const isGenerating = (convId: string) => !!pendingMessages[convId];

    /**
     * Streams an assistant reply to the branch ending at `leafNodeId` and
     * appends it to storage under that node.
     *
     * If the user aborts, whatever content was received so far is still saved.
     * Other errors are logged, alerted and rethrown.
     *
     * @throws Error when the conversation or its messages cannot be found, or
     *   when the request/stream fails for a reason other than an abort.
     */
    const generateMessage = async (
      convId: string,
      leafNodeId: Message['id'],
      onChunk: CallbackGeneratedChunk
    ) => {
      if (isGenerating(convId)) return;
      const config = StorageUtils.getConfig();
      const currConversation = await StorageUtils.getOneConversation(convId);
      if (!currConversation) {
        throw new Error('Current conversation is not found');
      }
      const currMessages = StorageUtils.filterByLeafNodeId(
        await StorageUtils.getMessages(convId),
        leafNodeId,
        false
      );
      if (!currMessages) {
        throw new Error('Current messages are not found');
      }

      // Register the abort handle only once generation is certain to start, so
      // a failed lookup above cannot leave a dangling controller behind.
      const abortController = new AbortController();
      setAbort(convId, abortController);

      // NOTE(review): +1 presumably keeps this id distinct from a user message
      // created in the same millisecond.
      const pendingId = Date.now() + 1;
      let pendingMsg: PendingMessage = {
        id: pendingId,
        convId,
        type: 'text',
        timestamp: pendingId,
        role: 'assistant',
        content: null,
        parent: leafNodeId,
        children: [],
      };
      setPending(convId, pendingMsg);

      try {
        // prepare messages for API
        let messages: APIMessage[] = [
          ...(config.systemMessage.length === 0
            ? []
            : [{ role: 'system', content: config.systemMessage } as APIMessage]),
          ...normalizeMsgsForAPI(currMessages),
        ];
        if (config.excludeThoughtOnReq) {
          messages = filterThoughtFromMsgs(messages);
        }
        if (isDev) console.log({ messages });

        // prepare request body
        const reqParams = {
          messages,
          stream: true,
          cache_prompt: true,
          samplers: config.samplers,
          temperature: config.temperature,
          dynatemp_range: config.dynatemp_range,
          dynatemp_exponent: config.dynatemp_exponent,
          top_k: config.top_k,
          top_p: config.top_p,
          min_p: config.min_p,
          typical_p: config.typical_p,
          xtc_probability: config.xtc_probability,
          xtc_threshold: config.xtc_threshold,
          repeat_last_n: config.repeat_last_n,
          repeat_penalty: config.repeat_penalty,
          presence_penalty: config.presence_penalty,
          frequency_penalty: config.frequency_penalty,
          dry_multiplier: config.dry_multiplier,
          dry_base: config.dry_base,
          dry_allowed_length: config.dry_allowed_length,
          dry_penalty_last_n: config.dry_penalty_last_n,
          max_tokens: config.max_tokens,
          timings_per_token: !!config.showTokensPerSecond,
          ...(config.custom.length ? JSON.parse(config.custom) : {}),
        };

        // send request
        const fetchResponse = await fetch(`${BASE_URL}/v1/chat/completions`, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            ...(config.apiKey
              ? { Authorization: `Bearer ${config.apiKey}` }
              : {}),
          },
          body: JSON.stringify(reqParams),
          signal: abortController.signal,
        });
        if (fetchResponse.status !== 200) {
          // An error body is not guaranteed to be JSON (e.g. a proxy error page).
          const body = await fetchResponse.json().catch(() => null);
          throw new Error(body?.error?.message || 'Unknown error');
        }

        const chunks = getSSEStreamAsync(fetchResponse);
        for await (const chunk of chunks) {
          if (chunk.error) {
            throw new Error(chunk.error?.message || 'Unknown error');
          }
          // A chunk may arrive with an empty `choices` array (OpenAI-style
          // usage-only chunks), so don't assume choices[0] exists.
          const addedContent = chunk.choices?.[0]?.delta?.content;
          if (addedContent) {
            pendingMsg = {
              ...pendingMsg,
              content: (pendingMsg.content || '') + addedContent,
            };
          }
          const timings = chunk.timings;
          if (timings && config.showTokensPerSecond) {
            // Copy rather than mutate: the previous pendingMsg object is
            // already held in React state. Only keep the fields we display.
            pendingMsg = {
              ...pendingMsg,
              timings: {
                prompt_n: timings.prompt_n,
                prompt_ms: timings.prompt_ms,
                predicted_n: timings.predicted_n,
                predicted_ms: timings.predicted_ms,
              },
            };
          }
          setPending(convId, pendingMsg);
          onChunk(); // don't need to switch node for pending message
        }
      } catch (err) {
        setPending(convId, null);
        if ((err as Error | null)?.name === 'AbortError') {
          // User stopped the generation via stopGenerating(); not an error.
          // Any partial content received so far is still saved below.
        } else {
          console.error(err);
          alert((err as { message?: string } | null)?.message ?? 'Unknown error');
          throw err; // rethrow
        }
      } finally {
        clearAbort(convId, abortController);
      }

      const saved = pendingMsg.content !== null;
      if (saved) {
        await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
      }
      setPending(convId, null);
      // Scroll to bottom and switch to the new reply. If nothing was produced
      // (e.g. stopped before the first token), stay on the existing leaf rather
      // than pointing at a node that was never stored.
      onChunk(saved ? pendingId : leafNodeId);
    };

    /**
     * Appends a user message and generates a reply to it. A null/empty
     * `convId` or a null `leafNodeId` creates a new conversation and
     * navigates to it first.
     *
     * @returns false when a generation is already running, `content` is
     *   blank, or generation failed; true otherwise.
     */
    const sendMessage = async (
      convId: string | null,
      leafNodeId: Message['id'] | null,
      content: string,
      extra: Message['extra'],
      onChunk: CallbackGeneratedChunk
    ): Promise<boolean> => {
      if (isGenerating(convId ?? '') || content.trim().length === 0) return false;

      if (convId === null || convId.length === 0 || leafNodeId === null) {
        const conv = await StorageUtils.createConversation(
          content.substring(0, 256)
        );
        convId = conv.id;
        leafNodeId = conv.currNode;
        // if user is creating a new conversation, redirect to the new conversation
        navigate(`/chat/${convId}`);
      }

      const now = Date.now();
      const currMsgId = now;
      // Must be awaited: generateMessage reads this message back from storage.
      await StorageUtils.appendMsg(
        {
          id: currMsgId,
          timestamp: now,
          type: 'text',
          convId,
          role: 'user',
          content,
          extra,
          parent: leafNodeId,
          children: [],
        },
        leafNodeId
      );
      onChunk(currMsgId);

      try {
        await generateMessage(convId, currMsgId, onChunk);
        return true;
      } catch {
        // generateMessage has already logged and alerted the error.
        // TODO: rollback
      }
      return false;
    };

    /** Clears the pending reply for `convId` and aborts its request, if any. */
    const stopGenerating = (convId: string) => {
      setPending(convId, null);
      aborts[convId]?.abort();
    };

    /**
     * Generates a new reply under `parentNodeId` (the parent of the message
     * being replaced); existing messages are left in place as another branch.
     *
     * When `content` is non-null, a user message with that content is appended
     * under `parentNodeId` first, and the reply is generated under it. When it
     * is null, the reply is generated directly under `parentNodeId`.
     */
    const replaceMessageAndGenerate = async (
      convId: string,
      parentNodeId: Message['id'],
      content: string | null,
      extra: Message['extra'],
      onChunk: CallbackGeneratedChunk
    ) => {
      if (isGenerating(convId)) return;
      if (content !== null) {
        const now = Date.now();
        const currMsgId = now;
        // Must be awaited: generateMessage reads this message back from storage.
        await StorageUtils.appendMsg(
          {
            id: currMsgId,
            timestamp: now,
            type: 'text',
            convId,
            role: 'user',
            content,
            extra,
            parent: parentNodeId,
            children: [],
          },
          parentNodeId
        );
        parentNodeId = currMsgId;
      }
      onChunk(parentNodeId);
      await generateMessage(convId, parentNodeId, onChunk);
    };

    /** Persists `config` to storage and updates the in-memory copy. */
    const saveConfig = (config: typeof CONFIG_DEFAULT) => {
      StorageUtils.setConfig(config);
      setConfig(config);
    };

    return (
      <AppContext.Provider
        value={{
          isGenerating,
          viewingChat,
          pendingMessages,
          sendMessage,
          stopGenerating,
          replaceMessageAndGenerate,
          canvasData,
          setCanvasData,
          config,
          saveConfig,
          showSettings,
          setShowSettings,
        }}
      >
        {children}
      </AppContext.Provider>
    );
  };
  /** Returns the value supplied by the nearest enclosing `AppContextProvider`. */
  export function useAppContext() {
    return useContext(AppContext);
  }