19
19
ModelLoadingStrategy
20
20
)
21
21
22
+ from small_text .integrations .transformers .utils .setfit import (
23
+ _check_model_kwargs ,
24
+ _check_trainer_kwargs ,
25
+ _check_train_kwargs ,
26
+ _truncate_texts
27
+ )
28
+
22
29
try :
23
30
import torch
24
31
29
36
from small_text .integrations .transformers .utils .classification import (
30
37
_get_arguments_for_from_pretrained_model
31
38
)
32
- from small_text .integrations .transformers .utils .setfit import (
33
- _check_model_kwargs ,
34
- _check_trainer_kwargs ,
35
- _check_train_kwargs ,
36
- _truncate_texts
37
- )
38
- except ImportError :
39
- pass
39
+ except ImportError as e :
40
+ print (e )
41
+ print (e )
40
42
41
43
42
44
class SetFitModelArguments(object):
    """Model arguments for :py:class:`SetFitClassification`.

    Bundles the sentence transformer model name, the extra keyword arguments
    for the SetFit model and for the SetFit trainer, and the model loading
    strategy.

    .. versionadded:: 1.2.0
    """

    def __init__(self,
                 sentence_transformer_model: str,
                 model_kwargs=None,
                 trainer_kwargs=None,
                 model_loading_strategy: ModelLoadingStrategy = ModelLoadingStrategy.DEFAULT):
        """
        Parameters
        ----------
        sentence_transformer_model : str
            Name of a sentence transformer model.
        model_kwargs : dict, default=None
            Keyword arguments used for the SetFit model. `None` is treated as an empty dict.
            The keyword `use_differentiable_head` is excluded and is instead controlled by the
            keyword of the same name of :py:class:`SetFitClassification`. Arguments that are
            managed by small-text (such as the model name given by `model`) are excluded as
            well. The other keywords are directly passed to `SetFitModel.from_pretrained()`.

            .. seealso::

               `SetFitModel.from_pretrained()
               <https://huggingface.co/docs/setfit/en/reference/main#setfit.SetFitModel.from_pretrained>`_
               in the SetFit documentation.
        trainer_kwargs : dict, default=None
            Keyword arguments used for the SetFit trainer. `None` is treated as an empty dict.
            The keyword `batch_size` is excluded and is instead controlled by the keyword
            `mini_batch_size` of :py:class:`SetFitClassification`. The other keywords are
            directly passed to `SetFitTrainer.__init__()`.

            .. seealso:: `Trainer
                         <https://huggingface.co/docs/setfit/en/reference/trainer>`_
                         in the SetFit documentation.
        model_loading_strategy: ModelLoadingStrategy, default=ModelLoadingStrategy.DEFAULT
            Specifies if there should be attempts to download the model or if only local
            files should be used.
        """
        # A literal `{}` default is created once at definition time and would be
        # shared by every instance that relies on the default; build a fresh dict
        # per call instead.
        if model_kwargs is None:
            model_kwargs = {}
        if trainer_kwargs is None:
            trainer_kwargs = {}

        self.sentence_transformer_model = sentence_transformer_model
        # Both kwargs dicts are passed through their respective check helper
        # before being stored.
        self.model_kwargs = _check_model_kwargs(model_kwargs)
        self.trainer_kwargs = _check_trainer_kwargs(trainer_kwargs)
        self.model_loading_strategy = model_loading_strategy
61
88
62
89
@@ -135,8 +162,7 @@ class SetFitClassification(SetFitClassificationEmbeddingMixin, Classifier):
135
162
"""
136
163
137
164
def __init__ (self , setfit_model_args , num_classes , multi_label = False , max_seq_len = 512 ,
138
- use_differentiable_head = False , mini_batch_size = 32 , model_kwargs = dict (),
139
- trainer_kwargs = dict (), device = None ):
165
+ use_differentiable_head = False , mini_batch_size = 32 , device = None ):
140
166
"""
141
167
sentence_transformer_model : SetFitModelArguments
142
168
Settings for the sentence transformer model to be used.
@@ -149,21 +175,6 @@ def __init__(self, setfit_model_args, num_classes, multi_label=False, max_seq_le
149
175
Uses a differentiable head instead of a logistic regression for the classification head.
150
176
Corresponds to the keyword argument with the same name in
151
177
`SetFitModel.from_pretrained()`.
152
- model_kwargs : dict
153
- Keyword arguments used for the SetFit model. The keyword `use_differentiable_head` is
154
- excluded and managed by this class. The other keywords are directly passed to
155
- `SetFitModel.from_pretrained()`.
156
-
157
- .. seealso:: `SetFit: src/setfit/modeling.py
158
- <https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py>`_
159
-
160
- trainer_kwargs : dict
161
- Keyword arguments used for the SetFit model. The keyword `batch_size` is excluded and
162
- is instead controlled by the keyword `mini_batch_size` of this class. The other
163
- keywords are directly passed to `SetFitTrainer.__init__()`.
164
-
165
- .. seealso:: `SetFit: src/setfit/trainer.py
166
- <https://github.com/huggingface/setfit/blob/main/src/setfit/trainer.py>`_
167
178
device : str or torch.device, default=None
168
179
Torch device on which the computation will be performed.
169
180
"""
@@ -173,10 +184,7 @@ def __init__(self, setfit_model_args, num_classes, multi_label=False, max_seq_le
173
184
self .num_classes = num_classes
174
185
self .multi_label = multi_label
175
186
176
- self .model_kwargs = _check_model_kwargs (model_kwargs )
177
- self .trainer_kwargs = _check_trainer_kwargs (trainer_kwargs )
178
-
179
- model_kwargs = self .model_kwargs .copy ()
187
+ model_kwargs = self .setfit_model_args .model_kwargs .copy ()
180
188
if self .multi_label and 'multi_target_strategy' not in model_kwargs :
181
189
model_kwargs ['multi_target_strategy' ] = 'one-vs-rest'
182
190
@@ -264,7 +272,7 @@ def _fit(self, sub_train, sub_valid, setfit_train_kwargs):
264
272
eval_dataset = sub_valid ,
265
273
batch_size = self .mini_batch_size ,
266
274
seed = seed ,
267
- ** self .trainer_kwargs
275
+ ** self .setfit_model_args . trainer_kwargs
268
276
)
269
277
trainer .train (max_length = self .max_seq_len , ** setfit_train_kwargs )
270
278
return self
0 commit comments