UnityPaul commited on
Commit
e05e85c
1 Parent(s): f60dfaf

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. RunTinyStories.cs +31 -42
  3. info.json +1 -1
  4. tinystories.sentis +2 -2
README.md CHANGED
@@ -4,8 +4,8 @@ library_name: unity-sentis
4
  pipeline_tag: text-generation
5
  ---
6
 
7
- # Tiny Stories Model in Unity Sentis (Version 1.3.0-pre.3*)
8
- *Version 1.3.0 Sentis files are not compatible with Version 1.4.0 and above and would need to be recreated
9
 
10
  This is the [Tiny Stories model](https://huggingface.co/roneneldan/TinyStories-33M) checked to run on Unity 2023. Tiny Stories is a Large Language Model that was trained on children's stories and can create stories based on the first couple of sentences.
11
 
 
4
  pipeline_tag: text-generation
5
  ---
6
 
7
+ # Tiny Stories Model in Unity Sentis Format (Sentis 1.4.0-pre.2*)
8
+ *Version 1.3.0 Sentis files are not compatible with Sentis 1.4.0 and would need to be recreated/downloaded
9
 
10
  This is the [Tiny Stories model](https://huggingface.co/roneneldan/TinyStories-33M) checked to run on Unity 2023. Tiny Stories is a Large Language Model that was trained on children's stories and can create stories based on the first couple of sentences.
11
 
RunTinyStories.cs CHANGED
@@ -3,8 +3,8 @@ using System.Collections.Generic;
3
  using UnityEngine;
4
  using Unity.Sentis;
5
  using System.IO;
6
- using Newtonsoft.Json;
7
  using System.Text;
 
8
 
9
  /*
10
  * Tiny Stories Inference Code
@@ -14,7 +14,7 @@ using System.Text;
14
  *
15
  * In Assets/StreamingAssets put:
16
  *
17
- * tinystories.sentis
18
  * vocab.json
19
  * merges.txt
20
  *
@@ -26,6 +26,8 @@ using System.Text;
26
 
27
  public class RunTinyStories : MonoBehaviour
28
  {
 
 
29
  const BackendType backend = BackendType.GPUCompute;
30
 
31
  //string outputString = "Once upon a time, there were three bears";
@@ -40,9 +42,6 @@ public class RunTinyStories : MonoBehaviour
40
  //Special tokens
41
  const int END_OF_TEXT = 50256;
42
 
43
- Ops ops;
44
- ITensorAllocator allocator;
45
-
46
  //Store the vocabulary
47
  string[] tokens;
48
 
@@ -68,16 +67,23 @@ public class RunTinyStories : MonoBehaviour
68
 
69
  void Start()
70
  {
71
- allocator = new TensorCachingAllocator();
72
- ops = WorkerFactory.CreateOps(backend, allocator);
73
-
74
  SetupWhiteSpaceShifts();
75
 
76
  LoadVocabulary();
77
 
78
- Model model = ModelLoader.Load(Application.streamingAssetsPath + "/tinystories.sentis");
 
 
 
 
 
 
 
 
 
 
79
 
80
- engine = WorkerFactory.CreateWorker(backend, model);
81
 
82
  DecodePrompt(outputString);
83
 
@@ -96,17 +102,18 @@ public class RunTinyStories : MonoBehaviour
96
  void RunInference()
97
  {
98
  using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
99
- engine.Execute(tokensSoFar);
 
 
100
 
101
- var tokensOut = engine.PeekOutput() as TensorFloat;
 
102
 
103
- using var row = ops.Slice(tokensOut, new[] { currentToken }, new[] { currentToken + 1 }, new[] { 1 }, new[] { 1 });
104
- using var rowB = ops.Mul(predictability, row);
105
- using var probs = ops.Softmax(rowB, 2);
106
- probs.MakeReadable();
107
 
108
- int ID = SelectRandomToken(probs.ToReadOnlyArray());
109
 
 
110
  if (currentToken >= maxTokens - 1)
111
  {
112
  for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
@@ -126,7 +133,6 @@ public class RunTinyStories : MonoBehaviour
126
 
127
  }
128
 
129
-
130
  void DecodePrompt(string text)
131
  {
132
  var inputTokens = GetTokens(text);
@@ -138,10 +144,9 @@ public class RunTinyStories : MonoBehaviour
138
  currentToken = inputTokens.Count - 1;
139
  }
140
 
141
-
142
  void LoadVocabulary()
143
  {
144
- var jsonText = File.ReadAllText(Application.streamingAssetsPath + "/vocab.json");
145
  vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
146
  tokens = new string[vocab.Count];
147
  foreach (var item in vocab)
@@ -149,23 +154,7 @@ public class RunTinyStories : MonoBehaviour
149
  tokens[item.Value] = item.Key;
150
  }
151
 
152
- merges = File.ReadAllLines(Application.streamingAssetsPath + "/merges.txt");
153
- }
154
-
155
-
156
- int SelectRandomToken(float[] probs)
157
- {
158
- float p = UnityEngine.Random.Range(0, 1f);
159
- float t = 0;
160
- for (int i = 0; i < probs.Length; i++)
161
- {
162
- t += probs[i];
163
- if (p < t)
164
- {
165
- return i;
166
- }
167
- }
168
- return probs.Length - 1;
169
  }
170
 
171
  // Translates encoded special characters to Unicode
@@ -206,7 +195,7 @@ public class RunTinyStories : MonoBehaviour
206
  for (int i = 0, n = 0; i < 256; i++)
207
  {
208
  encodedCharacters[i] = i;
209
- if (IsWhiteSpace((char)i))
210
  {
211
  encodedCharacters[i] = n + 256;
212
  whiteSpaceCharacters[n++] = i;
@@ -214,9 +203,10 @@ public class RunTinyStories : MonoBehaviour
214
  }
215
  }
216
 
217
- bool IsWhiteSpace(char c)
218
  {
219
- return !(('!' <= c && c <= '~') || ('�' <= c && c <= '�') || ('�' <= c && c <= '�'));
 
220
  }
221
 
222
  List<int> GetTokens(string text)
@@ -267,7 +257,6 @@ public class RunTinyStories : MonoBehaviour
267
  private void OnDestroy()
268
  {
269
  engine?.Dispose();
270
- ops?.Dispose();
271
- allocator?.Dispose();
272
  }
 
273
  }
 
3
  using UnityEngine;
4
  using Unity.Sentis;
5
  using System.IO;
 
6
  using System.Text;
7
+ using FF = Unity.Sentis.Functional;
8
 
9
  /*
10
  * Tiny Stories Inference Code
 
14
  *
15
  * In Assets/StreamingAssets put:
16
  *
17
+ * tinystories.sentis (or put in asset folder and drag onto field)
18
  * vocab.json
19
  * merges.txt
20
  *
 
26
 
27
  public class RunTinyStories : MonoBehaviour
28
  {
29
+ //Drop the tinystories.sentis or onnx file on here if using an asset:
30
+ //public ModelAsset asset;
31
  const BackendType backend = BackendType.GPUCompute;
32
 
33
  //string outputString = "Once upon a time, there were three bears";
 
42
  //Special tokens
43
  const int END_OF_TEXT = 50256;
44
 
 
 
 
45
  //Store the vocabulary
46
  string[] tokens;
47
 
 
67
 
68
  void Start()
69
  {
 
 
 
70
  SetupWhiteSpaceShifts();
71
 
72
  LoadVocabulary();
73
 
74
+ var model1 = ModelLoader.Load(Path.Join(Application.streamingAssetsPath , "tinystories.sentis"));
75
+ //var model1 = ModelLoader.Load(asset);
76
+ //Create a new model to select the random token:
77
+ var model2 = FF.Compile(
78
+ (input, currentToken) =>
79
+ {
80
+ var row = FF.Select(model1.Forward(input)[8], 1, currentToken);
81
+ return FF.Multinomial(predictability * row, 1);
82
+ },
83
+ (InputDef.FromModel(model1)[0], new InputDef(DataType.Int, new TensorShape()))
84
+ );
85
 
86
+ engine = WorkerFactory.CreateWorker(backend, model2);
87
 
88
  DecodePrompt(outputString);
89
 
 
102
  void RunInference()
103
  {
104
  using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
105
+ using var index = new TensorInt(currentToken);
106
+
107
+ engine.Execute(new Dictionary<string, Tensor> { {"input_0", tokensSoFar }, { "input_1", index }});
108
 
109
+ var probs = engine.PeekOutput() as TensorInt;
110
+ Debug.Log(probs.shape);
111
 
112
+ probs.CompleteOperationsAndDownload();
 
 
 
113
 
114
+ int ID = probs[0];
115
 
116
+ //shift window down if got to the end
117
  if (currentToken >= maxTokens - 1)
118
  {
119
  for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
 
133
 
134
  }
135
 
 
136
  void DecodePrompt(string text)
137
  {
138
  var inputTokens = GetTokens(text);
 
144
  currentToken = inputTokens.Count - 1;
145
  }
146
 
 
147
  void LoadVocabulary()
148
  {
149
+ var jsonText = File.ReadAllText(Path.Join(Application.streamingAssetsPath , "vocab.json"));
150
  vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
151
  tokens = new string[vocab.Count];
152
  foreach (var item in vocab)
 
154
  tokens[item.Value] = item.Key;
155
  }
156
 
157
+ merges = File.ReadAllLines(Path.Join(Application.streamingAssetsPath , "merges.txt"));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  }
159
 
160
  // Translates encoded special characters to Unicode
 
195
  for (int i = 0, n = 0; i < 256; i++)
196
  {
197
  encodedCharacters[i] = i;
198
+ if (IsWhiteSpace(i))
199
  {
200
  encodedCharacters[i] = n + 256;
201
  whiteSpaceCharacters[n++] = i;
 
203
  }
204
  }
205
 
206
+ bool IsWhiteSpace(int i)
207
  {
208
+ //returns true if it is a whitespace character
209
+ return i <= 32 || (i >= 127 && i <= 160) || i == 173;
210
  }
211
 
212
  List<int> GetTokens(string text)
 
257
  private void OnDestroy()
258
  {
259
  engine?.Dispose();
 
 
260
  }
261
+
262
  }
info.json CHANGED
@@ -10,6 +10,6 @@
10
  "merges.txt"
11
  ],
12
  "version": [
13
- "1.3.0-pre.3"
14
  ]
15
  }
 
10
  "merges.txt"
11
  ],
12
  "version": [
13
+ "1.4.0"
14
  ]
15
  }
tinystories.sentis CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8ed28a03db24da6fa58cc2bde739ecfb83b731ca47c263d17cdfec22e4b1698
3
- size 478881707
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7962eb7db56b241cc19cd3f0cffcf5d76d3c35639917f07effa6b3c242c91e9
3
+ size 478818076