evanqjones ggerganov commited on
Commit
46f0c56
·
unverified ·
1 Parent(s): aa7c2e9

whisper : add grammar-based sampling (#1229)

Browse files

* whisper : add grammar-based sampling

* build : fix after master merge

* command : fix exception when recognizing the command

* whisper : fine-tuning grammar functionality

* command : grammar-related improvements

- option to read grammar from file
- add sample grammars for colors and chess moves
- fine-tune the performance further

* grammars : add assistant + update comments

* command : enable beam-search, add "no_timestamps", add "context", add p

* whisper : remove comment

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Makefile CHANGED
@@ -362,8 +362,8 @@ quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
362
  stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
363
  $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
364
 
365
- command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
366
- $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
367
 
368
  lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
369
  $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
 
362
  stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
363
  $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
364
 
365
+ command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
366
+ $(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
367
 
368
  lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
369
  $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
examples/CMakeLists.txt CHANGED
@@ -23,6 +23,7 @@ add_library(${TARGET} STATIC
23
  common.cpp
24
  common-ggml.h
25
  common-ggml.cpp
 
26
  )
27
 
28
  include(DefaultTargetOptions)
 
23
  common.cpp
24
  common-ggml.h
25
  common-ggml.cpp
26
+ grammar-parser.cpp
27
  )
28
 
29
  include(DefaultTargetOptions)
examples/command/command.cpp CHANGED
@@ -9,6 +9,7 @@
9
  #include "common-sdl.h"
10
  #include "common.h"
11
  #include "whisper.h"
 
12
 
13
  #include <sstream>
14
  #include <cassert>
@@ -21,6 +22,11 @@
21
  #include <vector>
22
  #include <map>
23
 
 
 
 
 
 
24
  // command-line parameters
25
  struct whisper_params {
26
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@@ -30,8 +36,12 @@ struct whisper_params {
30
  int32_t max_tokens = 32;
31
  int32_t audio_ctx = 0;
32
 
33
- float vad_thold = 0.6f;
34
- float freq_thold = 100.0f;
 
 
 
 
35
 
36
  bool speed_up = false;
37
  bool translate = false;
@@ -45,6 +55,8 @@ struct whisper_params {
45
  std::string fname_out;
46
  std::string commands;
47
  std::string prompt;
 
 
48
  };
49
 
50
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -75,6 +87,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
75
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
76
  else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
77
  else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
 
 
 
78
  else {
79
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
80
  whisper_print_usage(argc, argv, params);
@@ -109,16 +124,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
109
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
110
  fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
111
  fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
 
 
 
112
  fprintf(stderr, "\n");
113
  }
114
 
115
- std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
 
 
 
 
 
 
 
 
116
  const auto t_start = std::chrono::high_resolution_clock::now();
117
 
118
- prob = 0.0f;
 
 
119
  t_ms = 0;
120
 
121
- whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
122
 
123
  wparams.print_progress = false;
124
  wparams.print_special = params.print_special;
@@ -126,19 +155,41 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
126
  wparams.print_timestamps = !params.no_timestamps;
127
  wparams.translate = params.translate;
128
  wparams.no_context = true;
 
129
  wparams.single_segment = true;
130
  wparams.max_tokens = params.max_tokens;
131
  wparams.language = params.language.c_str();
132
  wparams.n_threads = params.n_threads;
133
 
134
- wparams.audio_ctx = params.audio_ctx;
135
- wparams.speed_up = params.speed_up;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
138
  return "";
139
  }
140
 
141
- int prob_n = 0;
142
  std::string result;
143
 
144
  const int n_segments = whisper_full_n_segments(ctx);
@@ -147,19 +198,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
147
 
148
  result += text;
149
 
150
- const int n_tokens = whisper_full_n_tokens(ctx, i);
151
- for (int j = 0; j < n_tokens; ++j) {
152
  const auto token = whisper_full_get_token_data(ctx, i, j);
153
 
154
- prob += token.p;
155
- ++prob_n;
 
 
156
  }
157
  }
158
 
159
- if (prob_n > 0) {
160
- prob /= prob_n;
161
- }
162
-
163
  const auto t_end = std::chrono::high_resolution_clock::now();
164
  t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
165
 
@@ -250,7 +299,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
250
  fprintf(stderr, " ]\n");
251
  }
252
 
253
- std::string k_prompt = "select one from the available words: ";
254
  for (int i = 0; i < (int) allowed_commands.size(); ++i) {
255
  if (i > 0) {
256
  k_prompt += ", ";
@@ -418,7 +467,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
418
  bool is_running = true;
419
  bool ask_prompt = true;
420
 
421
- float prob = 0.0f;
 
 
422
 
423
  std::vector<float> pcmf32_cur;
424
 
@@ -456,7 +507,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
456
  // detect the commands
457
  audio.get(params.command_ms, pcmf32_cur);
458
 
459
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
460
 
461
  const auto words = get_words(txt);
462
 
@@ -492,18 +543,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
492
 
493
  // general-purpose mode
494
  // freely transcribe the voice into text
495
- int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
496
  bool is_running = true;
497
  bool have_prompt = false;
498
  bool ask_prompt = true;
499
 
500
- float prob0 = 0.0f;
501
- float prob = 0.0f;
 
 
 
 
 
 
502
 
503
  std::vector<float> pcmf32_cur;
504
  std::vector<float> pcmf32_prompt;
505
 
506
- const std::string k_prompt = "Ok Whisper, start listening for commands.";
 
 
 
507
 
508
  fprintf(stderr, "\n");
509
  fprintf(stderr, "%s: general-purpose mode\n", __func__);
@@ -536,9 +596,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
536
  // wait for activation phrase
537
  audio.get(params.prompt_ms, pcmf32_cur);
538
 
539
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
540
 
541
- fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
 
 
542
 
543
  const float sim = similarity(txt, k_prompt);
544
 
@@ -559,19 +621,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
559
  // we have heard the activation phrase, now detect the commands
560
  audio.get(params.command_ms, pcmf32_cur);
561
 
 
 
 
 
 
 
562
  // prepend the prompt audio
563
  pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
564
 
565
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
566
 
567
- prob = 100.0f*(prob - prob0);
 
568
 
569
  //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
570
 
571
  // find the prompt in the text
572
  float best_sim = 0.0f;
573
  size_t best_len = 0;
574
- for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
 
 
 
 
575
  const auto prompt = txt.substr(0, n);
576
 
577
  const float sim = similarity(prompt, k_prompt);
@@ -584,9 +657,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
584
  }
585
  }
586
 
587
- const std::string command = ::trim(txt.substr(best_len));
 
 
 
 
 
 
 
 
588
 
589
- fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
590
  fprintf(stdout, "\n");
591
  }
592
 
