ChatScreen.tsx 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import { useEffect, useMemo, useState } from 'react';
  2. import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
  3. import ChatMessage from './ChatMessage';
  4. import { CanvasType, Message, PendingMessage } from '../utils/types';
  5. import { classNames, cleanCurrentUrl, throttle } from '../utils/misc';
  6. import CanvasPyInterpreter from './CanvasPyInterpreter';
  7. import StorageUtils from '../utils/storage';
  8. import { useVSCodeContext } from '../utils/llama-vscode';
  9. import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts';
  /**
   * One row of the rendered chat list: a message node plus the extra data the
   * UI needs to draw its sibling (branch) switcher.
   *
   * - `msg`: the saved message, or a pending one that is not (yet) part of the
   *   saved conversation.
   * - `siblingLeafNodeIds`: one entry per child of this message's parent, in
   *   order; each entry is the id of the leaf reached by repeatedly following
   *   the last child, so every sibling branch is represented by its last node.
   * - `siblingCurrIdx`: index of this message among those siblings.
   * - `isPending`: true for the pending-message row appended after the saved
   *   messages.
   */
  export interface MessageDisplay {
    msg: Message | PendingMessage;
    siblingLeafNodeIds: Array<Message['id']>;
    siblingCurrIdx: number;
    isPending?: boolean;
  }
  /**
   * Prefilled-message support driven by the current page URL's query string:
   * - `?m=...` prefills the message input with the value.
   * - `?q=...` prefills the input AND sends it right away.
   * When both are present, the text comes from `m`, while the presence of `q`
   * still makes `shouldSend()` return true.
   */
  const prefilledMsg = {
    /** Text to prefill: the `m` value if present, else the `q` value, else ''. */
    content() {
      const params = new URL(window.location.href).searchParams;
      const fromM = params.get('m');
      if (fromM !== null) return fromM;
      const fromQ = params.get('q');
      return fromQ !== null ? fromQ : '';
    },
    /** Whether the URL carries a `q` parameter (prefill should be auto-sent). */
    shouldSend() {
      return new URL(window.location.href).searchParams.has('q');
    },
    /** Asks `cleanCurrentUrl` to drop the `m` and `q` parameters. */
    clear() {
      cleanCurrentUrl(['m', 'q']);
    },
  };
  /**
   * Builds the rows to render for the conversation branch ending at `leafNodeId`.
   *
   * @param msgs every message node of the conversation
   * @param leafNodeId leaf of the branch to show; forwarded to
   *   `StorageUtils.filterByLeafNodeId` (ChatScreen passes -1 for "latest")
   * @returns one entry per node on that branch that has a parent present in
   *   `msgs` and is not of type 'root', each annotated with the leaf id of every
   *   sibling branch and the node's own index among its siblings
   */
  function getListMessageDisplay(
    msgs: Readonly<Message[]>,
    leafNodeId: Message['id']
  ): MessageDisplay[] {
    const branch = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);

    // id -> node lookup table for parent/child navigation
    const byId = new Map<Message['id'], Message>(
      msgs.map((m): [Message['id'], Message] => [m.id, m])
    );

    // Walk down from `startId`, always taking the last child, until a node
    // without children is reached; -1 if the walk hits an unknown id.
    const leafOf = (startId: Message['id']): Message['id'] => {
      let node = byId.get(startId);
      while (node && node.children.length > 0) {
        node = byId.get(node.children.at(-1) ?? -1);
      }
      return node?.id ?? -1;
    };

    return branch.flatMap((msg): MessageDisplay[] => {
      const parent = byId.get(msg.parent ?? -1);
      // nodes without a known parent (e.g. the root) and root-type nodes are not rendered
      if (!parent || msg.type === 'root') return [];
      const siblings = parent.children;
      return [
        {
          msg,
          siblingLeafNodeIds: siblings.map(leafOf),
          siblingCurrIdx: siblings.indexOf(msg.id),
        },
      ];
    });
  }
  /**
   * Scrolls the `#main-scroll` element to its bottom. Wrapped in the project's
   * `throttle` helper with an 80 ms window (TODO confirm throttle's exact
   * leading/trailing semantics).
   *
   * @param requiresNearBottom when true, only scroll if the view is already
   *   less than 50px from the bottom (so a user reading history is not yanked down)
   * @param delay milliseconds to wait before scrolling (default 80); presumably
   *   lets freshly rendered content count toward `scrollHeight`
   */
  const scrollToBottom = throttle(
    (requiresNearBottom: boolean, delay: number = 80) => {
      const container = document.getElementById('main-scroll');
      if (container === null) return;

      const distanceFromBottom =
        container.scrollHeight - container.scrollTop - container.clientHeight;
      if (requiresNearBottom && distanceFromBottom >= 50) return;

      setTimeout(() => {
        container.scrollTo({ top: container.scrollHeight });
      }, delay);
    },
    80
  );
  /**
   * Main chat view. Renders the message list of the conversation being viewed
   * (plus a pending message, if any), the input box with Send/Stop buttons, and
   * a side canvas panel when `canvasData` is set.
   */
  export default function ChatScreen() {
    const {
      viewingChat,
      sendMessage,
      isGenerating,
      stopGenerating,
      pendingMessages,
      canvasData,
      replaceMessageAndGenerate,
    } = useAppContext();

    const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content());

    const { extraContext, clearExtraContext } = useVSCodeContext(textarea);
    // TODO: improve this when we have "upload file" feature
    const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;

    // Leaf node of the branch being rendered; -1 means "latest" (see the
    // reset effect below).
    const [currNodeId, setCurrNodeId] = useState<number>(-1);
    const messages: MessageDisplay[] = useMemo(() => {
      if (!viewingChat) return [];
      else return getListMessageDisplay(viewingChat.messages, currNodeId);
    }, [currNodeId, viewingChat]);

    const currConvId = viewingChat?.conv.id ?? null;
    const pendingMsg: PendingMessage | undefined =
      pendingMessages[currConvId ?? ''];

    useEffect(() => {
      // reset to latest node when conversation changes
      setCurrNodeId(-1);
      // scroll to bottom when conversation changes
      scrollToBottom(false, 1);
    }, [currConvId]);

    /**
     * Streaming callback: switch to the leaf reported by the generator (when
     * given) and keep the view pinned to the bottom if it already was there.
     */
    const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
      if (currLeafNodeId) {
        setCurrNodeId(currLeafNodeId);
      }
      scrollToBottom(true);
    };

    /**
     * Sends the current input as a new message appended after the last
     * displayed message. Does nothing for blank input or while this
     * conversation is generating. On failure the typed text is put back AND the
     * VS Code extra context is kept, so a retry sends the same content.
     */
    const sendNewMessage = async () => {
      const lastInpMsg = textarea.value();
      if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? ''))
        return;
      textarea.setValue('');
      scrollToBottom(false);
      setCurrNodeId(-1);
      // get the last message node
      const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
      const sent = await sendMessage(
        currConvId,
        lastMsgNodeId,
        lastInpMsg,
        currExtra,
        onChunk
      );
      if (!sent) {
        // restore the input message if failed; keep the extra context too,
        // otherwise a retry of the restored text would silently lose it
        textarea.setValue(lastInpMsg);
        return;
      }
      // OK: the extra context was consumed by this message
      clearExtraContext();
    };

    // for vscode context
    textarea.refOnSubmit.current = sendNewMessage;

    /** Replaces `msg` with edited `content` (as a new sibling) and regenerates from there. */
    const handleEditMessage = async (msg: Message, content: string) => {
      if (!viewingChat) return;
      setCurrNodeId(msg.id);
      scrollToBottom(false);
      await replaceMessageAndGenerate(
        viewingChat.conv.id,
        msg.parent,
        content,
        msg.extra,
        onChunk
      );
      setCurrNodeId(-1);
      scrollToBottom(false);
    };

    /** Regenerates `msg` from its parent, passing `null` as the content. */
    const handleRegenerateMessage = async (msg: Message) => {
      if (!viewingChat) return;
      setCurrNodeId(msg.parent);
      scrollToBottom(false);
      await replaceMessageAndGenerate(
        viewingChat.conv.id,
        msg.parent,
        null,
        msg.extra,
        onChunk
      );
      setCurrNodeId(-1);
      scrollToBottom(false);
    };

    const hasCanvas = !!canvasData;

    useEffect(() => {
      if (prefilledMsg.shouldSend()) {
        // send the prefilled message if needed
        void sendNewMessage();
      } else {
        // otherwise, focus on the input
        textarea.focus();
      }
      prefilledMsg.clear();
      // no need to keep track of sendNewMessage
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [textarea.ref]);

    // due to some timing issues of StorageUtils.appendMsg(), we need to make
    // sure the pendingMsg is not duplicated upon rendering (i.e. appears once in
    // the saved conversation and once in the pendingMsg)
    const pendingMsgDisplay: MessageDisplay[] =
      pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
        ? [
            {
              msg: pendingMsg,
              siblingLeafNodeIds: [],
              siblingCurrIdx: 0,
              isPending: true,
            },
          ]
        : [];

    return (
      <div
        className={classNames({
          'grid lg:gap-8 grow transition-[300ms]': true,
          'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
          'grid-cols-[1fr_0fr]': !hasCanvas,
        })}
      >
        <div
          className={classNames({
            'flex flex-col w-full max-w-[900px] mx-auto': true,
            'hidden lg:flex': hasCanvas, // adapted for mobile
            flex: !hasCanvas,
          })}
        >
          {/* chat messages */}
          <div id="messages-list" className="grow">
            <div className="mt-auto flex justify-center">
              {/* placeholder to shift the message to the bottom */}
              {viewingChat ? '' : 'Send a message to start'}
            </div>
            {[...messages, ...pendingMsgDisplay].map((msg) => (
              <ChatMessage
                key={msg.msg.id}
                msg={msg.msg}
                siblingLeafNodeIds={msg.siblingLeafNodeIds}
                siblingCurrIdx={msg.siblingCurrIdx}
                onRegenerateMessage={handleRegenerateMessage}
                onEditMessage={handleEditMessage}
                onChangeSibling={setCurrNodeId}
              />
            ))}
          </div>

          {/* chat input */}
          <div className="flex flex-row items-end pt-8 pb-6 sticky bottom-0 bg-base-100">
            <textarea
              // Default (mobile): Enable vertical resize, overflow auto for scrolling if needed
              // Large screens (lg:): Disable manual resize, apply max-height for autosize limit
              // NOTE(review): Tailwind's vertical-only resize utility is `resize-y`;
              // confirm `resize-vertical` is defined elsewhere, otherwise it is a no-op.
              className="textarea textarea-bordered w-full resize-vertical lg:resize-none lg:max-h-48 lg:overflow-y-auto" // Adjust lg:max-h-48 as needed (e.g., lg:max-h-60)
              placeholder="Type a message (Shift+Enter to add a new line)"
              ref={textarea.ref}
              onInput={textarea.onInput} // Hook's input handler (will only resize height on lg+ screens)
              onKeyDown={(e) => {
                // ignore Enter while an IME composition is in progress
                if (e.nativeEvent.isComposing || e.keyCode === 229) return;
                if (e.key === 'Enter' && !e.shiftKey) {
                  e.preventDefault();
                  void sendNewMessage();
                }
              }}
              id="msg-input"
              dir="auto"
              // Set a base height of 2 rows for mobile views
              // On lg+ screens, the hook will calculate and set the initial height anyway
              rows={2}
            ></textarea>
            {isGenerating(currConvId ?? '') ? (
              <button
                className="btn btn-neutral ml-2"
                onClick={() => stopGenerating(currConvId ?? '')}
              >
                Stop
              </button>
            ) : (
              <button className="btn btn-primary ml-2" onClick={sendNewMessage}>
                Send
              </button>
            )}
          </div>
        </div>
        <div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
          {canvasData?.type === CanvasType.PY_INTERPRETER && (
            <CanvasPyInterpreter />
          )}
        </div>
      </div>
    );
  }