LLM for Unity v2.4.1
Create characters in Unity with LLMs!
LLMCharacter.cs
Go to the documentation of this file.
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;

namespace LLMUnity
{
    [DefaultExecutionOrder(-2)]
    public class LLMCharacter : LLMCaller
    {
        [LLM] public string save = "";
        [LLM] public bool saveCache = false;
        [LLM] public bool debugPrompt = false;
        [Model] public int numPredict = 256;
        [ModelAdvanced] public int slot = -1;
        [ModelAdvanced] public string grammar = null;
        [ModelAdvanced] public bool cachePrompt = true;
        [ModelAdvanced] public int seed = 0;
        [ModelAdvanced, Float(0f, 2f)] public float temperature = 0.2f;
        [ModelAdvanced, Int(-1, 100)] public int topK = 40;
        [ModelAdvanced, Float(0f, 1f)] public float topP = 0.9f;
        [ModelAdvanced, Float(0f, 1f)] public float minP = 0.05f;
        [ModelAdvanced, Float(0f, 2f)] public float repeatPenalty = 1.1f;
        [ModelAdvanced, Float(0f, 1f)] public float presencePenalty = 0f;
        [ModelAdvanced, Float(0f, 1f)] public float frequencyPenalty = 0f;

        [ModelAdvanced, Float(0f, 1f)] public float tfsZ = 1f;
        [ModelAdvanced, Float(0f, 1f)] public float typicalP = 1f;
        [ModelAdvanced, Int(0, 2048)] public int repeatLastN = 64;
        [ModelAdvanced] public bool penalizeNl = true;
        [ModelAdvanced] public string penaltyPrompt;
        [ModelAdvanced, Int(0, 2)] public int mirostat = 0;
        [ModelAdvanced, Float(0f, 10f)] public float mirostatTau = 5f;
        [ModelAdvanced, Float(0f, 1f)] public float mirostatEta = 0.1f;
        [ModelAdvanced, Int(0, 10)] public int nProbs = 0;
        [ModelAdvanced] public bool ignoreEos = false;

        public int nKeep = -1;
        public List<string> stop = new List<string>();
        public Dictionary<int, string> logitBias = null;

        [Chat] public bool stream = true;
        [Chat] public string playerName = "user";
        [Chat] public string AIName = "assistant";
        [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
        public bool setNKeepToPrompt = true;
        public List<ChatMessage> chat = new List<ChatMessage>();
        public string grammarString;

        protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
        protected string chatTemplate;
        protected ChatTemplate template = null;

        public override void Awake()
        {
            if (!enabled) return;
            base.Awake();
            if (!remote)
            {
                int slotFromServer = llm.Register(this);
                if (slot == -1) slot = slotFromServer;
            }
            InitGrammar();
            InitHistory();
        }

        protected override void OnValidate()
        {
            base.OnValidate();
            if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set");
        }

        protected override string NotValidLLMError()
        {
            return base.NotValidLLMError() + $", it is an embedding only model";
        }

        public override bool IsValidLLM(LLM llmSet)
        {
            return !llmSet.embeddingsOnly;
        }

        protected virtual void InitHistory()
        {
            ClearChat();
            _ = LoadHistory();
        }

        protected virtual async Task LoadHistory()
        {
            if (save == "" || !File.Exists(GetJsonSavePath(save))) return;
            await chatLock.WaitAsync(); // Acquire the lock
            try
            {
                await Load(save);
            }
            finally
            {
                chatLock.Release(); // Release the lock
            }
        }

        protected virtual string GetSavePath(string filename)
        {
            return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
        }

        public virtual string GetJsonSavePath(string filename)
        {
            return GetSavePath(filename + ".json");
        }

        public virtual string GetCacheSavePath(string filename)
        {
            return GetSavePath(filename + ".cache");
        }

        public virtual void ClearChat()
        {
            chat.Clear();
            ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt };
            chat.Add(promptMessage);
        }

        public virtual void SetPrompt(string newPrompt, bool clearChat = true)
        {
            prompt = newPrompt;
            nKeep = -1;
            if (clearChat) ClearChat();
            else chat[0] = new ChatMessage { role = "system", content = prompt };
        }

        protected virtual bool CheckTemplate()
        {
            if (template == null)
            {
                LLMUnitySetup.LogError("Template not set!");
                return false;
            }
            return true;
        }

        protected virtual async Task<bool> InitNKeep()
        {
            if (setNKeepToPrompt && nKeep == -1)
            {
                if (!CheckTemplate()) return false;
                string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
                List<int> tokens = await Tokenize(systemPrompt);
                if (tokens == null) return false;
                SetNKeep(tokens);
            }
            return true;
        }

        protected virtual void InitGrammar()
        {
            if (grammar != null && grammar != "")
            {
                grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
            }
        }

        protected virtual void SetNKeep(List<int> tokens)
        {
            // set the tokens to keep
            nKeep = tokens.Count;
        }

        public virtual async Task LoadTemplate()
        {
            string llmTemplate;
            if (remote)
            {
                llmTemplate = await AskTemplate();
            }
            else
            {
                llmTemplate = llm.GetTemplate();
            }
            if (llmTemplate != chatTemplate)
            {
                chatTemplate = llmTemplate;
                template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
                nKeep = -1;
            }
        }

        public virtual async void SetGrammar(string path)
        {
#if UNITY_EDITOR
            if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
            await LLMUnitySetup.AndroidExtractAsset(path, true);
            grammar = path;
            InitGrammar();
        }

        protected virtual List<string> GetStopwords()
        {
            if (!CheckTemplate()) return null;
            List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
            if (stop != null) stopAll.AddRange(stop);
            return stopAll;
        }

        protected virtual ChatRequest GenerateRequest(string prompt)
        {
            // setup the request struct
            ChatRequest chatRequest = new ChatRequest();
            if (debugPrompt) LLMUnitySetup.Log(prompt);
            chatRequest.prompt = prompt;
            // the generation settings defined above (temperature, topK, topP, penalties,
            // grammar, seed, slot, caching, ...) are assigned to the corresponding request fields
            chatRequest.stop = GetStopwords();
            return chatRequest;
        }

        public virtual void AddMessage(string role, string content)
        {
            // add the question / answer to the chat list, update prompt
            chat.Add(new ChatMessage { role = role, content = content });
        }

        public virtual void AddPlayerMessage(string content)
        {
            AddMessage(playerName, content);
        }

        public virtual void AddAIMessage(string content)
        {
            AddMessage(AIName, content);
        }
375
376 protected virtual string ChatContent(ChatResult result)
377 {
378 // get content from a chat result received from the endpoint
379 return result.content.Trim();
380 }
381
382 protected virtual string MultiChatContent(MultiChatResult result)
383 {
384 // get content from a chat result received from the endpoint
385 string response = "";
386 foreach (ChatResult resultPart in result.data)
387 {
388 response += resultPart.content;
389 }
390 return response.Trim();
391 }
392
393 protected virtual string SlotContent(SlotResult result)
394 {
395 // get the tokens from a tokenize result received from the endpoint
396 return result.filename;
397 }
398
399 protected virtual string TemplateContent(TemplateResult result)
400 {
401 // get content from a char result received from the endpoint in open AI format
402 return result.template;
403 }
404
        protected virtual async Task<string> CompletionRequest(string json, Callback<string> callback = null)
        {
            string result = "";
            if (stream)
            {
                result = await PostRequest<MultiChatResult, string>(json, "completion", MultiChatContent, callback);
            }
            else
            {
                result = await PostRequest<ChatResult, string>(json, "completion", ChatContent, callback);
            }
            return result;
        }

        public virtual async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true)
        {
            // handle a chat message by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received
            await LoadTemplate();
            if (!CheckTemplate()) return null;
            if (!await InitNKeep()) return null;

            string json;
            await chatLock.WaitAsync();
            try
            {
                AddPlayerMessage(query);
                string prompt = template.ComputePrompt(chat, playerName, AIName);
                json = JsonUtility.ToJson(GenerateRequest(prompt));
                chat.RemoveAt(chat.Count - 1);
            }
            finally
            {
                chatLock.Release();
            }

            string result = await CompletionRequest(json, callback);

            if (addToHistory && result != null)
            {
                await chatLock.WaitAsync();
                try
                {
                    AddPlayerMessage(query);
                    AddAIMessage(result);
                }
                finally
                {
                    chatLock.Release();
                }
                if (save != "") _ = Save(save);
            }

            completionCallback?.Invoke();
            return result;
        }

        public virtual async Task<string> Complete(string prompt, Callback<string> callback = null, EmptyCallback completionCallback = null)
        {
            // handle a completion request by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received
            string json = JsonUtility.ToJson(GenerateRequest(prompt));
            string result = await CompletionRequest(json, callback);
            completionCallback?.Invoke();
            return result;
        }

        public virtual async Task Warmup(EmptyCallback completionCallback = null)
        {
            await LoadTemplate();
            if (!CheckTemplate()) return;
            if (!await InitNKeep()) return;

            string prompt = template.ComputePrompt(chat, playerName, AIName);
            ChatRequest request = GenerateRequest(prompt);
            request.n_predict = 0; // only process the prompt, do not generate any tokens
            string json = JsonUtility.ToJson(request);
            await CompletionRequest(json);
            completionCallback?.Invoke();
        }

        public virtual async Task<string> AskTemplate()
        {
            return await PostRequest<TemplateResult, string>("{}", "template", TemplateContent);
        }

        protected override void CancelRequestsLocal()
        {
            if (slot >= 0) llm.CancelRequest(slot);
        }

        protected virtual async Task<string> Slot(string filepath, string action)
        {
            // build the slot request: the slot id, the file path and the action ("save" / "restore")
            // are assigned to the corresponding request fields
            SlotRequest slotRequest = new SlotRequest();
            string json = JsonUtility.ToJson(slotRequest);
            return await PostRequest<SlotResult, string>(json, "slots", SlotContent);
        }

        public virtual async Task<string> Save(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            string dirname = Path.GetDirectoryName(filepath);
            if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
            string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
            File.WriteAllText(filepath, json);

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache) return null;
            string result = await Slot(cachepath, "save");
            return result;
        }

        public virtual async Task<string> Load(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            if (!File.Exists(filepath))
            {
                LLMUnitySetup.LogError($"File {filepath} does not exist.");
                return null;
            }
            string json = File.ReadAllText(filepath);
            List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
            ClearChat();
            chat.AddRange(chatHistory);
            LLMUnitySetup.Log($"Loaded {filepath}");

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
            string result = await Slot(cachepath, "restore");
            return result;
        }

        protected override async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback);

            while (!llm.failed && !llm.started) await Task.Yield();

            string callResult = null;
            bool callbackCalled = false;
            if (llm.embeddingsOnly) LLMUnitySetup.LogError("The LLM can't be used for completion, only for embeddings");
            else
            {
                Callback<string> callbackString = null;
                if (stream && callback != null)
                {
                    if (typeof(Ret) == typeof(string))
                    {
                        // wrap the callback so that partial results are forwarded while streaming
                        callbackString = (strArg) =>
                        {
                            callback(ConvertContent(strArg, getContent));
                        };
                    }
                    else
                    {
                        LLMUnitySetup.LogError($"wrong callback type, should be string");
                    }
                    callbackCalled = true;
                }
                callResult = await llm.Completion(json, callbackString);
            }

            Ret result = ConvertContent(callResult, getContent);
            if (!callbackCalled) callback?.Invoke(result);
            return result;
        }
    }

    [Serializable]
    public class ChatListWrapper
    {
        public List<ChatMessage> chat;
    }
}
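
For orientation, a minimal usage sketch (not part of the file above): it assumes a scene with an LLM component and an LLMCharacter component whose llm field points to it; the class and value names below are illustrative.

using LLMUnity;
using UnityEngine;

public class CharacterExample : MonoBehaviour
{
    public LLMCharacter llmCharacter;   // LLMCharacter component with its llm field assigned in the Inspector

    void Start()
    {
        llmCharacter.playerName = "Adventurer";
        llmCharacter.AIName = "Innkeeper";
        llmCharacter.SetPrompt("You are a friendly innkeeper in a fantasy village.");
    }
}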
Class implementing the skeleton of a chat template.
static ChatTemplate GetTemplate(string template)
Creates the chat template based on the provided chat template name.
Class implementing calling of LLM functions (local and remote).
Definition LLMCaller.cs:17
virtual async Task< List< int > > Tokenize(string query, Callback< List< int > > callback=null)
Tokenises the provided query.
Definition LLMCaller.cs:343
bool remote
toggle to use remote LLM server or local LLM
Definition LLMCaller.cs:21
Class implementing the LLM characters.
bool cachePrompt
option to cache the prompt as it is being created by the chat to avoid reprocessing the entire prompt...
int slot
specify which slot of the server to use for computation (affects caching)
virtual async Task< string > Chat(string query, Callback< string > callback=null, EmptyCallback completionCallback=null, bool addToHistory=true)
Chat functionality of the LLM. It calls the LLM completion based on the provided query including the ...
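A usage sketch for Chat, assuming an LLMCharacter reference named llmCharacter; HandleReply and ReplyCompleted are illustrative handlers. With stream enabled, the callback is invoked repeatedly with the partial reply as it grows.

void HandleReply(string reply) => Debug.Log(reply);      // partial reply while streaming
void ReplyCompleted() => Debug.Log("Reply finished");    // full answer received

async void OnPlayerInput(string text)
{
    // addToHistory defaults to true, so the question / answer pair is added to the chat history
    string answer = await llmCharacter.Chat(text, HandleReply, ReplyCompleted);
}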
List< string > stop
stopwords to stop the LLM in addition to the default stopwords from the chat template.
float topP
top-p sampling (1.0 = disabled). The top p value controls the cumulative probability of generated tok...
virtual async Task LoadTemplate()
Loads the chat template of the LLMCharacter.
string AIName
the name of the AI
int nProbs
if greater than 0, the response also contains the probabilities of top N tokens for each generated to...
bool ignoreEos
ignore end of stream token and continue generating.
float mirostatTau
set the Mirostat target entropy, parameter tau.
int numPredict
number of tokens to predict (-1 = infinity, -2 = until context filled). This is the amount of tokens ...
string prompt
a description of the AI role. This defines the LLMCharacter system prompt
float temperature
LLM temperature, lower values give more deterministic answers. The temperature setting adjusts how ra...
override void Awake()
The Unity Awake function that initializes the state before the application starts....
float presencePenalty
repeated token presence penalty (0.0 = disabled). Positive values penalize new tokens based on whethe...
virtual async Task< string > AskTemplate()
Asks the LLM for the chat template to use.
string playerName
the name of the player
virtual string GetCacheSavePath(string filename)
Allows to get the save path of the LLM cache based on the provided filename or relative path.
float mirostatEta
set the Mirostat learning rate, parameter eta.
float minP
minimum probability for a token to be used. The probability is defined relative to the probability of...
float typicalP
enable locally typical sampling with parameter p (1.0 = disabled).
virtual async Task< string > Load(string filename)
Load the chat history and cache from the provided filename / relative path.
int nKeep
number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prom...
virtual async void SetGrammar(string path)
Sets the grammar file of the LLMCharacter.
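As a sketch, SetGrammar can point at a grammar file shipped with the project (the filename below is illustrative); in the editor the file is registered as an asset and its contents are read into grammarString.

llmCharacter.SetGrammar("json.gbnf");   // constrain generation with a grammar file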
bool penalizeNl
penalize newline tokens when applying the repeat penalty.
bool debugPrompt
select to log the constructed prompt in the Unity Editor.
bool saveCache
toggle to save the LLM cache. This speeds up the prompt calculation but also requires ~100MB of space...
int topK
top-k sampling (0 = disabled). The top k value controls the top k most probable tokens at each step o...
string grammar
grammar file used for the LLM in .gbnf format (relative to the Assets/StreamingAssets folder)
virtual async Task< string > Save(string filename)
Saves the chat history and cache to the provided filename / relative path.
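A small sketch of Save and Load; "player1" is only an example name, and the files are written relative to Application.persistentDataPath. With saveCache enabled on a local LLM, the prompt cache is stored alongside the JSON history.

await llmCharacter.Save("player1");   // writes player1.json (and player1.cache if saveCache is set)
await llmCharacter.Load("player1");   // restores the chat history (and the cache, if present)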
virtual void AddAIMessage(string content)
Allows to add an AI message in the chat history.
string penaltyPrompt
prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers...
override bool IsValidLLM(LLM llmSet)
Checks if an LLM is valid for the LLMCaller.
virtual void AddPlayerMessage(string content)
Allows to add a player message in the chat history.
float repeatPenalty
control the repetition of token sequences in the generated text. The penalty is applied to repeated t...
virtual string GetJsonSavePath(string filename)
Allows to get the save path of the chat history based on the provided filename or relative path.
bool stream
option to receive the reply from the model as it is produced (recommended!). If it is not selected,...
float frequencyPenalty
repeated token frequency penalty (0.0 = disabled). Positive values penalize new tokens based on their...
virtual async Task Warmup(EmptyCallback completionCallback=null)
Allows to warm up a model by processing the prompt. The prompt processing will be cached (if cacheProm...
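A sketch of warming up at scene start so the first reply is not delayed by prompt processing:

async void Start()
{
    await llmCharacter.Warmup(() => Debug.Log("Warmup done"));
}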
int mirostat
enable Mirostat sampling, controlling perplexity during text generation (0 = disabled,...
bool setNKeepToPrompt
option to set the number of tokens to retain from the prompt (nKeep) based on the LLMCharacter system...
int repeatLastN
last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size).
int seed
seed for reproducibility. For random results every time set to -1.
float tfsZ
enable tail free sampling with parameter z (1.0 = disabled).
string grammarString
the grammar to use
Dictionary< int, string > logitBias
the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the ge...
virtual async Task< string > Complete(string prompt, Callback< string > callback=null, EmptyCallback completionCallback=null)
Pure completion functionality of the LLM. It calls the LLM completion based solely on the provided pr...
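Complete bypasses the chat history and the chat template, so the provided prompt is sent as-is; a sketch:

string continuation = await llmCharacter.Complete("Once upon a time");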
virtual void AddMessage(string role, string content)
Allows to add a message in the chat history.
virtual void ClearChat()
Clear the chat of the LLMCharacter.
List< ChatMessage > chat
the chat history as list of chat messages
string save
file to save the chat history. The file is saved only for Chat calls with addToHistory set to true....
virtual void SetPrompt(string newPrompt, bool clearChat=true)
Set the system prompt for the LLMCharacter.
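SetPrompt replaces the system prompt and, by default, clears the chat history; a sketch:

llmCharacter.SetPrompt("You are a terse pirate.", clearChat: true);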
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Definition LLM.cs:19
string GetTemplate()
Returns the chat template of the LLM.
Definition LLM.cs:387
void CancelRequest(int id_slot)
Allows to cancel the requests in a specific slot of the LLM.
Definition LLM.cs:785
int parallelPrompts
number of prompts that can happen in parallel (-1 = number of LLMCaller objects)
Definition LLM.cs:35
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
Definition LLM.cs:44
async Task< string > Completion(string json, Callback< string > streamCallback=null)
Allows to use the chat and completion functionality of the LLM.
Definition LLM.cs:767
int Register(LLMCaller llmCaller)
Registers a local LLMCaller object. This allows to bind the LLMCaller "client" to a specific slot of ...
Definition LLM.cs:557
bool failed
Boolean set to true if the server has failed to start.
Definition LLM.cs:46