@@ -654,12 +734,36 @@ int main(int argc, char ** argv) {
654
 
655
  int ret_val = 0;
656
 
657
- if (!params.commands.empty()) {
658
- ret_val = process_command_list(ctx, audio, params);
659
- } else if (!params.prompt.empty()) {
660
- ret_val = always_prompt_transcription(ctx, audio, params);
661
- } else {
662
- ret_val = process_general_transcription(ctx, audio, params);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  }
664
 
665
  audio.pause();
 
9
  #include "common-sdl.h"
10
  #include "common.h"
11
  #include "whisper.h"
12
+ #include "grammar-parser.h"
13
 
14
  #include <sstream>
15
  #include <cassert>
 
22
  #include <vector>
23
  #include <map>
24
 
25
+ bool file_exists(const std::string & fname) {
26
+ std::ifstream f(fname.c_str());
27
+ return f.good();
28
+ }
29
+
30
  // command-line parameters
31
  struct whisper_params {
32
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
 
36
  int32_t max_tokens = 32;
37
  int32_t audio_ctx = 0;
38
 
39
+ float vad_thold = 0.6f;
40
+ float freq_thold = 100.0f;
41
+
42
+ float grammar_penalty = 100.0f;
43
+
44
+ grammar_parser::parse_state grammar_parsed;
45
 
46
  bool speed_up = false;
47
  bool translate = false;
 
55
  std::string fname_out;
56
  std::string commands;
57
  std::string prompt;
58
+ std::string context;
59
+ std::string grammar;
60
  };
61
 
62
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
87
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
88
  else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
89
  else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
90
+ else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
91
+ else if ( arg == "--grammar") { params.grammar = argv[++i]; }
92
+ else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
93
  else {
94
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
95
  whisper_print_usage(argc, argv, params);
 
124
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
125
  fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
126
  fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
127
+ fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
128
+ fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
129
+ fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
130
  fprintf(stderr, "\n");
131
  }
132
 
133
+ std::string transcribe(
134
+ whisper_context * ctx,
135
+ const whisper_params & params,
136
+ const std::vector<float> & pcmf32,
137
+ const std::string & grammar_rule,
138
+ float & logprob_min,
139
+ float & logprob_sum,
140
+ int & n_tokens,
141
+ int64_t & t_ms) {
142
  const auto t_start = std::chrono::high_resolution_clock::now();
143
 
144
+ logprob_min = 0.0f;
145
+ logprob_sum = 0.0f;
146
+ n_tokens = 0;
147
  t_ms = 0;
148
 
149
+ //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
150
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
151
 
152
  wparams.print_progress = false;
153
  wparams.print_special = params.print_special;
 
155
  wparams.print_timestamps = !params.no_timestamps;
156
  wparams.translate = params.translate;
157
  wparams.no_context = true;
158
+ wparams.no_timestamps = params.no_timestamps;
159
  wparams.single_segment = true;
160
  wparams.max_tokens = params.max_tokens;
161
  wparams.language = params.language.c_str();
162
  wparams.n_threads = params.n_threads;
163
 
164
+ wparams.audio_ctx = params.audio_ctx;
165
+ wparams.speed_up = params.speed_up;
166
+
167
+ wparams.temperature = 0.4f;
168
+ wparams.temperature_inc = 1.0f;
169
+ wparams.greedy.best_of = 5;
170
+
171
+ wparams.beam_search.beam_size = 5;
172
+
173
+ wparams.initial_prompt = params.context.data();
174
+
175
+ const auto & grammar_parsed = params.grammar_parsed;
176
+ auto grammar_rules = grammar_parsed.c_rules();
177
+
178
+ if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
179
+ if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
180
+ fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
181
+ } else {
182
+ wparams.grammar_rules = grammar_rules.data();
183
+ wparams.n_grammar_rules = grammar_rules.size();
184
+ wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
185
+ wparams.grammar_penalty = params.grammar_penalty;
186
+ }
187
+ }
188
 
189
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
190
  return "";
191
  }
192
 
 
193
  std::string result;
194
 
195
  const int n_segments = whisper_full_n_segments(ctx);
 
198
 
199
  result += text;
200
 
201
+ const int n = whisper_full_n_tokens(ctx, i);
202
+ for (int j = 0; j < n; ++j) {
203
  const auto token = whisper_full_get_token_data(ctx, i, j);
204
 
205
+ if(token.plog > 0.0f) exit(0);
206
+ logprob_min = std::min(logprob_min, token.plog);
207
+ logprob_sum += token.plog;
208
+ ++n_tokens;
209
  }
210
  }
211
 
 
 
 
 
212
  const auto t_end = std::chrono::high_resolution_clock::now();
213
  t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
214
 
 
299
  fprintf(stderr, " ]\n");
300
  }
301
 
302
+ std::string k_prompt = "select one from the available words: ";
303
  for (int i = 0; i < (int) allowed_commands.size(); ++i) {
304
  if (i > 0) {
305
  k_prompt += ", ";
 
467
  bool is_running = true;
468
  bool ask_prompt = true;
469
 
470
+ float logprob_min = 0.0f;
471
+ float logprob_sum = 0.0f;
472
+ int n_tokens = 0;
473
 
474
  std::vector<float> pcmf32_cur;
475
 
 
507
  // detect the commands
508
  audio.get(params.command_ms, pcmf32_cur);
509
 
510
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
511
 
512
  const auto words = get_words(txt);
513
 
 
543
 
544
  // general-purpose mode
545
  // freely transcribe the voice into text
546
+ int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
547
  bool is_running = true;
548
  bool have_prompt = false;
549
  bool ask_prompt = true;
550
 
551
+ float logprob_min0 = 0.0f;
552
+ float logprob_min = 0.0f;
553
+
554
+ float logprob_sum0 = 0.0f;
555
+ float logprob_sum = 0.0f;
556
+
557
+ int n_tokens0 = 0;
558
+ int n_tokens = 0;
559
 
560
  std::vector<float> pcmf32_cur;
561
  std::vector<float> pcmf32_prompt;
562
 
563
+ std::string k_prompt = "Ok Whisper, start listening for commands.";
564
+ if (!params.prompt.empty()) {
565
+ k_prompt = params.prompt;
566
+ }
567
 
568
  fprintf(stderr, "\n");
569
  fprintf(stderr, "%s: general-purpose mode\n", __func__);
 
596
  // wait for activation phrase
597
  audio.get(params.prompt_ms, pcmf32_cur);
598
 
599
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
600
 
601
+ const float p = 100.0f * std::exp(logprob_min0);
602
+
603
+ fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
604
 
605
  const float sim = similarity(txt, k_prompt);
606
 
 
621
  // we have heard the activation phrase, now detect the commands
622
  audio.get(params.command_ms, pcmf32_cur);
623
 
624
+ //printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
625
+ //printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
626
+
627
+ // prepend 3 second of silence
628
+ pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
629
+
630
  // prepend the prompt audio
631
  pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
632
 
633
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
634
 
635
+ //const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
636
+ const float p = 100.0f * std::exp(logprob_min);
637
 
638
  //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
639
 
640
  // find the prompt in the text
641
  float best_sim = 0.0f;
642
  size_t best_len = 0;
643
+ for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
644
+ if (n >= txt.size()) {
645
+ break;
646
+ }
647
+
648
  const auto prompt = txt.substr(0, n);
649
 
650
  const float sim = similarity(prompt, k_prompt);
 
657
  }
658
  }
659
 
660
+ fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
661
+ if (best_len == 0) {
662
+ fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
663
+ } else {
664
+ // cut the prompt from the decoded text
665
+ const std::string command = ::trim(txt.substr(best_len));
666
+
667
+ fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
668
+ }
669
 
 
670
  fprintf(stdout, "\n");
671
  }
672
 
 
734
 
735
  int ret_val = 0;
736
 
737
+ if (!params.grammar.empty()) {
738
+ auto & grammar = params.grammar_parsed;
739
+ if (file_exists(params.grammar.c_str())) {
740
+ // read grammar from file
741
+ std::ifstream ifs(params.grammar.c_str());
742
+ const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
743
+ grammar = grammar_parser::parse(txt.c_str());
744
+ } else {
745
+ // read grammar from string
746
+ grammar = grammar_parser::parse(params.grammar.c_str());
747
+ }
748
+
749
+ // will be empty (default) if there are parse errors
750
+ if (grammar.rules.empty()) {
751
+ ret_val = 1;
752
+ } else {
753
+ fprintf(stderr, "%s: grammar:\n", __func__);
754
+ grammar_parser::print_grammar(stderr, grammar);
755
+ fprintf(stderr, "\n");
756
+ }
757
+ }
758
+
759
+ if (ret_val == 0) {
760
+ if (!params.commands.empty()) {
761
+ ret_val = process_command_list(ctx, audio, params);
762
+ } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
763
+ ret_val = always_prompt_transcription(ctx, audio, params);
764
+ } else {
765
+ ret_val = process_general_transcription(ctx, audio, params);
766
+ }
767
  }
768
 
769
  audio.pause();
