TheJCDenton commited on
Commit
aa7c2e9
·
unverified ·
1 Parent(s): bb48f57

talk-llama : add n_gpu_layers parameter (#1475)

Browse files
Files changed (1) hide show
  1. examples/talk-llama/talk-llama.cpp +5 -0
examples/talk-llama/talk-llama.cpp CHANGED
@@ -53,6 +53,7 @@ struct whisper_params {
53
  int32_t capture_id = -1;
54
  int32_t max_tokens = 32;
55
  int32_t audio_ctx = 0;
 
56
 
57
  float vad_thold = 0.6f;
58
  float freq_thold = 100.0f;
@@ -90,6 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
90
  else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
91
  else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
92
  else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
 
93
  else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
94
  else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
95
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
@@ -134,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
134
  fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
135
  fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
136
  fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
 
137
  fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
138
  fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
139
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
@@ -268,6 +271,8 @@ int main(int argc, char ** argv) {
268
  auto lmparams = llama_model_default_params();
269
  if (!params.use_gpu) {
270
  lmparams.n_gpu_layers = 0;
 
 
271
  }
272
 
273
  struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
 
53
  int32_t capture_id = -1;
54
  int32_t max_tokens = 32;
55
  int32_t audio_ctx = 0;
56
+ int32_t n_gpu_layers = 0;
57
 
58
  float vad_thold = 0.6f;
59
  float freq_thold = 100.0f;
 
91
  else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
92
  else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
93
  else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
94
+ else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
95
  else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
96
  else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
97
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
 
136
  fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
137
  fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
138
  fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
139
+ fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7s] number of layers to store in VRAM\n", params.n_gpu_layers);
140
  fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
141
  fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
142
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
 
271
  auto lmparams = llama_model_default_params();
272
  if (!params.use_gpu) {
273
  lmparams.n_gpu_layers = 0;
274
+ } else {
275
+ lmparams.n_gpu_layers = params.n_gpu_layers;
276
  }
277
 
278
  struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);