ChatScreen.tsx 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. import { useEffect, useMemo, useRef, useState } from 'react';
  2. import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
  3. import ChatMessage from './ChatMessage';
  4. import { CanvasType, Message, PendingMessage } from '../utils/types';
  5. import { classNames, cleanCurrentUrl } from '../utils/misc';
  6. import CanvasPyInterpreter from './CanvasPyInterpreter';
  7. import StorageUtils from '../utils/storage';
  8. import { useVSCodeContext } from '../utils/llama-vscode';
  9. import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts';
  10. import {
  11. ArrowUpIcon,
  12. StopIcon,
  13. PaperClipIcon,
  14. } from '@heroicons/react/24/solid';
  15. import {
  16. ChatExtraContextApi,
  17. useChatExtraContext,
  18. } from './useChatExtraContext.tsx';
  19. import Dropzone from 'react-dropzone';
  20. import toast from 'react-hot-toast';
  21. import ChatInputExtraContextItem from './ChatInputExtraContextItem.tsx';
  22. import { scrollToBottom, useChatScroll } from './useChatScroll.tsx';
  /**
   * Render-ready wrapper around a single message node.
   *
   * On top of the message itself it carries what the UI needs to switch between
   * sibling branches: each sibling is represented by the id of its last node
   * (its leaf node).
   */
  export interface MessageDisplay {
    /** The message to render: either a stored message or a pending one. */
    msg: Message | PendingMessage;
    /** One entry per sibling of `msg`: the id of that sibling's leaf node. */
    siblingLeafNodeIds: Array<Message['id']>;
    /** Index of `msg` within its parent's list of children. */
    siblingCurrIdx: number;
    /** Set to `true` on the entry built from the pending message. */
    isPending?: boolean;
  }
  /**
   * Prefill support driven by the query string of the current URL:
   * - "?m=..." only prefills the message input with the value;
   * - "?q=..." prefills the input and additionally asks for the message to be SENT.
   */
  const prefilledMsg = {
    /** Text to prefill: the "m" parameter if present, else "q", else an empty string. */
    content() {
      const { searchParams } = new URL(window.location.href);
      const fromM = searchParams.get('m');
      if (fromM !== null) return fromM;
      const fromQ = searchParams.get('q');
      return fromQ !== null ? fromQ : '';
    },
    /** True when the URL carries a "q" parameter. */
    shouldSend() {
      return new URL(window.location.href).searchParams.has('q');
    },
    /** Hands both prefill parameters to cleanCurrentUrl(). */
    clear() {
      cleanCurrentUrl(['m', 'q']);
    },
  };
  /**
   * Builds the render list for the branch selected by `leafNodeId`.
   *
   * The nodes to show come from StorageUtils.filterByLeafNodeId(); each of them
   * that has a known parent and is not of type 'root' becomes one MessageDisplay
   * carrying, for every child of that parent, the id of the leaf reached by
   * always following the last child.
   *
   * @param msgs all message nodes of the conversation
   * @param leafNodeId id forwarded to StorageUtils.filterByLeafNodeId()
   * @returns the display entries, in the order the selected nodes were returned
   */
  function getListMessageDisplay(
    msgs: Readonly<Message[]>,
    leafNodeId: Message['id']
  ): MessageDisplay[] {
    const selectedNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);

    // index every node by id so parents and children can be looked up directly
    const byId = new Map<Message['id'], Message>();
    msgs.forEach((node) => {
      byId.set(node.id, node);
    });

    // walk down through the last child until a childless node is reached;
    // yields -1 when the start id (or a child id on the way) is unknown
    const leafOf = (startId: Message['id']): Message['id'] => {
      let node: Message | undefined = byId.get(startId);
      while (node && node.children.length > 0) {
        node = byId.get(node.children.at(-1) ?? -1);
      }
      return node?.id ?? -1;
    };

    const displayList: MessageDisplay[] = [];
    for (const node of selectedNodes) {
      const parent = byId.get(node.parent ?? -1);
      // nodes without a known parent, and root nodes, are never rendered
      if (!parent || node.type === 'root') continue;
      const siblingIds = parent.children;
      displayList.push({
        msg: node,
        siblingLeafNodeIds: siblingIds.map((id) => leafOf(id)),
        siblingCurrIdx: siblingIds.indexOf(node.id),
      });
    }
    return displayList;
  }
  /**
   * Main chat view.
   *
   * Renders the messages of the conversation currently being viewed (plus the
   * pending message, if any), the chat input, and - when the canvas data is of
   * type PY_INTERPRETER - the Python interpreter canvas next to it.
   */
  export default function ChatScreen() {
    const {
      viewingChat,
      sendMessage,
      isGenerating,
      stopGenerating,
      pendingMessages,
      canvasData,
      replaceMessageAndGenerate,
    } = useAppContext();

    const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content());
    const extraContext = useChatExtraContext();
    useVSCodeContext(textarea, extraContext);

    const msgListRef = useRef<HTMLDivElement>(null);
    useChatScroll(msgListRef);

    // keep track of leaf node for rendering
    const [currNodeId, setCurrNodeId] = useState<number>(-1);
    const messages: MessageDisplay[] = useMemo(() => {
      if (!viewingChat) return [];
      else return getListMessageDisplay(viewingChat.messages, currNodeId);
    }, [currNodeId, viewingChat]);

    const currConvId = viewingChat?.conv.id ?? null;
    const pendingMsg: PendingMessage | undefined =
      pendingMessages[currConvId ?? ''];

    useEffect(() => {
      // reset to latest node when conversation changes
      setCurrNodeId(-1);
      // scroll to bottom when conversation changes
      scrollToBottom(false, 1);
    }, [currConvId]);

    // follow the leaf node reported while a response is being generated
    const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
      if (currLeafNodeId) {
        setCurrNodeId(currLeafNodeId);
      }
      // useChatScroll will handle the auto scroll
    };

    /**
     * Sends the content of the input as a new message appended after the last
     * displayed message. On failure the typed text is put back and the attached
     * extra context is kept, so the user can simply retry.
     */
    const sendNewMessage = async () => {
      const lastInpMsg = textarea.value();
      if (lastInpMsg.trim().length === 0) {
        toast.error('Please enter a message');
        return;
      }
      if (isGenerating(currConvId ?? '')) {
        // the input is fine; the conversation is just busy, so say that
        // instead of asking the user to enter a message
        toast.error('Please wait for the current response to finish');
        return;
      }

      textarea.setValue('');
      scrollToBottom(false);
      setCurrNodeId(-1);

      // get the last message node
      const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
      const sent = await sendMessage(
        currConvId,
        lastMsgNodeId,
        lastInpMsg,
        extraContext.items,
        onChunk
      );
      if (sent) {
        // OK: the attachments went out with the message, drop them
        extraContext.clearItems();
      } else {
        // restore the input message if failed (attachments are left in place)
        textarea.setValue(lastInpMsg);
      }
    };

    // for vscode context
    textarea.refOnSubmit.current = sendNewMessage;

    /** Replaces `msg` with `content` (same parent, same extra) and regenerates. */
    const handleEditMessage = async (msg: Message, content: string) => {
      if (!viewingChat) return;
      setCurrNodeId(msg.id);
      scrollToBottom(false);
      await replaceMessageAndGenerate(
        viewingChat.conv.id,
        msg.parent,
        content,
        msg.extra,
        onChunk
      );
      setCurrNodeId(-1);
      scrollToBottom(false);
    };

    /** Regenerates from the parent of `msg`, passing `null` as the new content. */
    const handleRegenerateMessage = async (msg: Message) => {
      if (!viewingChat) return;
      setCurrNodeId(msg.parent);
      scrollToBottom(false);
      await replaceMessageAndGenerate(
        viewingChat.conv.id,
        msg.parent,
        null,
        msg.extra,
        onChunk
      );
      setCurrNodeId(-1);
      scrollToBottom(false);
    };

    const hasCanvas = !!canvasData;

    useEffect(() => {
      if (prefilledMsg.shouldSend()) {
        // send the prefilled message if needed
        sendNewMessage();
      } else {
        // otherwise, focus on the input
        textarea.focus();
      }
      prefilledMsg.clear();
      // no need to keep track of sendNewMessage
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [textarea.ref]);

    // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
    const pendingMsgDisplay: MessageDisplay[] =
      pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
        ? [
            {
              msg: pendingMsg,
              siblingLeafNodeIds: [],
              siblingCurrIdx: 0,
              isPending: true,
            },
          ]
        : [];

    return (
      <div
        className={classNames({
          'grid lg:gap-8 grow transition-[300ms]': true,
          'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
          'grid-cols-[1fr_0fr]': !hasCanvas,
        })}
      >
        <div
          className={classNames({
            'flex flex-col w-full max-w-[900px] mx-auto': true,
            'hidden lg:flex': hasCanvas, // adapted for mobile
            flex: !hasCanvas,
          })}
        >
          {/* chat messages */}
          <div id="messages-list" className="grow" ref={msgListRef}>
            <div className="mt-auto flex flex-col items-center">
              {/* placeholder to shift the message to the bottom */}
              {viewingChat ? (
                ''
              ) : (
                <>
                  <div className="mb-4">Send a message to start</div>
                  <ServerInfo />
                </>
              )}
            </div>
            {[...messages, ...pendingMsgDisplay].map((msg) => (
              <ChatMessage
                key={msg.msg.id}
                msg={msg.msg}
                siblingLeafNodeIds={msg.siblingLeafNodeIds}
                siblingCurrIdx={msg.siblingCurrIdx}
                onRegenerateMessage={handleRegenerateMessage}
                onEditMessage={handleEditMessage}
                onChangeSibling={setCurrNodeId}
                isPending={msg.isPending}
              />
            ))}
          </div>

          {/* chat input */}
          <ChatInput
            textarea={textarea}
            extraContext={extraContext}
            onSend={sendNewMessage}
            onStop={() => stopGenerating(currConvId ?? '')}
            isGenerating={isGenerating(currConvId ?? '')}
          />
        </div>
        <div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
          {canvasData?.type === CanvasType.PY_INTERPRETER && (
            <CanvasPyInterpreter />
          )}
        </div>
      </div>
    );
  }
  /**
   * Small info card showing the model and the build taken from the
   * `serverProps` of the app context.
   */
  function ServerInfo() {
    const { serverProps } = useAppContext();
    // only the last segment of the model path (split on "\" or "/") is shown
    const modelName = serverProps?.model_path?.split(/(\\|\/)/).pop();
    const buildInfo = serverProps?.build_info;

    return (
      <div className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6">
        <div className="card-body">
          <b>Server Info</b>
          <p>
            <b>Model</b>: {modelName}
            <br />
            <b>Build</b>: {buildInfo}
            <br />
          </p>
        </div>
      </div>
    );
  }
  /**
   * Message composer: a textarea with an attach-file button and a send/stop
   * button, wrapped in a dropzone so files can also be dropped onto it.
   *
   * - `textarea`: API of the chat textarea hook (ref and input handler are wired here)
   * - `extraContext`: holder of the attached items; receives added files
   * - `onSend`: called on Enter (without Shift) and on the send button
   * - `onStop`: called on the stop button, shown while `isGenerating` is true
   */
  function ChatInput({
    textarea,
    extraContext,
    onSend,
    onStop,
    isGenerating,
  }: {
    textarea: ChatTextareaApi;
    extraContext: ChatExtraContextApi;
    onSend: () => void;
    onStop: () => void;
    isGenerating: boolean;
  }) {
    const [isDrag, setIsDrag] = useState(false);

    const wrapperClass = classNames({
      'flex items-end pt-8 pb-6 sticky bottom-0 bg-base-100': true,
      // dimmed while a file is dragged over, as a hint that it will be accepted
      'opacity-50': isDrag,
    });

    const attachButtonClass = classNames({
      'btn w-8 h-8 p-0 rounded-full': true,
      'btn-disabled': isGenerating,
    });

    const handleDrop = (files: File[]) => {
      setIsDrag(false);
      extraContext.onFileAdded(files);
    };
    const markDragging = () => setIsDrag(true);
    const unmarkDragging = () => setIsDrag(false);

    // stop button while generating, send button otherwise
    const actionButton = isGenerating ? (
      <button
        className="btn btn-neutral w-8 h-8 p-0 rounded-full"
        onClick={onStop}
      >
        <StopIcon className="h-5 w-5" />
      </button>
    ) : (
      <button
        className="btn btn-primary w-8 h-8 p-0 rounded-full"
        onClick={onSend}
      >
        <ArrowUpIcon className="h-5 w-5" />
      </button>
    );

    return (
      <div className={wrapperClass}>
        <Dropzone
          noClick
          onDrop={handleDrop}
          onDragEnter={markDragging}
          onDragLeave={unmarkDragging}
          multiple
        >
          {({ getRootProps, getInputProps }) => (
            <div
              className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
              {...getRootProps()}
            >
              {!isGenerating && (
                <ChatInputExtraContextItem
                  items={extraContext.items}
                  removeItem={extraContext.removeItem}
                />
              )}

              <div className="flex flex-row w-full">
                {/*
                  Mobile (default): vertical resize allowed, base height of 2 rows.
                  lg+ screens: manual resize disabled and height capped by lg:max-h-48
                  (adjust as needed, e.g. lg:max-h-60); the hook's input handler
                  only resizes the height on those screens.
                */}
                <textarea
                  className="text-md outline-none border-none w-full resize-vertical lg:resize-none lg:max-h-48 lg:overflow-y-auto"
                  placeholder="Type a message (Shift+Enter to add a new line)"
                  ref={textarea.ref}
                  onInput={textarea.onInput}
                  onKeyDown={(e) => {
                    // ignore key presses that belong to an IME composition
                    const composing =
                      e.nativeEvent.isComposing || e.keyCode === 229;
                    if (composing || e.key !== 'Enter' || e.shiftKey) return;
                    e.preventDefault();
                    onSend();
                  }}
                  id="msg-input"
                  dir="auto"
                  rows={2}
                />

                {/* buttons area */}
                <div className="flex flex-row gap-2 ml-2">
                  <label htmlFor="file-upload" className={attachButtonClass}>
                    <PaperClipIcon className="h-5 w-5" />
                  </label>
                  <input
                    id="file-upload"
                    type="file"
                    className="hidden"
                    disabled={isGenerating}
                    {...getInputProps()}
                    hidden
                  />
                  {actionButton}
                </div>
              </div>
            </div>
          )}
        </Dropzone>
      </div>
    );
  }