examples/grammar-parser.cpp ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "grammar-parser.h"
2
+ #include <cstdint>
3
+ #include <cwchar>
4
+ #include <string>
5
+ #include <utility>
6
+ #include <stdexcept>
7
+ #include <exception>
8
+
9
+ namespace grammar_parser {
10
+ // NOTE: assumes valid utf8 (but checks for overrun)
11
+ // copied from whisper.cpp
12
+ std::pair<uint32_t, const char *> decode_utf8(const char * src) {
13
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
14
+ uint8_t first_byte = static_cast<uint8_t>(*src);
15
+ uint8_t highbits = first_byte >> 4;
16
+ int len = lookup[highbits];
17
+ uint8_t mask = (1 << (8 - len)) - 1;
18
+ uint32_t value = first_byte & mask;
19
+ const char * end = src + len; // may overrun!
20
+ const char * pos = src + 1;
21
+ for ( ; pos < end && *pos; pos++) {
22
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
23
+ }
24
+ return std::make_pair(value, pos);
25
+ }
26
+
27
+ uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
28
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
29
+ auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
30
+ return result.first->second;
31
+ }
32
+
33
+ uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
34
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
35
+ state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
36
+ return next_id;
37
+ }
38
+
39
+ void add_rule(
40
+ parse_state & state,
41
+ uint32_t rule_id,
42
+ const std::vector<whisper_grammar_element> & rule) {
43
+ if (state.rules.size() <= rule_id) {
44
+ state.rules.resize(rule_id + 1);
45
+ }
46
+ state.rules[rule_id] = rule;
47
+ }
48
+
49
+ bool is_word_char(char c) {
50
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
51
+ }
52
+
53
+ std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
54
+ const char * pos = src;
55
+ const char * end = src + size;
56
+ uint32_t value = 0;
57
+ for ( ; pos < end && *pos; pos++) {
58
+ value <<= 4;
59
+ char c = *pos;
60
+ if ('a' <= c && c <= 'f') {
61
+ value += c - 'a' + 10;
62
+ } else if ('A' <= c && c <= 'F') {
63
+ value += c - 'A' + 10;
64
+ } else if ('0' <= c && c <= '9') {
65
+ value += c - '0';
66
+ } else {
67
+ break;
68
+ }
69
+ }
70
+ if (pos != end) {
71
+ throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
72
+ }
73
+ return std::make_pair(value, pos);
74
+ }
75
+
76
+ const char * parse_space(const char * src, bool newline_ok) {
77
+ const char * pos = src;
78
+ while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
79
+ (newline_ok && (*pos == '\r' || *pos == '\n'))) {
80
+ if (*pos == '#') {
81
+ while (*pos && *pos != '\r' && *pos != '\n') {
82
+ pos++;
83
+ }
84
+ } else {
85
+ pos++;
86
+ }
87
+ }
88
+ return pos;
89
+ }
90
+
91
+ const char * parse_name(const char * src) {
92
+ const char * pos = src;
93
+ while (is_word_char(*pos)) {
94
+ pos++;
95
+ }
96
+ if (pos == src) {
97
+ throw std::runtime_error(std::string("expecting name at ") + src);
98
+ }
99
+ return pos;
100
+ }
101
+
102
+ std::pair<uint32_t, const char *> parse_char(const char * src) {
103
+ if (*src == '\\') {
104
+ switch (src[1]) {
105
+ case 'x': return parse_hex(src + 2, 2);
106
+ case 'u': return parse_hex(src + 2, 4);
107
+ case 'U': return parse_hex(src + 2, 8);
108
+ case 't': return std::make_pair('\t', src + 2);
109
+ case 'r': return std::make_pair('\r', src + 2);
110
+ case 'n': return std::make_pair('\n', src + 2);
111
+ case '\\':
112
+ case '"':
113
+ case '[':
114
+ case ']':
115
+ return std::make_pair(src[1], src + 2);
116
+ default:
117
+ throw std::runtime_error(std::string("unknown escape at ") + src);
118
+ }
119
+ } else if (*src) {
120
+ return decode_utf8(src);
121
+ }
122
+ throw std::runtime_error("unexpected end of input");
123
+ }
124
+
125
+ const char * parse_alternates(
126
+ parse_state & state,
127
+ const char * src,
128
+ const std::string & rule_name,
129
+ uint32_t rule_id,
130
+ bool is_nested);
131
+
132
+ const char * parse_sequence(
133
+ parse_state & state,
134
+ const char * src,
135
+ const std::string & rule_name,
136
+ std::vector<whisper_grammar_element> & out_elements,
137
+ bool is_nested) {
138
+ size_t last_sym_start = out_elements.size();
139
+ const char * pos = src;
140
+ while (*pos) {
141
+ if (*pos == '"') { // literal string
142
+ pos++;
143
+ last_sym_start = out_elements.size();
144
+ while (*pos != '"') {
145
+ auto char_pair = parse_char(pos);
146
+ pos = char_pair.second;
147
+ out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
148
+ }
149
+ pos = parse_space(pos + 1, is_nested);
150
+ } else if (*pos == '[') { // char range(s)
151
+ pos++;
152
+ enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
153
+ if (*pos == '^') {
154
+ pos++;
155
+ start_type = WHISPER_GRETYPE_CHAR_NOT;
156
+ }
157
+ last_sym_start = out_elements.size();
158
+ while (*pos != ']') {
159
+ auto char_pair = parse_char(pos);
160
+ pos = char_pair.second;
161
+ enum whisper_gretype type = last_sym_start < out_elements.size()
162
+ ? WHISPER_GRETYPE_CHAR_ALT
163
+ : start_type;
164
+
165
+ out_elements.push_back({type, char_pair.first});
166
+ if (pos[0] == '-' && pos[1] != ']') {
167
+ auto endchar_pair = parse_char(pos + 1);
168
+ pos = endchar_pair.second;
169
+ out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
170
+ }
171
+ }
172
+ pos = parse_space(pos + 1, is_nested);
173
+ } else if (is_word_char(*pos)) { // rule reference
174
+ const char * name_end = parse_name(pos);
175
+ uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
176
+ pos = parse_space(name_end, is_nested);
177
+ last_sym_start = out_elements.size();
178
+ out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
179
+ } else if (*pos == '(') { // grouping
180
+ // parse nested alternates into synthesized rule
181
+ pos = parse_space(pos + 1, true);
182
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
183
+ pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
184
+ last_sym_start = out_elements.size();
185
+ // output reference to synthesized rule
186
+ out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
187
+ if (*pos != ')') {
188
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
189
+ }
190
+ pos = parse_space(pos + 1, is_nested);
191
+ } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
192
+ if (last_sym_start == out_elements.size()) {
193
+ throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
194
+ }
195
+
196
+ // apply transformation to previous symbol (last_sym_start to end) according to
197
+ // rewrite rules:
198
+ // S* --> S' ::= S S' |
199
+ // S+ --> S' ::= S S' | S
200
+ // S? --> S' ::= S |
201
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
202
+ std::vector<whisper_grammar_element> sub_rule;
203
+ // add preceding symbol to generated rule
204
+ sub_rule.insert(
205
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
206
+ if (*pos == '*' || *pos == '+') {
207
+ // cause generated rule to recurse
208
+ sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
209
+ }
210
+ // mark start of alternate def
211
+ sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
212
+ if (*pos == '+') {
213
+ // add preceding symbol as alternate only for '+' (otherwise empty)
214
+ sub_rule.insert(
215
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
216
+ }
217
+ sub_rule.push_back({WHISPER_GRETYPE_END, 0});
218
+ add_rule(state, sub_rule_id, sub_rule);
219
+
220
+ // in original rule, replace previous symbol with reference to generated rule
221
+ out_elements.resize(last_sym_start);
222
+ out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
223
+
224
+ pos = parse_space(pos + 1, is_nested);
225
+ } else {
226
+ break;
227
+ }
228
+ }
229
+ return pos;
230
+ }
231
+
232
+ const char * parse_alternates(
233
+ parse_state & state,
234
+ const char * src,
235
+ const std::string & rule_name,
236
+ uint32_t rule_id,
237
+ bool is_nested) {
238
+ std::vector<whisper_grammar_element> rule;
239
+ const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
240
+ while (*pos == '|') {
241
+ rule.push_back({WHISPER_GRETYPE_ALT, 0});
242
+ pos = parse_space(pos + 1, true);
243
+ pos = parse_sequence(state, pos, rule_name, rule, is_nested);
244
+ }
245
+ rule.push_back({WHISPER_GRETYPE_END, 0});
246
+ add_rule(state, rule_id, rule);
247
+ return pos;
248
+ }
249
+
250
+ const char * parse_rule(parse_state & state, const char * src) {
251
+ const char * name_end = parse_name(src);
252
+ const char * pos = parse_space(name_end, false);
253
+ size_t name_len = name_end - src;
254
+ uint32_t rule_id = get_symbol_id(state, src, name_len);
255
+ const std::string name(src, name_len);
256
+
257
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
258
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
259
+ }
260
+ pos = parse_space(pos + 3, true);
261
+
262
+ pos = parse_alternates(state, pos, name, rule_id, false);
263
+
264
+ if (*pos == '\r') {
265
+ pos += pos[1] == '\n' ? 2 : 1;
266
+ } else if (*pos == '\n') {
267
+ pos++;
268
+ } else if (*pos) {
269
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
270
+ }
271
+ return parse_space(pos, true);
272
+ }
273
+
274
+ parse_state parse(const char * src) {
275
+ try {
276
+ parse_state state;
277
+ const char * pos = parse_space(src, true);
278
+ while (*pos) {
279
+ pos = parse_rule(state, pos);
280
+ }
281
+ return state;
282
+ } catch (const std::exception & err) {
283
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
284
+ return parse_state();
285
+ }
286
+ }
287
+
288
+ void print_grammar_char(FILE * file, uint32_t c) {
289
+ if (0x20 <= c && c <= 0x7f) {
290
+ fprintf(file, "%c", static_cast<char>(c));
291
+ } else {
292
+ // cop out of encoding UTF-8
293
+ fprintf(file, "<U+%04X>", c);
294
+ }
295
+ }
296
+
297
+ bool is_char_element(whisper_grammar_element elem) {
298
+ switch (elem.type) {
299
+ case WHISPER_GRETYPE_CHAR: return true;
300
+ case WHISPER_GRETYPE_CHAR_NOT: return true;
301
+ case WHISPER_GRETYPE_CHAR_ALT: return true;
302
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
303
+ default: return false;
304
+ }
305
+ }
306
+
307
+ void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
308
+ for (auto elem : rule) {
309
+ switch (elem.type) {
310
+ case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
311
+ case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
312
+ case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
313
+ case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
314
+ case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
315
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
316
+ case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
317
+ }
318
+ switch (elem.type) {
319
+ case WHISPER_GRETYPE_END:
320
+ case WHISPER_GRETYPE_ALT:
321
+ case WHISPER_GRETYPE_RULE_REF:
322
+ fprintf(file, "(%u) ", elem.value);
323
+ break;
324
+ case WHISPER_GRETYPE_CHAR:
325
+ case WHISPER_GRETYPE_CHAR_NOT:
326
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER:
327
+ case WHISPER_GRETYPE_CHAR_ALT:
328
+ fprintf(file, "(\"");
329
+ print_grammar_char(file, elem.value);
330
+ fprintf(file, "\") ");
331
+ break;
332
+ }
333
+ }
334
+ fprintf(file, "\n");
335
+ }
336
+
337
+ void print_rule(
338
+ FILE * file,
339
+ uint32_t rule_id,
340
+ const std::vector<whisper_grammar_element> & rule,
341
+ const std::map<uint32_t, std::string> & symbol_id_names) {
342
+ if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
343
+ throw std::runtime_error(
344
+ "malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
345
+ }
346
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
347
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
348
+ whisper_grammar_element elem = rule[i];
349
+ switch (elem.type) {
350
+ case WHISPER_GRETYPE_END:
351
+ throw std::runtime_error(
352
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
353
+ std::to_string(i));
354
+ case WHISPER_GRETYPE_ALT:
355
+ fprintf(file, "| ");
356
+ break;
357
+ case WHISPER_GRETYPE_RULE_REF:
358
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
359
+ break;
360
+ case WHISPER_GRETYPE_CHAR:
361
+ fprintf(file, "[");
362
+ print_grammar_char(file, elem.value);
363
+ break;
364
+ case WHISPER_GRETYPE_CHAR_NOT:
365
+ fprintf(file, "[^");
366
+ print_grammar_char(file, elem.value);
367
+ break;
368
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER:
369
+ if (i == 0 || !is_char_element(rule[i - 1])) {
370
+ throw std::runtime_error(
371
+ "WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
372
+ std::to_string(rule_id) + "," + std::to_string(i));
373
+ }
374
+ fprintf(file, "-");
375
+ print_grammar_char(file, elem.value);
376
+ break;
377
+ case WHISPER_GRETYPE_CHAR_ALT:
378
+ if (i == 0 || !is_char_element(rule[i - 1])) {
379
+ throw std::runtime_error(
380
+ "WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
381
+ std::to_string(rule_id) + "," + std::to_string(i));
382
+ }
383
+ print_grammar_char(file, elem.value);
384
+ break;
385
+ }
386
+ if (is_char_element(elem)) {
387
+ switch (rule[i + 1].type) {
388
+ case WHISPER_GRETYPE_CHAR_ALT:
389
+ case WHISPER_GRETYPE_CHAR_RNG_UPPER:
390
+ break;
391
+ default:
392
+ fprintf(file, "] ");
393
+ }
394
+ }
395
+ }
396
+ fprintf(file, "\n");
397
+ }
398
+
399
+ void print_grammar(FILE * file, const parse_state & state) {
400
+ try {
401
+ std::map<uint32_t, std::string> symbol_id_names;
402
+ for (auto kv : state.symbol_ids) {
403
+ symbol_id_names[kv.second] = kv.first;
404
+ }
405
+ for (size_t i = 0, end = state.rules.size(); i < end; i++) {
406
+ // fprintf(file, "%zu: ", i);
407
+ // print_rule_binary(file, state.rules[i]);
408
+ print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
409
+ // fprintf(file, "\n");
410
+ }
411
+ } catch (const std::exception & err) {
412
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
413
+ }
414
+ }
415
+
416
+ std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
417
+ std::vector<const whisper_grammar_element *> ret;
418
+ for (const auto & rule : rules) {
419
+ ret.push_back(rule.data());
420
+ }
421
+ return ret;
422
+ }
423
+ }
examples/grammar-parser.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Implements a parser for an extended Backus-Naur form (BNF), producing the
2
+ // binary context-free grammar format specified by whisper.h. Supports character
3
+ // ranges, grouping, and repetition operators. As an example, a grammar for
4
+ // arithmetic might look like:
5
+ //
6
+ // root ::= expr
7
+ // expr ::= term ([-+*/] term)*
8
+ // term ::= num | "(" space expr ")" space
9
+ // num ::= [0-9]+ space
10
+ // space ::= [ \t\n]*
11
+
12
+ #pragma once
13
+ #include "whisper.h"
14
+ #include <vector>
15
+ #include <map>
16
+ #include <cstdint>
17
+ #include <string>
18
+
19
+ namespace grammar_parser {
20
+ struct parse_state {
21
+ std::map<std::string, uint32_t> symbol_ids;
22
+ std::vector<std::vector<whisper_grammar_element>> rules;
23
+
24
+ std::vector<const whisper_grammar_element *> c_rules() const;
25
+ };
26
+
27
+ parse_state parse(const char * src);
28
+ void print_grammar(FILE * file, const parse_state & state);
29
+ }
grammars/assistant.gbnf ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - "turn on lights."
2
+ # - "set thermostat to 22."
3
+ # - "increase TV by 10."
4
+ # - "decrease oven by 50."
5
+ # - "play music."
6
+ # - "stop podcast."
7
+ # - "schedule cleaning at 3pm."
8
+ # - "cancel cleaning."
9
+ # - "remind me to buy milk at 5pm."
10
+ # - "show me security system."
11
+ # - "hide washing machine."
12
+ # - "what is the lights status?"
13
+ # - "what is the current thermostat value?"
14
+ # - "what is the security system status?"
15
+ # - "what is the door lock status?"
16
+ # - "what is the camera battery level?"
17
+ # - "what is the weather like today?"
18
+ # - "what is the forecast for tomorrow?"
19
+ # - "what is the time?"
20
+ # - "what is my schedule for today?"
21
+ # - "what tasks do I have?"
22
+ # - "what reminders do I have?"
23
+ #
24
+ # example:
25
+ #
26
+ # ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
27
+ #
28
+
29
+ root ::= init " " (command | question) "."
30
+ prompt ::= init
31
+
32
+ # leading space is very important!
33
+ init ::= " Ok Whisper, start listening for commands."
34
+
35
+ command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
36
+ "Increase " device " by " value | "Decrease " device " by " value |
37
+ "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
38
+ "Remind me to " task " at " time | "Show me " device | "Hide " device
39
+
40
+ question ::= "What is the " device " status?" | "What is the current " device " value?" |
41
+ "What is the " device " temperature?" | "What is the " device " humidity?" |
42
+ "What is the " device " power consumption?" | "What is the " device " battery level?" |
43
+ "What is the weather like today?" | "What is the forecast for tomorrow?" |
44
+ "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
45
+ "What reminders do I have?"
46
+
47
+ device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
48
+ "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
49
+ "vacuum cleaner"
50
+
51
+ value ::= [0-9]+
52
+
53
+ media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
54
+
55
+ task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
56
+
57
+ time ::= [0-9] [0-9]? ("am" | "pm")?
grammars/chess.gbnf ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - bishop to c3
2
+ # - rook to d4
3
+ # - knight to e5
4
+ # - d4 d5 knight to c3
5
+ # - c3 queen to d4 king b1
6
+ # - pawn to a1 bishop to b2 knight to c3
7
+ #
8
+ # The prompt (--prompt) is the initial phrase that the user has to say.
9
+ # This is used to prime Whisper with how the user is expected to speak.
10
+ #
11
+ # Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
12
+ # Longer context is better, but it slightly increases the processing time.
13
+ #
14
+ # example:
15
+ #
16
+ # ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
17
+ #
18
+
19
+ root ::= init move move? move? "."
20
+ prompt ::= init "."
21
+
22
+ # leading space is very important!
23
+ init ::= " rook to b4, f3"
24
+
25
+ move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
26
+
27
+ piece ::= "bishop" | "rook" | "knight" | "queen"
28
+ king ::= "king"
29
+ pawn ::= "pawn"
grammars/colors.gbnf ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - red
2
+ # - green
3
+ # - blue
4
+ #
5
+ # example:
6
+ #
7
+ # ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
8
+ #
9
+
10
+ root ::= init color "."
11
+ prompt ::= init "."
12
+
13
+ # leading space is very important!
14
+ init ::= " red, green, blue"
15
+
16
+ color ::= ", " ("red" | "green" | "blue")
whisper.cpp CHANGED
@@ -579,6 +579,25 @@ struct whisper_model {
579
  std::map<std::string, struct ggml_tensor *> tensors;
580
  };
