ggerganov committed
Commit cafe46d · 1 Parent(s): 0225795

parallel : adding tool for parallel transformer inference

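In short: this adds a new examples/parallel tool (whose source, for now, mirrors the main example) together with the library plumbing for running several decoders side by side: whisper_model_load() takes an n_processors count and allocates that many copies of the self-attention and cross-attention key/value memory, the internal whisper_encode()/whisper_decode() take a processor_id selecting a slice of that memory, and a new whisper_init_parallel() entry point is exposed in whisper.h.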
examples/parallel/CMakeLists.txt ADDED
@@ -0,0 +1,3 @@
+ set(TARGET parallel)
+ add_executable(${TARGET} parallel.cpp)
+ target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
examples/parallel/README.md ADDED
@@ -0,0 +1,3 @@
+ # parallel
+
+ TODO
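The README is a stub for now. Since the source below mirrors the main example's CLI, a typical invocation would presumably be ./parallel -m models/ggml-base.en.bin -f some-audio.wav (an assumption based on the option parsing below, not documented behavior).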
examples/parallel/parallel.cpp ADDED
@@ -0,0 +1,422 @@
+ #include "whisper.h"
+
+ // third-party utilities
+ // use your favorite implementations
+ #define DR_WAV_IMPLEMENTATION
+ #include "dr_wav.h"
+
+ #include <cmath>
+ #include <fstream>
+ #include <cstdio>
+ #include <string>
+ #include <thread>
+ #include <vector>
+
+ // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+ // Lowest is red, middle is yellow, highest is green.
+ const std::vector<std::string> k_colors = {
+     "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
+     "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
+ };
+
+ //  500 -> 00:00:05.000
+ // 6000 -> 00:01:00.000
+ std::string to_timestamp(int64_t t, bool comma = false) {
+     int64_t msec = t * 10;
+     int64_t hr = msec / (1000 * 60 * 60);
+     msec = msec - hr * (1000 * 60 * 60);
+     int64_t min = msec / (1000 * 60);
+     msec = msec - min * (1000 * 60);
+     int64_t sec = msec / 1000;
+     msec = msec - sec * 1000;
+
+     char buf[32];
+     snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
+
+     return std::string(buf);
+ }
+
+ // command-line parameters
+ struct whisper_params {
+     int32_t seed        = -1; // RNG seed, not used currently
+     int32_t n_threads   = std::min(4, (int32_t) std::thread::hardware_concurrency());
+     int32_t offset_t_ms = 0;
+     int32_t offset_n    = 0;
+
+     bool verbose              = false;
+     bool translate            = false;
+     bool output_txt           = false;
+     bool output_vtt           = false;
+     bool output_srt           = false;
+     bool print_special_tokens = false;
+     bool print_colors         = false;
+     bool no_timestamps        = false;
+
+     std::string language = "en";
+     std::string model    = "models/ggml-base.en.bin";
+
+     std::vector<std::string> fname_inp = {};
+ };
+
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
+
+ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
+     for (int i = 1; i < argc; i++) {
+         std::string arg = argv[i];
+
+         if (arg[0] != '-') {
+             params.fname_inp.push_back(arg);
+             continue;
+         }
+
+         if (arg == "-s" || arg == "--seed") {
+             params.seed = std::stoi(argv[++i]);
+         } else if (arg == "-t" || arg == "--threads") {
+             params.n_threads = std::stoi(argv[++i]);
+         } else if (arg == "-ot" || arg == "--offset-t") {
+             params.offset_t_ms = std::stoi(argv[++i]);
+         } else if (arg == "-on" || arg == "--offset-n") {
+             params.offset_n = std::stoi(argv[++i]);
+         } else if (arg == "-v" || arg == "--verbose") {
+             params.verbose = true;
+         } else if (arg == "--translate") {
+             params.translate = true;
+         } else if (arg == "-l" || arg == "--language") {
+             params.language = argv[++i];
+             if (whisper_lang_id(params.language.c_str()) == -1) {
+                 fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+                 whisper_print_usage(argc, argv, params);
+                 exit(0);
+             }
+         } else if (arg == "-otxt" || arg == "--output-txt") {
+             params.output_txt = true;
+         } else if (arg == "-ovtt" || arg == "--output-vtt") {
+             params.output_vtt = true;
+         } else if (arg == "-osrt" || arg == "--output-srt") {
+             params.output_srt = true;
+         } else if (arg == "-ps" || arg == "--print_special") {
+             params.print_special_tokens = true;
+         } else if (arg == "-pc" || arg == "--print_colors") {
+             params.print_colors = true;
+         } else if (arg == "-nt" || arg == "--no_timestamps") {
+             params.no_timestamps = true;
+         } else if (arg == "-m" || arg == "--model") {
+             params.model = argv[++i];
+         } else if (arg == "-f" || arg == "--file") {
+             params.fname_inp.push_back(argv[++i]);
+         } else if (arg == "-h" || arg == "--help") {
+             whisper_print_usage(argc, argv, params);
+             exit(0);
+         } else {
+             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+             whisper_print_usage(argc, argv, params);
+             exit(0);
+         }
+     }
+
+     return true;
+ }
+
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
+     fprintf(stderr, "\n");
+     fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
+     fprintf(stderr, "\n");
+     fprintf(stderr, "options:\n");
+     fprintf(stderr, "  -h,       --help           show this help message and exit\n");
+     fprintf(stderr, "  -s SEED,  --seed SEED      RNG seed (default: -1)\n");
+     fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
+     fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
+     fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
+     fprintf(stderr, "  -v,       --verbose        verbose output\n");
+     fprintf(stderr, "            --translate      translate from source language to english\n");
+     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
+     fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
+     fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
+     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
+     fprintf(stderr, "  -pc,      --print_colors   print colors\n");
+     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
+     fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
+     fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
+     fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path\n");
+     fprintf(stderr, "\n");
+ }
+
+ void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
+     const whisper_params & params = *(whisper_params *) user_data;
+
+     const int n_segments = whisper_full_n_segments(ctx);
+
+     // print the last segment
+     const int i = n_segments - 1;
+     if (i == 0) {
+         printf("\n");
+     }
+
+     if (params.no_timestamps) {
+         if (params.print_colors) {
+             for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                 if (params.print_special_tokens == false) {
+                     const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                     if (id >= whisper_token_eot(ctx)) {
+                         continue;
+                     }
+                 }
+
+                 const char * text = whisper_full_get_token_text(ctx, i, j);
+                 const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                 const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                 printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+             }
+         } else {
+             const char * text = whisper_full_get_segment_text(ctx, i);
+             printf("%s", text);
+         }
+         fflush(stdout);
+     } else {
+         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+         if (params.print_colors) {
+             printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+             for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                 if (params.print_special_tokens == false) {
+                     const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                     if (id >= whisper_token_eot(ctx)) {
+                         continue;
+                     }
+                 }
+
+                 const char * text = whisper_full_get_token_text(ctx, i, j);
+                 const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                 const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                 printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+             }
+             printf("\n");
+         } else {
+             const char * text = whisper_full_get_segment_text(ctx, i);
+
+             printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+         }
+     }
+ }
+
+ bool output_txt(struct whisper_context * ctx, const char * fname) {
+     std::ofstream fout(fname);
+     if (!fout.is_open()) {
+         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+         return false;
+     }
+
+     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+     const int n_segments = whisper_full_n_segments(ctx);
+     for (int i = 0; i < n_segments; ++i) {
+         const char * text = whisper_full_get_segment_text(ctx, i);
+         fout << text;
+     }
+
+     return true;
+ }
+
+ bool output_vtt(struct whisper_context * ctx, const char * fname) {
+     std::ofstream fout(fname);
+     if (!fout.is_open()) {
+         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+         return false;
+     }
+
+     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+     fout << "WEBVTT\n\n";
+
+     const int n_segments = whisper_full_n_segments(ctx);
+     for (int i = 0; i < n_segments; ++i) {
+         const char * text = whisper_full_get_segment_text(ctx, i);
+         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+         fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+         fout << text << "\n\n";
+     }
+
+     return true;
+ }
+
+ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
+     std::ofstream fout(fname);
+     if (!fout.is_open()) {
+         fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+         return false;
+     }
+
+     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+     const int n_segments = whisper_full_n_segments(ctx);
+     for (int i = 0; i < n_segments; ++i) {
+         const char * text = whisper_full_get_segment_text(ctx, i);
+         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+         fout << i + 1 + params.offset_n << "\n";
+         fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
+         fout << text << "\n\n";
+     }
+
+     return true;
+ }
+
+ int main(int argc, char ** argv) {
+     whisper_params params;
+
+     if (whisper_params_parse(argc, argv, params) == false) {
+         return 1;
+     }
+
+     if (params.seed < 0) {
+         params.seed = time(NULL);
+     }
+
+     if (params.fname_inp.empty()) {
+         fprintf(stderr, "error: no input files specified\n");
+         whisper_print_usage(argc, argv, params);
+         return 2;
+     }
+
+     // whisper init
+
+     struct whisper_context * ctx = whisper_init(params.model.c_str());
+
+     if (ctx == nullptr) {
+         fprintf(stderr, "error: failed to initialize whisper context\n");
+         return 3;
+     }
+
+     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
+         const auto fname_inp = params.fname_inp[f];
+
+         // WAV input
+         std::vector<float> pcmf32;
+         {
+             drwav wav;
+             if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
+                 fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
+                 whisper_print_usage(argc, argv, {});
+                 return 4;
+             }
+
+             if (wav.channels != 1 && wav.channels != 2) {
+                 fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
+                 return 5;
+             }
+
+             if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
+                 fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
+                 return 6;
+             }
+
+             if (wav.bitsPerSample != 16) {
+                 fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
+                 return 7;
+             }
+
+             int n = wav.totalPCMFrameCount;
+
+             std::vector<int16_t> pcm16;
+             pcm16.resize(n*wav.channels);
+             drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+             drwav_uninit(&wav);
+
+             // convert to mono, float
+             pcmf32.resize(n);
+             if (wav.channels == 1) {
+                 for (int i = 0; i < n; i++) {
+                     pcmf32[i] = float(pcm16[i])/32768.0f;
+                 }
+             } else {
+                 for (int i = 0; i < n; i++) {
+                     pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+                 }
+             }
+         }
+
+         // print system information
+         {
+             fprintf(stderr, "\n");
+             fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
+         }
+
+         // print some info about the processing
+         {
+             fprintf(stderr, "\n");
+             if (!whisper_is_multilingual(ctx)) {
+                 if (params.language != "en" || params.translate) {
+                     params.language = "en";
+                     params.translate = false;
+                     fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+                 }
+             }
+             fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+                     __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
+                     params.language.c_str(),
+                     params.translate ? "translate" : "transcribe",
+                     params.no_timestamps ? 0 : 1);
+
+             fprintf(stderr, "\n");
+         }
+
+
+         // run the inference
+         {
+             whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+
+             wparams.print_realtime       = false;
+             wparams.print_progress       = false;
+             wparams.print_timestamps     = !params.no_timestamps;
+             wparams.print_special_tokens = params.print_special_tokens;
+             wparams.translate            = params.translate;
+             wparams.language             = params.language.c_str();
+             wparams.n_threads            = params.n_threads;
+             wparams.offset_ms            = params.offset_t_ms;
+
+             // this callback is called on each new segment
+             if (!wparams.print_realtime) {
+                 wparams.new_segment_callback           = whisper_print_segment_callback;
+                 wparams.new_segment_callback_user_data = &params;
+             }
+
+             if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                 return 8;
+             }
+
+             printf("\n");
+
+             // output to text file
+             if (params.output_txt) {
+                 const auto fname_txt = fname_inp + ".txt";
+                 output_txt(ctx, fname_txt.c_str());
+             }
+
+             // output to VTT file
+             if (params.output_vtt) {
+                 const auto fname_vtt = fname_inp + ".vtt";
+                 output_vtt(ctx, fname_vtt.c_str());
+             }
+
+             // output to SRT file
+             if (params.output_srt) {
+                 const auto fname_srt = fname_inp + ".srt";
+                 output_srt(ctx, fname_srt.c_str(), params);
+             }
+         }
+     }
+
+     whisper_print_timings(ctx);
+     whisper_free(ctx);
+
+     return 0;
+ }
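Note that parallel.cpp is, at this point, essentially the main example under a new name: it still calls whisper_init(), and the per-processor machinery only exists in the library changes below. One detail worth flagging for readers: segment timestamps t0/t1 are in 10 ms ticks, which is why to_timestamp() multiplies by 10. A standalone sanity check of that conversion (an illustration, not part of the commit):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <string>

    // same arithmetic as to_timestamp() above: t counts 10 ms ticks
    std::string ticks_to_timestamp(int64_t t) {
        int64_t msec = t * 10;
        const int64_t hr  = msec / (1000*60*60); msec -= hr *(1000*60*60);
        const int64_t min = msec / (1000*60);    msec -= min*(1000*60);
        const int64_t sec = msec / 1000;         msec -= sec*1000;

        char buf[32];
        snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%03d", (int) hr, (int) min, (int) sec, (int) msec);
        return buf;
    }

    int main() {
        assert(ticks_to_timestamp( 500) == "00:00:05.000"); //  500 ticks =  5 s
        assert(ticks_to_timestamp(6000) == "00:01:00.000"); // 6000 ticks = 60 s
        printf("ok\n");
    }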
whisper.cpp CHANGED
@@ -413,7 +413,6 @@ struct whisper_context {
      std::vector<float> probs;
      std::vector<float> logits;

-     std::vector<whisper_token_data> tokens_cur;
      std::vector<whisper_segment> result_all;

      std::vector<whisper_token> prompt_past;
@@ -430,7 +429,7 @@ struct whisper_context {
  //
  // see the convert-pt-to-ggml.py script for details
  //
- bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
+ bool whisper_model_load(const std::string & fname, const int n_processors, whisper_context & wctx) {
      fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());

      auto & model = wctx.model;
@@ -700,11 +699,11 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
          ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
      }

-     ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
-     ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
+     ctx_size += n_processors*n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
+     ctx_size += n_processors*n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v

-     ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
-     ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
+     ctx_size += n_processors*n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
+     ctx_size += n_processors*n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v

      ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
@@ -934,7 +933,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
      // key/value memory for the self-attention layer
      {
          const int n_mem = n_text_layer*n_text_ctx;
-         const int n_elements = n_text_state*n_mem;
+         const int n_elements = n_text_state*n_mem*n_processors;

          model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
          model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
@@ -945,7 +944,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
          const int n_audio_ctx = hparams.n_audio_ctx;

          const int n_mem = n_text_layer*n_audio_ctx;
-         const int n_elements = n_text_state*n_mem;
+         const int n_elements = n_text_state*n_mem*n_processors;

          model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
          model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
@@ -955,7 +954,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
              ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
              ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);

-         fprintf(stderr, "%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
+         fprintf(stderr, "%s: memory size = %8.2f MB (%d processors)\n", __func__, memory_size/1024.0/1024.0, n_processors);
      }

      // load weights
@@ -1046,7 +1045,8 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
  bool whisper_encode(
          whisper_context & wctx,
          const int n_threads,
-         const int mel_offset) {
+         const int mel_offset,
+         const int processor_id) {
      const auto & model   = wctx.model;
      const auto & mel_inp = wctx.mel;
      const auto & hparams = model.hparams;
@@ -1400,8 +1400,11 @@ bool whisper_encode(
                  Vcross),
              Vcross);

-         struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
-         struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+         const size_t offset_k = processor_id*(ggml_element_size(model.memory_cross_k)*n_state)*(model.hparams.n_text_layer*n_ctx);
+         const size_t offset_v = processor_id*(ggml_element_size(model.memory_cross_v)*n_state)*(model.hparams.n_text_layer*n_ctx);
+
+         struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, offset_k + (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
+         struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, offset_v + (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));

          ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
          ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1434,7 +1437,8 @@ bool whisper_decode(
          const int n_threads,
          const whisper_token * tokens,
          const int n_tokens,
-         const int n_past) {
+         const int n_past,
+         const int processor_id) {
      const auto & model   = wctx.model;
      const auto & hparams = model.hparams;

@@ -1529,10 +1533,13 @@ bool whisper_decode(
                  Vcur),
              Vcur);

+         const size_t offset_k = processor_id*(ggml_element_size(model.memory_k)*n_state)*(n_layer*n_ctx);
+         const size_t offset_v = processor_id*(ggml_element_size(model.memory_v)*n_state)*(n_layer*n_ctx);
+
          // store key and value to memory
          {
-             struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
-             struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+             struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, offset_k + (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
+             struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, offset_v + (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));

              ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
              ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
@@ -1550,7 +1557,7 @@ bool whisper_decode(
          struct ggml_tensor * K =
              ggml_permute(ctxL,
                      ggml_reshape_3d(ctxL,
-                         ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                         ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, offset_k + il*n_ctx*ggml_element_size(model.memory_k)*n_state),
                          n_state/n_head, n_head, n_past + N),
                      0, 2, 1, 3);

@@ -1570,7 +1577,7 @@ bool whisper_decode(
          struct ggml_tensor * V_trans =
              ggml_permute(ctxL,
                      ggml_reshape_3d(ctxL,
-                         ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                         ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, offset_v + il*n_ctx*ggml_element_size(model.memory_v)*n_state),
                          n_state/n_head, n_head, n_past + N),
                      1, 2, 0, 3);

@@ -1622,15 +1629,18 @@ bool whisper_decode(

          Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

+         const size_t offset_k = processor_id*(ggml_element_size(model.memory_cross_k)*n_state)*(n_layer*M);
+         const size_t offset_v = processor_id*(ggml_element_size(model.memory_cross_v)*n_state)*(n_layer*M);
+
          // Kcross is already scaled
          struct ggml_tensor * Kcross =
              ggml_reshape_3d(ctxL,
-                     ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+                     ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, offset_k + il*M*ggml_element_size(model.memory_cross_k)*n_state),
                      n_state/n_head, n_head, M);

          struct ggml_tensor * Vcross =
              ggml_reshape_3d(ctxL,
-                     ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+                     ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, offset_v + il*M*ggml_element_size(model.memory_cross_v)*n_state),
                      n_state/n_head, n_head, M);

          // ------
@@ -2116,7 +2126,26 @@ struct whisper_context * whisper_init(const char * path_model) {

      ctx->t_start_us = t_start_us;

-     if (!whisper_model_load(path_model, *ctx)) {
+     if (!whisper_model_load(path_model, 1, *ctx)) {
+         fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
+         return NULL;
+     }
+
+     ctx->t_load_us = ggml_time_us() - t_start_us;
+
+     return ctx;
+ }
+
+ struct whisper_context * whisper_init_parallel(const char * path_model, int n_processors) {
+     ggml_time_init();
+
+     whisper_context * ctx = new whisper_context;
+
+     const int64_t t_start_us = ggml_time_us();
+
+     ctx->t_start_us = t_start_us;
+
+     if (!whisper_model_load(path_model, n_processors, *ctx)) {
          fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
          return NULL;
      }
@@ -2167,7 +2196,7 @@ int whisper_set_mel(
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
      const int64_t t_start_us = ggml_time_us();

-     if (!whisper_encode(*ctx, n_threads, offset)) {
+     if (!whisper_encode(*ctx, n_threads, offset, 0)) {
          fprintf(stderr, "%s: failed to eval\n", __func__);
          return -1;
      }
@@ -2180,7 +2209,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
      const int64_t t_start_us = ggml_time_us();

-     if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+     if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past, 0)) {
          fprintf(stderr, "%s: failed to eval\n", __func__);
          return 1;
      }
@@ -2302,6 +2331,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

          /*.n_threads    =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
          /*.offset_ms    =*/ 0,
+         /*.n_processors =*/ 1,

          /*.translate    =*/ false,
          /*.no_context   =*/ false,
@@ -2333,6 +2363,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

          /*.n_threads    =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
          /*.offset_ms    =*/ 0,
+         /*.n_processors =*/ 1,

          /*.translate    =*/ false,
          /*.no_context   =*/ false,
@@ -2369,7 +2400,6 @@ int whisper_full(
          int n_samples) {
      // clear old results
      auto & result_all = ctx->result_all;
-     auto & tokens_cur = ctx->tokens_cur;

      result_all.clear();

@@ -2379,10 +2409,12 @@
          return -1;
      }

+     const int seek_start = params.offset_ms/10;
+
      // if length of spectrogram is less than 1s (100 samples), then return
      // basically don't process anything that is less than 1s
      // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-     if (whisper_n_len(ctx) < 100) {
+     if (whisper_n_len(ctx) < 100 + seek_start) {
          return 0;
      }

@@ -2406,8 +2438,14 @@
      int progress_prev = 0;
      int progress_step = 5;

+     std::vector<whisper_token_data> tokens_cur;
+     tokens_cur.reserve(whisper_n_text_ctx(ctx));
+
+     std::vector<whisper_token> prompt;
+     prompt.reserve(whisper_n_text_ctx(ctx));
+
      // main loop
-     int seek = params.offset_ms/10;
+     int seek = seek_start;
      while (true) {
          int progress_cur = (100*seek)/whisper_n_len(ctx);
          while (progress_cur >= progress_prev + progress_step) {
@@ -2427,9 +2465,8 @@
              return 7;
          }

-         std::vector<whisper_token> prompt;
-
          int n_past = 0;
+         prompt.clear();

          // if we have already generated some text, use it as a prompt to condition the next generation
          if (prompt_past.size() > 0) {
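To make the new layout concrete: each KV tensor is now allocated as n_processors contiguous slabs, and processor_id picks a slab via a byte offset of processor_id * element_size * n_state * n_layer * n_ctx, so the slices never overlap. A standalone sketch of the indexing with made-up sizes (illustrative only, not a real model's hparams):

    #include <cstddef>
    #include <cstdio>

    int main() {
        // made-up sizes for illustration; real values come from the model hparams
        const int    n_processors = 2;
        const int    n_layer      = 6;    // n_text_layer
        const int    n_ctx        = 448;  // n_text_ctx
        const int    n_state      = 512;  // n_text_state
        const size_t elem_size    = 2;    // bytes per F16 element

        // mirrors whisper_model_load: n_elements = n_state * (n_layer*n_ctx) * n_processors
        const size_t total_bytes = elem_size*n_state*n_layer*n_ctx*n_processors;

        for (int processor_id = 0; processor_id < n_processors; ++processor_id) {
            // same formula as offset_k/offset_v in whisper_decode
            const size_t offset = processor_id*(elem_size*n_state)*(n_layer*n_ctx);
            printf("processor %d: slab starts at byte %zu of %zu\n", processor_id, offset, total_bytes);
        }

        return 0;
    }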
whisper.h CHANGED
@@ -72,6 +72,8 @@ extern "C" {
      // Returns NULL on failure.
      WHISPER_API struct whisper_context * whisper_init(const char * path_model);

+     WHISPER_API struct whisper_context * whisper_init_parallel(const char * path_model, int n_processors);
+
      // Frees all memory allocated by the model.
      WHISPER_API void whisper_free(struct whisper_context * ctx);

@@ -170,6 +172,7 @@ extern "C" {

          int n_threads;
          int offset_ms;
+         int n_processors;

          bool translate;
          bool no_context;
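Finally, a minimal sketch of the new public entry point (the model path is a placeholder, and the audio buffer is left empty here; note that whisper_full() still runs everything on processor 0, since no public API selects another processor yet):

    #include "whisper.h"

    #include <vector>

    int main() {
        // reserve key/value memory for 2 processors up front
        struct whisper_context * ctx = whisper_init_parallel("models/ggml-base.en.bin", 2);
        if (ctx == NULL) {
            return 1;
        }

        std::vector<float> pcmf32; // 16 kHz mono float samples, filled elsewhere

        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());

        whisper_print_timings(ctx);
        whisper_free(ctx);

        return 0;
    }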