ChatScreen.tsx 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import { useEffect, useMemo, useRef, useState } from 'react';
  2. import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
  3. import ChatMessage from './ChatMessage';
  4. import { CanvasType, Message, PendingMessage } from '../utils/types';
  5. import { classNames, throttle } from '../utils/misc';
  6. import CanvasPyInterpreter from './CanvasPyInterpreter';
  7. import StorageUtils from '../utils/storage';
  8. import { useVSCodeContext } from '../utils/llama-vscode';
  9. /**
  10. * A message display is a message node with additional information for rendering.
  11. * For example, siblings of the message node are stored as their last node (aka leaf node).
  12. */
  13. export interface MessageDisplay {
  14. msg: Message | PendingMessage;
  15. siblingLeafNodeIds: Message['id'][];
  16. siblingCurrIdx: number;
  17. isPending?: boolean;
  18. }
  19. function getListMessageDisplay(
  20. msgs: Readonly<Message[]>,
  21. leafNodeId: Message['id']
  22. ): MessageDisplay[] {
  23. const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
  24. const res: MessageDisplay[] = [];
  25. const nodeMap = new Map<Message['id'], Message>();
  26. for (const msg of msgs) {
  27. nodeMap.set(msg.id, msg);
  28. }
  29. // find leaf node from a message node
  30. const findLeafNode = (msgId: Message['id']): Message['id'] => {
  31. let currNode: Message | undefined = nodeMap.get(msgId);
  32. while (currNode) {
  33. if (currNode.children.length === 0) break;
  34. currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
  35. }
  36. return currNode?.id ?? -1;
  37. };
  38. // traverse the current nodes
  39. for (const msg of currNodes) {
  40. const parentNode = nodeMap.get(msg.parent ?? -1);
  41. if (!parentNode) continue;
  42. const siblings = parentNode.children;
  43. if (msg.type !== 'root') {
  44. res.push({
  45. msg,
  46. siblingLeafNodeIds: siblings.map(findLeafNode),
  47. siblingCurrIdx: siblings.indexOf(msg.id),
  48. });
  49. }
  50. }
  51. return res;
  52. }
  53. const scrollToBottom = throttle(
  54. (requiresNearBottom: boolean, delay: number = 80) => {
  55. const mainScrollElem = document.getElementById('main-scroll');
  56. if (!mainScrollElem) return;
  57. const spaceToBottom =
  58. mainScrollElem.scrollHeight -
  59. mainScrollElem.scrollTop -
  60. mainScrollElem.clientHeight;
  61. if (!requiresNearBottom || spaceToBottom < 50) {
  62. setTimeout(
  63. () => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
  64. delay
  65. );
  66. }
  67. },
  68. 80
  69. );
  70. export default function ChatScreen() {
  71. const {
  72. viewingChat,
  73. sendMessage,
  74. isGenerating,
  75. stopGenerating,
  76. pendingMessages,
  77. canvasData,
  78. replaceMessageAndGenerate,
  79. } = useAppContext();
  80. const [inputMsg, setInputMsg] = useState('');
  81. const inputRef = useRef<HTMLTextAreaElement>(null);
  82. const { extraContext, clearExtraContext } = useVSCodeContext(
  83. inputRef,
  84. setInputMsg
  85. );
  86. // TODO: improve this when we have "upload file" feature
  87. const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
  88. // keep track of leaf node for rendering
  89. const [currNodeId, setCurrNodeId] = useState<number>(-1);
  90. const messages: MessageDisplay[] = useMemo(() => {
  91. if (!viewingChat) return [];
  92. else return getListMessageDisplay(viewingChat.messages, currNodeId);
  93. }, [currNodeId, viewingChat]);
  94. const currConvId = viewingChat?.conv.id ?? null;
  95. const pendingMsg: PendingMessage | undefined =
  96. pendingMessages[currConvId ?? ''];
  97. useEffect(() => {
  98. // reset to latest node when conversation changes
  99. setCurrNodeId(-1);
  100. // scroll to bottom when conversation changes
  101. scrollToBottom(false, 1);
  102. }, [currConvId]);
  103. const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
  104. if (currLeafNodeId) {
  105. setCurrNodeId(currLeafNodeId);
  106. }
  107. scrollToBottom(true);
  108. };
  109. const sendNewMessage = async () => {
  110. if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
  111. const lastInpMsg = inputMsg;
  112. setInputMsg('');
  113. scrollToBottom(false);
  114. setCurrNodeId(-1);
  115. // get the last message node
  116. const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
  117. if (
  118. !(await sendMessage(
  119. currConvId,
  120. lastMsgNodeId,
  121. inputMsg,
  122. currExtra,
  123. onChunk
  124. ))
  125. ) {
  126. // restore the input message if failed
  127. setInputMsg(lastInpMsg);
  128. }
  129. // OK
  130. clearExtraContext();
  131. };
  132. const handleEditMessage = async (msg: Message, content: string) => {
  133. if (!viewingChat) return;
  134. setCurrNodeId(msg.id);
  135. scrollToBottom(false);
  136. await replaceMessageAndGenerate(
  137. viewingChat.conv.id,
  138. msg.parent,
  139. content,
  140. msg.extra,
  141. onChunk
  142. );
  143. setCurrNodeId(-1);
  144. scrollToBottom(false);
  145. };
  146. const handleRegenerateMessage = async (msg: Message) => {
  147. if (!viewingChat) return;
  148. setCurrNodeId(msg.parent);
  149. scrollToBottom(false);
  150. await replaceMessageAndGenerate(
  151. viewingChat.conv.id,
  152. msg.parent,
  153. null,
  154. msg.extra,
  155. onChunk
  156. );
  157. setCurrNodeId(-1);
  158. scrollToBottom(false);
  159. };
  160. const hasCanvas = !!canvasData;
  161. // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
  162. const pendingMsgDisplay: MessageDisplay[] =
  163. pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
  164. ? [
  165. {
  166. msg: pendingMsg,
  167. siblingLeafNodeIds: [],
  168. siblingCurrIdx: 0,
  169. isPending: true,
  170. },
  171. ]
  172. : [];
  173. return (
  174. <div
  175. className={classNames({
  176. 'grid lg:gap-8 grow transition-[300ms]': true,
  177. 'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
  178. 'grid-cols-[1fr_0fr]': !hasCanvas,
  179. })}
  180. >
  181. <div
  182. className={classNames({
  183. 'flex flex-col w-full max-w-[900px] mx-auto': true,
  184. 'hidden lg:flex': hasCanvas, // adapted for mobile
  185. flex: !hasCanvas,
  186. })}
  187. >
  188. {/* chat messages */}
  189. <div id="messages-list" className="grow">
  190. <div className="mt-auto flex justify-center">
  191. {/* placeholder to shift the message to the bottom */}
  192. {viewingChat ? '' : 'Send a message to start'}
  193. </div>
  194. {[...messages, ...pendingMsgDisplay].map((msg) => (
  195. <ChatMessage
  196. key={msg.msg.id}
  197. msg={msg.msg}
  198. siblingLeafNodeIds={msg.siblingLeafNodeIds}
  199. siblingCurrIdx={msg.siblingCurrIdx}
  200. onRegenerateMessage={handleRegenerateMessage}
  201. onEditMessage={handleEditMessage}
  202. onChangeSibling={setCurrNodeId}
  203. />
  204. ))}
  205. </div>
  206. {/* chat input */}
  207. <div className="flex flex-row items-center pt-8 pb-6 sticky bottom-0 bg-base-100">
  208. <textarea
  209. className="textarea textarea-bordered w-full"
  210. placeholder="Type a message (Shift+Enter to add a new line)"
  211. ref={inputRef}
  212. value={inputMsg}
  213. onChange={(e) => setInputMsg(e.target.value)}
  214. onKeyDown={(e) => {
  215. if (e.nativeEvent.isComposing || e.keyCode === 229) return;
  216. if (e.key === 'Enter' && e.shiftKey) return;
  217. if (e.key === 'Enter' && !e.shiftKey) {
  218. e.preventDefault();
  219. sendNewMessage();
  220. }
  221. }}
  222. id="msg-input"
  223. dir="auto"
  224. ></textarea>
  225. {isGenerating(currConvId ?? '') ? (
  226. <button
  227. className="btn btn-neutral ml-2"
  228. onClick={() => stopGenerating(currConvId ?? '')}
  229. >
  230. Stop
  231. </button>
  232. ) : (
  233. <button
  234. className="btn btn-primary ml-2"
  235. onClick={sendNewMessage}
  236. disabled={inputMsg.trim().length === 0}
  237. >
  238. Send
  239. </button>
  240. )}
  241. </div>
  242. </div>
  243. <div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
  244. {canvasData?.type === CanvasType.PY_INTERPRETER && (
  245. <CanvasPyInterpreter />
  246. )}
  247. </div>
  248. </div>
  249. );
  250. }