LLM for Unity  v2.4.2
Create characters in Unity with LLMs!
LLMCharacter.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;

namespace LLMUnity
{
    [DefaultExecutionOrder(-2)]
    public class LLMCharacter : LLMCaller
    {
22 [Tooltip("file to save the chat history. The file will be saved within the persistentDataPath directory.")]
23 [LLM] public string save = "";
25 [Tooltip("save the LLM cache. Speeds up the prompt calculation when reloading from history but also requires ~100MB of space per character.")]
26 [LLM] public bool saveCache = false;
28 [Tooltip("log the constructed prompt the Unity Editor.")]
29 [LLM] public bool debugPrompt = false;
31 [Tooltip("maximum number of tokens that the LLM will predict (-1 = infinity, -2 = until context filled).")]
32 [Model] public int numPredict = 256;
34 [Tooltip("slot of the server to use for computation (affects caching)")]
35 [ModelAdvanced] public int slot = -1;
37 [Tooltip("grammar file used for the LLMCharacter (.gbnf format)")]
38 [ModelAdvanced] public string grammar = null;
40 [Tooltip("cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)")]
41 [ModelAdvanced] public bool cachePrompt = true;
43 [Tooltip("seed for reproducibility (-1 = no reproducibility).")]
44 [ModelAdvanced] public int seed = 0;
46 [Tooltip("LLM temperature, lower values give more deterministic answers.")]
47 [ModelAdvanced, Float(0f, 2f)] public float temperature = 0.2f;
51 [Tooltip("Top-k sampling selects the next token only from the top k most likely predicted tokens (0 = disabled). Higher values lead to more diverse text, while lower value will generate more focused and conservative text. ")]
52 [ModelAdvanced, Int(-1, 100)] public int topK = 40;
56 [Tooltip("Top-p sampling selects the next token from a subset of tokens that together have a cumulative probability of at least p (1.0 = disabled). Higher values lead to more diverse text, while lower value will generate more focused and conservative text. ")]
57 [ModelAdvanced, Float(0f, 1f)] public float topP = 0.9f;
59 [Tooltip("minimum probability for a token to be used.")]
60 [ModelAdvanced, Float(0f, 1f)] public float minP = 0.05f;
62 [Tooltip("Penalty based on repeated tokens to control the repetition of token sequences in the generated text.")]
63 [ModelAdvanced, Float(0f, 2f)] public float repeatPenalty = 1.1f;
65 [Tooltip("Penalty based on token presence in previous responses to control the repetition of token sequences in the generated text. (0.0 = disabled).")]
66 [ModelAdvanced, Float(0f, 1f)] public float presencePenalty = 0f;
68 [Tooltip("Penalty based on token frequency in previous responses to control the repetition of token sequences in the generated text. (0.0 = disabled).")]
69 [ModelAdvanced, Float(0f, 1f)] public float frequencyPenalty = 0f;
71 [Tooltip("enable locally typical sampling (1.0 = disabled). Higher values will promote more contextually coherent tokens, while lower values will promote more diverse tokens.")]
72 [ModelAdvanced, Float(0f, 1f)] public float typicalP = 1f;
74 [Tooltip("last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size).")]
75 [ModelAdvanced, Int(0, 2048)] public int repeatLastN = 64;
77 [Tooltip("penalize newline tokens when applying the repeat penalty.")]
78 [ModelAdvanced] public bool penalizeNl = true;
80 [Tooltip("prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers representing tokens (null/'' = use original prompt)")]
81 [ModelAdvanced] public string penaltyPrompt;
83 [Tooltip("enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).")]
84 [ModelAdvanced, Int(0, 2)] public int mirostat = 0;
86 [Tooltip("The Mirostat target entropy (tau) controls the balance between coherence and diversity in the generated text.")]
87 [ModelAdvanced, Float(0f, 10f)] public float mirostatTau = 5f;
89 [Tooltip("The Mirostat learning rate (eta) controls how quickly the algorithm responds to feedback from the generated text.")]
90 [ModelAdvanced, Float(0f, 1f)] public float mirostatEta = 0.1f;
92 [Tooltip("if greater than 0, the response also contains the probabilities of top N tokens for each generated token.")]
93 [ModelAdvanced, Int(0, 10)] public int nProbs = 0;
95 [Tooltip("ignore end of stream token and continue generating.")]
96 [ModelAdvanced] public bool ignoreEos = false;
98 [Tooltip("number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prompt tokens if setNKeepToPrompt is set to true).")]
99 public int nKeep = -1;
101 [Tooltip("stopwords to stop the LLM in addition to the default stopwords from the chat template.")]
102 public List<string> stop = new List<string>();
105 [Tooltip("the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the generated text. By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated.")]
106 public Dictionary<int, string> logitBias = null;
109 [Tooltip("Receive the reply from the model as it is produced (recommended!). If not selected, the full reply from the model is received in one go")]
110 [Chat] public bool stream = true;
112 [Tooltip("the name of the player")]
113 [Chat] public string playerName = "user";
115 [Tooltip("the name of the AI")]
116 [Chat] public string AIName = "assistant";
118 [Tooltip("a description of the AI role (system prompt)")]
119 [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
121 [Tooltip("set the number of tokens to always retain from the prompt (nKeep) based on the LLMCharacter system prompt")]
122 public bool setNKeepToPrompt = true;
124 [Tooltip("the chat history as list of chat messages")]
125 public List<ChatMessage> chat = new List<ChatMessage>();
127 [Tooltip("the grammar to use")]
128 public string grammarString;
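
        // Illustrative sketch (not part of the class): the public fields above can also be set from code
        // before the first request, from any script holding a reference to the character (llmCharacter is a placeholder name):
        //   llmCharacter.temperature = 0.7f;          // more creative answers
        //   llmCharacter.numPredict = 128;            // cap the reply length
        //   llmCharacter.stop.Add("\nPlayer:");       // extra stopword on top of the chat template defaults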

        protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
        protected string chatTemplate;
        protected ChatTemplate template = null;

        public override void Awake()
        {
            if (!enabled) return;
            base.Awake();
            if (!remote)
            {
                int slotFromServer = llm.Register(this);
                if (slot == -1) slot = slotFromServer;
            }
            InitGrammar();
            InitHistory();
        }

        protected override void OnValidate()
        {
            base.OnValidate();
            if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set");
        }

        protected override string NotValidLLMError()
        {
            return base.NotValidLLMError() + $", it is an embedding only model";
        }

        public override bool IsValidLLM(LLM llmSet)
        {
            return !llmSet.embeddingsOnly;
        }

        protected virtual void InitHistory()
        {
            ClearChat();
            _ = LoadHistory();
        }

        protected virtual async Task LoadHistory()
        {
            if (save == "" || !File.Exists(GetJsonSavePath(save))) return;
            await chatLock.WaitAsync(); // Acquire the lock
            try
            {
                await Load(save);
            }
            finally
            {
                chatLock.Release(); // Release the lock
            }
        }

        protected virtual string GetSavePath(string filename)
        {
            return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
        }

        public virtual string GetJsonSavePath(string filename)
        {
            return GetSavePath(filename + ".json");
        }

        public virtual string GetCacheSavePath(string filename)
        {
            return GetSavePath(filename + ".cache");
        }

        public virtual void ClearChat()
        {
            chat.Clear();
            ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt };
            chat.Add(promptMessage);
        }

        public virtual void SetPrompt(string newPrompt, bool clearChat = true)
        {
            prompt = newPrompt;
            nKeep = -1;
            if (clearChat) ClearChat();
            else chat[0] = new ChatMessage { role = "system", content = prompt };
        }

        protected virtual bool CheckTemplate()
        {
            if (template == null)
            {
                LLMUnitySetup.LogError("Template not set!");
                return false;
            }
            return true;
        }

        protected virtual async Task<bool> InitNKeep()
        {
            if (setNKeepToPrompt && nKeep == -1)
            {
                if (!CheckTemplate()) return false;
                string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
                List<int> tokens = await Tokenize(systemPrompt);
                if (tokens == null) return false;
                SetNKeep(tokens);
            }
            return true;
        }

        protected virtual void InitGrammar()
        {
            if (grammar != null && grammar != "")
            {
                grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
            }
        }

        protected virtual void SetNKeep(List<int> tokens)
        {
            // set the tokens to keep
            nKeep = tokens.Count;
        }

        public virtual async Task LoadTemplate()
        {
            string llmTemplate;
            if (remote)
            {
                llmTemplate = await AskTemplate();
            }
            else
            {
                llmTemplate = llm.GetTemplate();
            }
            if (llmTemplate != chatTemplate)
            {
                chatTemplate = llmTemplate;
                template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
                nKeep = -1;
            }
        }

        public virtual async void SetGrammar(string path)
        {
#if UNITY_EDITOR
            if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
            await LLMUnitySetup.AndroidExtractAsset(path, true);
            grammar = path;
            InitGrammar();
        }
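
        // Illustrative sketch: constraining the output with a GBNF grammar, assuming a grammar file
        // has been added to the project (the "json.gbnf" path is hypothetical):
        //   llmCharacter.SetGrammar("json.gbnf");     // loads the grammar into grammarString for subsequent requests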

        protected virtual List<string> GetStopwords()
        {
            if (!CheckTemplate()) return null;
            List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
            if (stop != null) stopAll.AddRange(stop);
            return stopAll;
        }

        protected virtual ChatRequest GenerateRequest(string prompt)
        {
            // setup the request struct: the prompt and the sampling parameters defined above are copied into the corresponding request fields
            ChatRequest chatRequest = new ChatRequest();
            if (debugPrompt) LLMUnitySetup.Log(prompt);
            chatRequest.stop = GetStopwords();
            return chatRequest;
        }

        public virtual void AddMessage(string role, string content)
        {
            // add the question / answer to the chat list, update prompt
            chat.Add(new ChatMessage { role = role, content = content });
        }

        public virtual void AddPlayerMessage(string content)
        {
            AddMessage(playerName, content);
        }

        public virtual void AddAIMessage(string content)
        {
            AddMessage(AIName, content);
        }

        protected virtual string ChatContent(ChatResult result)
        {
            // get content from a chat result received from the endpoint
            return result.content.Trim();
        }

        protected virtual string MultiChatContent(MultiChatResult result)
        {
            // get content from a chat result received from the endpoint
            string response = "";
            foreach (ChatResult resultPart in result.data)
            {
                response += resultPart.content;
            }
            return response.Trim();
        }

        protected virtual string SlotContent(SlotResult result)
        {
            // get the filename from a slot result received from the endpoint
            return result.filename;
        }

        protected virtual string TemplateContent(TemplateResult result)
        {
            // get the template from a template result received from the endpoint
            return result.template;
        }

        protected virtual async Task<string> CompletionRequest(string json, Callback<string> callback = null)
        {
            string result = "";
            if (stream)
            {
                result = await PostRequest<MultiChatResult, string>(json, "completion", MultiChatContent, callback);
            }
            else
            {
                result = await PostRequest<ChatResult, string>(json, "completion", ChatContent, callback);
            }
            return result;
        }

        protected async Task<ChatRequest> PromptWithQuery(string query)
        {
            ChatRequest result = default;
            await chatLock.WaitAsync();
            try
            {
                AddPlayerMessage(query);
                string prompt = template.ComputePrompt(chat, playerName, AIName);
                result = GenerateRequest(prompt);
                chat.RemoveAt(chat.Count - 1);
            }
            finally
            {
                chatLock.Release();
            }
            return result;
        }

        public virtual async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true)
        {
            // handle a chat message by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received
            await LoadTemplate();
            if (!CheckTemplate()) return null;
            if (!await InitNKeep()) return null;

            string json = JsonUtility.ToJson(await PromptWithQuery(query));
            string result = await CompletionRequest(json, callback);

            if (addToHistory && result != null)
            {
                await chatLock.WaitAsync();
                try
                {
                    AddPlayerMessage(query);
                    AddAIMessage(result);
                }
                finally
                {
                    chatLock.Release();
                }
                if (save != "") _ = Save(save);
            }

            completionCallback?.Invoke();
            return result;
        }
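
        // Illustrative sketch: calling Chat from another script with a streaming callback
        // (llmCharacter and HandleReply are placeholder names):
        //   void HandleReply(string reply) { Debug.Log(reply); }    // called repeatedly while the reply streams in
        //   string answer = await llmCharacter.Chat("Hello bot!", HandleReply, () => Debug.Log("done"));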

        public virtual async Task<string> Complete(string prompt, Callback<string> callback = null, EmptyCallback completionCallback = null)
        {
            // handle a completion request by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received
            await LoadTemplate();

            string json = JsonUtility.ToJson(GenerateRequest(prompt));
            string result = await CompletionRequest(json, callback);
            completionCallback?.Invoke();
            return result;
        }
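
        // Illustrative sketch: Complete sends the provided text as-is, without the chat template or history:
        //   string continuation = await llmCharacter.Complete("The capital of France is");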

        public virtual async Task Warmup(EmptyCallback completionCallback = null)
        {
            await Warmup(null, completionCallback);
        }

        public virtual async Task Warmup(string query, EmptyCallback completionCallback = null)
        {
            await LoadTemplate();
            if (!CheckTemplate()) return;
            if (!await InitNKeep()) return;

            ChatRequest request;
            if (String.IsNullOrEmpty(query))
            {
                string prompt = template.ComputePrompt(chat, playerName, AIName);
                request = GenerateRequest(prompt);
            }
            else
            {
                request = await PromptWithQuery(query);
            }

            string json = JsonUtility.ToJson(request);
            await CompletionRequest(json);
            completionCallback?.Invoke();
        }
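
        // Illustrative sketch: warming up the system prompt at startup so that the first Chat call responds faster:
        //   async void Start() { await llmCharacter.Warmup(); }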

        public virtual async Task<string> AskTemplate()
        {
            return await PostRequest<TemplateResult, string>("{}", "template", TemplateContent);
        }

        protected override void CancelRequestsLocal()
        {
            if (slot >= 0) llm.CancelRequest(slot);
        }

        protected virtual async Task<string> Slot(string filepath, string action)
        {
            // build the slot request carrying the slot id, the target filepath and the action ("save" or "restore")
            SlotRequest slotRequest = new SlotRequest();
            slotRequest.id_slot = slot;
            slotRequest.filepath = filepath;
            slotRequest.action = action;
            string json = JsonUtility.ToJson(slotRequest);
            return await PostRequest<SlotResult, string>(json, "slots", SlotContent);
        }

        public virtual async Task<string> Save(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            string dirname = Path.GetDirectoryName(filepath);
            if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
            string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
            File.WriteAllText(filepath, json);

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache) return null;
            string result = await Slot(cachepath, "save");
            return result;
        }

        public virtual async Task<string> Load(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            if (!File.Exists(filepath))
            {
                LLMUnitySetup.LogError($"File {filepath} does not exist.");
                return null;
            }
            string json = File.ReadAllText(filepath);
            List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
            ClearChat();
            chat.AddRange(chatHistory);
            LLMUnitySetup.Log($"Loaded {filepath}");

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
            string result = await Slot(cachepath, "restore");
            return result;
        }
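
        // Illustrative sketch: persisting and restoring a conversation (the "myChat" name is arbitrary);
        // files are written under Application.persistentDataPath as myChat.json (plus myChat.cache if saveCache is enabled):
        //   await llmCharacter.Save("myChat");
        //   await llmCharacter.Load("myChat");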

        protected override async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback);

            while (!llm.failed && !llm.started) await Task.Yield();

            string callResult = null;
            bool callbackCalled = false;
            if (llm.embeddingsOnly) LLMUnitySetup.LogError("The LLM can't be used for completion, only for embeddings");
            else
            {
                Callback<string> callbackString = null;
                if (stream && callback != null)
                {
                    if (typeof(Ret) == typeof(string))
                    {
                        // stream partial results to the caller, converting them with the conversion helper inherited from LLMCaller
                        callbackString = (strArg) =>
                        {
                            callback(ConvertContent(strArg, getContent));
                        };
                    }
                    else
                    {
                        LLMUnitySetup.LogError($"wrong callback type, should be string");
                    }
                    callbackCalled = true;
                }
                callResult = await llm.Completion(json, callbackString);
            }

            Ret result = ConvertContent(callResult, getContent);
            if (!callbackCalled) callback?.Invoke(result);
            return result;
        }
    }

    [Serializable]
    public class ChatListWrapper
    {
        public List<ChatMessage> chat;
    }
}
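
Example usage (a minimal sketch, not part of the package): the MonoBehaviour below assumes a scene containing an LLM and an LLMCharacter component; the class name ChatExample and its field are illustrative, and only Warmup, Chat and Save come from the API above.

using LLMUnity;
using UnityEngine;

// Minimal chat loop bound to an LLMCharacter in the scene.
public class ChatExample : MonoBehaviour
{
    public LLMCharacter llmCharacter;   // assign in the Inspector

    async void Start()
    {
        await llmCharacter.Warmup();                        // process the system prompt once
        string reply = await llmCharacter.Chat(
            "Hi! Who are you?",
            partial => Debug.Log(partial),                  // streaming callback (used when stream = true)
            () => Debug.Log("Reply completed"));
        Debug.Log(reply);
        await llmCharacter.Save("exampleChat");             // writes exampleChat.json under persistentDataPath
    }
}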