danbev ggerganov committed on
Commit
a28f35e
·
unverified ·
1 Parent(s): cac1a97

vad : add initial Voice Activity Detection (VAD) support (#3065)

Browse files

* vad : add initial Voice Activity Detection (VAD) support

This commit adds support for Voice Activity Detection (VAD). When enabled
this feature will process the audio input and detect speech segments.
This information is then used to reduce the number of samples that need
to be processed by whisper_full.

Resolves: https://github.com/ggml-org/whisper.cpp/issues/3003

---------

Co-authored-by: Georgi Gerganov <[email protected]>

.github/workflows/build.yml CHANGED
@@ -1253,3 +1253,23 @@ jobs:
1253
  source venv/bin/activate
1254
  pip install ane_transformers openai-whisper coremltools
1255
  ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1253
  source venv/bin/activate
1254
  pip install ane_transformers openai-whisper coremltools
1255
  ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
1256
+
1257
+ vad:
1258
+ if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
1259
+ github.event.inputs.run_type == 'full-ci' }}
1260
+ runs-on: ubuntu-latest
1261
+
1262
+ steps:
1263
+ - name: Checkout
1264
+ uses: actions/checkout@v4
1265
+
1266
+ - name: Build
1267
+ shell: bash
1268
+ run: |
1269
+ cmake -B build
1270
+ cmake --build build --config Release
1271
+
1272
+ - name: Test
1273
+ shell: bash
1274
+ run: |
1275
+ ctest -R ^test-vad$ --test-dir build --output-on-failure -VV
README.md CHANGED
@@ -25,6 +25,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
25
  - [Ascend NPU Support](#ascend-npu-support)
26
  - [Moore Threads GPU Support](#moore-threads-gpu-support)
27
  - [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
 
28
 
29
  Supported platforms:
30
 
@@ -732,6 +733,64 @@ let package = Package(
732
  )
733
  ```
734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  ## Examples
736
 
737
  There are various examples of using the library for different projects in the [examples](examples) folder.
 
25
  - [Ascend NPU Support](#ascend-npu-support)
26
  - [Moore Threads GPU Support](#moore-threads-gpu-support)
27
  - [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
28
+ - [Voice Activity Detection (VAD)](#voice-activity-detection-vad)
29
 
30
  Supported platforms:
31
 
 
733
  )
734
  ```
735
 
736
+ ### Voice Activity Detection (VAD)
737
+ Support for Voice Activity Detection (VAD) can be enabled using the `--vad`
738
+ argument to `whisper-cli`. In addition to this option a VAD model is also
739
+ required.
740
+
741
+ The way this works is that first the audio samples are passed through
742
+ the VAD model which will detect speech segments. Using this information
743
+ only the speech segments that are detected are extracted from the original audio
744
+ input and passed to whisper for processing. This reduces the amount of audio
745
+ data that needs to be processed by whisper and can significantly speed up the
746
+ transcription process.
747
+
748
+ The following VAD models are currently supported:
749
+
750
+ #### Silero-VAD
751
+ [Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
752
+ written in Python that is fast and accurate.
753
+
754
+ This model can be converted to ggml using the following command:
755
+ ```console
756
+ $ python3 -m venv venv && source venv/bin/activate
757
+ $ (venv) pip install silero-vad
758
+ $ (venv) python models/convert-silero-vad-to-ggml.py --output models/silero.bin
759
+ Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
760
+ ```
761
+ And it can then be used with whisper as follows:
762
+ ```console
763
+ $ ./build/bin/whisper-cli \
764
+ --file ./samples/jfk.wav \
765
+ --model ./models/ggml-base.en.bin \
766
+ --vad \
767
+ --vad-model ./models/silero-v5.1.2-ggml.bin
768
+ ```
769
+
770
+ #### VAD Options
771
+
772
+ * --vad-threshold: Threshold probability for speech detection. A probability
773
+ for a speech segment/frame above this threshold will be considered as speech.
774
+
775
+ * --vad-min-speech-duration-ms: Minimum speech duration in milliseconds. Speech
776
+ segments shorter than this value will be discarded to filter out brief noise or
777
+ false positives.
778
+
779
+ * --vad-min-silence-duration-ms: Minimum silence duration in milliseconds. Silence
780
+ periods must be at least this long to end a speech segment. Shorter silence
781
+ periods will be ignored and included as part of the speech.
782
+
783
+ * --vad-max-speech-duration-s: Maximum speech duration in seconds. Speech segments
784
+ longer than this will be automatically split into multiple segments at silence
785
+ points exceeding 98ms to prevent excessively long segments.
786
+
787
+ * --vad-speech-pad-ms: Speech padding in milliseconds. Adds this amount of padding
788
+ before and after each detected speech segment to avoid cutting off speech edges.
789
+
790
+ * --vad-samples-overlap: Amount of audio to extend from each speech segment into
791
+ the next one, in seconds (e.g., 0.10 = 100ms overlap). This ensures speech isn't
792
+ cut off abruptly between segments when they're concatenated together.
793
+
794
  ## Examples
795
 
796
  There are various examples of using the library for different projects in the [examples](examples) folder.
examples/cli/cli.cpp CHANGED
@@ -11,6 +11,7 @@
11
  #include <thread>
12
  #include <vector>
13
  #include <cstring>
 
14
 
15
  #if defined(_WIN32)
16
  #ifndef NOMINMAX
@@ -97,6 +98,16 @@ struct whisper_params {
97
  std::vector<std::string> fname_out = {};
98
 
99
  grammar_parser::parse_state grammar_parsed;
 
 
 
 
 
 
 
 
 
 
100
  };
101
 
102
  static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -185,6 +196,15 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
185
  else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
186
  else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
187
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
 
 
 
 
 
 
 
 
 
188
  else {
189
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
190
  whisper_print_usage(argc, argv, params);
@@ -254,6 +274,18 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
254
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
255
  fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
256
  fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
 
 
 
 
 
 
 
 
 
 
 
 
257
  fprintf(stderr, "\n");
258
  }
259
 
@@ -1134,6 +1166,16 @@ int main(int argc, char ** argv) {
1134
 
1135
  wparams.suppress_nst = params.suppress_nst;
1136
 
 
 
 
 
 
 
 
 
 
 
1137
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
1138
 
1139
  const auto & grammar_parsed = params.grammar_parsed;
 
11
  #include <thread>
12
  #include <vector>
13
  #include <cstring>
14
+ #include <cfloat>
15
 
16
  #if defined(_WIN32)
17
  #ifndef NOMINMAX
 
98
  std::vector<std::string> fname_out = {};
99
 
100
  grammar_parser::parse_state grammar_parsed;
101
+
102
+ // Voice Activity Detection (VAD) parameters
103
+ bool vad = false;
104
+ std::string vad_model = "";
105
+ float vad_threshold = 0.5f;
106
+ int vad_min_speech_duration_ms = 250;
107
+ int vad_min_silence_duration_ms = 100;
108
+ float vad_max_speech_duration_s = FLT_MAX;
109
+ int vad_speech_pad_ms = 30;
110
+ float vad_samples_overlap = 0.1f;
111
  };
112
 
113
  static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
196
  else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
197
  else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
198
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
199
+ // Voice Activity Detection (VAD)
200
+ else if (arg == "-v" || arg == "--vad") { params.vad = true; }
201
+ else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
202
+ else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
203
+ else if (arg == "-vsd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
204
+ else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
205
+ else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); }
206
+ else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); }
207
+ else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(ARGV_NEXT); }
208
  else {
209
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
210
  whisper_print_usage(argc, argv, params);
 
274
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
275
  fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
276
  fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
277
+ // Voice Activity Detection (VAD) parameters
278
+ fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
279
+ fprintf(stderr, " -v, --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
280
+ fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
281
+ fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
282
+ fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
283
+ fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
284
+ fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
285
+ std::string("FLT_MAX").c_str() :
286
+ std::to_string(params.vad_max_speech_duration_s).c_str());
287
+ fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
288
+ fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
289
  fprintf(stderr, "\n");
290
  }
291
 
 
1166
 
1167
  wparams.suppress_nst = params.suppress_nst;
1168
 
1169
+ wparams.vad = params.vad;
1170
+ wparams.vad_model_path = params.vad_model.c_str();
1171
+
1172
+ wparams.vad_params.threshold = params.vad_threshold;
1173
+ wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
1174
+ wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
1175
+ wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
1176
+ wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
1177
+ wparams.vad_params.samples_overlap = params.vad_samples_overlap;
1178
+
1179
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
1180
 
1181
  const auto & grammar_parsed = params.grammar_parsed;
include/whisper.h CHANGED
@@ -189,6 +189,15 @@ extern "C" {
189
  uint32_t value; // Unicode code point or rule ID
190
  } whisper_grammar_element;
191
 
 
 
 
 
 
 
 
 
 
192
  // Various functions for loading a ggml whisper model.
193
  // Allocate (almost) all memory needed for the model.
194
  // Return NULL on failure
@@ -570,11 +579,18 @@ extern "C" {
570
  size_t n_grammar_rules;
571
  size_t i_start_rule;
572
  float grammar_penalty;
 
 
 
 
 
 
573
  };
574
 
575
  // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
576
  WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
577
  WHISPER_API struct whisper_context_params whisper_context_default_params (void);
 
578
  WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
579
  WHISPER_API struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy);
580
 
@@ -652,6 +668,53 @@ extern "C" {
652
  WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
653
  WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  ////////////////////////////////////////////////////////////////////////////
656
 
657
  // Temporary helpers needed for exposing ggml interface
 
189
  uint32_t value; // Unicode code point or rule ID
190
  } whisper_grammar_element;
191
 
192
+ typedef struct whisper_vad_params {
193
+ float threshold; // Probability threshold to consider as speech.
194
+ int min_speech_duration_ms; // Min duration for a valid speech segment.
195
+ int min_silence_duration_ms; // Min silence duration to consider speech as ended.
196
+ float max_speech_duration_s; // Max duration of a speech segment before forcing a new segment.
197
+ int speech_pad_ms; // Padding added before and after speech segments.
198
+ float samples_overlap; // Overlap in seconds when copying audio samples from speech segment.
199
+ } whisper_vad_params;
200
+
201
  // Various functions for loading a ggml whisper model.
202
  // Allocate (almost) all memory needed for the model.
203
  // Return NULL on failure
 
579
  size_t n_grammar_rules;
580
  size_t i_start_rule;
581
  float grammar_penalty;
582
+
583
+ // Voice Activity Detection (VAD) params
584
+ bool vad; // Enable VAD
585
+ const char * vad_model_path; // Path to VAD model
586
+
587
+ whisper_vad_params vad_params;
588
  };
589
 
590
  // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
591
  WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
592
  WHISPER_API struct whisper_context_params whisper_context_default_params (void);
593
+
594
  WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
595
  WHISPER_API struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy);
596
 
 
668
  WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
669
  WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
670
 
671
+ //
672
+ // Voice Activity Detection (VAD)
673
+ //
674
+
675
+ struct whisper_vad_context;
676
+
677
+ WHISPER_API struct whisper_vad_params whisper_vad_default_params(void);
678
+
679
+ struct whisper_vad_context_params {
680
+ int n_threads; // The number of threads to use for processing.
681
+ bool use_gpu;
682
+ int gpu_device; // CUDA device
683
+ };
684
+
685
+ WHISPER_API struct whisper_vad_context_params whisper_vad_default_context_params(void);
686
+
687
+ WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(const char * path_model, struct whisper_vad_context_params params);
688
+ WHISPER_API struct whisper_vad_context * whisper_vad_init_with_params (struct whisper_model_loader * loader, struct whisper_vad_context_params params);
689
+
690
+ WHISPER_API bool whisper_vad_detect_speech(
691
+ struct whisper_vad_context * vctx,
692
+ const float * samples,
693
+ int n_samples);
694
+
695
+ WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx);
696
+ WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx);
697
+
698
+ struct whisper_vad_segments;
699
+
700
+ WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_probs(
701
+ struct whisper_vad_context * vctx,
702
+ struct whisper_vad_params params);
703
+
704
+ WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_samples(
705
+ struct whisper_vad_context * vctx,
706
+ struct whisper_vad_params params,
707
+ const float * samples,
708
+ int n_samples);
709
+
710
+ WHISPER_API int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments);
711
+
712
+ WHISPER_API float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment);
713
+ WHISPER_API float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment);
714
+
715
+ WHISPER_API void whisper_vad_free_segments(struct whisper_vad_segments * segments);
716
+ WHISPER_API void whisper_vad_free (struct whisper_vad_context * ctx);
717
+
718
  ////////////////////////////////////////////////////////////////////////////
