LLM for Unity v2.5.2
Create characters in Unity with LLMs!
LLMCharacter.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;

namespace LLMUnity
{
    [DefaultExecutionOrder(-2)]
    public class LLMCharacter : LLMCaller
    {
        [Tooltip("file to save the chat history. The file will be saved within the persistentDataPath directory.")]
        [LLM] public string save = "";
        [Tooltip("save the LLM cache. Speeds up the prompt calculation when reloading from history but also requires ~100MB of space per character.")]
        [LLM] public bool saveCache = false;
        [Tooltip("log the constructed prompt in the Unity Editor.")]
        [LLM] public bool debugPrompt = false;
        [Tooltip("maximum number of tokens that the LLM will predict (-1 = infinity).")]
        [Model] public int numPredict = -1;
        [Tooltip("slot of the server to use for computation (affects caching)")]
        [ModelAdvanced] public int slot = -1;
        [Tooltip("grammar file used for the LLMCharacter (.gbnf format)")]
        [ModelAdvanced] public string grammar = null;
        [Tooltip("grammar file used for the LLMCharacter (.json format)")]
        [ModelAdvanced] public string grammarJSON = null;
        [Tooltip("cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)")]
        [ModelAdvanced] public bool cachePrompt = true;
        [Tooltip("seed for reproducibility (-1 = no reproducibility).")]
        [ModelAdvanced] public int seed = 0;
        [Tooltip("LLM temperature, lower values give more deterministic answers.")]
        [ModelAdvanced, Float(0f, 2f)] public float temperature = 0.2f;
        [Tooltip("Top-k sampling selects the next token only from the top k most likely predicted tokens (0 = disabled). Higher values lead to more diverse text, while lower values will generate more focused and conservative text.")]
        [ModelAdvanced, Int(-1, 100)] public int topK = 40;
        [Tooltip("Top-p sampling selects the next token from a subset of tokens that together have a cumulative probability of at least p (1.0 = disabled). Higher values lead to more diverse text, while lower values will generate more focused and conservative text.")]
        [ModelAdvanced, Float(0f, 1f)] public float topP = 0.9f;
        [Tooltip("minimum probability for a token to be used.")]
        [ModelAdvanced, Float(0f, 1f)] public float minP = 0.05f;
        [Tooltip("Penalty based on repeated tokens to control the repetition of token sequences in the generated text.")]
        [ModelAdvanced, Float(0f, 2f)] public float repeatPenalty = 1.1f;
        [Tooltip("Penalty based on token presence in previous responses to control the repetition of token sequences in the generated text (0.0 = disabled).")]
        [ModelAdvanced, Float(0f, 1f)] public float presencePenalty = 0f;
        [Tooltip("Penalty based on token frequency in previous responses to control the repetition of token sequences in the generated text (0.0 = disabled).")]
        [ModelAdvanced, Float(0f, 1f)] public float frequencyPenalty = 0f;
        [Tooltip("enable locally typical sampling (1.0 = disabled). Higher values will promote more contextually coherent tokens, while lower values will promote more diverse tokens.")]
        [ModelAdvanced, Float(0f, 1f)] public float typicalP = 1f;
        [Tooltip("last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size).")]
        [ModelAdvanced, Int(0, 2048)] public int repeatLastN = 64;
        [Tooltip("penalize newline tokens when applying the repeat penalty.")]
        [ModelAdvanced] public bool penalizeNl = true;
        [Tooltip("prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers representing tokens (null/'' = use original prompt)")]
        [ModelAdvanced] public string penaltyPrompt;
        [Tooltip("enable Mirostat sampling, controlling perplexity during text generation (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).")]
        [ModelAdvanced, Int(0, 2)] public int mirostat = 0;
        [Tooltip("The Mirostat target entropy (tau) controls the balance between coherence and diversity in the generated text.")]
        [ModelAdvanced, Float(0f, 10f)] public float mirostatTau = 5f;
        [Tooltip("The Mirostat learning rate (eta) controls how quickly the algorithm responds to feedback from the generated text.")]
        [ModelAdvanced, Float(0f, 1f)] public float mirostatEta = 0.1f;
        [Tooltip("if greater than 0, the response also contains the probabilities of top N tokens for each generated token.")]
        [ModelAdvanced, Int(0, 10)] public int nProbs = 0;
        [Tooltip("ignore end of stream token and continue generating.")]
        [ModelAdvanced] public bool ignoreEos = false;
        [Tooltip("number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prompt tokens if setNKeepToPrompt is set to true).")]
        public int nKeep = -1;
        [Tooltip("stopwords to stop the LLM in addition to the default stopwords from the chat template.")]
        public List<string> stop = new List<string>();
        [Tooltip("the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the generated text. By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated.")]
        public Dictionary<int, string> logitBias = null;
        [Tooltip("Receive the reply from the model as it is produced (recommended!). If not selected, the full reply from the model is received in one go")]
        [Chat] public bool stream = true;
        [Tooltip("the name of the player")]
        [Chat] public string playerName = "user";
        [Tooltip("the name of the AI")]
        [Chat] public string AIName = "assistant";
        [Tooltip("a description of the AI role (system prompt)")]
        [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
        [Tooltip("set the number of tokens to always retain from the prompt (nKeep) based on the LLMCharacter system prompt")]
        public bool setNKeepToPrompt = true;
        [Tooltip("the chat history as list of chat messages")]
        public List<ChatMessage> chat = new List<ChatMessage>();
        [Tooltip("the GBNF grammar to use")]
        public string grammarString;
        [Tooltip("the JSON schema grammar to use")]
        public string grammarJSONString;

        protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
        protected string chatTemplate;
        protected ChatTemplate template = null;

        public override void Awake()
        {
            if (!enabled) return;
            base.Awake();
            if (!remote)
            {
                int slotFromServer = llm.Register(this);
                if (slot == -1) slot = slotFromServer;
            }
            InitGrammar();
            InitHistory();
        }

        protected override void OnValidate()
        {
            base.OnValidate();
            if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set");
        }

        protected override string NotValidLLMError()
        {
            return base.NotValidLLMError() + $", it is an embedding only model";
        }

        public override bool IsValidLLM(LLM llmSet)
        {
            return !llmSet.embeddingsOnly;
        }

        protected virtual void InitHistory()
        {
            ClearChat();
            _ = LoadHistory();
        }

        protected virtual async Task LoadHistory()
        {
            if (save == "" || !File.Exists(GetJsonSavePath(save))) return;
            await chatLock.WaitAsync(); // acquire the lock
            try
            {
                await Load(save);
            }
            finally
            {
                chatLock.Release(); // release the lock
            }
        }

        protected virtual string GetSavePath(string filename)
        {
            return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
        }

        public virtual string GetJsonSavePath(string filename)
        {
            return GetSavePath(filename + ".json");
        }

        public virtual string GetCacheSavePath(string filename)
        {
            return GetSavePath(filename + ".cache");
        }

        public virtual void ClearChat()
        {
            chat.Clear();
            ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt };
            chat.Add(promptMessage);
        }

        public virtual void SetPrompt(string newPrompt, bool clearChat = true)
        {
            prompt = newPrompt;
            nKeep = -1;
            if (clearChat) ClearChat();
            else chat[0] = new ChatMessage { role = "system", content = prompt };
        }

        protected virtual bool CheckTemplate()
        {
            if (template == null)
            {
                LLMUnitySetup.LogError("Template not set!");
                return false;
            }
            return true;
        }

        protected virtual async Task<bool> InitNKeep()
        {
            if (setNKeepToPrompt && nKeep == -1)
            {
                if (!CheckTemplate()) return false;
                string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
                List<int> tokens = await Tokenize(systemPrompt);
                if (tokens == null) return false;
                SetNKeep(tokens);
            }
            return true;
        }

        protected virtual void InitGrammar()
        {
            grammarString = "";
            grammarJSONString = "";
            if (!String.IsNullOrEmpty(grammar))
            {
                grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
                if (!String.IsNullOrEmpty(grammarJSON))
                    LLMUnitySetup.LogWarning("Both GBNF and JSON grammars are set, only the GBNF will be used");
            }
            else if (!String.IsNullOrEmpty(grammarJSON))
            {
                grammarJSONString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammarJSON));
            }
        }

        protected virtual void SetNKeep(List<int> tokens)
        {
            // set the tokens to keep
            nKeep = tokens.Count;
        }

        public virtual async Task LoadTemplate()
        {
            string llmTemplate;
            if (remote)
            {
                llmTemplate = await AskTemplate();
            }
            else
            {
                llmTemplate = llm.GetTemplate();
            }
            if (llmTemplate != chatTemplate)
            {
                chatTemplate = llmTemplate;
                template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
                nKeep = -1;
            }
        }

        public virtual async Task SetGrammarFile(string path, bool gnbf)
        {
#if UNITY_EDITOR
            if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
            await LLMUnitySetup.AndroidExtractAsset(path, true);
            if (gnbf) grammar = path;
            else grammarJSON = path;
            InitGrammar();
        }

        public virtual async Task SetGrammar(string path)
        {
            await SetGrammarFile(path, true);
        }

        public virtual async Task SetJSONGrammar(string path)
        {
            await SetGrammarFile(path, false);
        }

        protected virtual List<string> GetStopwords()
        {
            if (!CheckTemplate()) return null;
            List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
            if (stop != null) stopAll.AddRange(stop);
            return stopAll;
        }

        protected virtual ChatRequest GenerateRequest(string prompt)
        {
            // setup the request struct
            ChatRequest chatRequest = new ChatRequest();
            if (debugPrompt) LLMUnitySetup.Log(prompt);
            // the prompt and the generation parameters (temperature, topK, topP, penalties, etc.)
            // are copied from the corresponding fields into chatRequest;
            // the individual field assignments are not shown in this listing
            chatRequest.stop = GetStopwords();
            return chatRequest;
        }

        public virtual void AddMessage(string role, string content)
        {
            // add the question / answer to the chat list, update prompt
            chat.Add(new ChatMessage { role = role, content = content });
        }

        public virtual void AddPlayerMessage(string content)
        {
            AddMessage(playerName, content);
        }

        public virtual void AddAIMessage(string content)
        {
            AddMessage(AIName, content);
        }

        protected virtual string ChatContent(ChatResult result)
        {
            // get the content from a chat result received from the endpoint
            return result.content.Trim();
        }

        protected virtual string MultiChatContent(MultiChatResult result)
        {
            // concatenate the content of a streaming chat result received from the endpoint
            string response = "";
            foreach (ChatResult resultPart in result.data)
            {
                response += resultPart.content;
            }
            return response.Trim();
        }

        protected virtual string SlotContent(SlotResult result)
        {
            // get the filename from a slot result received from the endpoint
            return result.filename;
        }

        protected virtual string TemplateContent(TemplateResult result)
        {
            // get the chat template from a template result received from the endpoint
            return result.template;
        }

        protected virtual string ChatRequestToJson(ChatRequest request)
        {
            string json = JsonUtility.ToJson(request);
            int grammarIndex = json.LastIndexOf('}');
            if (!String.IsNullOrEmpty(request.grammar))
            {
                // the grammar string is serialised separately (as grammarToJSON) so that it is
                // properly escaped; the construction of grammarToJSON is not shown in this listing
                int start = grammarToJSON.IndexOf(":\"") + 2;
                int end = grammarToJSON.LastIndexOf("\"");
                string grammarSerialised = grammarToJSON.Substring(start, end - start);
                json = json.Insert(grammarIndex, $",\"grammar\": \"{grammarSerialised}\"");
            }
            else if (!String.IsNullOrEmpty(request.json_schema))
            {
                json = json.Insert(grammarIndex, $",\"json_schema\":{request.json_schema}");
            }
            return json;
        }

        protected virtual async Task<string> CompletionRequest(ChatRequest request, Callback<string> callback = null)
        {
            string json = ChatRequestToJson(request);
            string result = "";
            if (stream)
            {
                result = await PostRequest<MultiChatResult, string>(json, "completion", MultiChatContent, callback);
            }
            else
            {
                result = await PostRequest<ChatResult, string>(json, "completion", ChatContent, callback);
            }
            return result;
        }

        protected async Task<ChatRequest> PromptWithQuery(string query)
        {
            ChatRequest result = default;
            await chatLock.WaitAsync();
            try
            {
                AddPlayerMessage(query);
                string prompt = template.ComputePrompt(chat, playerName, AIName);
                result = GenerateRequest(prompt);
                chat.RemoveAt(chat.Count - 1);
            }
            finally
            {
                chatLock.Release();
            }
            return result;
        }

        public virtual async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true)
        {
            // handle a chat message by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received
            await LoadTemplate();
            if (!CheckTemplate()) return null;
            if (!await InitNKeep()) return null;

            ChatRequest request = await PromptWithQuery(query);
            string result = await CompletionRequest(request, callback);

            if (addToHistory && result != null)
            {
                await chatLock.WaitAsync();
                try
                {
                    AddPlayerMessage(query);
                    AddAIMessage(result);
                }
                finally
                {
                    chatLock.Release();
                }
                if (save != "") _ = Save(save);
            }

            completionCallback?.Invoke();
            return result;
        }

        public virtual async Task<string> Complete(string prompt, Callback<string> callback = null, EmptyCallback completionCallback = null)
        {
            // handle a completion request by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received

            ChatRequest request = GenerateRequest(prompt);
            string result = await CompletionRequest(request, callback);
            completionCallback?.Invoke();
            return result;
        }

        public virtual async Task Warmup(EmptyCallback completionCallback = null)
        {
            await Warmup(null, completionCallback);
        }

        public virtual async Task Warmup(string query, EmptyCallback completionCallback = null)
        {
            await LoadTemplate();
            if (!CheckTemplate()) return;
            if (!await InitNKeep()) return;

            ChatRequest request;
            if (String.IsNullOrEmpty(query))
            {
                string prompt = template.ComputePrompt(chat, playerName, AIName, false);
                request = GenerateRequest(prompt);
            }
            else
            {
                request = await PromptWithQuery(query);
            }

            await CompletionRequest(request);
            completionCallback?.Invoke();
        }

        public virtual async Task<string> AskTemplate()
        {
            return await PostRequest<TemplateResult, string>("{}", "template", TemplateContent);
        }

        protected override void CancelRequestsLocal()
        {
            if (slot >= 0) llm.CancelRequest(slot);
        }

        protected virtual async Task<string> Slot(string filepath, string action)
        {
            // build the request for the slots endpoint from the slot id, the file path and
            // the action ("save" or "restore"); the construction of slotRequest is not shown in this listing
            string json = JsonUtility.ToJson(slotRequest);
            return await PostRequest<SlotResult, string>(json, "slots", SlotContent);
        }

        public virtual async Task<string> Save(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            string dirname = Path.GetDirectoryName(filepath);
            if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
            string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
            File.WriteAllText(filepath, json);

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache) return null;
            string result = await Slot(cachepath, "save");
            return result;
        }

        public virtual async Task<string> Load(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            if (!File.Exists(filepath))
            {
                LLMUnitySetup.LogError($"File {filepath} does not exist.");
                return null;
            }
            string json = File.ReadAllText(filepath);
            List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
            ClearChat();
            chat.AddRange(chatHistory);
            LLMUnitySetup.Log($"Loaded {filepath}");

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
            string result = await Slot(cachepath, "restore");
            return result;
        }

        protected override async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback);

            while (!llm.failed && !llm.started) await Task.Yield();

            string callResult = null;
            bool callbackCalled = false;
            if (llm.embeddingsOnly) LLMUnitySetup.LogError("The LLM can't be used for completion, only for embeddings");
            else
            {
                Callback<string> callbackString = null;
                if (stream && callback != null)
                {
                    if (typeof(Ret) == typeof(string))
                    {
                        callbackString = (strArg) =>
                        {
                            // convert the partial completion string and forward it to the callback
                            // (the conversion call is not shown in this listing)
                        };
                    }
                    else
                    {
                        LLMUnitySetup.LogError($"wrong callback type, should be string");
                    }
                    callbackCalled = true;
                }
                callResult = await llm.Completion(json, callbackString);
            }

            // convert callResult to the requested return type (conversion not shown in this listing)
            if (!callbackCalled) callback?.Invoke(result);
            return result;
        }
    }

    [Serializable]
    public class ChatListWrapper
    {
        public List<ChatMessage> chat;
    }
}
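
For orientation, a minimal usage sketch of the class above. The MonoBehaviour, the field name llmCharacter and the messages are illustrative and assume the component has already been configured in the Inspector with an LLM and a system prompt.

    using UnityEngine;
    using LLMUnity;

    // Illustrative sketch: query an LLMCharacter and stream the reply to the console.
    public class ChatExample : MonoBehaviour
    {
        public LLMCharacter llmCharacter; // assumed to be assigned in the Inspector

        async void Start()
        {
            string reply = await llmCharacter.Chat(
                "Hello, who are you?",
                partial => Debug.Log(partial),      // called while the reply is being produced (stream = true)
                () => Debug.Log("reply complete")); // called once the full reply has been received
            Debug.Log(reply);
        }
    }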
Class implementing the skeleton of a chat template.
static ChatTemplate GetTemplate(string template)
Creates the chat template based on the provided chat template name.
Class implementing the LLM characters.
bool cachePrompt
cache the processed prompt to avoid reprocessing the entire prompt every time (default: true,...
int slot
slot of the server to use for computation (affects caching)
virtual async Task SetJSONGrammar(string path)
Sets the grammar file of the LLMCharacter (JSON schema)
virtual async Task< string > Chat(string query, Callback< string > callback=null, EmptyCallback completionCallback=null, bool addToHistory=true)
Chat functionality of the LLM. It calls the LLM completion based on the provided query including the ...
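A hedged example of calling Chat from an async method; llmCharacter and uiText are illustrative names, and the query is added to the chat history because addToHistory defaults to true.

    // stream the reply into a UI element while it is generated
    string reply = await llmCharacter.Chat(
        "What can you do?",
        partial => uiText.text = partial,     // invoked repeatedly while streaming
        () => Debug.Log("answer complete"));  // invoked once at the end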
string grammarJSONString
the JSON schema grammar to use
List< string > stop
stopwords to stop the LLM in addition to the default stopwords from the chat template.
float topP
Top-p sampling selects the next token from a subset of tokens that together have a cumulative probabi...
virtual async Task LoadTemplate()
Loads the chat template of the LLMCharacter.
string AIName
the name of the AI
int nProbs
if greater than 0, the response also contains the probabilities of top N tokens for each generated to...
bool ignoreEos
ignore end of stream token and continue generating.
float mirostatTau
The Mirostat target entropy (tau) controls the balance between coherence and diversity in the generat...
int numPredict
maximum number of tokens that the LLM will predict (-1 = infinity).
string prompt
a description of the AI role (system prompt)
float temperature
LLM temperature, lower values give more deterministic answers.
override void Awake()
The Unity Awake function that initializes the state before the application starts....
float presencePenalty
Penalty based on token presence in previous responses to control the repetition of token sequences in...
virtual async Task< string > AskTemplate()
Asks the LLM for the chat template to use.
string playerName
the name of the player
virtual string GetCacheSavePath(string filename)
Allows to get the save path of the LLM cache based on the provided filename or relative path.
float mirostatEta
The Mirostat learning rate (eta) controls how quickly the algorithm responds to feedback from the gen...
float minP
minimum probability for a token to be used.
float typicalP
enable locally typical sampling (1.0 = disabled). Higher values will promote more contextually cohere...
virtual async Task< string > Load(string filename)
Load the chat history and cache from the provided filename / relative path.
int nKeep
number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prom...
bool penalizeNl
penalize newline tokens when applying the repeat penalty.
bool debugPrompt
log the constructed prompt in the Unity Editor.
string grammarJSON
grammar file used for the LLMCharacter (.json format)
bool saveCache
save the LLM cache. Speeds up the prompt calculation when reloading from history but also requires ~1...
int topK
Top-k sampling selects the next token only from the top k most likely predicted tokens (0 = disabled)...
virtual async Task SetGrammar(string path)
Sets the grammar file of the LLMCharacter (GBNF)
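A sketch of constrained generation; the grammar file names are illustrative and must exist as project assets.

    // constrain the output with a GBNF grammar before chatting
    await llmCharacter.SetGrammar("json.gbnf");
    // alternatively, use a JSON schema grammar:
    // await llmCharacter.SetJSONGrammar("reply_schema.json");
    string reply = await llmCharacter.Chat("List three colors as JSON.");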
string grammar
grammar file used for the LLMCharacter (.gbnf format)
virtual async Task< string > Save(string filename)
Saves the chat history and cache to the provided filename / relative path.
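A sketch of persisting and restoring a conversation; "player1" is an illustrative filename that is resolved under Application.persistentDataPath (see GetJsonSavePath / GetCacheSavePath).

    // save the chat history (and the LLM cache if saveCache is enabled) ...
    await llmCharacter.Save("player1");
    // ... and restore it later, e.g. after the scene is reloaded
    await llmCharacter.Load("player1");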
virtual void AddAIMessage(string content)
Allows to add an AI message in the chat history.
string penaltyPrompt
prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers...
override bool IsValidLLM(LLM llmSet)
Checks if an LLM is valid for the LLMCaller.
virtual void AddPlayerMessage(string content)
Allows to add a player message in the chat history.
float repeatPenalty
Penalty based on repeated tokens to control the repetition of token sequences in the generated text.
virtual string GetJsonSavePath(string filename)
Allows to get the save path of the chat history based on the provided filename or relative path.
bool stream
Receive the reply from the model as it is produced (recommended!). If not selected,...
virtual async Task SetGrammarFile(string path, bool gnbf)
Sets the grammar file of the LLMCharacter.
float frequencyPenalty
Penalty based on token frequency in previous responses to control the repetition of token sequences i...
virtual async Task Warmup(EmptyCallback completionCallback=null)
Allow to warm-up a model by processing the system prompt. The prompt processing will be cached (if ca...
int mirostat
enable Mirostat sampling, controlling perplexity during text generation (0 = disabled,...
bool setNKeepToPrompt
set the number of tokens to always retain from the prompt (nKeep) based on the LLMCharacter system pr...
int repeatLastN
last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size).
int seed
seed for reproducibility (-1 = no reproducibility).
string grammarString
the GBNF grammar to use
Dictionary< int, string > logitBias
the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the ge...
virtual async Task< string > Complete(string prompt, Callback< string > callback=null, EmptyCallback completionCallback=null)
Pure completion functionality of the LLM. It calls the LLM completion based solely on the provided pr...
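A hedged example of the pure completion path, which sends the prompt as-is without the chat template or history; llmCharacter is an illustrative name.

    // plain text completion: nothing is added to the chat history
    string completion = await llmCharacter.Complete(
        "The capital of France is",
        partial => Debug.Log(partial));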
virtual void AddMessage(string role, string content)
Allows to add a message in the chat history.
virtual async Task Warmup(string query, EmptyCallback completionCallback=null)
Allow to warm-up a model by processing the provided prompt without adding it to history....
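Warmup can be used to hide the initial prompt-processing latency, e.g. behind a loading screen; a rough sketch assuming an already configured llmCharacter.

    // process the system prompt ahead of time so the first Chat call responds faster
    await llmCharacter.Warmup(() => Debug.Log("character ready"));
    // or warm up on an expected query without adding it to the history:
    // await llmCharacter.Warmup("Hello!", () => Debug.Log("character ready"));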
virtual void ClearChat()
Clear the chat of the LLMCharacter.
List< ChatMessage > chat
the chat history as list of chat messages
string save
file to save the chat history. The file will be saved within the persistentDataPath directory.
virtual void SetPrompt(string newPrompt, bool clearChat=true)
Set the system prompt for the LLMCharacter.
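For example, a character's role can be switched at runtime; clearing the chat makes the new system prompt take effect from a clean history.

    // replace the system prompt and start a fresh conversation
    llmCharacter.SetPrompt("You are a gruff blacksmith in a medieval village.", clearChat: true);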
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Definition LLM.cs:19