ggerganov committed
Commit f9c5a09 · unverified · 0 Parent(s)

Initial release

Files changed (11)
  1. .gitattributes +12 -0
  2. .gitignore +3 -0
  3. Makefile +109 -0
  4. convert-pt-to-ggml.py +328 -0
  5. dr_wav.h +0 -0
  6. ggml.c +0 -0
  7. ggml.h +527 -0
  8. main.cpp +2116 -0
  9. models/.gitignore +1 -0
  10. samples/.gitignore +1 -0
  11. samples/jfk.wav +3 -0
.gitattributes ADDED
@@ -0,0 +1,12 @@
1
+ bindings/go/samples/jfk.wav filter=lfs diff=lfs merge=lfs -text
2
+ models/for-tests-ggml-base.bin filter=lfs diff=lfs merge=lfs -text
3
+ models/for-tests-ggml-base.en.bin filter=lfs diff=lfs merge=lfs -text
4
+ models/for-tests-ggml-large.bin filter=lfs diff=lfs merge=lfs -text
5
+ models/for-tests-ggml-medium.bin filter=lfs diff=lfs merge=lfs -text
6
+ models/for-tests-ggml-medium.en.bin filter=lfs diff=lfs merge=lfs -text
7
+ models/for-tests-ggml-small.bin filter=lfs diff=lfs merge=lfs -text
8
+ models/for-tests-ggml-small.en.bin filter=lfs diff=lfs merge=lfs -text
9
+ models/for-tests-ggml-tiny.bin filter=lfs diff=lfs merge=lfs -text
10
+ models/for-tests-ggml-tiny.en.bin filter=lfs diff=lfs merge=lfs -text
11
+ models/for-tests-silero-v5.1.2-ggml.bin filter=lfs diff=lfs merge=lfs -text
12
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ sync.sh
2
+ main
3
+ *.o
Makefile ADDED
@@ -0,0 +1,109 @@
1
+ main: ggml.o main.o
2
+ g++ -o main ggml.o main.o
3
+
4
+ ggml.o: ggml.c ggml.h
5
+ gcc -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
6
+
7
+ main.o: main.cpp ggml.h
8
+ g++ -O3 -std=c++11 -c main.cpp
9
+
10
+ # clean up the directory
11
+ clean:
12
+ rm -f *.o main
13
+
14
+ # run the program
15
+ run: main
16
+ ./main
17
+
18
+ # download the following audio samples into folder "./samples":
19
+ .PHONY: samples
20
+ samples:
21
+ @echo "Downloading samples..."
22
+ mkdir -p samples
23
+ @wget --quiet --show-progress -O samples/gb0.ogg https://upload.wikimedia.org/wikipedia/commons/2/22/George_W._Bush%27s_weekly_radio_address_%28November_1%2C_2008%29.oga
24
+ @wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
25
+ @wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
26
+ @echo "Converting to 16-bit WAV ..."
27
+ @ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
28
+ @ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
29
+ @ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
30
+
31
+ .PHONY: tiny.en
32
+ tiny.en: main
33
+ @echo "Downloading tiny.en (75 MB just once)"
34
+ mkdir -p models
35
+ @if [ ! -f models/ggml-tiny.en.bin ]; then \
36
+ wget --quiet --show-progress -O models/ggml-tiny.en.bin https://ggml.ggerganov.com/ggml-model-whisper-tiny.en.bin ; \
37
+ fi
38
+ @echo "==============================================="
39
+ @echo "Running tiny.en on all samples in ./samples ..."
40
+ @echo "==============================================="
41
+ @echo ""
42
+ @for f in samples/*.wav; do \
43
+ echo "----------------------------------------------" ; \
44
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
45
+ echo "----------------------------------------------" ; \
46
+ echo "" ; \
47
+ ./main -m models/ggml-tiny.en.bin -f $$f ; \
48
+ echo "" ; \
49
+ done
50
+
51
+ .PHONY: base.en
52
+ base.en: main
53
+ @echo "Downloading base.en (142 MB just once)"
54
+ mkdir -p models
55
+ @if [ ! -f models/ggml-base.en.bin ]; then \
56
+ wget --quiet --show-progress -O models/ggml-base.en.bin https://ggml.ggerganov.com/ggml-model-whisper-base.en.bin ; \
57
+ fi
58
+ @echo "==============================================="
59
+ @echo "Running base.en on all samples in ./samples ..."
60
+ @echo "==============================================="
61
+ @echo ""
62
+ @for f in samples/*.wav; do \
63
+ echo "----------------------------------------------" ; \
64
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
65
+ echo "----------------------------------------------" ; \
66
+ echo "" ; \
67
+ ./main -m models/ggml-base.en.bin -f $$f ; \
68
+ echo "" ; \
69
+ done
70
+
71
+ .PHONY: small.en
72
+ small.en: main
73
+ @echo "Downloading small.en (466 MB just once)"
74
+ mkdir -p models
75
+ @if [ ! -f models/ggml-small.en.bin ]; then \
76
+ wget --quiet --show-progress -O models/ggml-small.en.bin https://ggml.ggerganov.com/ggml-model-whisper-small.en.bin ; \
77
+ fi
78
+ @echo "==============================================="
79
+ @echo "Running small.en on all samples in ./samples ..."
80
+ @echo "==============================================="
81
+ @echo ""
82
+ @for f in samples/*.wav; do \
83
+ echo "----------------------------------------------" ; \
84
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
85
+ echo "----------------------------------------------" ; \
86
+ echo "" ; \
87
+ ./main -m models/ggml-small.en.bin -f $$f ; \
88
+ echo "" ; \
89
+ done
90
+
91
+ .PHONY: medium.en
92
+ medium.en: main
93
+ @echo "Downloading medium.en (1.5 GB just once)"
94
+ mkdir -p models
95
+ @if [ ! -f models/ggml-medium.en.bin ]; then \
96
+ wget --quiet --show-progress -O models/ggml-medium.en.bin https://ggml.ggerganov.com/ggml-model-whisper-medium.en.bin ; \
97
+ fi
98
+ @echo "==============================================="
99
+ @echo "Running medium.en on all samples in ./samples ..."
100
+ @echo "==============================================="
101
+ @echo ""
102
+ @for f in samples/*.wav; do \
103
+ echo "----------------------------------------------" ; \
104
+ echo "[+] Running base.en on $$f ... (run 'ffplay $$f' to listen)" ; \
105
+ echo "----------------------------------------------" ; \
106
+ echo "" ; \
107
+ ./main -m models/ggml-medium.en.bin -f $$f ; \
108
+ echo "" ; \
109
+ done
convert-pt-to-ggml.py ADDED
@@ -0,0 +1,328 @@
1
+ # Convert Whisper transformer model from PyTorch to ggml format
2
+ #
3
+ # Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
4
+ #
5
+ # You need to clone the original repo in ~/path/to/repo/whisper/
6
+ #
7
+ # git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
8
+ #
9
+ # It is used to load various assets needed by the algorithm:
10
+ #
11
+ # - tokenizer
12
+ # - mel filters
13
+ #
14
+ # Also, you need to have the original models in ~/.cache/whisper/
15
+ # See the original repo for more details.
16
+ #
17
+ # This script loads the specified model and whisper assets and saves them in ggml format.
18
+ # The output is a single binary file containing the following information:
19
+ #
20
+ # - hparams
21
+ # - mel filters
22
+ # - tokenizer vocab
23
+ # - model variables
24
+ #
25
+ # For each variable, write the following:
26
+ #
27
+ # - Number of dimensions (int)
28
+ # - Name length (int)
29
+ # - Dimensions (int[n_dims])
30
+ # - Name (char[name_length])
31
+ # - Data (float[n_elements])
32
+ #
33
+
34
+ import io
35
+ import os
36
+ import sys
37
+ import struct
38
+ import json
39
+ import code
40
+ import torch
41
+ import numpy as np
42
+
43
+ from transformers import GPTJForCausalLM
44
+ from transformers import GPT2TokenizerFast
45
+
46
+ # ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
47
+ LANGUAGES = {
48
+ "en": "english",
49
+ "zh": "chinese",
50
+ "de": "german",
51
+ "es": "spanish",
52
+ "ru": "russian",
53
+ "ko": "korean",
54
+ "fr": "french",
55
+ "ja": "japanese",
56
+ "pt": "portuguese",
57
+ "tr": "turkish",
58
+ "pl": "polish",
59
+ "ca": "catalan",
60
+ "nl": "dutch",
61
+ "ar": "arabic",
62
+ "sv": "swedish",
63
+ "it": "italian",
64
+ "id": "indonesian",
65
+ "hi": "hindi",
66
+ "fi": "finnish",
67
+ "vi": "vietnamese",
68
+ "iw": "hebrew",
69
+ "uk": "ukrainian",
70
+ "el": "greek",
71
+ "ms": "malay",
72
+ "cs": "czech",
73
+ "ro": "romanian",
74
+ "da": "danish",
75
+ "hu": "hungarian",
76
+ "ta": "tamil",
77
+ "no": "norwegian",
78
+ "th": "thai",
79
+ "ur": "urdu",
80
+ "hr": "croatian",
81
+ "bg": "bulgarian",
82
+ "lt": "lithuanian",
83
+ "la": "latin",
84
+ "mi": "maori",
85
+ "ml": "malayalam",
86
+ "cy": "welsh",
87
+ "sk": "slovak",
88
+ "te": "telugu",
89
+ "fa": "persian",
90
+ "lv": "latvian",
91
+ "bn": "bengali",
92
+ "sr": "serbian",
93
+ "az": "azerbaijani",
94
+ "sl": "slovenian",
95
+ "kn": "kannada",
96
+ "et": "estonian",
97
+ "mk": "macedonian",
98
+ "br": "breton",
99
+ "eu": "basque",
100
+ "is": "icelandic",
101
+ "hy": "armenian",
102
+ "ne": "nepali",
103
+ "mn": "mongolian",
104
+ "bs": "bosnian",
105
+ "kk": "kazakh",
106
+ "sq": "albanian",
107
+ "sw": "swahili",
108
+ "gl": "galician",
109
+ "mr": "marathi",
110
+ "pa": "punjabi",
111
+ "si": "sinhala",
112
+ "km": "khmer",
113
+ "sn": "shona",
114
+ "yo": "yoruba",
115
+ "so": "somali",
116
+ "af": "afrikaans",
117
+ "oc": "occitan",
118
+ "ka": "georgian",
119
+ "be": "belarusian",
120
+ "tg": "tajik",
121
+ "sd": "sindhi",
122
+ "gu": "gujarati",
123
+ "am": "amharic",
124
+ "yi": "yiddish",
125
+ "lo": "lao",
126
+ "uz": "uzbek",
127
+ "fo": "faroese",
128
+ "ht": "haitian creole",
129
+ "ps": "pashto",
130
+ "tk": "turkmen",
131
+ "nn": "nynorsk",
132
+ "mt": "maltese",
133
+ "sa": "sanskrit",
134
+ "lb": "luxembourgish",
135
+ "my": "myanmar",
136
+ "bo": "tibetan",
137
+ "tl": "tagalog",
138
+ "mg": "malagasy",
139
+ "as": "assamese",
140
+ "tt": "tatar",
141
+ "haw": "hawaiian",
142
+ "ln": "lingala",
143
+ "ha": "hausa",
144
+ "ba": "bashkir",
145
+ "jw": "javanese",
146
+ "su": "sundanese",
147
+ }
148
+
149
+ # ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
150
+ def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
151
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
152
+ path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
153
+ tokenizer = GPT2TokenizerFast.from_pretrained(path)
154
+
155
+ specials = [
156
+ "<|startoftranscript|>",
157
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
158
+ "<|translate|>",
159
+ "<|transcribe|>",
160
+ "<|startoflm|>",
161
+ "<|startofprev|>",
162
+ "<|nocaptions|>",
163
+ "<|notimestamps|>",
164
+ ]
165
+
166
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
167
+ return tokenizer
168
+
169
+ # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
170
+ def bytes_to_unicode():
171
+ """
172
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
173
+ The reversible bpe codes work on unicode strings.
174
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
175
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
176
+ This is a significant percentage of your normal, say, 32K bpe vocab.
177
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
178
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
179
+ """
180
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
181
+ cs = bs[:]
182
+ n = 0
183
+ for b in range(2**8):
184
+ if b not in bs:
185
+ bs.append(b)
186
+ cs.append(2**8+n)
187
+ n += 1
188
+ cs = [chr(n) for n in cs]
189
+ return dict(zip(bs, cs))
190
+
191
+
192
+ if len(sys.argv) < 4:
193
+ print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
194
+ sys.exit(1)
195
+
196
+ fname_inp = sys.argv[1]
197
+ dir_whisper = sys.argv[2]
198
+ dir_out = sys.argv[3]
199
+
200
+ # try to load PyTorch binary data
201
+ try:
202
+ model_bytes = open(fname_inp, "rb").read()
203
+ with io.BytesIO(model_bytes) as fp:
204
+ checkpoint = torch.load(fp, map_location="cpu")
205
+ except:
206
+ print("Error: failed to load PyTorch model file: %s" % fname_inp)
207
+ sys.exit(1)
208
+
209
+ hparams = checkpoint["dims"]
210
+ print("hparams:", hparams)
211
+
212
+ list_vars = checkpoint["model_state_dict"]
213
+
214
+ #print(list_vars['encoder.positional_embedding'])
215
+ #print(list_vars['encoder.conv1.weight'])
216
+ #print(list_vars['encoder.conv1.weight'].shape)
217
+
218
+ # load mel filters
219
+ n_mels = hparams["n_mels"]
220
+ with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
221
+ filters = torch.from_numpy(f[f"mel_{n_mels}"])
222
+ #print (filters)
223
+
224
+ #code.interact(local=locals())
225
+
226
+ multilingual = hparams["n_vocab"] == 51865
227
+ tokenizer = build_tokenizer(dir_whisper, multilingual and "multilingual" or "gpt2")
228
+
229
+ #print(tokenizer)
230
+ #print(tokenizer.name_or_path)
231
+ #print(len(tokenizer.additional_special_tokens))
232
+ dir_tokenizer = tokenizer.name_or_path
233
+
234
+ # output in the same directory as the model
235
+ fname_out = dir_out + "/ggml-model.bin"
236
+
237
+ with open(dir_tokenizer + "/vocab.json", "r") as f:
238
+ tokens = json.load(f)
239
+
240
+ # use 16-bit or 32-bit floats
241
+ use_f16 = True
242
+ if len(sys.argv) > 4:
243
+ use_f16 = False
244
+ fname_out = dir_out + "/ggml-model-f32.bin"
245
+
246
+ fout = open(fname_out, "wb")
247
+
248
+ fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
249
+ fout.write(struct.pack("i", hparams["n_vocab"]))
250
+ fout.write(struct.pack("i", hparams["n_audio_ctx"]))
251
+ fout.write(struct.pack("i", hparams["n_audio_state"]))
252
+ fout.write(struct.pack("i", hparams["n_audio_head"]))
253
+ fout.write(struct.pack("i", hparams["n_audio_layer"]))
254
+ fout.write(struct.pack("i", hparams["n_text_ctx"]))
255
+ fout.write(struct.pack("i", hparams["n_text_state"]))
256
+ fout.write(struct.pack("i", hparams["n_text_head"]))
257
+ fout.write(struct.pack("i", hparams["n_text_layer"]))
258
+ fout.write(struct.pack("i", hparams["n_mels"]))
259
+ fout.write(struct.pack("i", use_f16))
260
+
261
+ # write mel filters
262
+ fout.write(struct.pack("i", filters.shape[0]))
263
+ fout.write(struct.pack("i", filters.shape[1]))
264
+ for i in range(filters.shape[0]):
265
+ for j in range(filters.shape[1]):
266
+ fout.write(struct.pack("f", filters[i][j]))
267
+
268
+ byte_encoder = bytes_to_unicode()
269
+ byte_decoder = {v:k for k, v in byte_encoder.items()}
270
+
271
+ fout.write(struct.pack("i", len(tokens)))
272
+
273
+ for key in tokens:
274
+ text = bytearray([byte_decoder[c] for c in key]).decode('utf-8', errors='replace').encode('utf-8')
275
+ fout.write(struct.pack("i", len(text)))
276
+ fout.write(text)
277
+
278
+ for name in list_vars.keys():
279
+ data = list_vars[name].squeeze().numpy()
280
+ print("Processing variable: " + name + " with shape: ", data.shape)
281
+
282
+ # reshape conv bias from [n] to [n, 1]
283
+ if name == "encoder.conv1.bias" or \
284
+ name == "encoder.conv2.bias":
285
+ data = data.reshape(data.shape[0], 1)
286
+ print(" Reshaped variable: " + name + " to shape: ", data.shape)
287
+
288
+ n_dims = len(data.shape);
289
+
290
+ # looks like the whisper models are in f16 by default
291
+ # so we need to convert the small tensors to f32 until we fully support f16 in ggml
292
+ # ftype == 0 -> float32, ftype == 1 -> float16
293
+ ftype = 1;
294
+ if use_f16:
295
+ if n_dims < 2 or \
296
+ name == "encoder.conv1.bias" or \
297
+ name == "encoder.conv2.bias" or \
298
+ name == "encoder.positional_embedding" or \
299
+ name == "decoder.positional_embedding":
300
+ ftype = 0
301
+ data = data.astype(np.float32)
302
+ print(" Converting to float32")
303
+ data = data.astype(np.float32)
304
+ ftype = 0
305
+ else:
306
+ data = data.astype(np.float32)
307
+ ftype = 0
308
+
309
+ #if name.startswith("encoder"):
310
+ # if name.endswith("mlp.0.weight") or \
311
+ # name.endswith("mlp.2.weight"):
312
+ # print(" Transposing")
313
+ # data = data.transpose()
314
+
315
+ # header
316
+ str = name.encode('utf-8')
317
+ fout.write(struct.pack("iii", n_dims, len(str), ftype))
318
+ for i in range(n_dims):
319
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
320
+ fout.write(str);
321
+
322
+ # data
323
+ data.tofile(fout)
324
+
325
+ fout.close()
326
+
327
+ print("Done. Output file: " + fname_out)
328
+ print("")
dr_wav.h ADDED
The diff for this file is too large to render. See raw diff
 
ggml.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml.h ADDED
@@ -0,0 +1,527 @@
1
+ #pragma once
2
+
3
+ #ifdef __cplusplus
4
+ extern "C" {
5
+ #endif
6
+
7
+ #include <stdint.h>
8
+ #include <stddef.h>
9
+ #include <stdbool.h>
10
+
11
+ #define GGML_MAX_DIMS 4
12
+ #define GGML_MAX_NODES 4096
13
+ #define GGML_MAX_PARAMS 16
14
+ #define GGML_MAX_CONTEXTS 16
15
+
16
+ #ifdef __ARM_NEON
17
+ // we use the built-in 16-bit float type
18
+ typedef __fp16 ggml_fp16_t;
19
+ #else
20
+ typedef uint16_t ggml_fp16_t;
21
+ #endif
22
+
23
+ float ggml_fp16_to_fp32(ggml_fp16_t x);
24
+ ggml_fp16_t ggml_fp32_to_fp16(float x);
25
+
26
+ struct ggml_object;
27
+ struct ggml_context;
28
+
29
+ enum ggml_type {
30
+ GGML_TYPE_I8,
31
+ GGML_TYPE_I16,
32
+ GGML_TYPE_I32,
33
+ GGML_TYPE_F16,
34
+ GGML_TYPE_F32,
35
+ GGML_TYPE_COUNT,
36
+ };
37
+
38
+ enum ggml_op {
39
+ GGML_OP_NONE = 0,
40
+
41
+ GGML_OP_DUP,
42
+ GGML_OP_ADD,
43
+ GGML_OP_SUB,
44
+ GGML_OP_MUL,
45
+ GGML_OP_DIV,
46
+ GGML_OP_SQR,
47
+ GGML_OP_SQRT,
48
+ GGML_OP_SUM,
49
+ GGML_OP_MEAN,
50
+ GGML_OP_REPEAT,
51
+ GGML_OP_ABS,
52
+ GGML_OP_SGN,
53
+ GGML_OP_NEG,
54
+ GGML_OP_STEP,
55
+ GGML_OP_RELU,
56
+ GGML_OP_GELU,
57
+ GGML_OP_NORM, // normalize
58
+
59
+ GGML_OP_MUL_MAT,
60
+
61
+ GGML_OP_SCALE,
62
+ GGML_OP_CPY,
63
+ GGML_OP_RESHAPE,
64
+ GGML_OP_VIEW,
65
+ GGML_OP_PERMUTE,
66
+ GGML_OP_TRANSPOSE,
67
+ GGML_OP_GET_ROWS,
68
+ GGML_OP_DIAG_MASK_INF,
69
+ GGML_OP_SOFT_MAX,
70
+ GGML_OP_ROPE,
71
+ GGML_OP_CONV_1D_1S,
72
+ GGML_OP_CONV_1D_2S,
73
+
74
+ GGML_OP_COUNT,
75
+ };
76
+
77
+ // n-dimensional tensor
78
+ struct ggml_tensor {
79
+ enum ggml_type type;
80
+
81
+ int n_dims;
82
+ int ne[GGML_MAX_DIMS]; // number of elements
83
+ size_t nb[GGML_MAX_DIMS]; // stride in bytes:
84
+ // nb[0] = sizeof(type)
85
+ // nb[1] = nb[0] * ne[0] + padding
86
+ // nb[i] = nb[i-1] * ne[i-1]
87
+
88
+ // compute data
89
+ enum ggml_op op;
90
+
91
+ bool is_param;
92
+
93
+ struct ggml_tensor * grad;
94
+ struct ggml_tensor * src0;
95
+ struct ggml_tensor * src1;
96
+
97
+ // thread scheduling
98
+ int n_tasks;
99
+
100
+ // performance
101
+ int perf_runs;
102
+ int64_t perf_cycles;
103
+ int64_t perf_time_us;
104
+
105
+ void * data;
106
+ char pad[8];
107
+ };
108
+
109
+ // computation graph
110
+ struct ggml_cgraph {
111
+ int n_nodes;
112
+ int n_leafs;
113
+ int n_threads;
114
+
115
+ size_t work_size;
116
+ struct ggml_tensor * work;
117
+
118
+ struct ggml_tensor * nodes[GGML_MAX_NODES];
119
+ struct ggml_tensor * grads[GGML_MAX_NODES];
120
+ struct ggml_tensor * leafs[GGML_MAX_NODES];
121
+
122
+ // performance
123
+ int perf_runs;
124
+ int64_t perf_cycles;
125
+ int64_t perf_time_us;
126
+ };
127
+
128
+ struct ggml_init_params {
129
+ // memory pool
130
+ size_t mem_size; // bytes
131
+ void * mem_buffer; // if NULL, memory will be allocated internally
132
+ };
133
+
134
+ int64_t ggml_time_ms(void);
135
+ int64_t ggml_time_us(void);
136
+ int64_t ggml_cycles(void);
137
+ int64_t ggml_cycles_per_ms(void);
138
+
139
+ void ggml_print_object (const struct ggml_object * obj);
140
+ void ggml_print_objects(const struct ggml_context * ctx);
141
+
142
+ int ggml_nelements(const struct ggml_tensor * tensor);
143
+ size_t ggml_nbytes (const struct ggml_tensor * tensor);
144
+
145
+ size_t ggml_type_size (enum ggml_type type);
146
+ size_t ggml_element_size(const struct ggml_tensor * tensor);
147
+
148
+ struct ggml_context * ggml_init(struct ggml_init_params params);
149
+ void ggml_free(struct ggml_context * ctx);
150
+
151
+ size_t ggml_used_mem(const struct ggml_context * ctx);
152
+
153
+ struct ggml_tensor * ggml_new_tensor(
154
+ struct ggml_context * ctx,
155
+ enum ggml_type type,
156
+ int n_dims,
157
+ const int *ne);
158
+
159
+ struct ggml_tensor * ggml_new_tensor_1d(
160
+ struct ggml_context * ctx,
161
+ enum ggml_type type,
162
+ int ne0);
163
+
164
+ struct ggml_tensor * ggml_new_tensor_2d(
165
+ struct ggml_context * ctx,
166
+ enum ggml_type type,
167
+ int ne0,
168
+ int ne1);
169
+
170
+ struct ggml_tensor * ggml_new_tensor_3d(
171
+ struct ggml_context * ctx,
172
+ enum ggml_type type,
173
+ int ne0,
174
+ int ne1,
175
+ int ne2);
176
+
177
+ struct ggml_tensor * ggml_new_tensor_4d(
178
+ struct ggml_context * ctx,
179
+ enum ggml_type type,
180
+ int ne0,
181
+ int ne1,
182
+ int ne2,
183
+ int ne3);
184
+
185
+ struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
186
+
187
+ struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
188
+ struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
189
+
190
+ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
191
+ struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
192
+
193
+ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
194
+ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
195
+
196
+ void * ggml_get_data (const struct ggml_tensor * tensor);
197
+ float * ggml_get_data_f32(const struct ggml_tensor * tensor);
198
+
199
+ //
200
+ // operations on tensors with backpropagation
201
+ //
202
+
203
+ struct ggml_tensor * ggml_dup(
204
+ struct ggml_context * ctx,
205
+ struct ggml_tensor * a);
206
+
207
+ struct ggml_tensor * ggml_add(
208
+ struct ggml_context * ctx,
209
+ struct ggml_tensor * a,
210
+ struct ggml_tensor * b);
211
+
212
+ struct ggml_tensor * ggml_sub(
213
+ struct ggml_context * ctx,
214
+ struct ggml_tensor * a,
215
+ struct ggml_tensor * b);
216
+
217
+ struct ggml_tensor * ggml_mul(
218
+ struct ggml_context * ctx,
219
+ struct ggml_tensor * a,
220
+ struct ggml_tensor * b);
221
+
222
+ struct ggml_tensor * ggml_div(
223
+ struct ggml_context * ctx,
224
+ struct ggml_tensor * a,
225
+ struct ggml_tensor * b);
226
+
227
+ struct ggml_tensor * ggml_sqr(
228
+ struct ggml_context * ctx,
229
+ struct ggml_tensor * a);
230
+
231
+ struct ggml_tensor * ggml_sqrt(
232
+ struct ggml_context * ctx,
233
+ struct ggml_tensor * a);
234
+
235
+ // return scalar
236
+ // TODO: compute sum along rows
237
+ struct ggml_tensor * ggml_sum(
238
+ struct ggml_context * ctx,
239
+ struct ggml_tensor * a);
240
+
241
+ // mean along rows
242
+ struct ggml_tensor * ggml_mean(
243
+ struct ggml_context * ctx,
244
+ struct ggml_tensor * a);
245
+
246
+ // if a is the same shape as b, and a is not a parameter, return a
247
+ // otherwise, return a new tensor: repeat(a) to fit in b
248
+ struct ggml_tensor * ggml_repeat(
249
+ struct ggml_context * ctx,
250
+ struct ggml_tensor * a,
251
+ struct ggml_tensor * b);
252
+
253
+ struct ggml_tensor * ggml_abs(
254
+ struct ggml_context * ctx,
255
+ struct ggml_tensor * a);
256
+
257
+ struct ggml_tensor * ggml_sgn(
258
+ struct ggml_context * ctx,
259
+ struct ggml_tensor * a);
260
+
261
+ struct ggml_tensor * ggml_neg(
262
+ struct ggml_context * ctx,
263
+ struct ggml_tensor * a);
264
+
265
+ struct ggml_tensor * ggml_step(
266
+ struct ggml_context * ctx,
267
+ struct ggml_tensor * a);
268
+
269
+ struct ggml_tensor * ggml_relu(
270
+ struct ggml_context * ctx,
271
+ struct ggml_tensor * a);
272
+
273
+ // TODO: double-check this computation is correct
274
+ struct ggml_tensor * ggml_gelu(
275
+ struct ggml_context * ctx,
276
+ struct ggml_tensor * a);
277
+
278
+ // normalize along rows
279
+ // TODO: eps is hardcoded to 1e-5 for now
280
+ struct ggml_tensor * ggml_norm(
281
+ struct ggml_context * ctx,
282
+ struct ggml_tensor * a);
283
+
284
+ // A: m rows, n columns
285
+ // B: p rows, n columns (i.e. we transpose it internally)
286
+ // result is m columns, p rows
287
+ struct ggml_tensor * ggml_mul_mat(
288
+ struct ggml_context * ctx,
289
+ struct ggml_tensor * a,
290
+ struct ggml_tensor * b);
291
+
292
+ //
293
+ // operations on tensors without backpropagation
294
+ //
295
+
296
+ // in-place, returns view(a)
297
+ struct ggml_tensor * ggml_scale(
298
+ struct ggml_context * ctx,
299
+ struct ggml_tensor * a,
300
+ struct ggml_tensor * b);
301
+
302
+ // a -> b, return view(b)
303
+ struct ggml_tensor * ggml_cpy(
304
+ struct ggml_context * ctx,
305
+ struct ggml_tensor * a,
306
+ struct ggml_tensor * b);
307
+
308
+ // return view(a), b specifies the new shape
309
+ // TODO: when we start computing gradient, make a copy instead of view
310
+ struct ggml_tensor * ggml_reshape(
311
+ struct ggml_context * ctx,
312
+ struct ggml_tensor * a,
313
+ struct ggml_tensor * b);
314
+
315
+ // return view(a)
316
+ // TODO: when we start computing gradient, make a copy instead of view
317
+ struct ggml_tensor * ggml_reshape_2d(
318
+ struct ggml_context * ctx,
319
+ struct ggml_tensor * a,
320
+ int ne0,
321
+ int ne1);
322
+
323
+ // return view(a)
324
+ // TODO: when we start computing gradient, make a copy instead of view
325
+ struct ggml_tensor * ggml_reshape_3d(
326
+ struct ggml_context * ctx,
327
+ struct ggml_tensor * a,
328
+ int ne0,
329
+ int ne1,
330
+ int ne2);
331
+
332
+ // offset in bytes
333
+ struct ggml_tensor * ggml_view_1d(
334
+ struct ggml_context * ctx,
335
+ struct ggml_tensor * a,
336
+ int ne0,
337
+ size_t offset);
338
+
339
+ struct ggml_tensor * ggml_view_2d(
340
+ struct ggml_context * ctx,
341
+ struct ggml_tensor * a,
342
+ int ne0,
343
+ int ne1,
344
+ size_t nb1, // row stride in bytes
345
+ size_t offset);
346
+
347
+ struct ggml_tensor * ggml_permute(
348
+ struct ggml_context * ctx,
349
+ struct ggml_tensor * a,
350
+ int axis0,
351
+ int axis1,
352
+ int axis2,
353
+ int axis3);
354
+
355
+ // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
356
+ struct ggml_tensor * ggml_transpose(
357
+ struct ggml_context * ctx,
358
+ struct ggml_tensor * a);
359
+
360
+ struct ggml_tensor * ggml_get_rows(
361
+ struct ggml_context * ctx,
362
+ struct ggml_tensor * a,
363
+ struct ggml_tensor * b);
364
+
365
+ // set elements above the diagonal to -INF
366
+ // in-place, returns view(a)
367
+ struct ggml_tensor * ggml_diag_mask_inf(
368
+ struct ggml_context * ctx,
369
+ struct ggml_tensor * a,
370
+ int n_past);
371
+
372
+ // in-place, returns view(a)
373
+ struct ggml_tensor * ggml_soft_max(
374
+ struct ggml_context * ctx,
375
+ struct ggml_tensor * a);
376
+
377
+ // rotary position embedding
378
+ // in-place, returns view(a)
379
+ // if mode == 1, skip n_past elements
380
+ // TODO: avoid creating a new tensor every time
381
+ struct ggml_tensor * ggml_rope(
382
+ struct ggml_context * ctx,
383
+ struct ggml_tensor * a,
384
+ int n_past,
385
+ int n_dims,
386
+ int mode);
387
+
388
+ // padding = 1
389
+ // TODO: we don't support extra parameters for now
390
+ // that's why we are hard-coding the stride, padding, and dilation
391
+ // not great ..
392
+ struct ggml_tensor * ggml_conv_1d_1s(
393
+ struct ggml_context * ctx,
394
+ struct ggml_tensor * a,
395
+ struct ggml_tensor * b);
396
+
397
+ struct ggml_tensor * ggml_conv_1d_2s(
398
+ struct ggml_context * ctx,
399
+ struct ggml_tensor * a,
400
+ struct ggml_tensor * b);
401
+
402
+ //
403
+ // automatic differentiation
404
+ //
405
+
406
+ void ggml_set_param(
407
+ struct ggml_context * ctx,
408
+ struct ggml_tensor * tensor);
409
+
410
+ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
411
+
412
+ struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
413
+ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
414
+
415
+ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
416
+ void ggml_graph_reset (struct ggml_cgraph * cgraph);
417
+
418
+ // print info and performance information for the graph
419
+ void ggml_graph_print(const struct ggml_cgraph * cgraph);
420
+
421
+ // dump the graph into a file using the dot format
422
+ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
423
+
424
+ //
425
+ // optimization
426
+ //
427
+
428
+ // optimization methods
429
+ enum ggml_opt_type {
430
+ GGML_OPT_ADAM,
431
+ GGML_OPT_LBFGS,
432
+ };
433
+
434
+ // linesearch methods
435
+ enum ggml_linesearch {
436
+ GGML_LINESEARCH_DEFAULT = 1,
437
+
438
+ GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
439
+ GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
440
+ GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
441
+ };
442
+
443
+ // optimization return values
444
+ enum ggml_opt_result {
445
+ GGML_OPT_OK = 0,
446
+ GGML_OPT_DID_NOT_CONVERGE,
447
+ GGML_OPT_NO_CONTEXT,
448
+ GGML_OPT_INVALID_WOLFE,
449
+ GGML_OPT_FAIL,
450
+
451
+ GGML_LINESEARCH_FAIL = -128,
452
+ GGML_LINESEARCH_MINIMUM_STEP,
453
+ GGML_LINESEARCH_MAXIMUM_STEP,
454
+ GGML_LINESEARCH_MAXIMUM_ITERATIONS,
455
+ GGML_LINESEARCH_INVALID_PARAMETERS,
456
+ };
457
+
458
+ // optimization parameters
459
+ //
460
+ // see ggml.c (ggml_opt_default_params) for default values
461
+ //
462
+ struct ggml_opt_params {
463
+ enum ggml_opt_type type;
464
+
465
+ int n_threads;
466
+
467
+ // delta-based convergence test
468
+ //
469
+ // if past == 0 - disabled
470
+ // if past > 0:
471
+ // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
472
+ //
473
+ int past;
474
+ float delta;
475
+
476
+ // maximum number of iterations without improvement
477
+ //
478
+ // if 0 - disabled
479
+ // if > 0:
480
+ // assume convergence if no cost improvement in this number of iterations
481
+ //
482
+ int max_no_improvement;
483
+
484
+ bool print_forward_graph;
485
+ bool print_backward_graph;
486
+
487
+ union {
488
+ // ADAM parameters
489
+ struct {
490
+ int n_iter;
491
+
492
+ float alpha; // learning rate
493
+ float beta1;
494
+ float beta2;
495
+ float eps; // epsilon for numerical stability
496
+ float eps_f; // epsilon for convergence test
497
+ float eps_g; // epsilon for convergence test
498
+ } adam;
499
+
500
+ // LBFGS parameters
501
+ struct {
502
+ int m; // number of corrections to approximate the inv. Hessian
503
+ int n_iter;
504
+ int max_linesearch;
505
+
506
+ float eps; // convergence tolerance
507
+ float ftol; // line search tolerance
508
+ float wolfe;
509
+ float min_step;
510
+ float max_step;
511
+
512
+ enum ggml_linesearch linesearch;
513
+ } lbfgs;
514
+ };
515
+ };
516
+
517
+ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
518
+
519
+ // optimize the function defined by the tensor f
520
+ enum ggml_opt_result ggml_opt(
521
+ struct ggml_context * ctx,
522
+ struct ggml_opt_params params,
523
+ struct ggml_tensor * f);
524
+
525
+ #ifdef __cplusplus
526
+ }
527
+ #endif
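
The header above is the complete public API of this initial ggml release: a context backed by a single pre-allocated memory pool, tensor constructors, operations that lazily record a computation graph, and ggml_graph_compute to evaluate it. The snippet below is a minimal usage sketch, not part of this commit; it assumes the program is linked against ggml.c, and the pool size and tensor shapes are illustrative only.

```cpp
// Minimal sketch of using the ggml API declared above (assumed linked with ggml.c).
#include "ggml.h"

#include <cstdio>

int main() {
    // all tensors live in a single pre-allocated memory pool
    struct ggml_init_params params = {
        /* mem_size   */ 16*1024*1024,
        /* mem_buffer */ NULL, // let ggml allocate the pool internally
    };

    struct ggml_context * ctx = ggml_init(params);

    // A: 3 rows x 4 columns, B: 2 rows x 4 columns (shared inner dimension of 4)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);

    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 2.0f);

    // C = A * B^T -- see the ggml_mul_mat comment in the header
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

    // build the forward graph for c and evaluate it
    struct ggml_cgraph gf = ggml_build_forward(c);
    gf.n_threads = 1;

    ggml_graph_compute(ctx, &gf);

    const float * out = ggml_get_data_f32(c);
    for (int i = 0; i < ggml_nelements(c); ++i) {
        printf("c[%d] = %.1f\n", i, out[i]); // each element sums 4 products of 1*2 -> 8.0
    }

    ggml_free(ctx);

    return 0;
}
```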
main.cpp ADDED
@@ -0,0 +1,2116 @@
1
+ #include "ggml.h"
2
+
3
+ // third-party utilities
4
+ // use your favorite implementations
5
+ #define DR_WAV_IMPLEMENTATION
6
+ #include "dr_wav.h"
7
+
8
+ #include <algorithm>
9
+ #include <cassert>
10
+ #include <cmath>
11
+ #include <cstdio>
12
+ #include <cstring>
13
+ #include <fstream>
14
+ #include <map>
15
+ #include <string>
16
+ #include <thread>
17
+ #include <vector>
18
+
19
+ enum e_model {
20
+ MODEL_UNKNOWN,
21
+ MODEL_TINY,
22
+ MODEL_BASE,
23
+ MODEL_SMALL,
24
+ MODEL_MEDIUM,
25
+ MODEL_LARGE,
26
+ };
27
+
28
+ const size_t MB = 1024*1024;
29
+
30
+ const std::map<e_model, size_t> MEM_REQ_MODEL = {
31
+ { MODEL_TINY, 100ull*MB },
32
+ { MODEL_BASE, 190ull*MB },
33
+ { MODEL_SMALL, 610ull*MB },
34
+ { MODEL_MEDIUM, 1900ull*MB },
35
+ { MODEL_LARGE, 3600ull*MB },
36
+ };
37
+
38
+ const std::map<e_model, size_t> MEM_REQ_ENCODE = {
39
+ { MODEL_TINY, 80ull*MB },
40
+ { MODEL_BASE, 128ull*MB },
41
+ { MODEL_SMALL, 300ull*MB },
42
+ { MODEL_MEDIUM, 680ull*MB },
43
+ { MODEL_LARGE, 1100ull*MB },
44
+ };
45
+
46
+ const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
47
+ { MODEL_TINY, 170ull*MB },
48
+ { MODEL_BASE, 230ull*MB },
49
+ { MODEL_SMALL, 350ull*MB },
50
+ { MODEL_MEDIUM, 450ull*MB },
51
+ { MODEL_LARGE, 570ull*MB },
52
+ };
53
+
54
+ const std::map<e_model, size_t> MEM_REQ_DECODE = {
55
+ { MODEL_TINY, 190ull*MB },
56
+ { MODEL_BASE, 190ull*MB },
57
+ { MODEL_SMALL, 190ull*MB },
58
+ { MODEL_MEDIUM, 200ull*MB },
59
+ { MODEL_LARGE, 200ull*MB },
60
+ };
61
+
62
+ const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
63
+ { MODEL_TINY, 32ull*MB },
64
+ { MODEL_BASE, 44ull*MB },
65
+ { MODEL_SMALL, 64ull*MB },
66
+ { MODEL_MEDIUM, 84ull*MB },
67
+ { MODEL_LARGE, 110ull*MB },
68
+ };
69
+
70
+ const int SAMPLE_RATE = 16000;
71
+ const int N_FFT = 400;
72
+ const int N_MEL = 80;
73
+ const int HOP_LENGTH = 160;
74
+ const int CHUNK_SIZE = 30; // seconds
75
+
76
+ struct whisper_mel {
77
+ int n_len;
78
+ int n_mel;
79
+
80
+ std::vector<float> data;
81
+ };
82
+
83
+ struct whisper_filters {
84
+ int32_t n_mel;
85
+ int32_t n_fft;
86
+
87
+ std::vector<float> data;
88
+ };
89
+
90
+ struct whisper_vocab {
91
+ using id = int32_t;
92
+ using token = std::string;
93
+
94
+ int n_vocab = 51864;
95
+
96
+ std::map<token, id> token_to_id;
97
+ std::map<id, token> id_to_token;
98
+
99
+ id token_eot = 50256;
100
+ id token_sot = 50257;
101
+ id token_prev = 50360;
102
+ id token_solm = 50361; // ??
103
+ id token_beg = 50363;
104
+
105
+ bool is_multilingual() const {
106
+ return n_vocab == 51865;
107
+ }
108
+ };
109
+
110
+ // command-line parameters
111
+ struct whisper_params {
112
+ int32_t seed = -1; // RNG seed
113
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
114
+
115
+ int32_t max_tokens_per_iter = 64;
116
+
117
+ bool verbose = false;
118
+ bool print_special_tokens = false;
119
+
120
+ std::string model = "models/whisper-tiny.en/ggml-model.bin"; // model path
121
+
122
+ std::string fname_inp = "default.wav";
123
+ };
124
+
125
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
126
+
127
+ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
128
+ for (int i = 1; i < argc; i++) {
129
+ std::string arg = argv[i];
130
+
131
+ if (arg == "-s" || arg == "--seed") {
132
+ params.seed = std::stoi(argv[++i]);
133
+ } else if (arg == "-t" || arg == "--threads") {
134
+ params.n_threads = std::stoi(argv[++i]);
135
+ } else if (arg == "-T" || arg == "--tokens") {
136
+ params.max_tokens_per_iter = std::stoi(argv[++i]);
137
+ } else if (arg == "-v" || arg == "--verbose") {
138
+ params.verbose = true;
139
+ } else if (arg == "-ps" || arg == "--print_special") {
140
+ params.print_special_tokens = true;
141
+ } else if (arg == "-m" || arg == "--model") {
142
+ params.model = argv[++i];
143
+ } else if (arg == "-f" || arg == "--file") {
144
+ params.fname_inp = argv[++i];
145
+ } else if (arg == "-h" || arg == "--help") {
146
+ whisper_print_usage(argc, argv, params);
147
+ exit(0);
148
+ } else {
149
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
150
+ whisper_print_usage(argc, argv, params);
151
+ exit(0);
152
+ }
153
+ }
154
+
155
+ return true;
156
+ }
157
+
158
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
159
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
160
+ fprintf(stderr, "\n");
161
+ fprintf(stderr, "options:\n");
162
+ fprintf(stderr, " -h, --help show this help message and exit\n");
163
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
164
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
165
+ fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
166
+ fprintf(stderr, " -v, --verbose verbose output\n");
167
+ fprintf(stderr, " -ps, --print_special print special tokens\n");
168
+ fprintf(stderr, " -m FNAME, --model FNAME\n");
169
+ fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
170
+ fprintf(stderr, " -f FNAME, --file FNAME\n");
171
+ fprintf(stderr, " input WAV file path (default: %s)\n", params.fname_inp.c_str());
172
+ fprintf(stderr, "\n");
173
+ }
174
+
175
+
176
+ // medium
177
+ // hparams: {
178
+ // 'n_mels': 80,
179
+ // 'n_vocab': 51864,
180
+ // 'n_audio_ctx': 1500,
181
+ // 'n_audio_state': 1024,
182
+ // 'n_audio_head': 16,
183
+ // 'n_audio_layer': 24,
184
+ // 'n_text_ctx': 448,
185
+ // 'n_text_state': 1024,
186
+ // 'n_text_head': 16,
187
+ // 'n_text_layer': 24
188
+ // }
189
+ //
190
+ // default hparams (Whisper tiny)
191
+ struct whisper_hparams {
192
+ int32_t n_vocab = 51864;
193
+ int32_t n_audio_ctx = 1500;
194
+ int32_t n_audio_state = 384;
195
+ int32_t n_audio_head = 6;
196
+ int32_t n_audio_layer = 4;
197
+ int32_t n_text_ctx = 448;
198
+ int32_t n_text_state = 384;
199
+ int32_t n_text_head = 6;
200
+ int32_t n_text_layer = 4;
201
+ int32_t n_mels = 80;
202
+ int32_t f16 = 1;
203
+ };
204
+
205
+ // audio encoding layer
206
+ struct whisper_layer_encoder {
207
+ // encoder.blocks.*.attn_ln
208
+ struct ggml_tensor * attn_ln_0_w;
209
+ struct ggml_tensor * attn_ln_0_b;
210
+
211
+ // encoder.blocks.*.attn.out
212
+ struct ggml_tensor * attn_ln_1_w;
213
+ struct ggml_tensor * attn_ln_1_b;
214
+
215
+ // encoder.blocks.*.attn.query
216
+ struct ggml_tensor * attn_q_w;
217
+ struct ggml_tensor * attn_q_b;
218
+
219
+ // encoder.blocks.*.attn.key
220
+ struct ggml_tensor * attn_k_w;
221
+
222
+ // encoder.blocks.*.attn.value
223
+ struct ggml_tensor * attn_v_w;
224
+ struct ggml_tensor * attn_v_b;
225
+
226
+ // encoder.blocks.*.mlp_ln
227
+ struct ggml_tensor * mlp_ln_w;
228
+ struct ggml_tensor * mlp_ln_b;
229
+
230
+ // encoder.blocks.*.mlp.0
231
+ struct ggml_tensor * mlp_0_w;
232
+ struct ggml_tensor * mlp_0_b;
233
+
234
+ // encoder.blocks.*.mlp.2
235
+ struct ggml_tensor * mlp_1_w;
236
+ struct ggml_tensor * mlp_1_b;
237
+ };
238
+
239
+ // token decoding layer
240
+ struct whisper_layer_decoder {
241
+ // decoder.blocks.*.attn_ln
242
+ struct ggml_tensor * attn_ln_0_w;
243
+ struct ggml_tensor * attn_ln_0_b;
244
+
245
+ // decoder.blocks.*.attn.out
246
+ struct ggml_tensor * attn_ln_1_w;
247
+ struct ggml_tensor * attn_ln_1_b;
248
+
249
+ // decoder.blocks.*.attn.query
250
+ struct ggml_tensor * attn_q_w;
251
+ struct ggml_tensor * attn_q_b;
252
+
253
+ // decoder.blocks.*.attn.key
254
+ struct ggml_tensor * attn_k_w;
255
+
256
+ // decoder.blocks.*.attn.value
257
+ struct ggml_tensor * attn_v_w;
258
+ struct ggml_tensor * attn_v_b;
259
+
260
+ // decoder.blocks.*.cross_attn_ln
261
+ struct ggml_tensor * cross_attn_ln_0_w;
262
+ struct ggml_tensor * cross_attn_ln_0_b;
263
+
264
+ // decoder.blocks.*.cross_attn.out
265
+ struct ggml_tensor * cross_attn_ln_1_w;
266
+ struct ggml_tensor * cross_attn_ln_1_b;
267
+
268
+ // decoder.blocks.*.cross_attn.query
269
+ struct ggml_tensor * cross_attn_q_w;
270
+ struct ggml_tensor * cross_attn_q_b;
271
+
272
+ // decoder.blocks.*.cross_attn.key
273
+ struct ggml_tensor * cross_attn_k_w;
274
+
275
+ // decoder.blocks.*.cross_attn.value
276
+ struct ggml_tensor * cross_attn_v_w;
277
+ struct ggml_tensor * cross_attn_v_b;
278
+
279
+ // decoder.blocks.*.mlp_ln
280
+ struct ggml_tensor * mlp_ln_w;
281
+ struct ggml_tensor * mlp_ln_b;
282
+
283
+ // decoder.blocks.*.mlp.0
284
+ struct ggml_tensor * mlp_0_w;
285
+ struct ggml_tensor * mlp_0_b;
286
+
287
+ // decoder.blocks.*.mlp.2
288
+ struct ggml_tensor * mlp_1_w;
289
+ struct ggml_tensor * mlp_1_b;
290
+ };
291
+
292
+ struct whisper_model {
293
+ e_model type = MODEL_UNKNOWN;
294
+
295
+ whisper_hparams hparams;
296
+ whisper_filters filters;
297
+
298
+ // encoder.positional_embedding
299
+ struct ggml_tensor * e_pe;
300
+
301
+ // encoder.conv1
302
+ struct ggml_tensor * e_conv_1_w;
303
+ struct ggml_tensor * e_conv_1_b;
304
+
305
+ // encoder.conv2
306
+ struct ggml_tensor * e_conv_2_w;
307
+ struct ggml_tensor * e_conv_2_b;
308
+
309
+ // encoder.ln_post
310
+ struct ggml_tensor * e_ln_w;
311
+ struct ggml_tensor * e_ln_b;
312
+
313
+ // decoder.positional_embedding
314
+ struct ggml_tensor * d_pe; // DD
315
+
316
+ // decoder.token_embedding
317
+ struct ggml_tensor * d_te; // DD
318
+
319
+ // decoder.ln
320
+ struct ggml_tensor * d_ln_w; // DD
321
+ struct ggml_tensor * d_ln_b; // DD
322
+
323
+ std::vector<whisper_layer_encoder> layers_encoder;
324
+ std::vector<whisper_layer_decoder> layers_decoder;
325
+
326
+ // key + value memory
327
+ struct ggml_tensor * memory_k;
328
+ struct ggml_tensor * memory_v;
329
+
330
+ struct ggml_tensor * memory_cross_k;
331
+ struct ggml_tensor * memory_cross_v;
332
+
333
+ //
334
+ struct ggml_context * ctx;
335
+ std::map<std::string, struct ggml_tensor *> tensors;
336
+ };
337
+
338
+ // load the model from a ggml file
339
+ //
340
+ // file format:
341
+ //
342
+ // - hparams
343
+ // - pre-computed mel filters
344
+ // - vocab
345
+ // - weights
346
+ //
347
+ // see the convert-pt-to-ggml.py script for details
348
+ //
349
+ bool whisper_model_load(const std::string & fname, whisper_model & model, whisper_vocab & vocab) {
350
+ printf("%s: loading model from '%s'\n", __func__, fname.c_str());
351
+
352
+ auto fin = std::ifstream(fname, std::ios::binary);
353
+ if (!fin) {
354
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
355
+ return false;
356
+ }
357
+
358
+ // verify magic
359
+ {
360
+ uint32_t magic;
361
+ fin.read((char *) &magic, sizeof(magic));
362
+ if (magic != 0x67676d6c) {
363
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
364
+ return false;
365
+ }
366
+ }
367
+
368
+ // load hparams
369
+ {
370
+ auto & hparams = model.hparams;
371
+
372
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
373
+ fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
374
+ fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
375
+ fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
376
+ fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
377
+ fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
378
+ fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
379
+ fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
380
+ fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
381
+ fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
382
+ fin.read((char *) &hparams.f16, sizeof(hparams.f16));
383
+
384
+ assert(hparams.n_text_state == hparams.n_audio_state);
385
+
386
+ if (hparams.n_audio_layer == 4) {
387
+ model.type = e_model::MODEL_TINY;
388
+ }
389
+
390
+ if (hparams.n_audio_layer == 6) {
391
+ model.type = e_model::MODEL_BASE;
392
+ }
393
+
394
+ if (hparams.n_audio_layer == 12) {
395
+ model.type = e_model::MODEL_SMALL;
396
+ }
397
+
398
+ if (hparams.n_audio_layer == 24) {
399
+ model.type = e_model::MODEL_MEDIUM;
400
+ }
401
+
402
+ if (hparams.n_audio_layer == 32) {
403
+ model.type = e_model::MODEL_LARGE;
404
+ }
405
+
406
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
407
+ printf("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
408
+ printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
409
+ printf("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
410
+ printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
411
+ printf("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
412
+ printf("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
413
+ printf("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
414
+ printf("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
415
+ printf("%s: n_mels = %d\n", __func__, hparams.n_mels);
416
+ printf("%s: f16 = %d\n", __func__, hparams.f16);
417
+ printf("%s: type = %d\n", __func__, model.type);
418
+
419
+ const size_t mem_required =
420
+ MEM_REQ_MODEL.at(model.type) +
421
+ MEM_REQ_ENCODE.at(model.type) +
422
+ MEM_REQ_ENCODE_LAYER.at(model.type) +
423
+ MEM_REQ_DECODE.at(model.type) +
424
+ MEM_REQ_DECODE_LAYER.at(model.type);
425
+
426
+ printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
427
+ }
428
+
429
+ // load mel filters
430
+ {
431
+ auto & filters = model.filters;
432
+
433
+ fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
434
+ fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
435
+
436
+ filters.data.resize(filters.n_mel * filters.n_fft);
437
+ fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
438
+ }
439
+
440
+ // load vocab
441
+ {
442
+ int32_t n_vocab = 0;
443
+ fin.read((char *) &n_vocab, sizeof(n_vocab));
444
+
445
+ //if (n_vocab != model.hparams.n_vocab) {
446
+ // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
447
+ // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
448
+ // return false;
449
+ //}
450
+
451
+ std::string word;
452
+ for (int i = 0; i < n_vocab; i++) {
453
+ uint32_t len;
454
+ fin.read((char *) &len, sizeof(len));
455
+
456
+ word.resize(len);
457
+ fin.read((char *) word.data(), len);
458
+
459
+ vocab.token_to_id[word] = i;
460
+ vocab.id_to_token[i] = word;
461
+
462
+ //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
463
+ }
464
+
465
+ vocab.n_vocab = model.hparams.n_vocab;
466
+ if (vocab.is_multilingual()) {
467
+ vocab.token_eot++;
468
+ vocab.token_sot++;
469
+ vocab.token_prev++;
470
+ vocab.token_solm++;
471
+ vocab.token_beg++;
472
+ }
473
+
474
+ if (n_vocab < model.hparams.n_vocab) {
475
+ printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
476
+ for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
477
+ if (i > vocab.token_beg) {
478
+ word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
479
+ } else if (i == vocab.token_eot) {
480
+ word = "[_EOT_]";
481
+ } else if (i == vocab.token_sot) {
482
+ word = "[_SOT_]";
483
+ } else if (i == vocab.token_prev) {
484
+ word = "[_PREV_]";
485
+ } else if (i == vocab.token_beg) {
486
+ word = "[_BEG_]";
487
+ } else {
488
+ word = "[_extra_token_" + std::to_string(i) + "]";
489
+ }
490
+ vocab.token_to_id[word] = i;
491
+ vocab.id_to_token[i] = word;
492
+ }
493
+ }
494
+ }
495
+
496
+ // for the big tensors, we have the option to store the data in 16-bit floats
497
+ // in order to save memory and also to speed up the computation
498
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
499
+
500
+ auto & ctx = model.ctx;
501
+
502
+ size_t ctx_size = 0;
503
+
504
+ {
505
+ const auto & hparams = model.hparams;
506
+
507
+ const int n_vocab = hparams.n_vocab;
508
+
509
+ const int n_audio_ctx = hparams.n_audio_ctx;
510
+ const int n_audio_state = hparams.n_audio_state;
511
+ const int n_audio_layer = hparams.n_audio_layer;
512
+
513
+ const int n_text_ctx = hparams.n_text_ctx;
514
+ const int n_text_state = hparams.n_text_state;
515
+ const int n_text_layer = hparams.n_text_layer;
516
+
517
+ const int n_mels = hparams.n_mels;
518
+
519
+ // encoder
520
+ {
521
+ // TODO: F16 .. maybe not?
522
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
523
+
524
+ ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
525
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
526
+
527
+ ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
528
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
529
+
530
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
531
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
532
+ }
533
+
534
+ // decoder
535
+ {
536
+ // TODO: F16 .. maybe not?
537
+ ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
538
+
539
+ ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
540
+
541
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
542
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
543
+ }
544
+
545
+ // encoder layers
546
+ {
547
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
548
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
549
+
550
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
551
+ ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
552
+
553
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
554
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
555
+
556
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
557
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
558
+
559
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
560
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
561
+
562
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
563
+
564
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
565
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
566
+
567
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
568
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
569
+ }
570
+
571
+ // decoder layers
572
+ {
573
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
574
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
575
+
576
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
577
+ ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
578
+
579
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
580
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
581
+
582
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
583
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
584
+
585
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
586
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
587
+
588
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
589
+
590
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
591
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
592
+
593
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
594
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
595
+ //
596
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
597
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
598
+
599
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
600
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
601
+
602
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
603
+
604
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
605
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
606
+
607
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
608
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
609
+ }
610
+
611
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_k
612
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_v
613
+
614
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_k
615
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_v
616
+
617
+ ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
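+ // rough bookkeeping: 15 non-layer tensors (7 encoder, 4 decoder, 4 KV memory), plus 15 tensors
+ // per encoder layer and 24 per decoder layer, each charged ~256 bytes of ggml object overhead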
618
+
619
+ printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
620
+ }
621
+
622
+ // create the ggml context
623
+ {
624
+ struct ggml_init_params params = {
625
+ .mem_size = ctx_size,
626
+ .mem_buffer = NULL,
627
+ };
628
+
629
+ model.ctx = ggml_init(params);
630
+ if (!model.ctx) {
631
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
632
+ return false;
633
+ }
634
+ }
635
+
636
+ // prepare memory for the weights
637
+ {
638
+ const auto & hparams = model.hparams;
639
+
640
+ const int n_vocab = hparams.n_vocab;
641
+
642
+ const int n_audio_ctx = hparams.n_audio_ctx;
643
+ const int n_audio_state = hparams.n_audio_state;
644
+ const int n_audio_layer = hparams.n_audio_layer;
645
+
646
+ const int n_text_ctx = hparams.n_text_ctx;
647
+ const int n_text_state = hparams.n_text_state;
648
+ const int n_text_layer = hparams.n_text_layer;
649
+
650
+ const int n_mels = hparams.n_mels;
651
+
652
+ model.layers_encoder.resize(n_audio_layer);
653
+ model.layers_decoder.resize(n_text_layer);
654
+
655
+ // encoder
656
+ {
657
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
658
+
659
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
660
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
661
+
662
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
663
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
664
+
665
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
666
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
667
+
668
+ // map by name
669
+ model.tensors["encoder.positional_embedding"] = model.e_pe;
670
+
671
+ model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
672
+ model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
673
+
674
+ model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
675
+ model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
676
+
677
+ model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
678
+ model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
679
+
680
+ for (int i = 0; i < n_audio_layer; ++i) {
681
+ auto & layer = model.layers_encoder[i];
682
+
683
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
684
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
685
+
686
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
687
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
688
+
689
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
690
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
691
+
692
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
693
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
694
+
695
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
696
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
697
+
698
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
699
+
700
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
701
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
702
+
703
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
704
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
705
+
706
+ // map by name
707
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
708
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
709
+
710
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
711
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
712
+
713
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
714
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
715
+
716
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
717
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
718
+
719
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
720
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
721
+
722
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
723
+
724
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
725
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
726
+
727
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
728
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
729
+ }
730
+ }
731
+
732
+ // decoder
733
+ {
734
+ model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
735
+
736
+ model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
737
+
738
+ model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
739
+ model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
740
+
741
+ // map by name
742
+ model.tensors["decoder.positional_embedding"] = model.d_pe;
743
+
744
+ model.tensors["decoder.token_embedding.weight"] = model.d_te;
745
+
746
+ model.tensors["decoder.ln.weight"] = model.d_ln_w;
747
+ model.tensors["decoder.ln.bias"] = model.d_ln_b;
748
+
749
+ for (int i = 0; i < n_text_layer; ++i) {
750
+ auto & layer = model.layers_decoder[i];
751
+
752
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
753
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
754
+
755
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
756
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
757
+
758
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
759
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
760
+
761
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
762
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
763
+
764
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
765
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
766
+
767
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
768
+
769
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
770
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
771
+
772
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
773
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
774
+
775
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
776
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
777
+
778
+ layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
779
+ layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
780
+
781
+ layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
782
+
783
+ layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
784
+ layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
785
+
786
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
787
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
788
+
789
+ // map by name
790
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
791
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
792
+
793
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
794
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
795
+
796
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
797
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
798
+
799
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
800
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
801
+
802
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
803
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
804
+
805
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
806
+
807
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
808
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
809
+
810
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
811
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
812
+
813
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
814
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
815
+
816
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
817
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
818
+
819
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
820
+
821
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
822
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
823
+
824
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
825
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
826
+ }
827
+ }
828
+ }
829
+
830
+ // key + value memory
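+ // memory_k/memory_v act as the decoder's self-attention KV cache (one slot per layer and text
+ // position), while memory_cross_k/memory_cross_v hold the cross-attention keys/values that are
+ // pre-computed once per encoder pass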
831
+ {
832
+ const auto & hparams = model.hparams;
833
+
834
+ const int n_text_state = hparams.n_text_state;
835
+ const int n_text_layer = hparams.n_text_layer;
836
+ const int n_text_ctx = hparams.n_text_ctx;
837
+
838
+ {
839
+ const int n_mem = n_text_layer*n_text_ctx;
840
+ const int n_elements = n_text_state*n_mem;
841
+
842
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
843
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
844
+ }
845
+
846
+ {
847
+ const int n_audio_ctx = hparams.n_audio_ctx;
848
+
849
+ const int n_mem = n_text_layer*n_audio_ctx;
850
+ const int n_elements = n_text_state*n_mem;
851
+
852
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
853
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
854
+ }
855
+
856
+ const size_t memory_size =
857
+ ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
858
+ ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
859
+
860
+ printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
861
+ }
862
+
863
+ // load weights
864
+ {
865
+ size_t total_size = 0;
866
+
867
+ while (true) {
868
+ int32_t n_dims;
869
+ int32_t length;
870
+ int32_t ftype;
871
+
872
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
873
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
874
+ fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
875
+
876
+ if (fin.eof()) {
877
+ break;
878
+ }
879
+
880
+ int32_t nelements = 1;
881
+ int32_t ne[3] = { 1, 1, 1 };
882
+ for (int i = 0; i < n_dims; ++i) {
883
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
884
+ nelements *= ne[i];
885
+ }
886
+
887
+ std::string name(length, 0);
888
+ fin.read(&name[0], length);
889
+
890
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
891
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
892
+ return false;
893
+ }
894
+
895
+ auto tensor = model.tensors[name.data()];
896
+ if (ggml_nelements(tensor) != nelements) {
897
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
898
+ return false;
899
+ }
900
+
901
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
902
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
903
+ __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
904
+ return false;
905
+ }
906
+
907
+ const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
908
+
909
+ if (nelements*bpe != ggml_nbytes(tensor)) {
910
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
911
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
912
+ return false;
913
+ }
914
+
915
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
916
+
917
+ //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
918
+ total_size += ggml_nbytes(tensor);
919
+ }
920
+
921
+ printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
922
+ }
923
+
924
+ fin.close();
925
+
926
+ return true;
927
+ }
928
+
929
+ // evaluate the encoder
930
+ //
931
+ // given an audio recording (more specifically, its log mel spectrogram), runs the forward pass of the encoder
932
+ // part of the transformer model and returns the encoded features
933
+ //
934
+ // - model: the model
935
+ // - n_threads: number of threads to use
936
+ // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
937
+ // - mel_inp: input mel spectrogram
938
+ // - features: output encoded features
939
+ //
940
+ bool whisper_encode(
941
+ const whisper_model & model,
942
+ const int n_threads,
943
+ const int mel_offset,
944
+ const whisper_mel & mel_inp,
945
+ std::vector<float> & features) {
946
+ const auto & hparams = model.hparams;
947
+
948
+ const int n_vocab = hparams.n_vocab;
949
+
950
+ const int n_ctx = hparams.n_audio_ctx;
951
+ const int n_state = hparams.n_audio_state;
952
+ const int n_head = hparams.n_audio_head;
953
+ const int n_layer = hparams.n_audio_layer;
954
+
955
+ const int N = n_ctx;
956
+
957
+ const int n_mels = hparams.n_mels;
958
+ assert(mel_inp.n_mel == n_mels);
959
+
960
+ struct ggml_init_params params;
961
+
962
+ {
963
+ static size_t buf_size = MEM_REQ_ENCODE.at(model.type);
964
+ static void * buf = malloc(buf_size);
965
+
966
+ params = {
967
+ .mem_size = buf_size,
968
+ .mem_buffer = buf,
969
+ };
970
+ }
971
+
972
+ struct ggml_context * ctx0 = ggml_init(params);
973
+
974
+ struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
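+ // the second conv layer has stride 2, so the n_audio_ctx encoder positions consume 2*n_audio_ctx
+ // mel frames (a full 30-second chunk at 10 ms per frame for the standard Whisper hparams)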
975
+ assert(mel->type == GGML_TYPE_F32);
976
+ {
977
+ float * dst = (float *) mel->data;
978
+ memset(dst, 0, ggml_nbytes(mel));
979
+
980
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
981
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
982
+
983
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
984
+ for (int i = i0; i < i1; ++i) {
985
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
986
+ }
987
+ }
988
+ }
989
+
990
+ struct ggml_tensor * cur;
991
+
992
+ // convolution + gelu
993
+ {
994
+ cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
995
+ cur = ggml_add(ctx0,
996
+ ggml_repeat(ctx0,
997
+ model.e_conv_1_b,
998
+ cur),
999
+ cur);
1000
+
1001
+ cur = ggml_gelu(ctx0, cur);
1002
+
1003
+ cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1004
+ cur = ggml_add(ctx0,
1005
+ ggml_repeat(ctx0,
1006
+ model.e_conv_2_b,
1007
+ cur),
1008
+ cur);
1009
+
1010
+ cur = ggml_gelu(ctx0, cur);
1011
+ }
1012
+
1013
+ cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1014
+
1015
+ struct ggml_tensor * inpL = cur;
1016
+
1017
+ for (int il = 0; il < n_layer; ++il) {
1018
+ const auto & layer = model.layers_encoder[il];
1019
+
1020
+ // create separate context for each layer to reduce memory usage
1021
+
1022
+ struct ggml_init_params paramsL;
1023
+ {
1024
+ static size_t buf_size = MEM_REQ_ENCODE_LAYER.at(model.type);
1025
+ static void * buf = malloc(buf_size);
1026
+
1027
+ paramsL = {
1028
+ .mem_size = buf_size,
1029
+ .mem_buffer = buf,
1030
+ };
1031
+ }
1032
+
1033
+ struct ggml_context * ctxL = ggml_init(paramsL);
1034
+
1035
+ // norm
1036
+ {
1037
+ cur = ggml_norm(ctxL, inpL);
1038
+
1039
+ // cur = ln_0_w*cur + ln_0_b
1040
+ cur = ggml_add(ctxL,
1041
+ ggml_mul(ctxL,
1042
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1043
+ cur),
1044
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1045
+ }
1046
+
1047
+ // self-attention
1048
+ {
1049
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1050
+ layer.attn_q_w,
1051
+ cur);
1052
+
1053
+ Qcur = ggml_add(ctxL,
1054
+ ggml_repeat(ctxL,
1055
+ layer.attn_q_b,
1056
+ Qcur),
1057
+ Qcur);
1058
+
1059
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
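+ // Q and K are each scaled by (n_state/n_head)^-0.25 so that their product carries the usual
+ // 1/sqrt(d_head) attention scaling - hence the separate KQ_scaled step below stays commented out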
1060
+
1061
+ // no bias for Key
1062
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1063
+ layer.attn_k_w,
1064
+ cur);
1065
+
1066
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1067
+
1068
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1069
+ layer.attn_v_w,
1070
+ cur);
1071
+
1072
+ Vcur = ggml_add(ctxL,
1073
+ ggml_repeat(ctxL,
1074
+ layer.attn_v_b,
1075
+ Vcur),
1076
+ Vcur);
1077
+
1078
+ // ------
1079
+
1080
+ struct ggml_tensor * Q =
1081
+ ggml_permute(ctxL,
1082
+ ggml_cpy(ctxL,
1083
+ Qcur,
1084
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1085
+ 0, 2, 1, 3);
1086
+
1087
+ struct ggml_tensor * K =
1088
+ ggml_permute(ctxL,
1089
+ ggml_cpy(ctxL,
1090
+ Kcur,
1091
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), // F16 !
1092
+ 0, 2, 1, 3);
1093
+
1094
+ //// BLAS attempt
1095
+ //struct ggml_tensor * KQ =
1096
+ // ggml_mul_mat(ctxL,
1097
+ // ggml_cpy(ctxL, K, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)),
1098
+ // ggml_cpy(ctxL, Q, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)));
1099
+
1100
+ // K * Q
1101
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1102
+
1103
+ //struct ggml_tensor * K =
1104
+ // ggml_cpy(ctxL,
1105
+ // ggml_permute(ctxL,
1106
+ // ggml_reshape_3d(ctxL,
1107
+ // Kcur,
1108
+ // n_state/n_head, n_head, N),
1109
+ // 1, 2, 0, 3),
1110
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
1111
+ // );
1112
+
1113
+ //// K * Q
1114
+ //struct ggml_tensor * KQ = ggml_mul_mat(ctxL, ggml_transpose(ctxL, K), Q);
1115
+
1116
+ //struct ggml_tensor * KQ_scaled =
1117
+ // ggml_scale(ctxL,
1118
+ // KQ,
1119
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1120
+ // );
1121
+
1122
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
1123
+
1124
+ //struct ggml_tensor * V_trans =
1125
+ // ggml_permute(ctxL,
1126
+ // ggml_cpy(ctxL,
1127
+ // Vcur,
1128
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1129
+ // 1, 2, 0, 3);
1130
+
1131
+ //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1132
+
1133
+ struct ggml_tensor * V =
1134
+ ggml_cpy(ctxL,
1135
+ ggml_permute(ctxL,
1136
+ ggml_reshape_3d(ctxL,
1137
+ Vcur,
1138
+ n_state/n_head, n_head, N),
1139
+ 0, 2, 1, 3),
1140
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) // F16 !
1141
+ );
1142
+
1143
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
1144
+
1145
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1146
+
1147
+ cur = ggml_cpy(ctxL,
1148
+ KQV_merged,
1149
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1150
+ }
1151
+
1152
+ // projection
1153
+ {
1154
+ cur = ggml_mul_mat(ctxL,
1155
+ layer.attn_ln_1_w,
1156
+ cur);
1157
+
1158
+ cur = ggml_add(ctxL,
1159
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1160
+ cur);
1161
+ }
1162
+
1163
+ // add the input
1164
+ cur = ggml_add(ctxL, cur, inpL);
1165
+
1166
+ struct ggml_tensor * inpFF = cur;
1167
+
1168
+ // feed-forward network
1169
+ {
1170
+ // norm
1171
+ {
1172
+ cur = ggml_norm(ctxL, inpFF);
1173
+
1174
+ // cur = mlp_ln_w*cur + mlp_ln_b
1175
+ cur = ggml_add(ctxL,
1176
+ ggml_mul(ctxL,
1177
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1178
+ cur),
1179
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1180
+ }
1181
+
1182
+ // fully connected
1183
+ cur = ggml_mul_mat(ctxL,
1184
+ layer.mlp_0_w,
1185
+ cur);
1186
+
1187
+ cur = ggml_add(ctxL,
1188
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
1189
+ cur);
1190
+
1191
+ // GELU activation
1192
+ cur = ggml_gelu(ctxL, cur);
1193
+
1194
+ // projection
1195
+ cur = ggml_mul_mat(ctxL,
1196
+ layer.mlp_1_w,
1197
+ cur);
1198
+
1199
+ cur = ggml_add(ctxL,
1200
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
1201
+ cur);
1202
+ }
1203
+
1204
+ // output from this layer
1205
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1206
+
1207
+ {
1208
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1209
+
1210
+ ggml_build_forward_expand(&gf, inpO);
1211
+ ggml_graph_compute (ctxL, &gf);
1212
+
1213
+ //ggml_graph_print(&gf);
1214
+ }
1215
+
1216
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1217
+ // input for next layer (inpO -> inpL)
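+ // copying the data and clearing op/src0/src1 turns inpL back into a leaf tensor, so the per-layer
+ // scratch context can be freed without invalidating the input of the next layer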
1218
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1219
+ inpL->op = GGML_OP_NONE;
1220
+ inpL->src0 = NULL;
1221
+ inpL->src1 = NULL;
1222
+
1223
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1224
+
1225
+ ggml_free(ctxL);
1226
+ }
1227
+
1228
+ cur = inpL;
1229
+
1230
+ // norm
1231
+ {
1232
+ cur = ggml_norm(ctx0, cur);
1233
+
1234
+ // cur = ln_f_g*cur + ln_f_b
1235
+ cur = ggml_add(ctx0,
1236
+ ggml_mul(ctx0,
1237
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1238
+ cur),
1239
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1240
+ }
1241
+
1242
+ // run the computation
1243
+ {
1244
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1245
+
1246
+ ggml_build_forward_expand(&gf, cur);
1247
+ ggml_graph_compute (ctx0, &gf);
1248
+
1249
+ //ggml_graph_print(&gf);
1250
+ }
1251
+
1252
+ // cur
1253
+ //{
1254
+ // printf("ne0 = %d\n", cur->ne[0]);
1255
+ // printf("ne1 = %d\n", cur->ne[1]);
1256
+ // for (int i = 0; i < 10; ++i) {
1257
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1258
+ // }
1259
+ // printf("... ");
1260
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1261
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1262
+ // }
1263
+ // printf("\n");
1264
+ //}
1265
+
1266
+ // pre-compute cross-attention memory
1267
+ {
1268
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1269
+
1270
+ // TODO: hack to disconnect the encoded features from the previous graph
1271
+ cur->op = GGML_OP_NONE;
1272
+ cur->src0 = NULL;
1273
+ cur->src1 = NULL;
1274
+
1275
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1276
+ auto & layer = model.layers_decoder[il];
1277
+
1278
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1279
+ layer.cross_attn_k_w,
1280
+ cur);
1281
+
1282
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1283
+
1284
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1285
+ layer.cross_attn_v_w,
1286
+ cur);
1287
+
1288
+ Vcross = ggml_add(ctx0,
1289
+ ggml_repeat(ctx0,
1290
+ layer.cross_attn_v_b,
1291
+ Vcross),
1292
+ Vcross);
1293
+
1294
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
1295
+ struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
1296
+
1297
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1298
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1299
+ }
1300
+
1301
+ ggml_graph_compute(ctx0, &gf);
1302
+ }
1303
+
1304
+ ////////////////////////////////////////////////////////////////////////////
1305
+
1306
+ // output the features
1307
+ assert(cur->type == GGML_TYPE_F32);
1308
+ features.resize(cur->ne[0]*cur->ne[1]);
1309
+ memcpy(features.data(), cur->data, features.size()*sizeof(float));
1310
+
1311
+ //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
1312
+
1313
+ ggml_free(ctx0);
1314
+
1315
+ return true;
1316
+ }
1317
+
1318
+ // evaluate the decoder
1319
+ //
1320
+ // given a text prompt + audio features -> predicts the probabilities for the next token
1321
+ //
1322
+ // - model: the model
1323
+ // - n_threads: number of threads to use
1324
+ // - n_past: prompt length
1325
+ // - prompt: text prompt
1326
+ // - logits_out: output logits
1327
+ // - probs_out: output probabilities
1328
+ //
1329
+ bool whisper_decode(
1330
+ const whisper_model & model,
1331
+ const int n_threads,
1332
+ const int n_past,
1333
+ const std::vector<whisper_vocab::id> & prompt,
1334
+ std::vector<float> & logits_out,
1335
+ std::vector<float> & probs_out) {
1336
+ const auto & hparams = model.hparams;
1337
+
1338
+ const int n_vocab = hparams.n_vocab;
1339
+
1340
+ const int n_ctx = hparams.n_text_ctx;
1341
+ const int n_state = hparams.n_text_state;
1342
+ const int n_head = hparams.n_text_head;
1343
+ const int n_layer = hparams.n_text_layer;
1344
+
1345
+ const int N = prompt.size();
1346
+ const int M = hparams.n_audio_ctx;
1347
+
1348
+ struct ggml_init_params params;
1349
+
1350
+ {
1351
+ static size_t buf_size = MEM_REQ_DECODE.at(model.type);
1352
+ static void * buf = malloc(buf_size);
1353
+
1354
+ params = {
1355
+ .mem_size = buf_size,
1356
+ .mem_buffer = buf,
1357
+ };
1358
+ }
1359
+
1360
+ struct ggml_context * ctx0 = ggml_init(params);
1361
+
1362
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1363
+ memcpy(embd->data, prompt.data(), N*ggml_element_size(embd));
1364
+
1365
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1366
+ for (int i = 0; i < N; ++i) {
1367
+ ((int32_t *) position->data)[i] = n_past + i;
1368
+ }
1369
+
1370
+ // wte + wpe
1371
+ struct ggml_tensor * cur =
1372
+ ggml_add(ctx0,
1373
+ ggml_get_rows(ctx0, model.d_te, embd),
1374
+ ggml_get_rows(ctx0, model.d_pe, position));
1375
+
1376
+ struct ggml_tensor * inpL = cur;
1377
+
1378
+ for (int il = 0; il < n_layer; ++il) {
1379
+ const auto & layer = model.layers_decoder[il];
1380
+
1381
+ struct ggml_init_params paramsL;
1382
+
1383
+ {
1384
+ static size_t buf_size = MEM_REQ_DECODE_LAYER.at(model.type);
1385
+ static void * buf = malloc(buf_size);
1386
+
1387
+ paramsL = {
1388
+ .mem_size = buf_size,
1389
+ .mem_buffer = buf,
1390
+ };
1391
+ }
1392
+
1393
+ struct ggml_context * ctxL = ggml_init(paramsL);
1394
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1395
+
1396
+ // norm
1397
+ {
1398
+ cur = ggml_norm(ctxL, inpL);
1399
+
1400
+ // cur = ln_0_w*cur + ln_0_b
1401
+ cur = ggml_add(ctxL,
1402
+ ggml_mul(ctxL,
1403
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1404
+ cur),
1405
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1406
+ }
1407
+
1408
+ // self-attention
1409
+ {
1410
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1411
+ layer.attn_q_w,
1412
+ cur);
1413
+
1414
+ Qcur = ggml_add(ctxL,
1415
+ ggml_repeat(ctxL,
1416
+ layer.attn_q_b,
1417
+ Qcur),
1418
+ Qcur);
1419
+
1420
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1421
+
1422
+ // no bias for Key
1423
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1424
+ layer.attn_k_w,
1425
+ cur);
1426
+
1427
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1428
+
1429
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1430
+ layer.attn_v_w,
1431
+ cur);
1432
+
1433
+ Vcur = ggml_add(ctxL,
1434
+ ggml_repeat(ctxL,
1435
+ layer.attn_v_b,
1436
+ Vcur),
1437
+ Vcur);
1438
+
1439
+ // store key and value to memory
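+ // each decoder layer owns a contiguous slab of n_ctx*n_state elements in memory_k/v; the views
+ // below start at text position n_past inside layer il's slab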
1440
+ {
1441
+ struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
1442
+ struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
1443
+
1444
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
1445
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
1446
+ }
1447
+
1448
+ // ------
1449
+
1450
+ struct ggml_tensor * Q =
1451
+ ggml_permute(ctxL,
1452
+ ggml_cpy(ctxL,
1453
+ Qcur,
1454
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1455
+ 0, 2, 1, 3);
1456
+
1457
+ struct ggml_tensor * K =
1458
+ ggml_permute(ctxL,
1459
+ ggml_reshape_3d(ctxL,
1460
+ ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
1461
+ n_state/n_head, n_head, n_past + N),
1462
+ 0, 2, 1, 3);
1463
+
1464
+ // K * Q
1465
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1466
+
1467
+ //struct ggml_tensor * KQ_scaled =
1468
+ // ggml_scale(ctxL,
1469
+ // KQ,
1470
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1471
+ // );
1472
+
1473
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
1474
+
1475
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
1476
+
1477
+ struct ggml_tensor * V_trans =
1478
+ ggml_permute(ctxL,
1479
+ ggml_reshape_3d(ctxL,
1480
+ ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
1481
+ n_state/n_head, n_head, n_past + N),
1482
+ 1, 2, 0, 3);
1483
+
1484
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1485
+
1486
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1487
+
1488
+ cur = ggml_cpy(ctxL,
1489
+ KQV_merged,
1490
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1491
+ }
1492
+
1493
+ {
1494
+ cur = ggml_mul_mat(ctxL,
1495
+ layer.attn_ln_1_w,
1496
+ cur);
1497
+
1498
+ cur = ggml_add(ctxL,
1499
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1500
+ cur);
1501
+ }
1502
+
1503
+ // add the input
1504
+ struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
1505
+
1506
+ // norm
1507
+ {
1508
+ cur = ggml_norm(ctxL, inpCA); // Note we use inpCA here
1509
+
1510
+ // cur = ln_0_w*cur + ln_0_b
1511
+ cur = ggml_add(ctxL,
1512
+ ggml_mul(ctxL,
1513
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
1514
+ cur),
1515
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
1516
+ }
1517
+
1518
+ // cross-attention
1519
+ {
1520
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1521
+ layer.cross_attn_q_w,
1522
+ cur);
1523
+
1524
+ Qcur = ggml_add(ctxL,
1525
+ ggml_repeat(ctxL,
1526
+ layer.cross_attn_q_b,
1527
+ Qcur),
1528
+ Qcur);
1529
+
1530
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1531
+
1532
+ // Kcross is already scaled
1533
+ struct ggml_tensor * Kcross =
1534
+ ggml_reshape_3d(ctxL,
1535
+ ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
1536
+ n_state/n_head, n_head, M);
1537
+
1538
+ struct ggml_tensor * Vcross =
1539
+ ggml_reshape_3d(ctxL,
1540
+ ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
1541
+ n_state/n_head, n_head, M);
1542
+
1543
+ // ------
1544
+
1545
+ struct ggml_tensor * Q =
1546
+ ggml_permute(ctxL,
1547
+ ggml_cpy(ctxL,
1548
+ Qcur,
1549
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1550
+ 0, 2, 1, 3);
1551
+
1552
+ struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
1553
+
1554
+ // K * Q
1555
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1556
+
1557
+ //struct ggml_tensor * KQ_scaled =
1558
+ // ggml_scale(ctxL,
1559
+ // KQ,
1560
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1561
+ // );
1562
+
1563
+ // no masking for cross-attention
1564
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
1565
+
1566
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
1567
+
1568
+ struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
1569
+
1570
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1571
+
1572
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1573
+
1574
+ // cur = KQV_merged.contiguous().view(n_state, N)
1575
+ cur = ggml_cpy(ctxL,
1576
+ KQV_merged,
1577
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1578
+ }
1579
+
1580
+ // projection
1581
+ {
1582
+ cur = ggml_mul_mat(ctxL,
1583
+ layer.cross_attn_ln_1_w,
1584
+ cur);
1585
+
1586
+ cur = ggml_add(ctxL,
1587
+ ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
1588
+ cur);
1589
+ }
1590
+
1591
+
1592
+ // add the input
1593
+ cur = ggml_add(ctxL, cur, inpCA);
1594
+
1595
+ struct ggml_tensor * inpFF = cur;
1596
+
1597
+ // feed-forward network
1598
+ {
1599
+ // norm
1600
+ {
1601
+ cur = ggml_norm(ctxL, inpFF);
1602
+
1603
+ // cur = ln_2_g*cur + ln_2_b
1604
+ // [ 768, N]
1605
+ cur = ggml_add(ctxL,
1606
+ ggml_mul(ctxL,
1607
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1608
+ cur),
1609
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1610
+ }
1611
+
1612
+ // fully connected
1613
+ cur = ggml_mul_mat(ctxL,
1614
+ layer.mlp_0_w,
1615
+ cur);
1616
+
1617
+ cur = ggml_add(ctxL,
1618
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
1619
+ cur);
1620
+
1621
+ // GELU activation
1622
+ cur = ggml_gelu(ctxL, cur);
1623
+
1624
+ // projection
1625
+ cur = ggml_mul_mat(ctxL,
1626
+ layer.mlp_1_w,
1627
+ cur);
1628
+
1629
+ cur = ggml_add(ctxL,
1630
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
1631
+ cur);
1632
+ }
1633
+
1634
+ // output from this layer
1635
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1636
+
1637
+ {
1638
+ ggml_build_forward_expand(&gf, inpO);
1639
+ ggml_graph_compute (ctxL, &gf);
1640
+
1641
+ //ggml_graph_print(&gf);
1642
+ }
1643
+
1644
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1645
+ // input for next layer (inpO -> inpL)
1646
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1647
+ inpL->op = GGML_OP_NONE;
1648
+ inpL->src0 = NULL;
1649
+ inpL->src1 = NULL;
1650
+
1651
+ if (N > 1) {
1652
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1653
+ }
1654
+
1655
+ ggml_free(ctxL);
1656
+ }
1657
+
1658
+ cur = inpL;
1659
+
1660
+ // norm
1661
+ {
1662
+ cur = ggml_norm(ctx0, cur);
1663
+
1664
+ cur = ggml_add(ctx0,
1665
+ ggml_mul(ctx0,
1666
+ ggml_repeat(ctx0, model.d_ln_w, cur),
1667
+ cur),
1668
+ ggml_repeat(ctx0, model.d_ln_b, cur));
1669
+ }
1670
+
1671
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
1672
+
1673
+ // logits -> probs
1674
+ cur = ggml_dup(ctx0, logits);
1675
+ cur = ggml_soft_max(ctx0, cur); // in-place
1676
+
1677
+ // run the computation
1678
+ {
1679
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1680
+
1681
+ ggml_build_forward_expand(&gf, cur);
1682
+ ggml_graph_compute (ctx0, &gf);
1683
+ }
1684
+
1685
+ logits_out.resize(N*n_vocab);
1686
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
1687
+
1688
+ probs_out.resize(N*n_vocab);
1689
+ memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
1690
+
1691
+ //if (N > 1) {
1692
+ // const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
1693
+ // printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
1694
+ // printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
1695
+ //}
1696
+
1697
+ ggml_free(ctx0);
1698
+
1699
+ return true;
1700
+ }
1701
+
1702
+ // the most basic sampling scheme - select the top token
1703
+ // TODO: beam search
1704
+ // TODO: temperature
1705
+ whisper_vocab::id whisper_sample_best(
1706
+ const whisper_vocab & vocab,
1707
+ const float * probs,
1708
+ double temp,
1709
+ int offset = 0) {
1710
+ int n_logits = vocab.id_to_token.size();
1711
+
1712
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1713
+ probs_id.reserve(n_logits);
1714
+
1715
+ for (int i = offset; i < n_logits; i++) {
1716
+ probs_id.push_back(std::make_pair(probs[i], i));
1717
+ }
1718
+
1719
+ const int top_k = 10;
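+ // note: temp is currently unused - after the partial sort the highest-probability candidate is
+ // returned (skipping token_solm), i.e. the selection is effectively greedy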
1720
+
1721
+ // find the top K tokens
1722
+ std::partial_sort(
1723
+ probs_id.begin(),
1724
+ probs_id.begin() + top_k, probs_id.end(),
1725
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1726
+ return a.first > b.first;
1727
+ });
1728
+
1729
+ probs_id.resize(top_k);
1730
+
1731
+ //printf("\n");
1732
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
1733
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1734
+ //}
1735
+
1736
+ int res = 0;
1737
+ while (probs_id[res].second == vocab.token_solm && res < (int) probs_id.size() - 1) {
1738
+ res++;
1739
+ }
1740
+
1741
+ return probs_id[res].second;
1742
+ }
1743
+
1744
+ // Cooley-Tukey FFT
1745
+ // poor man's implementation - use something better
1746
+ // input is real-valued
1747
+ // output is complex-valued
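+ // out[2*k + 0] / out[2*k + 1] hold the real / imaginary part of bin k; the recursion splits the
+ // input into even and odd samples and recombines them with the twiddle factors exp(-2*pi*i*k/N)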
1748
+ void fft(const std::vector<float> & in, std::vector<float> & out) {
1749
+ out.resize(in.size()*2);
1750
+
1751
+ int N = in.size();
1752
+
1753
+ if (N == 1) {
1754
+ out[0] = in[0];
1755
+ out[1] = 0;
1756
+ return;
1757
+ }
1758
+
1759
+ std::vector<float> even;
1760
+ std::vector<float> odd;
1761
+
1762
+ for (int i = 0; i < N; i++) {
1763
+ if (i % 2 == 0) {
1764
+ even.push_back(in[i]);
1765
+ } else {
1766
+ odd.push_back(in[i]);
1767
+ }
1768
+ }
1769
+
1770
+ std::vector<float> even_fft;
1771
+ std::vector<float> odd_fft;
1772
+
1773
+ fft(even, even_fft);
1774
+ fft(odd, odd_fft);
1775
+
1776
+ for (int k = 0; k < N/2; k++) {
1777
+ float theta = 2*M_PI*k/N;
1778
+
1779
+ float re = cos(theta);
1780
+ float im = -sin(theta);
1781
+
1782
+ float re_odd = odd_fft[2*k + 0];
1783
+ float im_odd = odd_fft[2*k + 1];
1784
+
1785
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
1786
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
1787
+
1788
+ out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
1789
+ out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
1790
+ }
1791
+ }
1792
+
1793
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
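+ // pipeline: windowing -> FFT -> power spectrum -> mel filterbank -> log10 -> clamp + normalize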
1794
+ bool log_mel_spectrogram(
1795
+ const std::vector<float> sf32,
1796
+ const int sample_rate,
1797
+ const int fft_size,
1798
+ const int fft_step,
1799
+ const int n_mel,
1800
+ const int n_threads,
1801
+ const whisper_filters & filters,
1802
+ whisper_mel & mel) {
1803
+ const int n_sample = sf32.size();
1804
+ const float * samples = sf32.data();
1805
+
1806
+ // Hanning window
1807
+ std::vector<float> hann;
1808
+ hann.resize(fft_size);
1809
+ for (int i = 0; i < fft_size; i++) {
1810
+ hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
1811
+ }
1812
+
1813
+ mel.n_mel = n_mel;
1814
+ mel.n_len = (n_sample)/fft_step;
1815
+ mel.data.resize(mel.n_mel*mel.n_len);
1816
+
1817
+ const int n_fft = 1 + fft_size/2;
1818
+
1819
+ printf("%s: n_sample = %d, n_len = %d\n", __func__, n_sample, mel.n_len);
1820
+ printf("%s: recording length: %f s\n", __func__, (float) n_sample/sample_rate);
1821
+
1822
+ std::vector<std::thread> workers(n_threads);
1823
+ for (int iw = 0; iw < n_threads; ++iw) {
1824
+ workers[iw] = std::thread([&](int ith) {
1825
+ std::vector<float> fft_in;
1826
+ fft_in.resize(fft_size);
1827
+ for (int i = 0; i < fft_size; i++) {
1828
+ fft_in[i] = 0.0;
1829
+ }
1830
+
1831
+ std::vector<float> fft_out;
1832
+ fft_out.resize(2*fft_size);
1833
+
1834
+ for (int i = ith; i < mel.n_len; i += n_threads) {
1835
+ const int offset = i*fft_step;
1836
+
1837
+ // apply Hanning window
1838
+ for (int j = 0; j < fft_size; j++) {
1839
+ if (offset + j < n_sample) {
1840
+ fft_in[j] = hann[j]*samples[offset + j];
1841
+ } else {
1842
+ fft_in[j] = 0.0;
1843
+ }
1844
+ }
1845
+
1846
+ // FFT -> mag^2
1847
+ fft(fft_in, fft_out);
1848
+
1849
+ for (int j = 0; j < n_fft; j++) {
1850
+ fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
1851
+ }
1852
+
1853
+ // mel spectrogram
1854
+ for (int j = 0; j < mel.n_mel; j++) {
1855
+ double sum = 0.0;
1856
+
1857
+ for (int k = 0; k < n_fft; k++) {
1858
+ sum += fft_out[k]*filters.data[j*n_fft + k];
1859
+ }
1860
+ if (sum < 1e-10) {
1861
+ sum = 1e-10;
1862
+ }
1863
+
1864
+ sum = log10(sum);
1865
+
1866
+ mel.data[j*mel.n_len + i] = sum;
1867
+ }
1868
+ }
1869
+ }, iw);
1870
+ }
1871
+
1872
+ for (int iw = 0; iw < n_threads; ++iw) {
1873
+ workers[iw].join();
1874
+ }
1875
+
1876
+ // clamping and normalization
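+ // following the reference implementation: clamp every value to within 8.0 of the spectrogram
+ // maximum (in log10 units), then rescale with (x + 4.0)/4.0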
1877
+ double mmax = -1e20;
1878
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
1879
+ if (mel.data[i] > mmax) {
1880
+ mmax = mel.data[i];
1881
+ }
1882
+ }
1883
+
1884
+ mmax -= 8.0;
1885
+
1886
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
1887
+ if (mel.data[i] < mmax) {
1888
+ mel.data[i] = mmax;
1889
+ }
1890
+
1891
+ mel.data[i] = (mel.data[i] + 4.0)/4.0;
1892
+ }
1893
+
1894
+ return true;
1895
+ }
1896
+
1897
+ int main(int argc, char ** argv) {
1898
+ const int64_t t_main_start_us = ggml_time_us();
1899
+
1900
+ whisper_params params;
1901
+ params.model = "models/whisper-tiny.en/ggml-model.bin";
1902
+
1903
+ if (whisper_params_parse(argc, argv, params) == false) {
1904
+ return 1;
1905
+ }
1906
+
1907
+ if (params.seed < 0) {
1908
+ params.seed = time(NULL);
1909
+ }
1910
+
1911
+ // Model loading
1912
+
1913
+ //printf("%s: seed = %d\n", __func__, params.seed);
1914
+
1915
+ int64_t t_load_us = 0;
1916
+ int64_t t_mel_us = 0;
1917
+ int64_t t_sample_us = 0;
1918
+ int64_t t_encode_us = 0;
1919
+ int64_t t_decode_us = 0;
1920
+
1921
+ whisper_vocab vocab;
1922
+ whisper_model model;
1923
+
1924
+ // load the model
1925
+ {
1926
+ const int64_t t_start_us = ggml_time_us();
1927
+
1928
+ if (!whisper_model_load(params.model, model, vocab)) {
1929
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
1930
+ return 1;
1931
+ }
1932
+
1933
+ t_load_us = ggml_time_us() - t_start_us;
1934
+ }
1935
+
1936
+ // WAV input
1937
+ std::vector<float> pcmf32;
1938
+ {
1939
+ drwav wav;
1940
+ if (!drwav_init_file(&wav, params.fname_inp.c_str(), NULL)) {
1941
+ fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], params.fname_inp.c_str());
1942
+ return 2;
1943
+ }
1944
+
1945
+ if (wav.channels != 1) {
1946
+ fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str());
1947
+ return 3;
1948
+ }
1949
+
1950
+ if (wav.sampleRate != SAMPLE_RATE) {
1951
+ fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str());
1952
+ return 4;
1953
+ }
1954
+
1955
+ if (wav.bitsPerSample != 16) {
1956
+ fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], params.fname_inp.c_str());
1957
+ return 5;
1958
+ }
1959
+
1960
+ std::vector<int16_t> pcm16;
1961
+ pcm16.resize(wav.totalPCMFrameCount);
1962
+ drwav_read_pcm_frames_s16(&wav, wav.totalPCMFrameCount, pcm16.data());
1963
+ drwav_uninit(&wav);
1964
+
1965
+ // convert to float
1966
+ pcmf32.resize(pcm16.size());
1967
+ for (size_t i = 0; i < pcm16.size(); i++) {
1968
+ pcmf32[i] = float(pcm16[i])/32768.0f;
1969
+ }
1970
+ }
1971
+
1972
+ // compute log mel spectrogram
1973
+ whisper_mel mel_inp;
1974
+ {
1975
+ const int64_t t_start_us = ggml_time_us();
1976
+
1977
+ log_mel_spectrogram(pcmf32, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, params.n_threads, model.filters, mel_inp);
1978
+
1979
+ t_mel_us = ggml_time_us() - t_start_us;
1980
+ }
1981
+
1982
+ std::vector<whisper_vocab::id> prompt_past = { };
1983
+
1984
+ // main loop
1985
+ int seek = 0;
1986
+ while (true) {
1987
+ if (seek >= mel_inp.n_len) {
1988
+ break;
1989
+ }
1990
+
1991
+ // encode audio features starting at offset seek
1992
+ std::vector<float> features;
1993
+ {
1994
+ const int64_t t_start_us = ggml_time_us();
1995
+
1996
+ if (!whisper_encode(model, params.n_threads, seek, mel_inp, features)) {
1997
+ fprintf(stderr, "%s: failed to eval\n", __func__);
1998
+ return 1;
1999
+ }
2000
+
2001
+ t_encode_us = ggml_time_us() - t_start_us;
2002
+ }
2003
+
2004
+ std::vector<float> probs;
2005
+ std::vector<float> logits;
2006
+
2007
+ // SOT
2008
+ // ref: https://github.com/openai/whisper/blob/15ab54826343c27cfaf44ce31e9c8fb63d0aa775/whisper/decoding.py#L506-L526
2009
+ // TODO: use different initial tokens for different tasks
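+ // when there is previous text, the prompt below becomes [token_prev, up to n_text_ctx/2 most
+ // recent tokens, token_sot], which conditions the decoder on the earlier output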
2010
+ std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
2011
+
2012
+ int n_past = 0;
2013
+
2014
+ if (prompt_past.size() > 0) {
2015
+ int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
2016
+
2017
+ prompt = { vocab.token_prev };
2018
+ prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end());
2019
+ prompt.push_back(vocab.token_sot);
2020
+
2021
+ prompt_past.clear();
2022
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - 1);
2023
+ }
2024
+
2025
+ bool done = false;
2026
+ int seek_delta = 100*CHUNK_SIZE;
2027
+ whisper_vocab::id last_id = 0;
2028
+
2029
+ //for (int i = 0; i < prompt.size(); i++) {
2030
+ // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
2031
+ //}
2032
+
2033
+ printf("\n");
2034
+ for (int i = 0; i < model.hparams.n_text_ctx/2; ++i) {
2035
+ // decode
2036
+ if (prompt.size() > 0) {
2037
+ const int64_t t_start_us = ggml_time_us();
2038
+
2039
+ if (!whisper_decode(model, params.n_threads, n_past, prompt, logits, probs)) {
2040
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2041
+ return 1;
2042
+ }
2043
+
2044
+ t_decode_us += ggml_time_us() - t_start_us;
2045
+ }
2046
+
2047
+ n_past += prompt.size();
2048
+ prompt.clear();
2049
+
2050
+ {
2051
+ // sample next token
2052
+ const float temp = 1.0; // TODO
2053
+
2054
+ const int n_vocab = model.hparams.n_vocab;
2055
+
2056
+ whisper_vocab::id id = 0;
2057
+
2058
+ {
2059
+ const int64_t t_start_sample_us = ggml_time_us();
2060
+
2061
+ id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), temp, i > params.max_tokens_per_iter ? vocab.token_beg : 0);
2062
+
2063
+ t_sample_us += ggml_time_us() - t_start_sample_us;
2064
+ }
2065
+
2066
+ // end of text token
2067
+ if (id == vocab.token_eot) {
2068
+ break;
2069
+ }
2070
+
2071
+ // 2 consecutive time tokens
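+ // timestamp tokens encode multiples of 20 ms while mel frames are 10 ms apart, so the second
+ // timestamp yields a segment length of 2*(id - token_beg) frames to seek ahead by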
2072
+ if (id > vocab.token_beg && last_id > vocab.token_beg) {
2073
+ seek_delta = 2*(id - vocab.token_beg);
2074
+ done = true;
2075
+ }
2076
+ last_id = id;
2077
+
2078
+ // add it to the context
2079
+ prompt.push_back(id);
2080
+ prompt_past.push_back(id);
2081
+ }
2082
+
2083
+ // display text
2084
+ for (auto id : prompt) {
2085
+ if (params.print_special_tokens == false && id >= vocab.token_eot) {
2086
+ continue;
2087
+ }
2088
+ printf("%s", vocab.id_to_token[id].c_str());
2089
+ }
2090
+ fflush(stdout);
2091
+
2092
+ if (done) {
2093
+ break;
2094
+ }
2095
+ }
2096
+
2097
+ seek += seek_delta;
2098
+ }
2099
+
2100
+ // report timing
2101
+ {
2102
+ const int64_t t_main_end_us = ggml_time_us();
2103
+
2104
+ printf("\n\n");
2105
+ printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
2106
+ printf("%s: mel time = %8.2f ms\n", __func__, t_mel_us/1000.0f);
2107
+ printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
2108
+ printf("%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, t_encode_us/1000.0f, t_encode_us/1000.0f/model.hparams.n_audio_layer);
2109
+ printf("%s: decode time = %8.2f ms\n", __func__, t_decode_us/1000.0f);
2110
+ printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
2111
+ }
2112
+
2113
+ ggml_free(model.ctx);
2114
+
2115
+ return 0;
2116
+ }
models/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.bin
samples/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *
samples/jfk.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59dfb9a4acb36fe2a2affc14bacbee2920ff435cb13cc314a08c13f66ba7860e
3
+ size 352078