719
 
720
  // Temporary helpers needed for exposing ggml interface
models/convert-silero-vad-to-ggml.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import struct
3
+ import argparse
4
+ import torch
5
+ import numpy as np
6
+ from silero_vad import load_silero_vad, __version__ as silero_version
7
+
8
+ def convert_silero_vad(output_path, print_tensors=True):
9
+ model = load_silero_vad()
10
+ state_dict = model.state_dict()
11
+
12
+ # Clean up state dict keys - filter out 8k model
13
+ cleaned_dict = {}
14
+ for key, value in state_dict.items():
15
+ # Skip 8k model
16
+ if "_8k" not in key:
17
+ clean_key = key
18
+ if not key.startswith("_model."):
19
+ clean_key = "_model." + key
20
+ cleaned_dict[clean_key] = value
21
+
22
+ base, ext = os.path.splitext(output_path)
23
+ output_file = f"{base}-v{silero_version}-ggml{ext}"
24
+ print(f"Saving GGML Silero-VAD model to {output_file}")
25
+
26
+ print("\nTensor info for debugging:")
27
+ for key, tensor in cleaned_dict.items():
28
+ print(f" - {key}: {tensor.shape} ({tensor.dtype})")
29
+ print()
30
+
31
+ with open(output_file, "wb") as fout:
32
+ # Write magic and version
33
+ fout.write(struct.pack("i", 0x67676d6c))
34
+
35
+ model_type = "silero-16k"
36
+ str_len = len(model_type)
37
+ fout.write(struct.pack("i", str_len))
38
+ fout.write(model_type.encode('utf-8'))
39
+
40
+ version_parts = silero_version.split('.')
41
+ major, minor, patch = map(int, version_parts)
42
+ print(f"Version: {major}.{minor}.{patch}")
43
+ fout.write(struct.pack("i", major))
44
+ fout.write(struct.pack("i", minor))
45
+ fout.write(struct.pack("i", patch))
46
+
47
+ # Write model architecture parameters
48
+ window_size = 512
49
+ fout.write(struct.pack("i", window_size))
50
+ context_size = 64
51
+ fout.write(struct.pack("i", context_size))
52
+
53
+ n_encoder_layers = 4
54
+ fout.write(struct.pack("i", n_encoder_layers))
55
+
56
+ # Write encoder dimensions
57
+ input_channels = 129
58
+ encoder_in_channels = [input_channels, 128, 64, 64]
59
+ encoder_out_channels = [128, 64, 64, 128]
60
+ kernel_size = 3
61
+
62
+ for i in range(n_encoder_layers):
63
+ fout.write(struct.pack("i", encoder_in_channels[i]))
64
+ fout.write(struct.pack("i", encoder_out_channels[i]))
65
+ fout.write(struct.pack("i", kernel_size))
66
+
67
+ # Write LSTM dimensions
68
+ lstm_input_size = 128
69
+ lstm_hidden_size = 128
70
+ fout.write(struct.pack("i", lstm_input_size))
71
+ fout.write(struct.pack("i", lstm_hidden_size))
72
+
73
+ # Write final conv dimensions
74
+ final_conv_in = 128
75
+ final_conv_out = 1
76
+ fout.write(struct.pack("i", final_conv_in))
77
+ fout.write(struct.pack("i", final_conv_out))
78
+
79
+ # Define tensor keys to write
80
+ tensor_keys = []
81
+
82
+ # Encoder weights
83
+ for i in range(n_encoder_layers):
84
+ weight_key = f"_model.encoder.{i}.reparam_conv.weight"
85
+ bias_key = f"_model.encoder.{i}.reparam_conv.bias"
86
+ if weight_key in cleaned_dict and bias_key in cleaned_dict:
87
+ tensor_keys.append(weight_key)
88
+ tensor_keys.append(bias_key)
89
+
90
+ # LSTM weights
91
+ lstm_keys = [
92
+ "_model.decoder.rnn.weight_ih",
93
+ "_model.decoder.rnn.weight_hh",
94
+ "_model.decoder.rnn.bias_ih",
95
+ "_model.decoder.rnn.bias_hh"
96
+ ]
97
+ tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])
98
+
99
+ # Final conv weights
100
+ final_keys = [
101
+ "_model.decoder.decoder.2.weight",
102
+ "_model.decoder.decoder.2.bias"
103
+ ]
104
+ tensor_keys.extend([k for k in final_keys if k in cleaned_dict])
105
+
106
+ # STFT basis - add this last
107
+ stft_tensor = "_model.stft.forward_basis_buffer"
108
+ tensor_keys.append(stft_tensor)
109
+
110
+ print(f"Writing {len(tensor_keys)} tensors:")
111
+ for key in tensor_keys:
112
+ if key in cleaned_dict:
113
+ print(f" - {key}: {cleaned_dict[key].shape}")
114
+ else:
115
+ print(f" - {key}: MISSING")
116
+
117
+ # Process each tensor
118
+ for key in tensor_keys:
119
+ if key not in cleaned_dict:
120
+ print(f"Warning: Missing tensor {key}, skipping")
121
+ continue
122
+
123
+ tensor = cleaned_dict[key]
124
+
125
+ # Special handling for STFT tensor
126
+ if key == "_model.stft.forward_basis_buffer":
127
+ # Get the original numpy array without squeezing
128
+ data = tensor.detach().cpu().numpy()
129
+ # Ensure it has the expected shape
130
+ print(f"STFT tensor original shape: {data.shape}")
131
+ n_dims = 3
132
+ tensor_shape = [data.shape[2], data.shape[1], data.shape[0]]
133
+ is_conv_weight = True
134
+ else:
135
+ # For other tensors, we can use standard processing
136
+ data = tensor.detach().cpu().squeeze().numpy()
137
+ tensor_shape = list(data.shape)
138
+
139
+ # Ensure we have at most 4 dimensions for GGML
140
+ n_dims = min(len(tensor_shape), 4)
141
+
142
+ # Reverse dimensions for GGML
143
+ tensor_shape = tensor_shape[:n_dims]
144
+ tensor_shape.reverse()
145
+
146
+ # Check if this is a convolution weight tensor
147
+ is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)
148
+
149
+ # Convert to float16 for convolution weights
150
+ if is_conv_weight:
151
+ data = data.astype(np.float16)
152
+ ftype = 1 # float16
153
+ else:
154
+ ftype = 0 # float32
155
+
156
+ # Debug printing of tensor info
157
+ print(f"\nWriting tensor: {key}")
158
+ print(f" Original shape: {tensor.shape}")
159
+ print(f" Processed shape: {data.shape}")
160
+ print(f" GGML dimensions: {n_dims}")
161
+ print(f" GGML shape: {tensor_shape}")
162
+ print(f" Type: {'float16' if ftype == 1 else 'float32'}")
163
+
164
+ # Convert tensor name to bytes
165
+ name_bytes = key.encode('utf-8')
166
+ name_length = len(name_bytes)
167
+
168
+ # Write tensor header
169
+ fout.write(struct.pack("i", n_dims))
170
+ fout.write(struct.pack("i", name_length))
171
+ fout.write(struct.pack("i", ftype))
172
+
173
+ # Write tensor dimensions
174
+ for i in range(n_dims):
175
+ size = tensor_shape[i] if i < len(tensor_shape) else 1
176
+ fout.write(struct.pack("i", size))
177
+ print(f" Writing dimension {i}: {size}")
178
+
179
+ # Write tensor name
180
+ fout.write(name_bytes)
181
+
182
+ # Write tensor data
183
+ data.tofile(fout)
184
+
185
+ print(f" Wrote {data.size * (2 if ftype==1 else 4)} bytes")
186
+
187
+ print(f"\nDone! Model has been converted to GGML format: {output_file}")
188
+ print(f"File size: {os.path.getsize(output_file)} bytes")
189
+
190
+ if __name__ == "__main__":
191
+ parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
192
+ parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
193
+ parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True)
194
+ args = parser.parse_args()
195
+
196
+ convert_silero_vad(args.output, args.print_tensors)
models/for-tests-silero-v5.1.2-ggml.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29940d98d42b91fbd05ce489f3ecf7c72f0a42f027e4875919a28fb4c04ea2cf
3
+ size 885098
src/whisper-arch.h CHANGED
@@ -139,3 +139,59 @@ static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
139
  {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
140
  {ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD},
141
  };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
140
  {ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD},
141
  };
142
+
143
+ enum vad_tensor {
144
+ VAD_TENSOR_STFT_BASIS,
145
+ VAD_TENSOR_ENC_0_WEIGHT,
146
+ VAD_TENSOR_ENC_0_BIAS,
147
+ VAD_TENSOR_ENC_1_WEIGHT,
148
+ VAD_TENSOR_ENC_1_BIAS,
149
+ VAD_TENSOR_ENC_2_WEIGHT,
150
+ VAD_TENSOR_ENC_2_BIAS,
151
+ VAD_TENSOR_ENC_3_WEIGHT,
152
+ VAD_TENSOR_ENC_3_BIAS,
153
+ VAD_TENSOR_LSTM_WEIGHT_IH,
154
+ VAD_TENSOR_LSTM_WEIGHT_HH,
155
+ VAD_TENSOR_LSTM_BIAS_IH,
156
+ VAD_TENSOR_LSTM_BIAS_HH,
157
+ VAD_TENSOR_FINAL_CONV_WEIGHT,
158
+ VAD_TENSOR_FINAL_CONV_BIAS,
159
+ };
160
+
161
+ static const std::map<vad_tensor, ggml_op> VAD_TENSOR_OPS = {
162
+ {VAD_TENSOR_STFT_BASIS, GGML_OP_IM2COL},
163
+ {VAD_TENSOR_ENC_0_WEIGHT, GGML_OP_IM2COL},
164
+ {VAD_TENSOR_ENC_0_BIAS, GGML_OP_ADD},
165
+ {VAD_TENSOR_ENC_1_WEIGHT, GGML_OP_IM2COL},
166
+ {VAD_TENSOR_ENC_1_BIAS, GGML_OP_ADD},
167
+ {VAD_TENSOR_ENC_2_WEIGHT, GGML_OP_IM2COL},
168
+ {VAD_TENSOR_ENC_2_BIAS, GGML_OP_ADD},
169
+ {VAD_TENSOR_ENC_3_WEIGHT, GGML_OP_IM2COL},
170
+ {VAD_TENSOR_ENC_3_BIAS, GGML_OP_ADD},
171
+
172
+ {VAD_TENSOR_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT},
173
+ {VAD_TENSOR_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT},
174
+ {VAD_TENSOR_LSTM_BIAS_IH, GGML_OP_ADD},
175
+ {VAD_TENSOR_LSTM_BIAS_HH, GGML_OP_ADD},
176
+
177
+ {VAD_TENSOR_FINAL_CONV_WEIGHT, GGML_OP_IM2COL},
178
+ {VAD_TENSOR_FINAL_CONV_BIAS, GGML_OP_ADD}
179
+ };
180
+
181
+ static const std::map<vad_tensor, const char *> VAD_TENSOR_NAMES = {
182
+ {VAD_TENSOR_STFT_BASIS, "_model.stft.forward_basis_buffer"},
183
+ {VAD_TENSOR_ENC_0_WEIGHT, "_model.encoder.0.reparam_conv.weight"},
184
+ {VAD_TENSOR_ENC_0_BIAS, "_model.encoder.0.reparam_conv.bias"},
185
+ {VAD_TENSOR_ENC_1_WEIGHT, "_model.encoder.1.reparam_conv.weight"},
186
+ {VAD_TENSOR_ENC_1_BIAS, "_model.encoder.1.reparam_conv.bias"},
187
+ {VAD_TENSOR_ENC_2_WEIGHT, "_model.encoder.2.reparam_conv.weight"},
188
+ {VAD_TENSOR_ENC_2_BIAS, "_model.encoder.2.reparam_conv.bias"},
189
+ {VAD_TENSOR_ENC_3_WEIGHT, "_model.encoder.3.reparam_conv.weight"},
190
+ {VAD_TENSOR_ENC_3_BIAS, "_model.encoder.3.reparam_conv.bias"},
191
+ {VAD_TENSOR_LSTM_WEIGHT_IH, "_model.decoder.rnn.weight_ih"},
192
+ {VAD_TENSOR_LSTM_WEIGHT_HH, "_model.decoder.rnn.weight_hh"},
193
+ {VAD_TENSOR_LSTM_BIAS_IH, "_model.decoder.rnn.bias_ih"},
194
+ {VAD_TENSOR_LSTM_BIAS_HH, "_model.decoder.rnn.bias_hh"},
195
+ {VAD_TENSOR_FINAL_CONV_WEIGHT, "_model.decoder.decoder.2.weight"},
196
+ {VAD_TENSOR_FINAL_CONV_BIAS, "_model.decoder.decoder.2.bias"}
197
+ };
src/whisper.cpp CHANGED
@@ -17,6 +17,7 @@
17
  #include <atomic>
