diff options
Diffstat (limited to 'examples/server/webui/src/utils/app.context.tsx')
-rw-r--r-- | examples/server/webui/src/utils/app.context.tsx | 393 |
1 files changed, 393 insertions, 0 deletions
diff --git a/examples/server/webui/src/utils/app.context.tsx b/examples/server/webui/src/utils/app.context.tsx new file mode 100644 index 00000000..378e24dd --- /dev/null +++ b/examples/server/webui/src/utils/app.context.tsx @@ -0,0 +1,393 @@ +import React, { createContext, useContext, useEffect, useState } from 'react'; +import { + APIMessage, + CanvasData, + Conversation, + Message, + PendingMessage, + ViewingChat, +} from './types'; +import StorageUtils from './storage'; +import { + filterThoughtFromMsgs, + normalizeMsgsForAPI, + getSSEStreamAsync, +} from './misc'; +import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config'; +import { matchPath, useLocation, useNavigate } from 'react-router'; + +class Timer { + static timercount = 1; +} +interface AppContextValue { + // conversations and messages + viewingChat: ViewingChat | null; + pendingMessages: Record<Conversation['id'], PendingMessage>; + isGenerating: (convId: string) => boolean; + sendMessage: ( + convId: string | null, + leafNodeId: Message['id'] | null, + content: string, + extra: Message['extra'], + onChunk: CallbackGeneratedChunk + ) => Promise<boolean>; + stopGenerating: (convId: string) => void; + replaceMessageAndGenerate: ( + convId: string, + parentNodeId: Message['id'], // the parent node of the message to be replaced + content: string | null, + extra: Message['extra'], + onChunk: CallbackGeneratedChunk + ) => Promise<void>; + + // canvas + canvasData: CanvasData | null; + setCanvasData: (data: CanvasData | null) => void; + + // config + config: typeof CONFIG_DEFAULT; + saveConfig: (config: typeof CONFIG_DEFAULT) => void; + showSettings: boolean; + setShowSettings: (show: boolean) => void; +} + +// this callback is used for scrolling to the bottom of the chat and switching to the last node +export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const AppContext = createContext<AppContextValue>({} as any); + +const getViewingChat = async (convId: string): Promise<ViewingChat | null> => { + const conv = await StorageUtils.getOneConversation(convId); + if (!conv) return null; + return { + conv: conv, + // all messages from all branches, not filtered by last node + messages: await StorageUtils.getMessages(convId), + }; +}; + +export const AppContextProvider = ({ + children, +}: { + children: React.ReactElement; +}) => { + const { pathname } = useLocation(); + const navigate = useNavigate(); + const params = matchPath('/chat/:convId', pathname); + const convId = params?.params?.convId; + + const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null); + const [pendingMessages, setPendingMessages] = useState< + Record<Conversation['id'], PendingMessage> + >({}); + const [aborts, setAborts] = useState< + Record<Conversation['id'], AbortController> + >({}); + const [config, setConfig] = useState(StorageUtils.getConfig()); + const [canvasData, setCanvasData] = useState<CanvasData | null>(null); + const [showSettings, setShowSettings] = useState(false); + + // handle change when the convId from URL is changed + useEffect(() => { + // also reset the canvas data + setCanvasData(null); + const handleConversationChange = async (changedConvId: string) => { + if (changedConvId !== convId) return; + setViewingChat(await getViewingChat(changedConvId)); + }; + StorageUtils.onConversationChanged(handleConversationChange); + getViewingChat(convId ?? '').then(setViewingChat); + return () => { + StorageUtils.offConversationChanged(handleConversationChange); + }; + }, [convId]); + + const setPending = (convId: string, pendingMsg: PendingMessage | null) => { + // if pendingMsg is null, remove the key from the object + if (!pendingMsg) { + setPendingMessages((prev) => { + const newState = { ...prev }; + delete newState[convId]; + return newState; + }); + } else { + setPendingMessages((prev) => ({ ...prev, [convId]: pendingMsg })); + } + }; + + const setAbort = (convId: string, controller: AbortController | null) => { + if (!controller) { + setAborts((prev) => { + const newState = { ...prev }; + delete newState[convId]; + return newState; + }); + } else { + setAborts((prev) => ({ ...prev, [convId]: controller })); + } + }; + + //////////////////////////////////////////////////////////////////////// + // public functions + + const isGenerating = (convId: string) => !!pendingMessages[convId]; + + const generateMessage = async ( + convId: string, + leafNodeId: Message['id'], + onChunk: CallbackGeneratedChunk + ) => { + if (isGenerating(convId)) return; + + const config = StorageUtils.getConfig(); + const currConversation = await StorageUtils.getOneConversation(convId); + if (!currConversation) { + throw new Error('Current conversation is not found'); + } + + const currMessages = StorageUtils.filterByLeafNodeId( + await StorageUtils.getMessages(convId), + leafNodeId, + false + ); + const abortController = new AbortController(); + setAbort(convId, abortController); + + if (!currMessages) { + throw new Error('Current messages are not found'); + } + + const pendingId = Date.now() + Timer.timercount + 1; + Timer.timercount=Timer.timercount+2; + let pendingMsg: PendingMessage = { + id: pendingId, + convId, + type: 'text', + timestamp: pendingId, + role: 'assistant', + content: null, + parent: leafNodeId, + children: [], + }; + setPending(convId, pendingMsg); + + try { + // prepare messages for API + let messages: APIMessage[] = [ + ...(config.systemMessage.length === 0 + ? [] + : [{ role: 'system', content: config.systemMessage } as APIMessage]), + ...normalizeMsgsForAPI(currMessages), + ]; + if (config.excludeThoughtOnReq) { + messages = filterThoughtFromMsgs(messages); + } + if (isDev) console.log({ messages }); + + // prepare params + const params = { + messages, + stream: true, + cache_prompt: true, + samplers: config.samplers, + temperature: config.temperature, + dynatemp_range: config.dynatemp_range, + dynatemp_exponent: config.dynatemp_exponent, + top_k: config.top_k, + top_p: config.top_p, + min_p: config.min_p, + typical_p: config.typical_p, + xtc_probability: config.xtc_probability, + xtc_threshold: config.xtc_threshold, + repeat_last_n: config.repeat_last_n, + repeat_penalty: config.repeat_penalty, + presence_penalty: config.presence_penalty, + frequency_penalty: config.frequency_penalty, + dry_multiplier: config.dry_multiplier, + dry_base: config.dry_base, + dry_allowed_length: config.dry_allowed_length, + dry_penalty_last_n: config.dry_penalty_last_n, + max_tokens: config.max_tokens, + timings_per_token: !!config.showTokensPerSecond, + ...(config.custom.length ? JSON.parse(config.custom) : {}), + }; + + // send request + const fetchResponse = await fetch(`${BASE_URL}/v1/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(config.apiKey + ? { Authorization: `Bearer ${config.apiKey}` } + : {}), + }, + body: JSON.stringify(params), + signal: abortController.signal, + }); + if (fetchResponse.status !== 200) { + const body = await fetchResponse.json(); + throw new Error(body?.error?.message || 'Unknown error'); + } + const chunks = getSSEStreamAsync(fetchResponse); + for await (const chunk of chunks) { + // const stop = chunk.stop; + if (chunk.error) { + throw new Error(chunk.error?.message || 'Unknown error'); + } + const addedContent = chunk.choices[0].delta.content; + const lastContent = pendingMsg.content || ''; + if (addedContent) { + pendingMsg = { + ...pendingMsg, + content: lastContent + addedContent, + }; + } + const timings = chunk.timings; + if (timings && config.showTokensPerSecond) { + // only extract what's really needed, to save some space + pendingMsg.timings = { + prompt_n: timings.prompt_n, + prompt_ms: timings.prompt_ms, + predicted_n: timings.predicted_n, + predicted_ms: timings.predicted_ms, + }; + } + setPending(convId, pendingMsg); + onChunk(); // don't need to switch node for pending message + } + } catch (err) { + setPending(convId, null); + if ((err as Error).name === 'AbortError') { + // user stopped the generation via stopGeneration() function + // we can safely ignore this error + } else { + console.error(err); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + alert((err as any)?.message ?? 'Unknown error'); + //throw err; // rethrow + } + } + finally { + if (pendingMsg.content !== null) { + await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId); + } + } + setPending(convId, null); + onChunk(pendingId); // trigger scroll to bottom and switch to the last node + }; + + const sendMessage = async ( + convId: string | null, + leafNodeId: Message['id'] | null, + content: string, + extra: Message['extra'], + onChunk: CallbackGeneratedChunk + ): Promise<boolean> => { + if (isGenerating(convId ?? '') || content.trim().length === 0) return false; + + if (convId === null || convId.length === 0 || leafNodeId === null) { + const conv = await StorageUtils.createConversation( + content.substring(0, 256) + ); + convId = conv.id; + leafNodeId = conv.currNode; + // if user is creating a new conversation, redirect to the new conversation + navigate(`/chat/${convId}`); + } + + const now = Date.now()+Timer.timercount; + Timer.timercount=Timer.timercount + 2; + const currMsgId = now; + StorageUtils.appendMsg( + { + id: currMsgId, + timestamp: now, + type: 'text', + convId, + role: 'user', + content, + extra, + parent: leafNodeId, + children: [], + }, + leafNodeId + ); + onChunk(currMsgId); + + try { + await generateMessage(convId, currMsgId, onChunk); + return true; + } catch (_) { + // TODO: rollback + } + return false; + }; + + const stopGenerating = (convId: string) => { + setPending(convId, null); + aborts[convId]?.abort(); + }; + + // if content is undefined, we remove last assistant message + const replaceMessageAndGenerate = async ( + convId: string, + parentNodeId: Message['id'], // the parent node of the message to be replaced + content: string | null, + extra: Message['extra'], + onChunk: CallbackGeneratedChunk + ) => { + if (isGenerating(convId)) return; + + if (content !== null) { + const now = Date.now(); + const currMsgId = now; + StorageUtils.appendMsg( + { + id: currMsgId, + timestamp: now, + type: 'text', + convId, + role: 'user', + content, + extra, + parent: parentNodeId, + children: [], + }, + parentNodeId + ); + parentNodeId = currMsgId; + } + onChunk(parentNodeId); + + await generateMessage(convId, parentNodeId, onChunk); + }; + + const saveConfig = (config: typeof CONFIG_DEFAULT) => { + StorageUtils.setConfig(config); + setConfig(config); + }; + + return ( + <AppContext.Provider + value={{ + isGenerating, + viewingChat, + pendingMessages, + sendMessage, + stopGenerating, + replaceMessageAndGenerate, + canvasData, + setCanvasData, + config, + saveConfig, + showSettings, + setShowSettings, + }} + > + {children} + </AppContext.Provider> + ); +}; + +export const useAppContext = () => useContext(AppContext); |