581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
  struct whisper_sequence {
583
  std::vector<whisper_token_data> tokens;
584
 
@@ -600,6 +619,9 @@ struct whisper_decoder {
600
  // the currently generated sequence of tokens
601
  whisper_sequence sequence;
602
 
 
 
 
603
  int seek_delta; // the window shift found so far based on the decoded timestamp tokens
604
 
605
  bool failed; // has the current segment failed to decode?
@@ -3685,6 +3707,425 @@ const char * whisper_print_system_info(void) {
3685
  return s.c_str();
3686
  }
3687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3688
  ////////////////////////////////////////////////////////////////////////////
3689
 
3690
  struct whisper_context_params * whisper_context_default_params_by_ref() {
@@ -3714,6 +4155,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3714
 
3715
  /*.translate =*/ false,
3716
  /*.no_context =*/ true,
 
3717
  /*.single_segment =*/ false,
3718
  /*.print_special =*/ false,
3719
  /*.print_progress =*/ true,
@@ -3776,6 +4218,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3776
 
3777
  /*.logits_filter_callback =*/ nullptr,
3778
  /*.logits_filter_callback_user_data =*/ nullptr,
 
 
 
 
 
3779
  };
3780
 
3781
  switch (strategy) {
@@ -3927,6 +4374,11 @@ static void whisper_process_logits(
3927
  // suppress <|notimestamps|> token
3928
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
3929
  logits[vocab.token_not] = -INFINITY;
 
 
 
 
 
3930
 
3931
  // suppress sot and nosp tokens
3932
  logits[vocab.token_sot] = -INFINITY;
@@ -3942,6 +4394,14 @@ static void whisper_process_logits(
3942
  logits[vocab.token_transcribe] = -INFINITY;
3943
  logits[vocab.token_prev] = -INFINITY;
3944
 
 
 
 
 
 
 
 
 
3945
  if (params.logits_filter_callback) {
3946
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3947
  }
@@ -4052,10 +4512,33 @@ static void whisper_process_logits(
4052
  //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
4053
 
4054
  if (timestamp_logprob > max_text_token_logprob) {
 
4055
  for (int i = 0; i < vocab.token_beg; ++i) {
4056
  logits[i] = -INFINITY;
4057
  logprobs[i] = -INFINITY;
4058
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4059
  }
4060
  }
4061
  }
@@ -4073,32 +4556,55 @@ static void whisper_process_logits(
4073
 
4074
  #if 0
4075
  // print first 100 logits - token string : logit
4076
- for (int i = 0; i < 100; i++) {
4077
- const auto token = vocab.id_to_token.at(i);
4078
- const auto prob = probs[i];
4079
- const auto logit = logits[i];
4080
- const auto logprob = logprobs[i];
4081
- printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4082
  }
4083
 
4084
  // "And", "and", " And", " and"
4085
- printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
4086
- printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
4087
- printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
4088
- printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
4089
- printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
4090
-
4091
- printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
4092
- printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
4093
- printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
4094
- printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
4095
- printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
4096
-
4097
- printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
4098
- printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
4099
- printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
4100
- printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
4101
- printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
4102
  #endif
4103
  }
4104
 
@@ -4223,8 +4729,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
4223
  ptsum = sum_ts;
4224
  }
4225
 
 
 
4226
  for (int i = 0; i < k; ++i) {
4227
- const auto id = logits_id[i].second;
 
4228
 
4229
  result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
4230
 
@@ -4553,7 +5062,7 @@ int whisper_full_with_state(
4553
  state->exp_n_audio_ctx = params.audio_ctx;
4554
 
4555
  // these tokens determine the task that will be performed
4556
- std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
4557
 
4558
  if (whisper_is_multilingual(ctx)) {
4559
  const int lang_id = whisper_lang_id(params.language);
@@ -4566,17 +5075,19 @@ int whisper_full_with_state(
4566
  }
4567
  }
4568
 
 
4569
  {
4570
  const bool is_distil = ctx->model.hparams.n_text_layer == 2;
4571
-
4572
- // distilled models require the "no_timestamps" token
4573
- // TODO: add input parameter (#1229)
4574
- if (is_distil) {
4575
  WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
4576
- prompt_init.push_back(whisper_token_not(ctx));
4577
  }
4578
  }
4579
 
 
 
 
 
4580
  int seek = seek_start;
4581
 
4582
  std::vector<whisper_token> prompt;
@@ -4652,7 +5163,7 @@ int whisper_full_with_state(
4652
 
4653
  n_decoders_cur = std::max(1, n_decoders_cur);
4654
 
4655
- WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
4656
 
4657
  // TAGS: WHISPER_DECODER_INIT
4658
  for (int j = 0; j < n_decoders_cur; ++j) {
@@ -4673,6 +5184,13 @@ int whisper_full_with_state(
4673
  decoder.failed = false;
4674
  decoder.completed = false;
4675
  decoder.has_ts = false;
 
 
 
 
 
 
 
4676
  }
4677
 
4678
  // init prompt and kv cache for the current iteration
@@ -4790,6 +5308,10 @@ int whisper_full_with_state(
4790
  continue;
4791
  }
4792
 
 
 
 
 
4793
  auto & cur = beam_candidates[cur_c++];
4794
 
4795
  while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
@@ -4844,6 +5366,8 @@ int whisper_full_with_state(
4844
  has_ts = true;
4845
  }
4846
 
 
 
4847
  #ifdef WHISPER_DEBUG
4848
  {
4849
  const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
 
579
  std::map<std::string, struct ggml_tensor *> tensors;
580
  };
581
 
582
+ struct whisper_partial_utf8 {
583
+ uint32_t value; // bit value so far (unshifted)
584
+ int n_remain; // num bytes remaining; -1 indicates invalid sequence
585
+ };
586
+
587
+ struct whisper_grammar {
588
+ /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
589
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
590
+
591
+ // buffer for partially generated UTF-8 sequence from accepted tokens
592
+ whisper_partial_utf8 partial_utf8;
593
+ };
594
+
595
+ struct whisper_grammar_candidate {
596
+ whisper_token id;
597
+ const uint32_t * code_points;
598
+ whisper_partial_utf8 partial_utf8;
599
+ };
600
+
601
  struct whisper_sequence {
602
  std::vector<whisper_token_data> tokens;
603
 
 
619
  // the currently generated sequence of tokens
620
  whisper_sequence sequence;
621
 
622
+ // grammar parse state of generated sequence of tokens
623
+ whisper_grammar grammar;
624
+
625
  int seek_delta; // the window shift found so far based on the decoded timestamp tokens
626
 
627
  bool failed; // has the current segment failed to decode?
 
3707
  return s.c_str();
3708
  }
3709
 
3710
+ //////////////////////////////////
3711
+ // Grammar - ported from llama.cpp
3712
+ //////////////////////////////////
3713
+
3714
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
3715
+ // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
3716
+ std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
3717
+ const char * src,
3718
+ whisper_partial_utf8 partial_start) {
3719
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
3720
+ const char * pos = src;
3721
+ std::vector<uint32_t> code_points;
3722
+ uint32_t value = partial_start.value;
3723
+ int n_remain = partial_start.n_remain;
3724
+
3725
+ // continue previous decode, if applicable
3726
+ while (*pos != 0 && n_remain > 0) {
3727
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
3728
+ if ((next_byte >> 6) != 2) {
3729
+ // invalid sequence, abort
3730
+ code_points.push_back(0);
3731
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
3732
+ }
3733
+ value = (value << 6) + (next_byte & 0x3F);
3734
+ ++pos;
3735
+ --n_remain;
3736
+ }
3737
+
3738
+ if (partial_start.n_remain > 0 && n_remain == 0) {
3739
+ code_points.push_back(value);
3740
+ }
3741
+
3742
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
3743
+ while (*pos != 0) {
3744
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
3745
+ uint8_t highbits = first_byte >> 4;
3746
+ n_remain = lookup[highbits] - 1;
3747
+
3748
+ if (n_remain < 0) {
3749
+ // invalid sequence, abort
3750
+ code_points.clear();
3751
+ code_points.push_back(0);
3752
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
3753
+ }
3754
+
3755
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
3756
+ value = first_byte & mask;
3757
+ ++pos;
3758
+ while (*pos != 0 && n_remain > 0) {
3759
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
3760
+ ++pos;
3761
+ --n_remain;
3762
+ }
3763
+ if (n_remain == 0) {
3764
+ code_points.push_back(value);
3765
+ }
3766
+ }
3767
+ code_points.push_back(0);
3768
+
3769
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
3770
+ }
3771
+
3772
+ // returns true iff pos points to the end of one of the definitions of a rule
3773
+ static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
3774
+ switch (pos->type) {
3775
+ case WHISPER_GRETYPE_END: return true; // NOLINT
3776
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
3777
+ default: return false;
3778
+ }
3779
+ }
3780
+
3781
+ // returns true iff chr satisfies the char range at pos (regular or inverse range)
3782
+ // asserts that pos is pointing to a char range element
3783
+ static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
3784
+ const whisper_grammar_element * pos,
3785
+ const uint32_t chr) {
3786
+
3787
+ bool found = false;
3788
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
3789
+
3790
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
3791
+
3792
+ do {
3793
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
3794
+ // inclusive range, e.g. [a-z]
3795
+ found = found || (pos->value <= chr && chr <= pos[1].value);
3796
+ pos += 2;
3797
+ } else {
3798
+ // exact char match, e.g. [a] or "a"
3799
+ found = found || pos->value == chr;
3800
+ pos += 1;
3801
+ }
3802
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
3803
+
3804
+ return std::make_pair(found == is_positive_char, pos);
3805
+ }
3806
+
3807
+ // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
3808
+ // range at pos (regular or inverse range)
3809
+ // asserts that pos is pointing to a char range element
3810
+ static bool whisper_grammar_match_partial_char(
3811
+ const whisper_grammar_element * pos,
3812
+ const whisper_partial_utf8 partial_utf8) {
3813
+
3814
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
3815
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
3816
+
3817
+ uint32_t partial_value = partial_utf8.value;
3818
+ int n_remain = partial_utf8.n_remain;
3819
+
3820
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
3821
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
3822
+ return false;
3823
+ }
3824
+
3825
+ // range of possible code points this partial UTF-8 sequence could complete to
3826
+ uint32_t low = partial_value << (n_remain * 6);
3827
+ uint32_t high = low | ((1 << (n_remain * 6)) - 1);
3828
+
3829
+ if (low == 0) {
3830
+ if (n_remain == 2) {
3831
+ low = 1 << 11;
3832
+ } else if (n_remain == 3) {
3833
+ low = 1 << 16;
3834
+ }
3835
+ }
3836
+
3837
+ do {
3838
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
3839
+ // inclusive range, e.g. [a-z]
3840
+ if (pos->value <= high && low <= pos[1].value) {
3841
+ return is_positive_char;
3842
+ }
3843
+ pos += 2;
3844
+ } else {
3845
+ // exact char match, e.g. [a] or "a"
3846
+ if (low <= pos->value && pos->value <= high) {
3847
+ return is_positive_char;
3848
+ }
3849
+ pos += 1;
3850
+ }
3851
+ } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
3852
+
3853
+ return !is_positive_char;
3854
+ }
3855
+
3856
+
3857
+ // transforms a grammar pushdown stack into N possible stacks, all ending
3858
+ // at a character range (terminal element)
3859
+ static void whisper_grammar_advance_stack(
3860
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
3861
+ const std::vector<const whisper_grammar_element *> & stack,
3862
+ std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
3863
+
3864
+ if (stack.empty()) {
3865
+ new_stacks.push_back(stack);
3866
+ return;
3867
+ }
3868
+
3869
+ const whisper_grammar_element * pos = stack.back();
3870
+
3871
+ switch (pos->type) {
3872
+ case WHISPER_GRETYPE_RULE_REF: {
3873
+ const size_t rule_id = static_cast<size_t>(pos->value);
3874
+ const whisper_grammar_element * subpos = rules[rule_id].data();
3875
+ do {
3876
+ // init new stack without the top (pos)
3877
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
3878
+ if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
3879
+ // if this rule ref is followed by another element, add that to stack
3880
+ new_stack.push_back(pos + 1);
3881
+ }
3882
+ if (!whisper_grammar_is_end_of_sequence(subpos)) {
3883
+ // if alternate is nonempty, add to stack
3884
+ new_stack.push_back(subpos);
3885
+ }
3886
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
3887
+ while (!whisper_grammar_is_end_of_sequence(subpos)) {
3888
+ // scan to end of alternate def
3889
+ subpos++;
3890
+ }
3891
+ if (subpos->type == WHISPER_GRETYPE_ALT) {
3892
+ // there's another alternate def of this rule to process
3893
+ subpos++;
3894
+ } else {
3895
+ break;
3896
+ }
3897
+ } while (true);
3898
+ break;
3899
+ }
3900
+ case WHISPER_GRETYPE_CHAR:
3901
+ case WHISPER_GRETYPE_CHAR_NOT:
3902
+ new_stacks.push_back(stack);
3903
+ break;
3904
+ default:
3905
+ // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
3906
+ // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
3907
+ // those
3908
+ WHISPER_ASSERT(false);
3909
+ }
3910
+ }
3911
+
3912
+ // takes a set of possible pushdown stacks on a grammar, which are required to
3913
+ // be positioned at a character range (see `whisper_grammar_advance_stack`), and
3914
+ // produces the N possible stacks if the given char is accepted at those
3915
+ // positions
3916
+ static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
3917
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
3918
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
3919
+ const uint32_t chr) {
3920
+
3921
+ std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
3922
+
3923
+ for (const auto & stack : stacks) {
3924
+ if (stack.empty()) {
3925
+ continue;
3926
+ }
3927
+
3928
+ auto match = whisper_grammar_match_char(stack.back(), chr);
3929
+ if (match.first) {
3930
+ const whisper_grammar_element * pos = match.second;
3931
+
3932
+ // update top of stack to next element, if any
3933
+ std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
3934
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
3935
+ new_stack.push_back(pos);
3936
+ }
3937
+ whisper_grammar_advance_stack(rules, new_stack, new_stacks);
3938
+ }
3939
+ }
3940
+
3941
+ return new_stacks;
3942
+ }
3943
+
3944
+ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
3945
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
3946
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
3947
+ const std::vector<whisper_grammar_candidate> & candidates);
3948
+
3949
+ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
3950
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
3951
+ const std::vector<const whisper_grammar_element *> & stack,
3952
+ const std::vector<whisper_grammar_candidate> & candidates) {
3953
+
3954
+ std::vector<whisper_grammar_candidate> rejects;
3955
+
3956
+ if (stack.empty()) {
3957
+ for (auto tok : candidates) {
3958
+ if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
3959
+ rejects.push_back(tok);
3960
+ }
3961
+ }
3962
+ return rejects;
3963
+ }
3964
+
3965
+ const whisper_grammar_element * stack_pos = stack.back();
3966
+
3967
+ std::vector<whisper_grammar_candidate> next_candidates;
3968
+ for (auto tok : candidates) {
3969
+ if (*tok.code_points == 0) {
3970
+ // reached end of full codepoints in token, reject iff it ended in a partial sequence
3971
+ // that cannot satisfy this position in grammar
3972
+ if (tok.partial_utf8.n_remain != 0 &&
3973
+ !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
3974
+ rejects.push_back(tok);
3975
+ }
3976
+ } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
3977
+ next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
3978
+ } else {
3979
+ rejects.push_back(tok);
3980
+ }
3981
+ }
3982
+
3983
+ const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
3984
+
3985
+ // update top of stack to next element, if any
3986
+ std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
3987
+ if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
3988
+ stack_after.push_back(stack_pos_after);
3989
+ }
3990
+ std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
3991
+ whisper_grammar_advance_stack(rules, stack_after, next_stacks);
3992
+
3993
+ auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
3994
+ for (auto tok : next_rejects) {
3995
+ rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
3996
+ }
3997
+
3998
+ return rejects;
3999
+ }
4000
+
4001
+ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
4002
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
4003
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
4004
+ const std::vector<whisper_grammar_candidate> & candidates) {
4005
+ if (candidates.empty() || stacks.empty()) {
4006
+ return std::vector<whisper_grammar_candidate>();
4007
+ }
4008
+
4009
+ auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
4010
+
4011
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
4012
+ rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
4013
+ }
4014
+ return rejects;
4015
+ }
4016
+
4017
+ static struct whisper_grammar whisper_grammar_init(
4018
+ const whisper_grammar_element ** rules,
4019
+ size_t n_rules,
4020
+ size_t i_start_rule) {
4021
+ const whisper_grammar_element * pos;
4022
+
4023
+ // copy rule definitions into vectors
4024
+ std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
4025
+ for (size_t i = 0; i < n_rules; i++) {
4026
+ for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
4027
+ vec_rules[i].push_back(*pos);
4028
+ }
4029
+ vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
4030
+ }
4031
+
4032
+ // loop over alternates of start rule to build initial stacks
4033
+ std::vector<std::vector<const whisper_grammar_element *>> stacks;
4034
+ pos = rules[i_start_rule];
4035
+ do {
4036
+ std::vector<const whisper_grammar_element *> stack;
4037
+ if (!whisper_grammar_is_end_of_sequence(pos)) {
4038
+ // if alternate is nonempty, add to stack
4039
+ stack.push_back(pos);
4040
+ }
4041
+ whisper_grammar_advance_stack(vec_rules, stack, stacks);
4042
+ while (!whisper_grammar_is_end_of_sequence(pos)) {
4043
+ // scan to end of alternate def
4044
+ pos++;
4045
+ }
4046
+ if (pos->type == WHISPER_GRETYPE_ALT) {
4047
+ // there's another alternate def of this rule to process
4048
+ pos++;
4049
+ } else {
4050
+ break;
4051
+ }
4052
+ } while (true);
4053
+
4054
+ return { std::move(vec_rules), std::move(stacks), {} };
4055
+ }
4056
+
4057
+ static void whisper_suppress_invalid_grammar(
4058
+ whisper_context & ctx,
4059
+ const whisper_full_params & params,
4060
+ std::vector<float> & logits,
4061
+ const whisper_grammar & grammar) {
4062
+
4063
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
4064
+ return;
4065
+ }
4066
+
4067
+ //bool allow_eot = false;
4068
+ //for (const auto & stack : grammar.stacks) {
4069
+ // if (stack.empty()) {
4070
+ // allow_eot = true;
4071
+ // break;
4072
+ // }
4073
+ //}
4074
+
4075
+ const whisper_token eot = whisper_token_eot(&ctx);
4076
+
4077
+ std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
4078
+ std::vector<whisper_grammar_candidate> candidates_grammar;
4079
+
4080
+ for (whisper_token id = 0; id < eot; ++id) {
4081
+ const std::string & text = ctx.vocab.id_to_token[id];
4082
+ if (!text.empty()) {
4083
+ candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
4084
+ candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
4085
+ }
4086
+ }
4087
+
4088
+ const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
4089
+
4090
+ for (const auto & reject : rejects) {
4091
+ logits[reject.id] -= params.grammar_penalty;
4092
+ }
4093
+
4094
+ // when the grammar allows a continuation, we penalize the end-of-text token
4095
+ //if (!allow_eot) {
4096
+ // logits[eot] -= params.grammar_penalty;
4097
+ //}
4098
+ //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
4099
+ }
4100
+
4101
+ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
4102
+ if (grammar.rules.empty() || grammar.stacks.empty()) {
4103
+ return;
4104
+ }
4105
+
4106
+ //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
4107
+
4108
+ const std::string & text = ctx.vocab.id_to_token[token];
4109
+
4110
+ if (text.rfind("[_", 0) == 0) {
4111
+ // fprintf(stderr, " (skipped)\n");
4112
+ return;
4113
+ }
4114
+ // fprintf(stderr, "\n");
4115
+
4116
+ // Note terminating 0 in decoded string
4117
+ const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
4118
+ const auto & code_points = decoded.first;
4119
+ for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
4120
+ grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
4121
+ }
4122
+ grammar.partial_utf8 = decoded.second;
4123
+ }
4124
+
4125
+ //////////////
4126
+ // END grammar
4127
+ //////////////
4128
+
4129
  ////////////////////////////////////////////////////////////////////////////