18
  #include <algorithm>
19
  #include <cassert>
 
20
  #define _USE_MATH_DEFINES
21
  #include <cmath>
22
  #include <climits>
@@ -163,7 +164,6 @@ static bool ggml_graph_compute_helper(
163
  int n_threads,
164
  ggml_abort_callback abort_callback,
165
  void * abort_callback_data) {
166
-
167
  ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
168
 
169
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
@@ -184,8 +184,8 @@ static bool ggml_graph_compute_helper(
184
  static bool ggml_graph_compute_helper(
185
  ggml_backend_sched_t sched,
186
  struct ggml_cgraph * graph,
187
- int n_threads) {
188
-
189
  for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
190
  ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
191
  ggml_backend_dev_t dev = ggml_backend_get_device(backend);
@@ -197,8 +197,12 @@ static bool ggml_graph_compute_helper(
197
  }
198
  }
199
 
200
- bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
201
- ggml_backend_sched_reset(sched);
 
 
 
 
202
  return t;
203
  }
204
 
@@ -949,6 +953,15 @@ struct whisper_state {
949
 
950
  // [EXPERIMENTAL] speed-up techniques
951
  int32_t exp_n_audio_ctx = 0; // 0 - use default
 
 
 
 
 
 
 
 
 
952
  };
953
 
954
  struct whisper_context {
@@ -4341,225 +4354,1337 @@ const char * whisper_print_system_info(void) {
4341
  }
4342
 
4343
  //////////////////////////////////
4344
- // Grammar - ported from llama.cpp
4345
  //////////////////////////////////
4346
 
4347
- // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
4348
- // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
4349
- static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
4350
- const char * src,
4351
- whisper_partial_utf8 partial_start) {
4352
- static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
4353
- const char * pos = src;
4354
- std::vector<uint32_t> code_points;
4355
- uint32_t value = partial_start.value;
4356
- int n_remain = partial_start.n_remain;
4357
 
4358
- // continue previous decode, if applicable
4359
- while (*pos != 0 && n_remain > 0) {
4360
- uint8_t next_byte = static_cast<uint8_t>(*pos);
4361
- if ((next_byte >> 6) != 2) {
4362
- // invalid sequence, abort
4363
- code_points.push_back(0);
4364
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
4365
- }
4366
- value = (value << 6) + (next_byte & 0x3F);
4367
- ++pos;
4368
- --n_remain;
4369
- }
4370
 
4371
- if (partial_start.n_remain > 0 && n_remain == 0) {
4372
- code_points.push_back(value);
4373
- }
4374
 
4375
- // decode any subsequent utf-8 sequences, which may end in an incomplete one
4376
- while (*pos != 0) {
4377
- uint8_t first_byte = static_cast<uint8_t>(*pos);
4378
- uint8_t highbits = first_byte >> 4;
4379
- n_remain = lookup[highbits] - 1;
4380
 
4381
- if (n_remain < 0) {
4382
- // invalid sequence, abort
4383
- code_points.clear();
4384
- code_points.push_back(0);
4385
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
4386
- }
4387
 
4388
- uint8_t mask = (1 << (7 - n_remain)) - 1;
4389
- value = first_byte & mask;
4390
- ++pos;
4391
- while (*pos != 0 && n_remain > 0) {
4392
- value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
4393
- ++pos;
4394
- --n_remain;
4395
- }
4396
- if (n_remain == 0) {
4397
- code_points.push_back(value);
4398
- }
4399
- }
4400
- code_points.push_back(0);
4401
 
4402
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
4403
- }
 
4404
 
4405
- // returns true iff pos points to the end of one of the definitions of a rule
4406
- static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
4407
- switch (pos->type) {
4408
- case WHISPER_GRETYPE_END: return true; // NOLINT
4409
- case WHISPER_GRETYPE_ALT: return true; // NOLINT
4410
- default: return false;
4411
- }
4412
- }
4413
 
4414
- // returns true iff chr satisfies the char range at pos (regular or inverse range)
4415
- // asserts that pos is pointing to a char range element
4416
- static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
4417
- const whisper_grammar_element * pos,
4418
- const uint32_t chr) {
4419
 
4420
- bool found = false;
4421
- bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
4422
 
4423
- WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
 
4424
 
4425
- do {
4426
- if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
4427
- // inclusive range, e.g. [a-z]
4428
- found = found || (pos->value <= chr && chr <= pos[1].value);
4429
- pos += 2;
4430
- } else {
4431
- // exact char match, e.g. [a] or "a"
4432
- found = found || pos->value == chr;
4433
- pos += 1;
4434
- }
4435
- } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
4436
 
4437
- return std::make_pair(found == is_positive_char, pos);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4438
  }
4439
 
4440
- // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
4441
- // range at pos (regular or inverse range)
4442
- // asserts that pos is pointing to a char range element
4443
- static bool whisper_grammar_match_partial_char(
4444
- const whisper_grammar_element * pos,
4445
- const whisper_partial_utf8 partial_utf8) {
 
 
 
 
 
4446
 
4447
- bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
4448
- WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
4449
 
4450
- uint32_t partial_value = partial_utf8.value;
4451
- int n_remain = partial_utf8.n_remain;
 
 
 
 
 
 
 
 
 
 
 
4452
 
4453
- // invalid sequence or 7-bit char split across 2 bytes (overlong)
4454
- if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
4455
- return false;
4456
- }
 
4457
 
4458
- // range of possible code points this partial UTF-8 sequence could complete to
4459
- uint32_t low = partial_value << (n_remain * 6);
4460
- uint32_t high = low | ((1 << (n_remain * 6)) - 1);
4461
 
4462
- if (low == 0) {
4463
- if (n_remain == 2) {
4464
- low = 1 << 11;
4465
- } else if (n_remain == 3) {
4466
- low = 1 << 16;
4467
- }
4468
- }
4469
 
4470
- do {
4471
- if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
4472
- // inclusive range, e.g. [a-z]
4473
- if (pos->value <= high && low <= pos[1].value) {
4474
- return is_positive_char;
 
 
4475
  }
4476
- pos += 2;
4477
- } else {
4478
- // exact char match, e.g. [a] or "a"
4479
- if (low <= pos->value && pos->value <= high) {
4480
- return is_positive_char;
4481
  }
4482
- pos += 1;
 
 
 
 
 
 
 
 
 
 
 
4483
  }
4484
- } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
4485
 
4486
- return !is_positive_char;
4487
  }
4488
 
 
 
 
 
4489
 
4490
- // transforms a grammar pushdown stack into N possible stacks, all ending
4491
- // at a character range (terminal element)
4492
- static void whisper_grammar_advance_stack(
4493
- const std::vector<std::vector<whisper_grammar_element>> & rules,
4494
- const std::vector<const whisper_grammar_element *> & stack,
4495
- std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
4496
 
4497
- if (stack.empty()) {
4498
- new_stacks.emplace_back();
4499
- return;
4500
- }
4501
 
4502
- const whisper_grammar_element * pos = stack.back();
 
 
 
4503
 
4504
- switch (pos->type) {
4505
- case WHISPER_GRETYPE_RULE_REF: {
4506
- const size_t rule_id = static_cast<size_t>(pos->value);
4507
- const whisper_grammar_element * subpos = rules[rule_id].data();
4508
- do {
4509
- // init new stack without the top (pos)
4510
- std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
4511
- if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
4512
- // if this rule ref is followed by another element, add that to stack
4513
- new_stack.push_back(pos + 1);
4514
- }
4515
- if (!whisper_grammar_is_end_of_sequence(subpos)) {
4516
- // if alternate is nonempty, add to stack
4517
- new_stack.push_back(subpos);
4518
- }
4519
- whisper_grammar_advance_stack(rules, new_stack, new_stacks);
4520
- while (!whisper_grammar_is_end_of_sequence(subpos)) {
4521
- // scan to end of alternate def
4522
- subpos++;
4523
- }
4524
- if (subpos->type == WHISPER_GRETYPE_ALT) {
4525
- // there's another alternate def of this rule to process
4526
- subpos++;
4527
- } else {
4528
- break;
4529
- }
4530
- } while (true);
4531
- break;
4532
- }
4533
- case WHISPER_GRETYPE_CHAR:
4534
- case WHISPER_GRETYPE_CHAR_NOT:
4535
- new_stacks.push_back(stack);
4536
- break;
4537
- default:
4538
- // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
4539
- // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
4540
- // those
4541
- WHISPER_ASSERT(false);
4542
- }
4543
  }
4544
 
4545
- // takes a set of possible pushdown stacks on a grammar, which are required to
4546
- // be positioned at a character range (see `whisper_grammar_advance_stack`), and
4547
- // produces the N possible stacks if the given char is accepted at those
4548
- // positions
4549
- static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
4550
- const std::vector<std::vector<whisper_grammar_element>> & rules,
4551
- const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
4552
- const uint32_t chr) {
4553
 
4554
- std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
 
 
 
4555
 
4556
- for (const auto & stack : stacks) {
4557
- if (stack.empty()) {
4558
- continue;
4559
- }
4560
 
4561
- auto match = whisper_grammar_match_char(stack.back(), chr);
4562
- if (match.first) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4563
  const whisper_grammar_element * pos = match.second;
4564
 
4565
  // update top of stack to next element, if any
@@ -4856,6 +5981,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4856
  /*.n_grammar_rules =*/ 0,
4857
  /*.i_start_rule =*/ 0,
4858
  /*.grammar_penalty =*/ 100.0f,
 
 
 
 
 
4859
  };
4860
 
4861
  switch (strategy) {
@@ -5472,6 +6602,117 @@ static void whisper_sequence_score(
5472
  }
5473
  }
5474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5475
  int whisper_full_with_state(
5476
  struct whisper_context * ctx,
5477
  struct whisper_state * state,
@@ -5483,9 +6724,24 @@ int whisper_full_with_state(
5483
 
5484
  result_all.clear();
5485
 
5486
- if (n_samples > 0) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5487
  // compute log mel spectrogram
5488
- if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
5489
  WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
5490
  return -2;
5491
  }
@@ -6530,19 +7786,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
6530
  }
6531
 
6532
  int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
6533
- return state->result_all[i_segment].t0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6534
  }
6535
 
6536
  int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
6537
- return ctx->state->result_all[i_segment].t0;
6538
  }
6539
 
6540
  int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
6541
- return state->result_all[i_segment].t1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6542
  }
6543
 
6544
  int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
6545
- return ctx->state->result_all[i_segment].t1;
6546
  }
6547
 
6548
  bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
 
17
  #include <atomic>
18
  #include <algorithm>
19
  #include <cassert>
20
+ #include <cfloat>
21
  #define _USE_MATH_DEFINES
22
  #include <cmath>
23
  #include <climits>
 
164
  int n_threads,
165
  ggml_abort_callback abort_callback,
