danbev committed
Commit bf862e4 · unverified · 1 Parent(s): 26aba7a

coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] (#2979)

* coreml: fix Whisper to CoreML conversion by disabling SDPA

This commit disables the use of PyTorch's
`scaled_dot_product_attention` in the Whisper model to avoid
compatibility issues during CoreML conversion.
The issue occurs because coremltools requires PyTorch 2.5.0, but the
Whisper implementation may expect behavior from newer PyTorch versions.

By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to
use its fallback manual attention implementation, which works correctly
with PyTorch 2.5.0 during the tracing process.

Refs: https://github.com/ggerganov/whisper.cpp/issues/2783
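
In practice the workaround is a module-level monkey-patch applied before any
tracing happens. A minimal, hedged sketch (the "tiny" checkpoint is only an
example; the real logic lives in models/convert-whisper-to-coreml.py):

    # Force Whisper's manual attention fallback instead of
    # torch.nn.functional.scaled_dot_product_attention, whose behavior differs
    # between PyTorch 2.5.0 (required by coremltools) and newer releases.
    import whisper.model
    whisper.model.MultiHeadAttention.use_sdpa = False

    from whisper import load_model
    model = load_model("tiny", device="cpu")  # any checkpoint; the patch must precede tracing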

* coreml: fix audio shape in whisper decoder conversion

This commit fixes the audio shape used in the whisper decoder conversion
script: the previous shape was incorrect and caused the conversion to fail.
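
The corrected shape follows the encoder output layout (batch, n_audio_ctx,
n_audio_state) instead of the old (1, n_audio_state, 1, 1500). A standalone,
hedged sketch of the decoder trace inputs, assuming the plain openai-whisper
package and its "tiny" checkpoint (illustrative only; the real code is
convert_decoder() in models/convert-whisper-to-coreml.py):

    import torch
    from whisper import load_model

    model   = load_model("tiny", device="cpu")
    hparams = model.dims                 # ModelDimensions: n_vocab, n_audio_ctx, n_audio_state, ...
    decoder = model.decoder.eval()

    tokens_shape = (1, 1)
    audio_shape  = (1, hparams.n_audio_ctx, hparams.n_audio_state)   # e.g. (1, 1500, 384) for tiny

    audio_data = torch.randn(audio_shape)                            # dummy encoder output
    token_data = torch.randint(hparams.n_vocab, tokens_shape).long()

    traced = torch.jit.trace(decoder, (token_data, audio_data))      # tokens first, audio second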

* coreml : set -e in generate-coreml-interface.sh

This commit sets the -e flag in the generate-coreml-interface.sh script
so that the script exits immediately if any command fails.

* coreml : update generated encoder/decoder interfaces

This commit updates the generated encoder/decoder interfaces for the
whisper model, which are produced by running the
generate-coreml-interface.sh script.

models/convert-whisper-to-coreml.py CHANGED
@@ -12,6 +12,15 @@ from coremltools.models.neural_network.quantization_utils import quantize_weight
 from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
 from whisper import load_model
 
+# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
+# The Whisper implementation expects a specific behavior from
+# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
+# versions. Setting use_sdpa=False forces Whisper to use its manual attention
+# implementation instead, which is more stable across different PyTorch versions
+# (2.5.0 required by coremltools vs newer versions).
+import whisper.model
+whisper.model.MultiHeadAttention.use_sdpa = False
+
 # Use for changing dim of input in encoder and decoder embeddings
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
@@ -260,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
     model.eval()
 
     tokens_shape = (1, 1)
-    audio_shape = (1, hparams.n_audio_state, 1, 1500)
+    audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)
 
     audio_data = torch.randn(audio_shape)
-    token_data = torch.randint(50257, tokens_shape).long()
+    token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
+
     traced_model = torch.jit.trace(model, (token_data, audio_data))
 
     model = ct.convert(
models/generate-coreml-interface.sh CHANGED
@@ -5,6 +5,8 @@
 # - src/coreml/whisper-decoder-impl.h and src/coreml/whisper-decoder-impl.m
 #
 
+set -e
+
 wd=$(dirname "$0")
 cd "$wd/../" || exit
 
src/coreml/whisper-decoder-impl.h CHANGED
@@ -11,36 +11,33 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-
 /// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
 
-/// token_data as 1 by 1 matrix of 32-bit integers
+/// token_data as 1 by 1 matrix of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * token_data;
 
-/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
+/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
 - (instancetype)init NS_UNAVAILABLE;
 - (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
 
 @end
 
-
 /// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
 
-/// var_1346 as multidimensional array of floats
-@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
+/// cast_76 as multidimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
 - (instancetype)init NS_UNAVAILABLE;
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;
 
 @end
 
-
 /// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 
@@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
    @param configuration The model configuration
    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
    @param configuration The model configuration
    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Make a prediction using the standard interface
@@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 */
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_decoder_implInput to predict from
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_decoder_implInput to predict from
+    @param options prediction options
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
     Make a prediction using the convenience interface
-    @param token_data as 1 by 1 matrix of 32-bit integers:
-    @param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
+    @param token_data 1 by 1 matrix of floats
+    @param audio_data 1 × 1500 × 384 3-dimensional array of floats
    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
    @return the prediction as whisper_decoder_implOutput
 */
src/coreml/whisper-decoder-impl.m CHANGED
@@ -39,21 +39,21 @@
 
 @implementation whisper_decoder_implOutput
 
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
     self = [super init];
     if (self) {
-        _var_1346 = var_1346;
+        _cast_76 = cast_76;
     }
     return self;
 }
 
 - (NSSet<NSString *> *)featureNames {
-    return [NSSet setWithArray:@[@"var_1346"]];
+    return [NSSet setWithArray:@[@"cast_76"]];
 }
 
 - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
-    if ([featureName isEqualToString:@"var_1346"]) {
-        return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
+    if ([featureName isEqualToString:@"cast_76"]) {
+        return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
     }
     return nil;
 }
@@ -80,10 +80,13 @@
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
-    if (!self) { return nil; }
-    _model = model;
-    if (_model == nil) { return nil; }
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 
@@ -177,7 +180,29 @@
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
     if (!outFeatures) { return nil; }
-    return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
+    return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
 }
 
 - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@@ -192,7 +217,7 @@
     NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
     for (NSInteger i = 0; i < outBatch.count; i++) {
         id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
-        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
+        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
         [results addObject:result];
     }
     return results;
src/coreml/whisper-encoder-impl.h CHANGED
@@ -11,9 +11,8 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-
 /// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
 
 /// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 @end
 
-
 /// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
 
 /// output as multidimensional array of floats
@@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 @end
 
-
 /// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 
@@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
    @param configuration The model configuration
    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
    @param configuration The model configuration
    @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Make a prediction using the standard interface
@@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 */
 - (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_encoder_implInput to predict from
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_encoder_implInput to predict from
+    @param options prediction options
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
     Make a prediction using the convenience interface
-    @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
+    @param logmel_data 1 × 80 × 3000 3-dimensional array of floats
    @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
    @return the prediction as whisper_encoder_implOutput
 */
src/coreml/whisper-encoder-impl.m CHANGED
@@ -76,10 +76,13 @@
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
-    if (!self) { return nil; }
-    _model = model;
-    if (_model == nil) { return nil; }
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 
@@ -176,6 +179,28 @@
     return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
 }
 
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
 - (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
     return [self predictionFromFeatures:input_ error:error];