4130
 
4131
  struct whisper_context_params * whisper_context_default_params_by_ref() {
 
4155
 
4156
  /*.translate =*/ false,
4157
  /*.no_context =*/ true,
4158
+ /*.no_timestamps =*/ false,
4159
  /*.single_segment =*/ false,
4160
  /*.print_special =*/ false,
4161
  /*.print_progress =*/ true,
 
4218
 
4219
  /*.logits_filter_callback =*/ nullptr,
4220
  /*.logits_filter_callback_user_data =*/ nullptr,
4221
+
4222
+ /*.grammar_rules =*/ nullptr,
4223
+ /*.n_grammar_rules =*/ 0,
4224
+ /*.i_start_rule =*/ 0,
4225
+ /*.grammar_penalty =*/ 100.0f,
4226
  };
4227
 
4228
  switch (strategy) {
 
4374
  // suppress <|notimestamps|> token
4375
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
4376
  logits[vocab.token_not] = -INFINITY;
4377
+ if (params.no_timestamps) {
4378
+ for (int i = vocab.token_beg; i < n_logits; ++i) {
4379
+ logits[i] = -INFINITY;
4380
+ }
4381
+ }
4382
 
4383
  // suppress sot and nosp tokens
4384
  logits[vocab.token_sot] = -INFINITY;
 
4394
  logits[vocab.token_transcribe] = -INFINITY;
4395
  logits[vocab.token_prev] = -INFINITY;
4396
 
4397
+ // suppress lang tokens
4398
+ for (size_t i = 0; i < g_lang.size(); ++i) {
4399
+ logits[whisper_token_lang(&ctx, i)] = -INFINITY;
4400
+ }
4401
+
4402
+ // suppress prev token
4403
+ logits[vocab.token_prev] = -INFINITY;
4404
+
4405
  if (params.logits_filter_callback) {
4406
  params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
4407
  }
 
4512
  //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
4513
 
4514
  if (timestamp_logprob > max_text_token_logprob) {
4515
+ //printf("sampling timestamp\n");
4516
  for (int i = 0; i < vocab.token_beg; ++i) {
4517
  logits[i] = -INFINITY;
4518
  logprobs[i] = -INFINITY;
4519
  }
4520
+ } else if (params.n_grammar_rules > 0) {
4521
+ whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
4522
+
4523
+ // populate the logprobs array (log_softmax)
4524
+ {
4525
+ const float logit_max = *std::max_element(logits.begin(), logits.end());
4526
+ float logsumexp = 0.0f;
4527
+ for (int i = 0; i < n_logits; ++i) {
4528
+ if (logits[i] > -INFINITY) {
4529
+ logsumexp += expf(logits[i] - logit_max);
4530
+ }
4531
+ }
4532
+ logsumexp = logf(logsumexp) + logit_max;
4533
+
4534
+ for (int i = 0; i < n_logits; ++i) {
4535
+ if (logits[i] > -INFINITY) {
4536
+ logprobs[i] = logits[i] - logsumexp;
4537
+ } else {
4538
+ logprobs[i] = -INFINITY;
4539
+ }
4540
+ }
4541
+ }
4542
  }
4543
  }