166
  void * abort_callback_data) {
 
167
  ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
168
 
169
  auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
 
184
  static bool ggml_graph_compute_helper(
185
  ggml_backend_sched_t sched,
186
  struct ggml_cgraph * graph,
187
+ int n_threads,
188
+ bool sched_reset = true) {
189
  for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
190
  ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
191
  ggml_backend_dev_t dev = ggml_backend_get_device(backend);
 
197
  }
198
  }
199
 
200
+ const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
201
+
202
+ if (!t || sched_reset) {
203
+ ggml_backend_sched_reset(sched);
204
+ }
205
+
206
  return t;
207
  }
208
 
 
953
 
954
  // [EXPERIMENTAL] speed-up techniques
955
  int32_t exp_n_audio_ctx = 0; // 0 - use default
956
+
957
+ struct vad_segment_info {
958
+ float orig_start;
959
+ float orig_end;
960
+ float vad_start;
961
+ float vad_end;
962
+ };
963
+ std::vector<vad_segment_info> vad_segments;
964
+ bool has_vad_segments = false;
965
  };
966
 
967
  struct whisper_context {
 
4354
  }
4355
 
4356
  //////////////////////////////////
4357
+ // Voice Activity Detection (VAD)
4358
  //////////////////////////////////
4359
 
4360
+ struct whisper_vad_hparams {
4361
+ int32_t n_encoder_layers;
4362
+ int32_t * encoder_in_channels;
4363
+ int32_t * encoder_out_channels;
4364
+ int32_t * kernel_sizes;
4365
+ int32_t lstm_input_size;
4366
+ int32_t lstm_hidden_size;
4367
+ int32_t final_conv_in;
4368
+ int32_t final_conv_out;
4369
+ };
4370
 
4371
+ struct whisper_vad_model {
4372
+ std::string type;
4373
+ std::string version;
4374
+ whisper_vad_hparams hparams;
 
 
 
 
 
 
 
 
4375
 
4376
+ struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
 
 
4377
 
4378
+ // Encoder tensors - 4 convolutional layers
4379
+ struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
4380
+ struct ggml_tensor * encoder_0_bias; // [128]
 
 
4381
 
4382
+ // Second encoder layer
4383
+ struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
4384
+ struct ggml_tensor * encoder_1_bias; // [64]
 
 
 
4385
 
4386
+ // Third encoder layer
4387
+ struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
4388
+ struct ggml_tensor * encoder_2_bias; // [64]
 
 
 
 
 
 
 
 
 
 
4389
 
4390
+ // Fourth encoder layer
4391
+ struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
4392
+ struct ggml_tensor * encoder_3_bias; // [128]
4393
 
4394
+ // LSTM decoder tensors
4395
+ struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
4396
+ struct ggml_tensor * lstm_ih_bias; // [512]
4397
+ struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
4398
+ struct ggml_tensor * lstm_hh_bias; // [512]
 
 
 
4399
 
4400
+ // Final conv layer
4401
+ struct ggml_tensor * final_conv_weight; // [128]
4402
+ struct ggml_tensor * final_conv_bias; // [1]
 
 
4403
 
4404
+ // ggml contexts
4405
+ std::vector<ggml_context *> ctxs;
4406
 
4407
+ // buffer for the model tensors
4408
+ std::vector<ggml_backend_buffer_t> buffers;
4409
 
4410
+ // tensors
4411
+ int n_loaded;
4412
+ std::map<std::string, struct ggml_tensor *> tensors;
4413
+ };
 
 
 
 
 
 
 
4414
 
4415
+ struct whisper_vad_segment {
4416
+ float start; // Start time in seconds
4417
+ float end; // End time in seconds
4418
+ };
4419
+
4420
+ struct whisper_vad_segments {
4421
+ std::vector<whisper_vad_segment> data;
4422
+ };
4423
+
4424
+ struct whisper_vad_context {
4425
+ int64_t t_vad_us = 0;
4426
+
4427
+ int n_window;
4428
+ int n_context;
4429
+ int n_threads;
4430
+
4431
+ std::vector<ggml_backend_t> backends;
4432
+ ggml_backend_buffer_t buffer = nullptr;
4433
+ whisper_context_params params;
4434
+ std::vector<uint8_t> ctx_buf;
4435
+ whisper_sched sched;
4436
+
4437
+ whisper_vad_model model;
4438
+ std::string path_model;
4439
+ struct ggml_tensor * h_state;
4440
+ struct ggml_tensor * c_state;
4441
+ std::vector<float> probs;
4442
+ };
4443
+
4444
+ struct whisper_vad_context_params whisper_vad_default_context_params(void) {
4445
+ whisper_vad_context_params result = {
4446
+ /*.n_thread = */ 4,
4447
+ /*.use_gpu = */ false,
4448
+ /*.gpu_device = */ 0,
4449
+ };
4450
+ return result;
4451
  }
4452
 
4453
+ struct whisper_vad_params whisper_vad_default_params(void) {
4454
+ whisper_vad_params result = {
4455
+ /* threshold = */ 0.5f,
4456
+ /* min_speech_duration_ms = */ 250,
4457
+ /* min_silence_duration_ms = */ 100,
4458
+ /* max_speech_duration_s = */ FLT_MAX,
4459
+ /* speech_pad_ms = */ 30,
4460
+ /* samples_overlap = */ 0.1,
4461
+ };
4462
+ return result;
4463
+ }
4464
 
4465
+ static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
4466
+ bool op_supported = true;
4467
 
4468
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
4469
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
4470
+ // GPU and default CPU backend support all operators
4471
+ op_supported = true;
4472
+ } else {
4473
+ switch (op) {
4474
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
4475
+ case GGML_OP_MUL_MAT: {
4476
+ ggml_init_params params = {
4477
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
4478
+ /*.mem_buffer =*/ nullptr,
4479
+ /*.no_alloc =*/ true,
4480
+ };
4481
 
4482
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
4483
+ if (!ctx_ptr) {
4484
+ throw std::runtime_error("failed to create ggml context");
4485
+ }
4486
+ ggml_context * ctx = ctx_ptr.get();
4487
 
4488
+ ggml_tensor * op_tensor = nullptr;
 
 
4489
 
4490
+ int64_t n_ctx = hparams.lstm_hidden_size;
4491
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
4492
+ op_tensor = ggml_mul_mat(ctx, w, b);
 
 
 
 
4493
 
4494
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
4495
+ GGML_ASSERT(w->buffer == nullptr);
4496
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
4497
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
4498
+ ggml_backend_buffer_free(w->buffer);
4499
+ w->buffer = nullptr;
4500
+ break;
4501
  }
4502
+ default: {
4503
+ op_supported = false;
4504
+ break;
 
 
4505
  }
4506
+ };
4507
+ }
4508
+ return op_supported;
4509
+ }
4510
+
4511
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
4512
+ GGML_ASSERT(!buft_list.empty());
4513
+ for (const auto & p : buft_list) {
4514
+ ggml_backend_dev_t dev = p.first;
4515
+ ggml_backend_buffer_type_t buft = p.second;
4516
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
4517
+ return buft;
4518
  }
4519
+ }
4520
 
4521
+ return nullptr;
4522
  }
4523
 
4524
+ static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
4525
+ const whisper_vad_model & model, ggml_tensor * cur) {
4526
+ // Apply reflective padding to the input tensor
4527
+ ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
4528
 
4529
+ struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
 
 
 
 
 
4530
 
4531
+ // Calculate cutoff for real/imaginary parts
4532
+ int cutoff = model.stft_forward_basis->ne[2] / 2;
 
 
4533
 
4534
+ // Extract real part (first half of the STFT output).
4535
+ struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
4536
+ // Extract imaginary part (second half of the STFT output).
4537
+ struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
4538
 
4539
+ // Calculate magnitude: sqrt(real^2 + imag^2)
4540
+ struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
4541
+ struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part);
4542
+ struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared);
4543
+ struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares);
4544
+ return magnitude;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4545
  }
4546
 
