ChatScreen.tsx 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. import { useEffect, useMemo, useRef, useState } from 'react';
  2. import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
  3. import ChatMessage from './ChatMessage';
  4. import { CanvasType, Message, PendingMessage } from '../utils/types';
  5. import { classNames, cleanCurrentUrl, throttle } from '../utils/misc';
  6. import CanvasPyInterpreter from './CanvasPyInterpreter';
  7. import StorageUtils from '../utils/storage';
  8. import { useVSCodeContext } from '../utils/llama-vscode';
  /**
   * Render-time view of one message node in a conversation tree.
   *
   * Besides the node itself it carries what the UI needs for sibling (branch)
   * navigation: every sibling branch is identified by the id of its leaf node,
   * i.e. the node reached by repeatedly following the last child.
   */
  export interface MessageDisplay {
    /** The message to render: either a stored message or a pending one. */
    msg: Message | PendingMessage;
    /** Leaf node id of each sibling branch, in the parent's `children` order. */
    siblingLeafNodeIds: Message['id'][];
    /** Index of `msg` among its siblings (an index into `siblingLeafNodeIds`). */
    siblingCurrIdx: number;
    /** True when `msg` is a pending (not yet stored) message. */
    isPending?: boolean;
  }
  /**
   * Prefilled message taken from the page URL's query string:
   * - `?m=...` only prefills the message input;
   * - `?q=...` prefills the message input AND sends it.
   * If both are present, the text of `m` is used as the content.
   */
  const prefilledMsg = {
    /** Text to prefill, or '' when neither `m` nor `q` is present. */
    content(): string {
      const params = new URL(window.location.href).searchParams;
      return params.get('m') ?? params.get('q') ?? '';
    },
    /** Whether the prefilled text should be sent immediately (`q` present). */
    shouldSend(): boolean {
      return new URL(window.location.href).searchParams.has('q');
    },
    /** Removes the `m` and `q` parameters from the current URL. */
    clear(): void {
      cleanCurrentUrl(['m', 'q']);
    },
  };
  /**
   * Builds the render list for the branch of the conversation that ends at
   * `leafNodeId`.
   *
   * The root node is skipped. Every other node on the branch is paired with
   * the leaf ids of all its siblings (including itself) and its own position
   * among them, so the UI can switch between sibling branches.
   *
   * @param msgs All messages of the conversation.
   * @param leafNodeId Leaf of the branch to show; how this id is resolved
   *   (e.g. -1) is decided by `StorageUtils.filterByLeafNodeId`.
   * @returns One `MessageDisplay` per non-root node on the selected branch.
   */
  function getListMessageDisplay(
    msgs: Readonly<Message[]>,
    leafNodeId: Message['id']
  ): MessageDisplay[] {
    const byId = new Map<Message['id'], Message>(
      msgs.map((m): [Message['id'], Message] => [m.id, m])
    );

    // Walk down via the last child until a node with no children is found.
    // Returns -1 if the walk hits an id that is not in `msgs`.
    const leafOf = (startId: Message['id']): Message['id'] => {
      let node = byId.get(startId);
      while (node && node.children.length > 0) {
        node = byId.get(node.children.at(-1) ?? -1);
      }
      return node?.id ?? -1;
    };

    const branch = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
    const result: MessageDisplay[] = [];
    for (const msg of branch) {
      const parent = byId.get(msg.parent ?? -1);
      // nodes without a known parent, and the root itself, are not rendered
      if (!parent || msg.type === 'root') continue;
      const siblingIds = parent.children;
      result.push({
        msg,
        siblingLeafNodeIds: siblingIds.map(leafOf),
        siblingCurrIdx: siblingIds.indexOf(msg.id),
      });
    }
    return result;
  }
  /**
   * Scrolls the `#main-scroll` container to its bottom. The function is wrapped
   * in the project's `throttle` helper with an 80 ms interval.
   *
   * @param requiresNearBottom When true, scroll only if the container is
   *   already less than 50px from the bottom, so a user who scrolled up to read
   *   earlier messages is left where they are.
   * @param delay Milliseconds to wait (via `setTimeout`) before scrolling;
   *   defaults to 80. Presumably this lets newly rendered content grow
   *   `scrollHeight` first.
   */
  const scrollToBottom = throttle(
    (requiresNearBottom: boolean, delay: number = 80) => {
      const container = document.getElementById('main-scroll');
      if (container === null) return;

      const distanceFromBottom =
        container.scrollHeight - container.scrollTop - container.clientHeight;
      const isNearBottom = distanceFromBottom < 50;
      if (requiresNearBottom && !isNearBottom) return;

      setTimeout(() => {
        container.scrollTo({ top: container.scrollHeight });
      }, delay);
    },
    80
  );
  /**
   * Main chat view.
   *
   * Renders the messages of the conversation being viewed, plus the pending
   * message when one exists. It also renders the input box with Send/Stop
   * controls and, when `canvasData` is set, a side canvas.
   *
   * `currNodeId` selects which branch of the message tree is displayed. -1
   * means "latest node" and is the value used after a conversation switch or
   * a new message.
   */
  export default function ChatScreen() {
    const {
      viewingChat,
      sendMessage,
      isGenerating,
      stopGenerating,
      pendingMessages,
      canvasData,
      replaceMessageAndGenerate,
    } = useAppContext();

    const textarea = useOptimizedTextarea(prefilledMsg.content());
    const { extraContext, clearExtraContext } = useVSCodeContext(textarea);
    // TODO: improve this when we have "upload file" feature
    const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;

    // keep track of leaf node for rendering (-1 = latest node)
    const [currNodeId, setCurrNodeId] = useState<number>(-1);
    const messages: MessageDisplay[] = useMemo(() => {
      if (!viewingChat) return [];
      return getListMessageDisplay(viewingChat.messages, currNodeId);
    }, [currNodeId, viewingChat]);

    const currConvId = viewingChat?.conv.id ?? null;
    const pendingMsg: PendingMessage | undefined =
      pendingMessages[currConvId ?? ''];

    useEffect(() => {
      // reset to latest node when conversation changes
      setCurrNodeId(-1);
      // scroll to bottom when conversation changes
      scrollToBottom(false, 1);
    }, [currConvId]);

    // Passed to sendMessage / replaceMessageAndGenerate. It follows the leaf
    // node reported by the caller (if any) and keeps the view at the bottom
    // while the user has not scrolled away.
    const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
      // compare against null/undefined rather than truthiness so that a valid
      // message id of 0 is not ignored
      if (currLeafNodeId != null) {
        setCurrNodeId(currLeafNodeId);
      }
      scrollToBottom(true);
    };

    // Sends the textarea content as a new message appended after the last
    // displayed message. Ignores blank input or input while generating.
    const sendNewMessage = async () => {
      const lastInpMsg = textarea.value();
      if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? ''))
        return;
      textarea.setValue('');
      scrollToBottom(false);
      setCurrNodeId(-1);
      // get the last message node
      const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
      const sent = await sendMessage(
        currConvId,
        lastMsgNodeId,
        lastInpMsg,
        currExtra,
        onChunk
      );
      if (!sent) {
        // restore the input message if failed; the extra context is kept too,
        // so that a retry sends the same payload
        textarea.setValue(lastInpMsg);
        return;
      }
      // OK: the extra context was attached to the sent message
      clearExtraContext();
    };

    // Replaces the content of `msg` (a new sibling under the same parent) and
    // regenerates from there; afterwards the latest node is shown again.
    const handleEditMessage = async (msg: Message, content: string) => {
      if (!viewingChat) return;
      setCurrNodeId(msg.id);
      scrollToBottom(false);
      await replaceMessageAndGenerate(
        viewingChat.conv.id,
        msg.parent,
        content,
        msg.extra,
        onChunk
      );
      setCurrNodeId(-1);
      scrollToBottom(false);
    };

    // Regenerates `msg` from its parent (content `null`); afterwards the
    // latest node is shown again.
    const handleRegenerateMessage = async (msg: Message) => {
      if (!viewingChat) return;
      setCurrNodeId(msg.parent);
      scrollToBottom(false);
      await replaceMessageAndGenerate(
        viewingChat.conv.id,
        msg.parent,
        null,
        msg.extra,
        onChunk
      );
      setCurrNodeId(-1);
      scrollToBottom(false);
    };

    const hasCanvas = !!canvasData;

    useEffect(() => {
      if (prefilledMsg.shouldSend()) {
        // send the prefilled message if needed
        void sendNewMessage();
      } else {
        // otherwise, focus on the input
        textarea.focus();
      }
      prefilledMsg.clear();
      // no need to keep track of sendNewMessage
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [textarea.ref]);

    // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
    const pendingMsgDisplay: MessageDisplay[] =
      pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
        ? [
            {
              msg: pendingMsg,
              siblingLeafNodeIds: [],
              siblingCurrIdx: 0,
              isPending: true,
            },
          ]
        : [];

    return (
      <div
        className={classNames({
          'grid lg:gap-8 grow transition-[300ms]': true,
          'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
          'grid-cols-[1fr_0fr]': !hasCanvas,
        })}
      >
        <div
          className={classNames({
            'flex flex-col w-full max-w-[900px] mx-auto': true,
            'hidden lg:flex': hasCanvas, // adapted for mobile
            flex: !hasCanvas,
          })}
        >
          {/* chat messages */}
          <div id="messages-list" className="grow">
            <div className="mt-auto flex justify-center">
              {/* placeholder to shift the message to the bottom */}
              {viewingChat ? '' : 'Send a message to start'}
            </div>
            {[...messages, ...pendingMsgDisplay].map((msg) => (
              <ChatMessage
                key={msg.msg.id}
                msg={msg.msg}
                siblingLeafNodeIds={msg.siblingLeafNodeIds}
                siblingCurrIdx={msg.siblingCurrIdx}
                onRegenerateMessage={handleRegenerateMessage}
                onEditMessage={handleEditMessage}
                onChangeSibling={setCurrNodeId}
              />
            ))}
          </div>

          {/* chat input */}
          <div className="flex flex-row items-center pt-8 pb-6 sticky bottom-0 bg-base-100">
            <textarea
              className="textarea textarea-bordered w-full"
              placeholder="Type a message (Shift+Enter to add a new line)"
              ref={textarea.ref}
              onKeyDown={(e) => {
                // ignore key events that belong to an IME composition
                if (e.nativeEvent.isComposing || e.keyCode === 229) return;
                // Enter sends; Shift+Enter falls through to insert a newline
                if (e.key === 'Enter' && !e.shiftKey) {
                  e.preventDefault();
                  void sendNewMessage();
                }
              }}
              id="msg-input"
              dir="auto"
            ></textarea>
            {isGenerating(currConvId ?? '') ? (
              <button
                className="btn btn-neutral ml-2"
                onClick={() => stopGenerating(currConvId ?? '')}
              >
                Stop
              </button>
            ) : (
              <button className="btn btn-primary ml-2" onClick={sendNewMessage}>
                Send
              </button>
            )}
          </div>
        </div>
        <div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
          {canvasData?.type === CanvasType.PY_INTERPRETER && (
            <CanvasPyInterpreter />
          )}
        </div>
      </div>
    );
  }
  /**
   * Imperative handle returned by `useOptimizedTextarea`. The text is held by
   * the DOM element, so it is read and written through these accessors rather
   * than through React state.
   */
  export interface OptimizedTextareaValue {
    /** Returns the current text of the textarea. */
    value: () => string;
    /** Overwrites the text of the textarea. */
    setValue: (value: string) => void;
    /** Focuses the textarea. */
    focus: () => void;
    /** Ref to attach to the `<textarea>` element. */
    ref: React.RefObject<HTMLTextAreaElement>;
  }
  /**
   * Manages an uncontrolled `<textarea>`. This is a workaround to prevent the
   * textarea from re-rendering when its inner content changes; see
   * https://github.com/ggml-org/llama.cpp/pull/12299
   *
   * @param initValue Text written into the textarea once the element exists.
   * @returns Accessors to read, write and focus the textarea, plus the ref to
   *   attach to it.
   */
  function useOptimizedTextarea(initValue: string): OptimizedTextareaValue {
    const [pendingInit, setPendingInit] = useState<string>(initValue);
    const elemRef = useRef<HTMLTextAreaElement>(null);

    // Copy a non-empty initial value into the DOM element once it is attached,
    // then clear it so it is applied only once.
    useEffect(() => {
      const elem = elemRef.current;
      if (!elem || !pendingInit) return;
      elem.value = pendingInit;
      setPendingInit('');
    }, [elemRef, pendingInit, setPendingInit]);

    // Before the element is attached, report the not-yet-applied initial value.
    const value = () => elemRef.current?.value ?? pendingInit;

    const setValue = (next: string) => {
      const elem = elemRef.current;
      if (elem) elem.value = next;
    };

    const focus = () => {
      const elem = elemRef.current;
      if (!elem) return;
      elem.focus();
      // put the caret after the last character
      elem.selectionStart = elem.value.length;
    };

    return { value, setValue, focus, ref: elemRef };
  }