4544
  }
 
4556
 
4557
  #if 0
4558
  // print first 100 logits - token string : logit
4559
+ //for (int i = 0; i < 10; i++) {
4560
+ // const auto token = vocab.id_to_token.at(i);
4561
+ // const auto prob = probs[i];
4562
+ // const auto logit = logits[i];
4563
+ // const auto logprob = logprobs[i];
4564
+ // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
4565
+ //}
4566
+
4567
+ // print sorted
4568
+ {
4569
+ std::vector<std::pair<float, int>> pairs;
4570
+
4571
+ for (int i = 0; i < n_logits; ++i) {
4572
+ pairs.push_back(std::make_pair(probs[i], i));
4573
+ }
4574
+
4575
+ std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
4576
+ return a.first > b.first;
4577
+ });
4578
+
4579
+ for (int i = 0; i < 10; i++) {
4580
+ const auto token = vocab.id_to_token.at(pairs[i].second);
4581
+ const auto prob = pairs[i].first;
4582
+ const auto logit = logits[pairs[i].second];
4583
+ const auto logprob = logprobs[pairs[i].second];
4584
+ printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
4585
+ }
4586
+
4587
+ printf("----------------\n");
4588
  }
4589
 
4590
  // "And", "and", " And", " and"
4591
+ //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
4592
+ //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
4593
+ //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
4594
+ //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
4595
+ //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
4596
+
4597
+ //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
4598
+ //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
4599
+ //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
4600
+ //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
4601
+ //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
4602
+
4603
+ //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
4604
+ //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
4605
+ //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
4606
+ //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
4607
+ //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
4608
  #endif