4547
+ static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
4548
+ const whisper_vad_model & model, ggml_tensor * cur) {
4549
+ // First Conv1D: expands to 128 channels.
4550
+ cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
4551
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
4552
+ cur = ggml_relu(ctx0, cur);
 
 
4553
 
4554
+ // Second Conv1D: reduces to 64 channels.
4555
+ cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
4556
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
4557
+ cur = ggml_relu(ctx0, cur);
4558
 
4559
+ // Third Conv1D: maintains 64 channels
4560
+ cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
4561
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
4562
+ cur = ggml_relu(ctx0, cur);
4563
 
4564
+ // Fourth Conv1D: expands to 128 channels
4565
+ cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
4566
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
4567
+ cur = ggml_relu(ctx0, cur);
4568
+
4569
+ return cur;
4570
+ }
4571
+
4572
+ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
4573
+ const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
4574
+ const whisper_vad_model & model = vctx.model;
4575
+ const int hdim = model.hparams.lstm_hidden_size;
4576
+
4577
+ struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
4578
+
4579
+ // Create operations using the input-to-hidden weights.
4580
+ struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
4581
+ inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
4582
+
4583
+ // Create operations using the hidden-to-hidden weights.
4584
+ struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
4585
+ hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
4586
+
4587
+ // Create add operation to get preactivations for all gates.
4588
+ struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
4589
+
4590
+ const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
4591
+
4592
+ // Create sigmoid for input gate (using the first 128 bytes from the preactivations).
4593
+ struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
4594
+
4595
+ // Create sigmoid for the forget gate (using the second 128 bytes from the preactivations).
4596
+ struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
4597
+
4598
+ // Create sigmoid for the cell gate (using the third 128 bytes from the preactivations).
4599
+ struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
4600
+
4601
+ // Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations).
4602
+ struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
4603
+
4604
+ // Update cell state
4605
+ struct ggml_tensor * c_out = ggml_add(ctx0,
4606
+ ggml_mul(ctx0, f_t, vctx.c_state),
4607
+ ggml_mul(ctx0, i_t, g_t));
4608
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
4609
+
4610
+ // Update hidden state
4611
+ struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
4612
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
4613
+
4614
+ return out;
4615
+ }
4616
+
4617
+ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
4618
+ const auto & model = vctx.model;
4619
+
4620
+ struct ggml_init_params params = {
4621
+ /*.mem_size =*/ vctx.sched.meta.size(),
4622
+ /*.mem_buffer =*/ vctx.sched.meta.data(),
4623
+ /*.no_alloc =*/ true,
4624
+ };
4625
+
4626
+ struct ggml_context * ctx0 = ggml_init(params);
4627
+
4628
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
4629
+
4630
+ struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
4631
+ ggml_set_name(frame, "frame");
4632
+ ggml_set_input(frame);
4633
+
4634
+ struct ggml_tensor * cur = nullptr;
4635
+ {
4636
+ cur = whisper_vad_build_stft_layer(ctx0, model, frame);
4637
+
4638
+ cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
4639
+
4640
+ // Extract the first element of the first dimension
4641
+ // (equivalent to pytorch's [:, :, 0])
4642
+ cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
4643
+
4644
+ cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
4645
+ cur = ggml_relu(ctx0, cur);
4646
+ cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
4647
+ cur = ggml_add(ctx0, cur, model.final_conv_bias);
4648
+ cur = ggml_sigmoid(ctx0, cur);
4649
+ ggml_set_name(cur, "prob");
4650
+ ggml_set_output(cur);
4651
+ }
4652
+
4653
+ ggml_build_forward_expand(gf, cur);
4654
+
4655
+ ggml_free(ctx0);
4656
+
4657
+ return gf;
4658
+ }
4659
+
4660
+ static bool whisper_vad_init_context(whisper_vad_context * vctx) {
4661
+
4662
+ auto whisper_context_params = whisper_context_default_params();
4663
+ // TODO: GPU VAD is forced disabled until the performance is improved
4664
+ //whisper_context_params.use_gpu = vctx->params.use_gpu;
4665
+ whisper_context_params.use_gpu = false;
4666
+ whisper_context_params.gpu_device = vctx->params.gpu_device;
4667
+
4668
+ vctx->backends = whisper_backend_init(whisper_context_params);
4669
+ if (vctx->backends.empty()) {
4670
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
4671
+ return false;
4672
+ }
4673
+
4674
+ const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
4675
+
4676
+ vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
4677
+
4678
+ struct ggml_init_params params = {
4679
+ /*.mem_size =*/ vctx->ctx_buf.size(),
4680
+ /*.mem_buffer =*/ vctx->ctx_buf.data(),
4681
+ /*.no_alloc =*/ true,
4682
+ };
4683
+
4684
+ ggml_context * ctx = ggml_init(params);
4685
+ if (!ctx) {
4686
+ WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
4687
+ return false;
4688
+ }
4689
+
4690
+ // LSTM Hidden state
4691
+ vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4692
+ ggml_set_name(vctx->h_state, "h_state");
4693
+
4694
+ // LSTM Cell state
4695
+ vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4696
+ ggml_set_name(vctx->c_state, "c_state");
4697
+
4698
+ vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
4699
+ if (!vctx->buffer) {
4700
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
4701
+ return false;
4702
+ }
4703
+
4704
+ {
4705
+ bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
4706
+ [&]() {
4707
+ return whisper_vad_build_graph(*vctx);
4708
+ });
4709
+
4710
+ if (!ok) {
4711
+ WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
4712
+ return false;
4713
+ }
4714
+
4715
+ WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
4716
+ }
4717
+
4718
+ return true;
4719
+ }
4720
+
4721
+ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
4722
+ const char * path_model,
4723
+ struct whisper_vad_context_params params) {
4724
+ WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
4725
+ #ifdef _MSC_VER
4726
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
4727
+ std::wstring path_model_wide = converter.from_bytes(path_model);
4728
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
4729
+ #else
4730
+ auto fin = std::ifstream(path_model, std::ios::binary);
4731
+ #endif
4732
+ if (!fin) {
4733
+ WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
4734
+ return nullptr;
4735
+ }
4736
+
4737
+ whisper_model_loader loader = {};
4738
+ loader.context = &fin;
4739
+
4740
+ loader.read = [](void * ctx, void * output, size_t read_size) {
4741
+ std::ifstream * fin = (std::ifstream*)ctx;
4742
+ fin->read((char *)output, read_size);
4743
+ return read_size;
4744
+ };
4745
+
4746
+ loader.eof = [](void * ctx) {
4747
+ std::ifstream * fin = (std::ifstream*)ctx;
4748
+ return fin->eof();
4749
+ };
4750
+
4751
+ loader.close = [](void * ctx) {
4752
+ std::ifstream * fin = (std::ifstream*)ctx;
4753
+ fin->close();
4754
+ };
4755
+
4756
+ auto ctx = whisper_vad_init_with_params(&loader, params);
4757
+ if (!ctx) {
4758
+ whisper_vad_free(ctx);
4759
+ return nullptr;
4760
+ }
4761
+ ctx->path_model = path_model;
4762
+ return ctx;
4763
+ }
4764
+
4765
// Construct a whisper_vad_context by reading a Silero-style VAD model
// through the supplied loader.
//
// Stream layout consumed here, in order: GGML magic, model-type string,
// version triple (major.minor.patch), n_window/n_context, hyperparameters,
// then tensor records until EOF. Returns nullptr on validation failure.
//
// NOTE(review): the early `return nullptr` paths in the weight-loading loop
// below leak `vctx` (and the `new int32_t[]` hparams arrays); consider
// routing them through whisper_vad_free(vctx). The runtime_error throws in
// the lambdas would likewise escape with `vctx` unfreed - TODO confirm
// intended error-handling strategy.
struct whisper_vad_context * whisper_vad_init_with_params(
        struct whisper_model_loader * loader,
        struct whisper_vad_context_params params) {
    // Read the VAD model
    {
        uint32_t magic;
        read_safe(loader, magic);
        if (magic != GGML_FILE_MAGIC) {
            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
            return nullptr;
        }
    }

    whisper_vad_context * vctx = new whisper_vad_context;
    vctx->n_threads = params.n_threads;
    vctx->params.use_gpu = params.use_gpu;
    vctx->params.gpu_device = params.gpu_device;

    auto & model = vctx->model;
    auto & hparams = model.hparams;

    // load model context params.
    {
        // Length-prefixed model-type string (not NUL-terminated in the file).
        int32_t str_len;
        read_safe(loader, str_len);
        std::vector<char> buffer(str_len + 1, 0);
        loader->read(loader->context, buffer.data(), str_len);
        std::string model_type(buffer.data(), str_len);
        model.type = model_type;
        WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());

        int32_t major, minor, patch;
        read_safe(loader, major);
        read_safe(loader, minor);
        read_safe(loader, patch);
        std::string version_str = std::to_string(major) + "." +
                                  std::to_string(minor) + "." +
                                  std::to_string(patch);
        model.version = version_str;
        WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());

        read_safe(loader, vctx->n_window);
        read_safe(loader, vctx->n_context);
    }

    // load model hyper params (hparams).
    {
        read_safe(loader, hparams.n_encoder_layers);

        // Per-layer arrays sized by the encoder layer count just read.
        // NOTE(review): allocated with new[] - ownership/cleanup happens
        // outside this function; verify whisper_vad_free releases them.
        hparams.encoder_in_channels  = new int32_t[hparams.n_encoder_layers];
        hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
        hparams.kernel_sizes         = new int32_t[hparams.n_encoder_layers];

        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
            read_safe(loader, hparams.encoder_in_channels[i]);
            read_safe(loader, hparams.encoder_out_channels[i]);
            read_safe(loader, hparams.kernel_sizes[i]);
        }

        read_safe(loader, hparams.lstm_input_size);
        read_safe(loader, hparams.lstm_hidden_size);
        read_safe(loader, hparams.final_conv_in);
        read_safe(loader, hparams.final_conv_out);

        WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
            WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
        }
        for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
            WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
        }
        WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
        WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
        WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
        WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
    }

    // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
    const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;

    // One no-alloc ggml context per backend buffer type; contexts are
    // created lazily on first use and recorded in model.ctxs for cleanup.
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
                /*.mem_buffer =*/ nullptr,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                throw std::runtime_error("failed to create ggml context");
            }

            ctx_map[buft] = ctx;
            model.ctxs.emplace_back(ctx);

            return ctx;
        }

        return it->second;
    };

    whisper_context_params wparams = whisper_context_default_params();
    wparams.use_gpu    = params.use_gpu;
    wparams.gpu_device = params.gpu_device;
    buft_list_t buft_list = make_buft_list(wparams);

    // Duplicate the meta tensor into a context bound to a buffer type that
    // supports the op this tensor participates in, and register it by name.
    auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
        ggml_op op = VAD_TENSOR_OPS.at(type);
        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
        if (!buft) {
            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
        }
        ggml_context * ctx = get_ctx(buft);
        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
        model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;

        return tensor;
    };

    // create tensors
    {
        ggml_init_params params = {
            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };

        // Scratch context only used to build "meta" tensors that are
        // duplicated by create_tensor(); freed at the end of this scope.
        ggml_context * ctx = ggml_init(params);
        const auto & hparams = model.hparams;

        // SFTF precomputed basis matrix
        model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
            ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));

        model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
            ggml_new_tensor_3d(
                ctx,
                GGML_TYPE_F16,
                hparams.kernel_sizes[0],
                hparams.encoder_in_channels[0],
                hparams.encoder_out_channels[0]
            ));
        model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));

        model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
            ggml_new_tensor_3d(
                ctx,
                GGML_TYPE_F16,
                hparams.kernel_sizes[1],
                hparams.encoder_in_channels[1],
                hparams.encoder_out_channels[1]
            ));
        model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));

        model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
            ggml_new_tensor_3d(
                ctx,
                GGML_TYPE_F16,
                hparams.kernel_sizes[2],
                hparams.encoder_in_channels[2],
                hparams.encoder_out_channels[2]
            ));
        model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));

        model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
            ggml_new_tensor_3d(
                ctx,
                GGML_TYPE_F16,
                hparams.kernel_sizes[3],
                hparams.encoder_in_channels[3],
                hparams.encoder_out_channels[3]
            ));
        model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));

        // Hidden State dimension (input gate, forget gate, cell gate, output gate)
        const int hstate_dim = hparams.lstm_hidden_size * 4;

        // LSTM weights - input to hidden
        // NOTE(review): the first dimension uses lstm_hidden_size, which
        // assumes lstm_input_size == lstm_hidden_size - TODO confirm.
        model.lstm_ih_weight = create_tensor(
            VAD_TENSOR_LSTM_WEIGHT_IH,
            ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
        );
        model.lstm_ih_bias = create_tensor(
            VAD_TENSOR_LSTM_BIAS_IH,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
        );

        // LSTM weights - hidden to hidden
        model.lstm_hh_weight = create_tensor(
            VAD_TENSOR_LSTM_WEIGHT_HH,
            ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
        );
        model.lstm_hh_bias = create_tensor(
            VAD_TENSOR_LSTM_BIAS_HH,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
        );

        // Final conv layer weight
        model.final_conv_weight = create_tensor(
            VAD_TENSOR_FINAL_CONV_WEIGHT,
            ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
        );
        model.final_conv_bias = create_tensor(
            VAD_TENSOR_FINAL_CONV_BIAS,
            ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
        );

        ggml_free(ctx);
    }

    // allocate tensors in the backend buffers
    for (auto & p : ctx_map) {
        ggml_backend_buffer_type_t buft = p.first;
        ggml_context * ctx = p.second;
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (buf) {
            model.buffers.emplace_back(buf);

            size_t size_main = ggml_backend_buffer_get_size(buf);
            WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
        }
    }

    // load weights
    {
        size_t total_size = 0;
        model.n_loaded = 0;
        std::vector<char> read_buf;

        // Each record: n_dims, name length, ggml type, dims, name, raw data.
        // Loop terminates when the loader reports EOF after a header read.
        while (true) {
            int32_t n_dims;
            int32_t length;
            int32_t ttype;

            read_safe(loader, n_dims);
            read_safe(loader, length);
            read_safe(loader, ttype);

            if (loader->eof(loader->context)) {
                break;
            }

            int32_t nelements = 1;
            int32_t ne[4] = { 1, 1, 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                read_safe(loader, ne[i]);
                nelements *= ne[i];
            }

            std::string name;
            std::vector<char> tmp(length);
            loader->read(loader->context, &tmp[0], tmp.size());
            name.assign(&tmp[0], tmp.size());

            if (model.tensors.find(name) == model.tensors.end()) {
                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                return nullptr;
            }

            auto tensor = model.tensors[name.data()];

            // Validate element count, per-dimension shape, and byte size
            // against the tensor created earlier from the hparams.
            if (ggml_nelements(tensor) != nelements) {
                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                return nullptr;
            }

            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                return nullptr;
            }

            const size_t bpe = ggml_type_size(ggml_type(ttype));

            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                return nullptr;
            }

            if (ggml_backend_buffer_is_host(tensor->buffer)) {
                // for the CPU and Metal backend, we can read directly into the tensor
                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                BYTESWAP_TENSOR(tensor);
            } else {
                // read into a temporary buffer first, then copy to device memory
                read_buf.resize(ggml_nbytes(tensor));

                loader->read(loader->context, read_buf.data(), read_buf.size());

                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
            }

            total_size += ggml_nbytes(tensor);
            model.n_loaded++;
        }

        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1e6);

        if (model.n_loaded == 0) {
            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
        } else if (model.n_loaded != (int) model.tensors.size()) {
            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
            return nullptr;
        }

    }

    if (!whisper_vad_init_context(vctx)) {
        whisper_vad_free(vctx);
        return nullptr;
    }

    return vctx;
}
5089
+
5090
+ bool whisper_vad_detect_speech(
5091
+ struct whisper_vad_context * vctx,
5092
+ const float * samples,
5093
+ int n_samples) {
5094
+ int n_chunks = n_samples / vctx->n_window;
5095
+ if (n_samples % vctx->n_window != 0) {
5096
+ n_chunks += 1; // Add one more chunk for remaining samples.
5097
+ }
5098
+
5099
+ WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
5100
+ WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
5101
+
5102
+ // Reset LSTM hidden/cell states
5103
+ ggml_backend_buffer_clear(vctx->buffer, 0);
5104
+
5105
+ vctx->probs.resize(n_chunks);
5106
+ WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
5107
+
5108
+ std::vector<float> window(vctx->n_window, 0.0f);
5109
+
5110
+ auto & sched = vctx->sched.sched;
5111
+
5112
+ ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
5113
+
5114
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
5115
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
5116
+ return false;
5117
+ }
5118
+
5119
+ struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
5120
+ struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
5121
+
5122
+ // we are going to reuse the graph multiple times for each chunk
5123
+ const int64_t t_start_vad_us = ggml_time_us();
5124
+
5125
+ for (int i = 0; i < n_chunks; i++) {
5126
+ const int idx_start = i * vctx->n_window;
5127
+ const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
5128
+
5129
+ const int chunk_len = idx_end - idx_start;
5130
+
5131
+ if (chunk_len < vctx->n_window) {
5132
+ WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
5133
+ std::vector<float> partial_chunk(vctx->n_window, 0.0f);
5134
+ std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
5135
+
5136
+ // Copy the zero-padded chunk to the window.
5137
+ const int samples_to_copy_max = vctx->n_window;
5138
+ const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
5139
+ std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
5140
+ if (samples_to_copy_cur < samples_to_copy_max) {
5141
+ std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
5142
+ }
5143
+ } else {
5144
+ // Copy current frame samples to the window.
5145
+ const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
5146
+ std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
5147
+ }
5148
+
5149
+ // Set the frame tensor data with the samples.
5150
+ ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
5151
+
5152
+ // do not reset the scheduler - we will reuse the graph in the next chunk
5153
+ if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
5154
+ WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
5155
+ break;
5156
+ }
5157
+
5158
+ // Get the probability for this chunk.
5159
+ ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
5160
+
5161
+ //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
5162
+ }
5163
+
5164
+ vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
5165
+ WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
5166
+
5167
+ ggml_backend_sched_reset(sched);
5168
+
5169
+ return true;
5170
+ }
5171
+
5172
+ int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
5173
+ return segments->data.size();
5174
+ }
5175
+
5176
+ float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
5177
+ return segments->data[i_segment].start;
5178
+ }
5179
+
5180
+ float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
5181
+ return segments->data[i_segment].end;
5182
+ }
5183
+
5184
+ int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
5185
+ return vctx->probs.size();
5186
+ }
5187
+
5188
+ float * whisper_vad_probs(struct whisper_vad_context * vctx) {
5189
+ return vctx->probs.data();
5190
+ }
5191
+
5192
// Convert the per-chunk speech probabilities stored in `vctx` into
// time-stamped speech segments (seconds), following the Silero VAD
// get_speech_timestamps state machine: threshold with hysteresis
// (neg_threshold), minimum speech/silence durations, max-duration splits,
// gap merging, and symmetric padding. Returns a newly allocated
// whisper_vad_segments (caller frees with whisper_vad_free_segments), or
// nullptr on allocation failure.
struct whisper_vad_segments * whisper_vad_segments_from_probs(
        struct whisper_vad_context * vctx,
        whisper_vad_params params) {
    WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));

    int     n_probs                 = whisper_vad_n_probs(vctx);
    float * probs                   = whisper_vad_probs(vctx);
    float   threshold               = params.threshold;
    int     min_speech_duration_ms  = params.min_speech_duration_ms;
    int     min_silence_duration_ms = params.min_silence_duration_ms;
    float   max_speech_duration_s   = params.max_speech_duration_s;
    int     speech_pad_ms           = params.speech_pad_ms;
    int     n_window                = vctx->n_window;
    int     sample_rate             = WHISPER_SAMPLE_RATE;
    int     min_silence_samples     = sample_rate * min_silence_duration_ms / 1000;
    int     audio_length_samples    = n_probs * n_window;

    // Min number of samples to be considered valid speech.
    int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
    int speech_pad_samples = sample_rate * speech_pad_ms / 1000;

    // Max number of samples that a speech segment can contain before it is
    // split into multiple segments.
    int max_speech_samples;
    if (max_speech_duration_s > 100000.0f) {
        max_speech_samples = INT_MAX / 2;
    } else {
        // NOTE(review): (int64_t)(max_speech_duration_s) truncates the
        // fractional part of the duration BEFORE multiplying by the sample
        // rate, losing sub-second precision - confirm whether
        // (int64_t)(sample_rate * max_speech_duration_s) was intended.
        int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
        max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
        if (max_speech_samples < 0) {
            max_speech_samples = INT_MAX / 2;
        }
    }
    // Detect silence period that exceeds this value, then that location (sample)
    // is marked as a potential place where the segment could be split if
    // max_speech_samples is reached. The value 98 was taken from the original
    // silaro-vad python implementation:
    //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
    int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;

    // Calculate lower threshold for detecting end of speech segments.
    // Hysteresis: speech ends only when probability drops below this.
    float neg_threshold = threshold - 0.15f;
    if (neg_threshold < 0.01f) {
        neg_threshold = 0.01f;
    }

    // Segment boundaries in samples while the state machine runs;
    // converted to seconds only at the very end.
    struct speech_segment_t {
        int start;
        int end;
    };

    std::vector<speech_segment_t> speeches;
    speeches.reserve(256);

    // State machine variables:
    //  - temp_end:   candidate end while inside a not-yet-long-enough silence
    //  - prev_end:   last silence long enough to split at (max-speech handling)
    //  - next_start: where speech resumed after that silence
    bool is_speech_segment = false;
    int  temp_end          = 0;
    int  prev_end          = 0;
    int  next_start        = 0;
    int  curr_speech_start = 0;
    bool has_curr_speech   = false;

    for (int i = 0; i < n_probs; i++) {
        float curr_prob   = probs[i];
        int   curr_sample = n_window * i;

        // Reset temp_end when we get back to speech
        if ((curr_prob >= threshold) && temp_end) {
            temp_end = 0;
            if (next_start < prev_end) {
                next_start = curr_sample;
            }
        }

        // Start a new speech segment when probability exceeds threshold and not already in speech
        if ((curr_prob >= threshold) && !is_speech_segment) {
            is_speech_segment = true;
            curr_speech_start = curr_sample;
            has_curr_speech = true;
            continue;
        }

        // Handle maximum speech duration
        if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
            if (prev_end) {
                // Split at the last qualifying silence rather than mid-speech.
                speeches.push_back({ curr_speech_start, prev_end });
                has_curr_speech = true;

                if (next_start < prev_end) { // Previously reached silence and is still not speech
                    is_speech_segment = false;
                    has_curr_speech   = false;
                } else {
                    curr_speech_start = next_start;
                }
                prev_end = next_start = temp_end = 0;
            } else {
                // No silence seen - hard split at the current sample.
                speeches.push_back({ curr_speech_start, curr_sample });

                prev_end = next_start = temp_end = 0;
                is_speech_segment = false;
                has_curr_speech   = false;
                continue;
            }
        }

        // Handle silence after speech
        if ((curr_prob < neg_threshold) && is_speech_segment) {
            if (!temp_end) {
                temp_end = curr_sample;
            }

            // Track potential segment ends for max_speech handling
            if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
                prev_end = temp_end;
            }

            // Check if silence is long enough to end the segment
            if ((curr_sample - temp_end) < min_silence_samples) {
                continue;
            } else {
                // End the segment if it's long enough
                if ((temp_end - curr_speech_start) > min_speech_samples) {
                    speeches.push_back({ curr_speech_start, temp_end });
                }

                prev_end = next_start = temp_end = 0;
                is_speech_segment = false;
                has_curr_speech   = false;
                continue;
            }
        }
    }

    // Handle the case if we're still in a speech segment at the end
    if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
        speeches.push_back({ curr_speech_start, audio_length_samples });
    }

    // Merge adjacent segments with small gaps in between (post-processing)
    if (speeches.size() > 1) {
        int merged_count = 0;
        for (int i = 0; i < (int) speeches.size() - 1; i++) {
            // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
            int max_merge_gap_samples = sample_rate * 200 / 1000;

            // If the gap between this segment and the next is small enough
            if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
                // Merge by extending current segment to the end of next segment
                speeches[i].end = speeches[i+1].end;
                speeches.erase(speeches.begin() + i + 1);

                // Re-examine the same index: the newly merged segment may
                // also be close enough to its new neighbor.
                i--;
                merged_count++;
            }
        }
        WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
                __func__, merged_count, (int) speeches.size());
    }

    // Double-check for minimum speech duration
    for (int i = 0; i < (int) speeches.size(); i++) {
        if (speeches[i].end - speeches[i].start < min_speech_samples) {
            WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
                    __func__, i, speeches[i].end - speeches[i].start);

            speeches.erase(speeches.begin() + i);
            i--;
        }
    }

    WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());

    // Allocate final segments
    std::vector<whisper_vad_segment> segments;
    if (speeches.size() > 0) {
        try {
            segments.resize(speeches.size());
        } catch (const std::bad_alloc &) {
            WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
            return nullptr;
        }
    }

    // Apply padding to segments and copy to final segments
    for (int i = 0; i < (int) speeches.size(); i++) {
        // Apply padding to the start of the first segment
        if (i == 0) {
            speeches[i].start =
                (speeches[i].start > speech_pad_samples) ?
                (speeches[i].start - speech_pad_samples) : 0;
        }

        // Handle spacing between segments
        if (i < (int) speeches.size() - 1) {
            int silence_duration = speeches[i+1].start - speeches[i].end;

            if (silence_duration < 2 * speech_pad_samples) {
                // If segments are close, split the difference
                speeches[i].end += silence_duration / 2;
                speeches[i+1].start =
                    (speeches[i+1].start > silence_duration / 2) ?
                    (speeches[i+1].start - silence_duration / 2) : 0;
            } else {
                // Otherwise, apply full padding to both
                speeches[i].end =
                    (speeches[i].end + speech_pad_samples < audio_length_samples) ?
                    (speeches[i].end + speech_pad_samples) : audio_length_samples;
                speeches[i+1].start =
                    (speeches[i+1].start > speech_pad_samples) ?
                    (speeches[i+1].start - speech_pad_samples) : 0;
            }
        } else {
            // Apply padding to the end of the last segment
            speeches[i].end =
                (speeches[i].end + speech_pad_samples < audio_length_samples) ?
                (speeches[i].end + speech_pad_samples) : audio_length_samples;
        }

        // Convert from samples to seconds and copy to final segments
        segments[i].start = (float)speeches[i].start / sample_rate;
        segments[i].end   = (float)speeches[i].end   / sample_rate;

        WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
                __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
    }

    whisper_vad_segments * vad_segments = new whisper_vad_segments;
    // NOTE(review): plain `new` throws std::bad_alloc rather than returning
    // NULL, so this check is effectively dead code.
    if (vad_segments == NULL) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
        return nullptr;
    }

    vad_segments->data = std::move(segments);

    return vad_segments;
}
5427
+
5428
+ struct whisper_vad_segments * whisper_vad_segments_from_samples(
5429
+ whisper_vad_context * vctx,
5430
+ whisper_vad_params params,
5431
+ const float * samples,
5432
+ int n_samples) {
5433
+ WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
5434
+ if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
5435
+ WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
5436
+ return nullptr;
5437
+ }
5438
+ return whisper_vad_segments_from_probs(vctx, params);
5439
+ }
5440
+
5441
+ void whisper_vad_free(whisper_vad_context * ctx) {
5442
+ if (ctx) {
5443
+ for (ggml_context * context : ctx->model.ctxs) {
5444
+ ggml_free(context);
5445
+ }
5446
+
5447
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
5448
+ ggml_backend_buffer_free(buf);
5449
+ }
5450
+
5451
+ ggml_backend_sched_free(ctx->sched.sched);
5452
+
5453
+ for (auto & backend : ctx->backends) {
5454
+ ggml_backend_free(backend);
5455
+ }
5456
+
5457
+
5458
+ delete ctx;
5459
+ }
5460
+ }
5461
+
5462
+ void whisper_vad_free_segments(whisper_vad_segments * segments) {
5463
+ if (segments) {
5464
+ delete segments;
5465
+ }
5466
+ }
5467
+
5468
+ //////////////////////////////////
5469
+ // Grammar - ported from llama.cpp
5470
+ //////////////////////////////////
5471
+
5472
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
        const char * src,
        whisper_partial_utf8 partial_start) {
    // lookup[high nibble of first byte] = total bytes in the sequence
    // (0 marks an invalid leading byte, i.e. a stray continuation byte).
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    const char * pos = src;
    std::vector<uint32_t> code_points;
    uint32_t value = partial_start.value;
    int n_remain = partial_start.n_remain;

    // continue previous decode, if applicable
    while (*pos != 0 && n_remain > 0) {
        uint8_t next_byte = static_cast<uint8_t>(*pos);
        if ((next_byte >> 6) != 2) {
            // invalid sequence, abort
            code_points.push_back(0);
            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
        }
        // Accumulate 6 payload bits per continuation byte.
        value = (value << 6) + (next_byte & 0x3F);
        ++pos;
        --n_remain;
    }

    // The carried-over sequence completed exactly at this boundary.
    if (partial_start.n_remain > 0 && n_remain == 0) {
        code_points.push_back(value);
    }

    // decode any subsequent utf-8 sequences, which may end in an incomplete one
    while (*pos != 0) {
        uint8_t first_byte = static_cast<uint8_t>(*pos);
        uint8_t highbits   = first_byte >> 4;
        // Number of continuation bytes still expected after this one.
        n_remain           = lookup[highbits] - 1;

        if (n_remain < 0) {
            // invalid sequence, abort
            code_points.clear();
            code_points.push_back(0);
            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
        }

        // Mask selecting the payload bits of the leading byte.
        uint8_t mask = (1 << (7 - n_remain)) - 1;
        value        = first_byte & mask;
        ++pos;
        while (*pos != 0 && n_remain > 0) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
            ++pos;
            --n_remain;
        }
        if (n_remain == 0) {
            code_points.push_back(value);
        }
    }
    // Terminating 0 so the vector's data() can be walked like a C string.
    code_points.push_back(0);

    // A trailing n_remain > 0 reports the partially-decoded sequence to the caller.
    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
}
5529
+
5530
+ // returns true iff pos points to the end of one of the definitions of a rule
5531
+ static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
5532
+ switch (pos->type) {
5533
+ case WHISPER_GRETYPE_END: return true; // NOLINT
5534
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
5535
+ default: return false;
5536
+ }
5537
+ }
5538
+
5539
// returns true iff chr satisfies the char range at pos (regular or inverse range)
// asserts that pos is pointing to a char range element
//
// Returns {matched, next}: `next` points just past the range elements that
// were consumed, i.e. at the element following the char-range group.
static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
        const whisper_grammar_element * pos,
        const uint32_t chr) {

    bool found = false;
    // For WHISPER_GRETYPE_CHAR_NOT the final result is inverted (see return).
    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;

    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT

    do {
        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
            // inclusive range, e.g. [a-z]
            found = found || (pos->value <= chr && chr <= pos[1].value);
            pos += 2;
        } else {
            // exact char match, e.g. [a] or "a"
            found = found || pos->value == chr;
            pos += 1;
        }
    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);

    // `found == is_positive_char` is an XNOR: inverts the match for CHAR_NOT.
    return std::make_pair(found == is_positive_char, pos);
}
5564
+
5565
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
// range at pos (regular or inverse range)
// asserts that pos is pointing to a char range element
static bool whisper_grammar_match_partial_char(
        const whisper_grammar_element * pos,
        const whisper_partial_utf8 partial_utf8) {

    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);

    uint32_t partial_value = partial_utf8.value;
    int n_remain = partial_utf8.n_remain;

    // invalid sequence or 7-bit char split across 2 bytes (overlong)
    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
        return false;
    }

    // range of possible code points this partial UTF-8 sequence could complete to
    // (remaining continuation bytes contribute 6 bits each; low = all-zero
    // payload, high = all-one payload)
    uint32_t low  = partial_value << (n_remain * 6);
    uint32_t high = low | ((1 << (n_remain * 6)) - 1);

    if (low == 0) {
        // Exclude overlong encodings: raise the lower bound to the smallest
        // code point that legitimately needs this sequence length.
        if (n_remain == 2) {
            low = 1 << 11;
        } else if (n_remain == 3) {
            low = 1 << 16;
        }
    }

    do {
        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
            // inclusive range, e.g. [a-z]
            // Interval overlap test between [low, high] and [pos->value, pos[1].value].
            if (pos->value <= high && low <= pos[1].value) {
                return is_positive_char;
            }
            pos += 2;
        } else {
            // exact char match, e.g. [a] or "a"
            if (low <= pos->value && pos->value <= high) {
                return is_positive_char;
            }
            pos += 1;
        }
    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);

    // No possible completion matched; for CHAR_NOT that means acceptance.
    return !is_positive_char;
}
5613
+
5614
+
5615
// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
//
// Recursively expands rule references until every resulting stack's top is a
// terminal (CHAR / CHAR_NOT); the expanded stacks are appended to new_stacks.
static void whisper_grammar_advance_stack(
        const std::vector<std::vector<whisper_grammar_element>>   & rules,
        const std::vector<const whisper_grammar_element *>        & stack,
        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {

    // An empty stack means the grammar has been fully matched.
    if (stack.empty()) {
        new_stacks.emplace_back();
        return;
    }

    const whisper_grammar_element * pos = stack.back();

    switch (pos->type) {
        case WHISPER_GRETYPE_RULE_REF: {
            const size_t rule_id = static_cast<size_t>(pos->value);
            const whisper_grammar_element * subpos = rules[rule_id].data();
            // Expand each alternate of the referenced rule into its own stack.
            do {
                // init new stack without the top (pos)
                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
                    // if this rule ref is followed by another element, add that to stack
                    new_stack.push_back(pos + 1);
                }
                if (!whisper_grammar_is_end_of_sequence(subpos)) {
                    // if alternate is nonempty, add to stack
                    new_stack.push_back(subpos);
                }
                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
                while (!whisper_grammar_is_end_of_sequence(subpos)) {
                    // scan to end of alternate def
                    subpos++;
                }
                if (subpos->type == WHISPER_GRETYPE_ALT) {
                    // there's another alternate def of this rule to process
                    subpos++;
                } else {
                    break;
                }
            } while (true);
            break;
        }
        case WHISPER_GRETYPE_CHAR:
        case WHISPER_GRETYPE_CHAR_NOT:
            // Already at a terminal - the stack is usable as-is.
            new_stacks.push_back(stack);
            break;
        default:
            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
            // those
            WHISPER_ASSERT(false);
    }
}
5669
+
5670
+ // takes a set of possible pushdown stacks on a grammar, which are required to
5671
+ // be positioned at a character range (see `whisper_grammar_advance_stack`), and
5672
+ // produces the N possible stacks if the given char is accepted at those
5673
+ // positions
5674
+ static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
5675
+ const std::vector<std::vector<whisper_grammar_element>> & rules,
5676
+ const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
5677
+ const uint32_t chr) {
5678
+
5679
+ std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
5680
+
5681
+ for (const auto & stack : stacks) {
5682
+ if (stack.empty()) {
5683
+ continue;
5684
+ }
5685
+
5686
+ auto match = whisper_grammar_match_char(stack.back(), chr);
5687
+ if (match.first) {
5688
  const whisper_grammar_element * pos = match.second;
5689
 
5690
  // update top of stack to next element, if any
 
5981
  /*.n_grammar_rules =*/ 0,
5982
  /*.i_start_rule =*/ 0,
5983
  /*.grammar_penalty =*/ 100.0f,
5984
+
5985
+ /*.vad =*/ false,
5986
+ /*.vad_model_path =*/ nullptr,
5987
+
5988
+ /* vad_params =*/ whisper_vad_default_params(),
5989
  };
5990
 
5991
  switch (strategy) {
 
6602
  }
6603
  }
6604
 
6605
+ static bool whisper_vad(
6606
+ struct whisper_context * ctx,
6607
+ struct whisper_state * state,
6608
+ struct whisper_full_params params,
6609
+ const float * samples,
6610
+ int n_samples,
6611
+ std::vector<float> & filtered_samples,
6612
+ int & filtered_n_samples) {
6613
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
6614
+ filtered_n_samples = 0;
6615
+
6616
+ struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
6617
+ struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
6618
+ if (vctx == nullptr) {
6619
+ WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
6620
+ return false;
6621
+ }
6622
+
6623
+ const whisper_vad_params & vad_params = params.vad_params;
6624
+
6625
+ whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
6626
+
6627
+ if (vad_segments->data.size() > 0) {
6628
+ state->has_vad_segments = true;
6629
+ ctx->state->vad_segments.clear();
6630
+ ctx->state->vad_segments.reserve(vad_segments->data.size());
6631
+
6632
+ WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
6633
+ float overlap_seconds = vad_params.samples_overlap;
6634
+ int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
6635
+
6636
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6637
+ int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
6638
+ int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
6639
+
6640
+ if (i < (int)vad_segments->data.size() - 1) {
6641
+ segment_end_samples += overlap_samples;
6642
+ }
6643
+ segment_end_samples = std::min(segment_end_samples, n_samples - 1);
6644
+ filtered_n_samples += (segment_end_samples - segment_start_samples);
6645
+
6646
+ WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
6647
+ __func__, i, vad_segments->data[i].start,
6648
+ vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
6649
+ (vad_segments->data[i].end - vad_segments->data[i].start) +
6650
+ (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
6651
+ }
6652
+
6653
+ int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
6654
+ int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
6655
+ int total_samples_needed = filtered_n_samples + total_silence_samples;
6656
+
6657
+ WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
6658
+ __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
6659
+
6660
+ try {
6661
+ filtered_samples.resize(total_samples_needed);
6662
+ } catch (const std::bad_alloc & /* e */) {
6663
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
6664
+ whisper_vad_free_segments(vad_segments);
6665
+ whisper_vad_free(vctx);
6666
+ return false;
6667
+ }
6668
+
6669
+ int offset = 0;
6670
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6671
+ int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
6672
+ int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
6673
+
6674
+ if (i < (int)vad_segments->data.size() - 1) {
6675
+ segment_end_samples += overlap_samples;
6676
+ }
6677
+
6678
+ segment_start_samples = std::min(segment_start_samples, n_samples - 1);
6679
+ segment_end_samples = std::min(segment_end_samples, n_samples);
6680
+ int segment_length = segment_end_samples - segment_start_samples;
6681
+
6682
+ if (segment_length > 0) {
6683
+ whisper_state::vad_segment_info segment;
6684
+
6685
+ segment.orig_start = vad_segments->data[i].start;
6686
+ segment.orig_end = vad_segments->data[i].end;
6687
+
6688
+ segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
6689
+ segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
6690
+
6691
+ WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
6692
+ __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
6693
+ ctx->state->vad_segments.push_back(segment);
6694
+
6695
+ // Copy this speech segment
6696
+ memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
6697
+ offset += segment_length;
6698
+
6699
+ // Add silence after this segment (except after the last segment)
6700
+ if (i < (int)vad_segments->data.size() - 1) {
6701
+ // Fill with zeros (silence)
6702
+ memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
6703
+ offset += silence_samples;
6704
+ }
6705
+ }
6706
+ }
6707
+
6708
+ filtered_n_samples = offset;
6709
+ WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
6710
+ __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
6711
+ }
6712
+
6713
+ return true;
6714
+ }
6715
+
6716
  int whisper_full_with_state(
6717
  struct whisper_context * ctx,
6718
  struct whisper_state * state,
 
6724
 
6725
  result_all.clear();
6726
 
6727
+ const float * process_samples = samples;
6728
+ int n_process_samples = n_samples;
6729
+ std::vector<float> vad_samples;
6730
+
6731
+ if (params.vad) {
6732
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6733
+ int vad_n_samples;
6734
+ if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
6735
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
6736
+ return -1;
6737
+ }
6738
+ process_samples = vad_samples.data();
6739
+ n_process_samples = vad_n_samples;
6740
+ }
6741
+
6742
+ if (n_process_samples > 0) {
6743
  // compute log mel spectrogram
6744
+ if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
6745
  WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
6746
  return -2;
6747
  }
 
