shibukazu commited on
Commit
a325deb
·
unverified ·
1 Parent(s): 661a3a2

whisper : suppress non-speech-related token outputs (#473)

Browse files

* add non-speech-token suppression

* add suppress non-speech_tokens param

Files changed (2) hide show
  1. whisper.cpp +36 -0
  2. whisper.h +1 -0
whisper.cpp CHANGED
@@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2936
  /*.language =*/ "en",
2937
 
2938
  /*.suppress_blank =*/ true,
 
2939
 
2940
  /*.temperature =*/ 0.0f,
2941
  /*.max_initial_ts =*/ 1.0f,
@@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
3077
  return res;
3078
  }
3079
 
 
 
 
 
 
 
 
 
3080
  // process the logits for the selected decoder
3081
  // - applies logit filters
3082
  // - computes logprobs and probs
@@ -3137,6 +3146,33 @@ static void whisper_process_logits(
3137
  logits[vocab.token_translate] = -INFINITY;
3138
  logits[vocab.token_transcribe] = -INFINITY;
3139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3140
  // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
3141
  // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
3142
  {
 
2936
  /*.language =*/ "en",
2937
 
2938
  /*.suppress_blank =*/ true,
2939
+ /*.suppress_non_speech_tokens =*/true,
2940
 
2941
  /*.temperature =*/ 0.0f,
2942
  /*.max_initial_ts =*/ 1.0f,
 
3078
  return res;
3079
  }
3080
 
3081
+ static const std::vector<std::string> non_speech_tokens
3082
+ {
3083
+ "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
3084
+ "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
3085
+ "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
3086
+ "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
3087
+ };
3088
+
3089
  // process the logits for the selected decoder
3090
  // - applies logit filters
3091
  // - computes logprobs and probs
 
3146
  logits[vocab.token_translate] = -INFINITY;
3147
  logits[vocab.token_transcribe] = -INFINITY;
3148
 
3149
+
3150
+ // suppress non-speech tokens
3151
+ // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3152
+ if (params.suppress_non_speech_tokens)
3153
+ {
3154
+ for (const std::string &token : non_speech_tokens)
3155
+ {
3156
+ std::string suppress_tokens[] = {token, " " + token};
3157
+ for (const std::string &suppress_token : suppress_tokens)
3158
+ {
3159
+ if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
3160
+ {
3161
+ logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
3162
+ }
3163
+ }
3164
+ }
3165
+ // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3166
+ if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
3167
+ {
3168
+ logits[vocab.token_to_id.at(" -")] = -INFINITY;
3169
+ }
3170
+ if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
3171
+ {
3172
+ logits[vocab.token_to_id.at(" '")] = -INFINITY;
3173
+ }
3174
+ }
3175
+
3176
  // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
3177
  // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
3178
  {
whisper.h CHANGED
@@ -285,6 +285,7 @@ extern "C" {
285
 
286
  // common decoding parameters:
287
  bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
 
288
 
289
  float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
290
  float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
 
285
 
286
  // common decoding parameters:
287
  bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
288
+ bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
289
 
290
  float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
291
  float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97