4609
  }
4610
 
 
4729
  ptsum = sum_ts;
4730
  }
4731
 
4732
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
4733
+
4734
  for (int i = 0; i < k; ++i) {
4735
+ const auto id = dist(state.rng);
4736
+ //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
4737
 
4738
  result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
4739
 
 
5062
  state->exp_n_audio_ctx = params.audio_ctx;
5063
 
5064
  // these tokens determine the task that will be performed
5065
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
5066
 
5067
  if (whisper_is_multilingual(ctx)) {
5068
  const int lang_id = whisper_lang_id(params.language);
 
5075
  }
5076
  }
5077
 
5078
+ // distilled models require the "no_timestamps" token
5079
  {
5080
  const bool is_distil = ctx->model.hparams.n_text_layer == 2;
5081
+ if (is_distil && !params.no_timestamps) {
 
 
 
5082
  WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
5083
+ params.no_timestamps = true;
5084
  }
5085
  }
5086
 
5087
+ if (params.no_timestamps) {
5088
+ prompt_init.push_back(whisper_token_not(ctx));
5089
+ }
5090
+
5091
  int seek = seek_start;
5092
 
5093
  std::vector<whisper_token> prompt;
 
5163
 
5164
  n_decoders_cur = std::max(1, n_decoders_cur);