7786
  }
7787
 
7788
  int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
7789
+ // If VAD wasn't used, return the original timestamp
7790
+ if (!state->has_vad_segments || state->vad_segments.empty()) {
7791
+ return state->result_all[i_segment].t0;
7792
+ }
7793
+
7794
+ // Get the start timestamp produced by whisper_full. whisper_full processes
7795
+ // only the speech segments in this case so we need to map these timestamps
7796
+ // back to the original audio.
7797
+ float t0 = state->result_all[i_segment].t0 / 100.0f;
7798
+
7799
+ // Find which VAD segment this timestamp belongs.
7800
+ // TODO(danbev) This could be optimized by using a binary search if the number
7801
+ // of segments exceed a certain limit. Also we might be able to assume that
7802
+ // the access pattern is sequential and optimized for that too.
7803
+ for (size_t i = 0; i < state->vad_segments.size(); i++) {
7804
+ const auto & segment = state->vad_segments[i];
7805
+
7806
+ // Check if the timestamp falls within this segment.
7807
+ if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
7808
+ float proportion = 0.0f;
7809
+ if (segment.vad_end > segment.vad_start) {
7810
+ proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
7811
+ }
7812
+ float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
7813
+ return (int64_t)(orig_t0 * 100);
7814
+ }
7815
+ }
7816
+
7817
+ // Check if the timestamp falls between two segments.
7818
+ for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
7819
+ const auto & curr = state->vad_segments[i];
7820
+ const auto & next = state->vad_segments[i + 1];
7821
+
7822
+ if (t0 > curr.vad_end && t0 < next.vad_start) {
7823
+ // Calculate how far we are through the gap as a proportion
7824
+ float gap_proportion = 0.0f;
7825
+ if (next.vad_start > curr.vad_end) {
7826
+ gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
7827
+ }
7828
+ float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
7829
+ return (int64_t)(orig_t0 * 100);
7830
+ }
7831
+ }
7832
+
7833
+ // Handle the case where the timestamp is after the last segment.
7834
+ if (t0 > state->vad_segments.back().vad_end) {
7835
+ // For timestamps after the last segment, add the extra time to the end of the last segment
7836
+ const auto& last = state->vad_segments.back();
7837
+ // Calculate how far beyond the last segment
7838
+ float extra_time = t0 - last.vad_end;
7839
+ // Add this extra time to the original end time
7840
+ float orig_t0 = last.orig_end + extra_time;
7841
+ return (int64_t)(orig_t0 * 100);
7842
+ }
7843
+
7844
+ WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
7845
+ return t0;
7846
  }