5165
 
5166
+ WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
5167
 
5168
  // TAGS: WHISPER_DECODER_INIT
5169
  for (int j = 0; j < n_decoders_cur; ++j) {
 
5184
  decoder.failed = false;
5185
  decoder.completed = false;
5186
  decoder.has_ts = false;
5187
+
5188
+ if (params.grammar_rules != nullptr) {
5189
+ decoder.grammar = whisper_grammar_init(
5190
+ params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
5191
+ } else {
5192
+ decoder.grammar = {};
5193
+ }
5194
  }
5195
 
5196
  // init prompt and kv cache for the current iteration
 
5308
  continue;
5309
  }
5310
 
5311
+ if (cur_c >= beam_candidates.size()) {
5312
+ cur_c = 0;
5313
+ }
5314
+
5315
  auto & cur = beam_candidates[cur_c++];
5316
 
5317
  while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
 
5366
  has_ts = true;
5367
  }
5368
 
5369
+ whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
5370
+
5371
  #ifdef WHISPER_DEBUG
5372
  {
5373
  const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
whisper.h CHANGED
@@ -109,6 +109,37 @@ extern "C" {
109
  void (*close)(void * ctx);
110
  } whisper_model_loader;
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  // Various functions for loading a ggml whisper model.
113
  // Allocate (almost) all memory needed for the model.
114
  // Return NULL on failure
@@ -402,6 +433,7 @@ extern "C" {
402
 
403
  bool translate;
404
  bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
 
405
  bool single_segment; // force single segment output (useful for streaming)
406
  bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
407
  bool print_progress; // print progress information
@@ -479,6 +511,11 @@ extern "C" {
479
  // called by each decoder to filter obtained logits
480
  whisper_logits_filter_callback logits_filter_callback;
481
  void * logits_filter_callback_user_data;
 
 
 
 
 
482
  };
483
 
484
  // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
 
109
  void (*close)(void * ctx);
110
  } whisper_model_loader;
111
 
112
+ // grammar element type
113
+ enum whisper_gretype {
114
+ // end of rule definition
115
+ WHISPER_GRETYPE_END = 0,
116
+
117
+ // start of alternate definition for rule
118
+ WHISPER_GRETYPE_ALT = 1,
119
+
120
+ // non-terminal element: reference to rule
121
+ WHISPER_GRETYPE_RULE_REF = 2,
122
+
123
+ // terminal element: character (code point)
124
+ WHISPER_GRETYPE_CHAR = 3,
125
+
126
+ // inverse char(s) ([^a], [^a-b] [^abc])
127
+ WHISPER_GRETYPE_CHAR_NOT = 4,
128
+
129
+ // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
130
+ // be an inclusive range ([a-z])
131
+ WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
132
+
133
+ // modifies a preceding WHISPER_GRETYPE_CHAR or
134
+ // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
135
+ WHISPER_GRETYPE_CHAR_ALT = 6,
136
+ };
137
+
138
+ typedef struct whisper_grammar_element {
139
+ enum whisper_gretype type;
140
+ uint32_t value; // Unicode code point or rule ID
141
+ } whisper_grammar_element;
142
+
143
  // Various functions for loading a ggml whisper model.
144
  // Allocate (almost) all memory needed for the model.
145
  // Return NULL on failure
 
433
 
434
  bool translate;
435
  bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
436
+ bool no_timestamps; // do not generate timestamps
437
  bool single_segment; // force single segment output (useful for streaming)
438
  bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
439
  bool print_progress; // print progress information
 
511
  // called by each decoder to filter obtained logits
512
  whisper_logits_filter_callback logits_filter_callback;
513
  void * logits_filter_callback_user_data;
514
+
515
+ const whisper_grammar_element ** grammar_rules;
516
+ size_t n_grammar_rules;
517
+ size_t i_start_rule;
518
+ float grammar_penalty;
519
  };
520
 
521
  // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()