7847
 
7848
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
    // Thin wrapper: delegate to the state-based variant using the context's own state.
    return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
}
7851
 
7852
  int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
7853
+ // If VAD wasn't used, return the original timestamp
7854
+ if (!state->has_vad_segments || state->vad_segments.empty()) {
7855
+ return state->result_all[i_segment].t1;
7856
+ }
7857
+
7858
+ // Get the end timestamp produced by whisper_full. whisper_full processes
7859
+ // only the speech segments in this case so we need to map these timestamps
7860
+ // back to the original audio.
7861
+ float t1 = state->result_all[i_segment].t1 / 100.0f;
7862
+
7863
+ // Find which VAD segment this timestamp belongs.
7864
+ // TODO(danbev) This could be optimized by using a binary search if the number
7865
+ // of segments exceed a certain limit. Also we might be able to assume that
7866
+ // the access pattern is sequential and optimized for that too.
7867
+ for (size_t i = 0; i < state->vad_segments.size(); i++) {
7868
+ const auto& segment = state->vad_segments[i];
7869
+
7870
+ // Check if the timestamp falls within this segment.
7871
+ if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
7872
+ // Calculate the proportion through the filtered segment.
7873
+ float proportion = 0.0f;
7874
+ if (segment.vad_end > segment.vad_start) {
7875
+ proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
7876
+ }
7877
+ float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
7878
+ return (int64_t)(orig_t1 * 100);
7879
+ }
7880
+ }
7881
+
7882
+ // Check if the timestamp falls between two segments.
7883
+ for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
7884
+ const auto & curr = state->vad_segments[i];
7885
+ const auto & next = state->vad_segments[i + 1];
7886
+
7887
+ if (t1 > curr.vad_end && t1 < next.vad_start) {
7888
+ // Calculate how far we are through the gap as a proportion
7889
+ float gap_proportion = 0.0f;
7890
+ if (next.vad_start > curr.vad_end) {
7891
+ gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
7892
+ }
7893
+ // Map to the corresponding position in the original gap
7894
+ float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
7895
+ return (int64_t)(orig_t1 * 100);
7896
+ }
7897
+ }
7898
+
7899
+ // Handle the case where the timestamp is after the last segment
7900
+ if (t1 > state->vad_segments.back().vad_end) {
7901
+ // For the last segment, use the end of the last VAD segment
7902
+ const auto& last = state->vad_segments.back();
7903
+ // Calculate how far beyond the last segment
7904
+ float extra_time = t1 - last.vad_end;
7905
+ // Add this extra time to the original end time
7906
+ float orig_t1 = last.orig_end + extra_time;
7907
+ return (int64_t)(orig_t1 * 100);
7908
+ }
7909
+
7910
+ WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
7911
+ return t1;
7912
  }
7913
 
7914
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
    // Thin wrapper: delegate to the state-based variant using the context's own state.
    return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
}
7917
 
7918
  bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
tests/CMakeLists.txt CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  if (EMSCRIPTEN)
2
  #
3
  # test-whisper-js
@@ -85,3 +88,18 @@ if (WHISPER_FFMPEG)
85
  set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
86
  endif()
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set(CMAKE_CXX_STANDARD 17)
2
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
3
+
4
  if (EMSCRIPTEN)
5
  #
6
  # test-whisper-js
 
88
  set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
89
  endif()
90
 
91
# VAD test tests VAD in isolation
set(VAD_TEST test-vad)
add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
target_link_libraries(${VAD_TEST} PRIVATE common)
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
set_tests_properties(${VAD_TEST} PROPERTIES LABELS "unit")

# VAD test full uses whisper_full with VAD enabled
set(VAD_TEST test-vad-full)
add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
target_include_directories(${VAD_TEST} PRIVATE ../include ../ggml/include ../examples)
target_link_libraries(${VAD_TEST} PRIVATE common)
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
# was ${VAD_TARGET}: an undefined variable, so no labels were applied
set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en")
tests/test-vad-full.cpp ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "whisper.h"
2
+ #include "common-whisper.h"
3
+
4
+ #include <cstdio>
5
+ #include <cfloat>
6
+ #include <string>
7
+ #include <cstring>
8
+
9
+ #ifdef NDEBUG
10
+ #undef NDEBUG
11
+ #endif
12
+
13
+ #include <cassert>
14
+
15
+ int main() {
16
+ std::string whisper_model_path = "../../models/ggml-base.en.bin";
17
+ std::string vad_model_path = "../../models/for-tests-silero-v5.1.2-ggml.bin";
18
+ std::string sample_path = "../../samples/jfk.wav";
19
+
20
+ // Load the sample audio file
21
+ std::vector<float> pcmf32;
22
+ std::vector<std::vector<float>> pcmf32s;
23
+ assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
24
+
25
+ struct whisper_context_params cparams = whisper_context_default_params();
26
+ struct whisper_context * wctx = whisper_init_from_file_with_params(
27
+ whisper_model_path.c_str(),
28
+ cparams);
29
+
30
+ struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
31
+ wparams.vad = true;
32
+ wparams.vad_model_path = vad_model_path.c_str();
33
+
34
+ wparams.vad_params.threshold = 0.5f;
35
+ wparams.vad_params.min_speech_duration_ms = 250;
36
+ wparams.vad_params.min_silence_duration_ms = 100;
37
+ wparams.vad_params.max_speech_duration_s = FLT_MAX;
38
+ wparams.vad_params.speech_pad_ms = 30;
39
+
40
+ assert(whisper_full_parallel(wctx, wparams, pcmf32.data(), pcmf32.size(), 1) == 0);
41
+
42
+ const int n_segments = whisper_full_n_segments(wctx);
43
+ assert(n_segments == 1);
44
+
45
+ assert(strcmp(" And so my fellow Americans, ask not what your country can do for you,"
46
+ " ask what you can do for your country.",
47
+ whisper_full_get_segment_text(wctx, 0)) == 0);
48
+ assert(whisper_full_get_segment_t0(wctx, 0) == 29);
49
+ assert(whisper_full_get_segment_t1(wctx, 0) == 1049);
50
+
51
+ whisper_free(wctx);
52
+
53
+ return 0;
54
+ }
tests/test-vad.cpp ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "whisper.h"
2
+ #include "common-whisper.h"
3
+
4
+ #include <cstdio>
5
+ #include <string>
6
+
7
+ #ifdef NDEBUG
8
+ #undef NDEBUG
9
+ #endif
10
+ #include <cassert>
11
+
12
+ void assert_default_params(const struct whisper_vad_params & params) {
13
+ assert(params.threshold == 0.5);
14
+ assert(params.min_speech_duration_ms == 250);
15
+ assert(params.min_silence_duration_ms == 100);
16
+ assert(params.samples_overlap == 0.1f);
17
+ }
18
+
19
+ void assert_default_context_params(const struct whisper_vad_context_params & params) {
20
+ assert(params.n_threads == 4);
21
+ assert(params.use_gpu == false);
22
+ assert(params.gpu_device == 0);
23
+ }
24
+
25
+ void test_detect_speech(
26
+ struct whisper_vad_context * vctx,
27
+ struct whisper_vad_params params,
28
+ const float * pcmf32,
29
+ int n_samples) {
30
+ assert(whisper_vad_detect_speech(vctx, pcmf32, n_samples));
31
+ assert(whisper_vad_n_probs(vctx) == 344);
32
+ assert(whisper_vad_probs(vctx) != nullptr);
33
+ }
34
+
35
+ struct whisper_vad_segments * test_detect_timestamps(
36
+ struct whisper_vad_context * vctx,
37
+ struct whisper_vad_params params) {
38
+ struct whisper_vad_segments * timestamps = whisper_vad_segments_from_probs(vctx, params);
39
+ assert(whisper_vad_segments_n_segments(timestamps) == 5);
40
+
41
+ for (int i = 0; i < whisper_vad_segments_n_segments(timestamps); ++i) {
42
+ printf("VAD segment %d: start = %.2f, end = %.2f\n", i,
43
+ whisper_vad_segments_get_segment_t0(timestamps, i),
44
+ whisper_vad_segments_get_segment_t1(timestamps, i));
45
+ }
46
+
47
+ return timestamps;
48
+ }
49
+
50
+ int main() {
51
+ std::string vad_model_path = "../../models/for-tests-silero-v5.1.2-ggml.bin";
52
+ std::string sample_path = "../../samples/jfk.wav";
53
+
54
+ // Load the sample audio file
55
+ std::vector<float> pcmf32;
56
+ std::vector<std::vector<float>> pcmf32s;
57
+ assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
58
+ assert(pcmf32.size() > 0);
59
+ assert(pcmf32s.size() == 0); // no stereo vector
60
+
61
+ // Load the VAD model
62
+ struct whisper_vad_context_params ctx_params = whisper_vad_default_context_params();
63
+ assert_default_context_params(ctx_params);
64
+
65
+ struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(
66
+ vad_model_path.c_str(),
67
+ ctx_params);
68
+ assert(vctx != nullptr);
69
+
70
+ struct whisper_vad_params params = whisper_vad_default_params();
71
+ assert_default_params(params);
72
+
73
+ // Test speech probabilites
74
+ test_detect_speech(vctx, params, pcmf32.data(), pcmf32.size());
75
+
76
+ // Test speech timestamps (uses speech probabilities from above)
77
+ struct whisper_vad_segments * timestamps = test_detect_timestamps(vctx, params);
78
+
79
+ whisper_vad_free_segments(timestamps);
80
+ whisper_vad_free(vctx);
81
+
82
+ return 0;
83
+ }