Skip to content

Commit 4865085

Browse files
committed
Refactor model and trainer kwargs in SetFitClassification
Signed-off-by: Christopher Schröder <chschroeder@users.noreply.github.com>
1 parent 1d38af8 commit 4865085

File tree

4 files changed

+87
-32
lines changed

4 files changed

+87
-32
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ On the other hand, this also allowed us to deal with further issues that contain
2727
- TransformerBasedClassification:
2828
- Removed unnecessary `token_type_ids` keyword argument in model call.
2929
- Additional keyword args for config, tokenizer, and model can now be configured.
30+
- SetFitClassification:
31+
- Additional keyword args for trainer and model are now attached to `SetFitModelArguments` instead of `SetFitClassification`.
32+
3033
- Embeddings:
3134
- Prevented unnecessary gradient computations for some embedding types and unified code structure.
3235
- Pytorch:

MIGRATION_GUIDE.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ This is not an exhaustive list of changes, but here we try to collect changes th
1111
- [PoolBasedActiveLearner](https://small-text.readthedocs.io/en/latest/api/active_learner.html#activelearner-api):
1212
`initialize_data()` has been changed to `initialize()`. The method now takes a list of initial indices or an initialized first (proxy-)model.
1313

14+
- SetFitClassification: `model_kwargs` and `trainer_kwargs` are now attached to `SetFitModelArguments` instead of `SetFitClassification`.
15+
1416
### Renamed Classes
1517

16-
The following classes amd variables have been renamed for consistency:
18+
The following classes and variables have been renamed for consistency:
1719

1820
- KimCNNFactory -> KimCNNClassifierFactory
1921

small_text/integrations/transformers/classifiers/setfit.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
ModelLoadingStrategy
2020
)
2121

22+
from small_text.integrations.transformers.utils.setfit import (
23+
_check_model_kwargs,
24+
_check_trainer_kwargs,
25+
_check_train_kwargs,
26+
_truncate_texts
27+
)
28+
2229
try:
2330
import torch
2431

@@ -29,34 +36,54 @@
2936
from small_text.integrations.transformers.utils.classification import (
3037
_get_arguments_for_from_pretrained_model
3138
)
32-
from small_text.integrations.transformers.utils.setfit import (
33-
_check_model_kwargs,
34-
_check_trainer_kwargs,
35-
_check_train_kwargs,
36-
_truncate_texts
37-
)
38-
except ImportError:
39-
pass
39+
except ImportError as e:
40+
print(e)
4042

4143

4244
class SetFitModelArguments(object):
43-
"""
45+
"""Model arguments for :py:class:`SetFitClassification`.
46+
4447
.. versionadded:: 1.2.0
4548
"""
4649

4750
def __init__(self,
4851
sentence_transformer_model: str,
52+
model_kwargs={},
53+
trainer_kwargs={},
4954
model_loading_strategy: ModelLoadingStrategy = ModelLoadingStrategy.DEFAULT):
5055
"""
5156
Parameters
5257
----------
5358
sentence_transformer_model : str
5459
Name of a sentence transformer model.
60+
model_kwargs : dict, default={}
61+
Keyword arguments used for the SetFit model. The keyword `use_differentiable_head` is
62+
excluded and managed by `SetFitClassification`. The other keywords are directly passed
63+
to `SetFitModel.from_pretrained()`. Arguments that are managed by small-text
64+
(such as the model name given by `model`) are excluded.
66+
67+
.. seealso::
68+
69+
`SetFitModel.from_pretrained()
70+
<https://huggingface.co/docs/setfit/en/reference/main#setfit.SetFitModel.from_pretrained>`_
71+
in the SetFit documentation.
72+
trainer_kwargs : dict
73+
Keyword arguments used for the SetFit trainer. The keyword `batch_size` is excluded and
74+
is instead controlled by the keyword `mini_batch_size` of `SetFitClassification`. The other
75+
keywords are directly passed to `SetFitTrainer.__init__()`.
76+
77+
.. seealso:: `Trainer
78+
<https://huggingface.co/docs/setfit/en/reference/trainer>`_
79+
in the SetFit documentation.
5580
model_loading_strategy: ModelLoadingStrategy, default=ModelLoadingStrategy.DEFAULT
5681
Specifies if there should be attempts to download the model or if only local
5782
files should be used.
5883
"""
5984
self.sentence_transformer_model = sentence_transformer_model
85+
self.model_kwargs = _check_model_kwargs(model_kwargs)
86+
self.trainer_kwargs = _check_trainer_kwargs(trainer_kwargs)
6087
self.model_loading_strategy = model_loading_strategy
6188

6289

@@ -135,8 +162,7 @@ class SetFitClassification(SetFitClassificationEmbeddingMixin, Classifier):
135162
"""
136163

137164
def __init__(self, setfit_model_args, num_classes, multi_label=False, max_seq_len=512,
138-
use_differentiable_head=False, mini_batch_size=32, model_kwargs=dict(),
139-
trainer_kwargs=dict(), device=None):
165+
use_differentiable_head=False, mini_batch_size=32, device=None):
140166
"""
141167
setfit_model_args : SetFitModelArguments
142168
Settings for the sentence transformer model to be used.
@@ -149,21 +175,6 @@ def __init__(self, setfit_model_args, num_classes, multi_label=False, max_seq_le
149175
Uses a differentiable head instead of a logistic regression for the classification head.
150176
Corresponds to the keyword argument with the same name in
151177
`SetFitModel.from_pretrained()`.
152-
model_kwargs : dict
153-
Keyword arguments used for the SetFit model. The keyword `use_differentiable_head` is
154-
excluded and managed by this class. The other keywords are directly passed to
155-
`SetFitModel.from_pretrained()`.
156-
157-
.. seealso:: `SetFit: src/setfit/modeling.py
158-
<https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py>`_
159-
160-
trainer_kwargs : dict
161-
Keyword arguments used for the SetFit model. The keyword `batch_size` is excluded and
162-
is instead controlled by the keyword `mini_batch_size` of this class. The other
163-
keywords are directly passed to `SetFitTrainer.__init__()`.
164-
165-
.. seealso:: `SetFit: src/setfit/trainer.py
166-
<https://github.com/huggingface/setfit/blob/main/src/setfit/trainer.py>`_
167178
device : str or torch.device, default=None
168179
Torch device on which the computation will be performed.
169180
"""
@@ -173,10 +184,7 @@ def __init__(self, setfit_model_args, num_classes, multi_label=False, max_seq_le
173184
self.num_classes = num_classes
174185
self.multi_label = multi_label
175186

176-
self.model_kwargs = _check_model_kwargs(model_kwargs)
177-
self.trainer_kwargs = _check_trainer_kwargs(trainer_kwargs)
178-
179-
model_kwargs = self.model_kwargs.copy()
187+
model_kwargs = self.setfit_model_args.model_kwargs.copy()
180188
if self.multi_label and 'multi_target_strategy' not in model_kwargs:
181189
model_kwargs['multi_target_strategy'] = 'one-vs-rest'
182190

@@ -264,7 +272,7 @@ def _fit(self, sub_train, sub_valid, setfit_train_kwargs):
264272
eval_dataset=sub_valid,
265273
batch_size=self.mini_batch_size,
266274
seed=seed,
267-
**self.trainer_kwargs
275+
**self.setfit_model_args.trainer_kwargs
268276
)
269277
trainer.train(max_length=self.max_seq_len, **setfit_train_kwargs)
270278
return self

tests/unit/small_text/integrations/transformers/classifiers/test_setfit.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,36 @@ def test_setfit_model_arguments_init(self):
3636
sentence_transformer_model = 'sentence-transformers/all-MiniLM-L6-v2'
3737
args = SetFitModelArguments(sentence_transformer_model)
3838
self.assertEqual(sentence_transformer_model, args.sentence_transformer_model)
39+
self.assertIsNotNone(args.model_kwargs)
40+
self.assertEqual(0, len(args.model_kwargs))
41+
self.assertIsNotNone(args.trainer_kwargs)
42+
self.assertEqual(0, len(args.trainer_kwargs))
43+
self.assertIsNotNone(args.model_loading_strategy)
44+
self.assertEqual(ModelLoadingStrategy.DEFAULT, args.model_loading_strategy)
45+
self.assertFalse(args.compile_model)
46+
47+
def test_setfit_model_arguments_init_with_model_kwargs(self):
48+
sentence_transformer_model = 'sentence-transformers/all-MiniLM-L6-v2'
49+
model_kwargs = {'cache_dir': '/tmp/cache'}
50+
args = SetFitModelArguments(sentence_transformer_model, model_kwargs=model_kwargs)
51+
self.assertEqual(sentence_transformer_model, args.sentence_transformer_model)
52+
self.assertIsNotNone(args.model_kwargs)
53+
self.assertEqual(1, len(args.model_kwargs))
54+
self.assertIsNotNone(args.trainer_kwargs)
55+
self.assertEqual(0, len(args.trainer_kwargs))
56+
self.assertIsNotNone(args.model_loading_strategy)
57+
self.assertEqual(ModelLoadingStrategy.DEFAULT, args.model_loading_strategy)
58+
self.assertFalse(args.compile_model)
59+
60+
def test_setfit_model_arguments_init_with_trainer_kwargs(self):
61+
sentence_transformer_model = 'sentence-transformers/all-MiniLM-L6-v2'
62+
trainer_kwargs = {'batch_size': 32}
63+
args = SetFitModelArguments(sentence_transformer_model, trainer_kwargs=trainer_kwargs)
64+
self.assertEqual(sentence_transformer_model, args.sentence_transformer_model)
65+
self.assertIsNotNone(args.model_kwargs)
66+
self.assertEqual(0, len(args.model_kwargs))
67+
self.assertIsNotNone(args.trainer_kwargs)
68+
self.assertEqual(1, len(args.trainer_kwargs))
3969
self.assertIsNotNone(args.model_loading_strategy)
4070
self.assertEqual(ModelLoadingStrategy.DEFAULT, args.model_loading_strategy)
4171
self.assertFalse(args.compile_model)
@@ -46,6 +76,10 @@ def test_setfit_model_arguments_init_with_model_loading_strategy(self):
4676
args = SetFitModelArguments(sentence_transformer_model,
4777
model_loading_strategy=model_loading_strategy)
4878
self.assertEqual(sentence_transformer_model, args.sentence_transformer_model)
79+
self.assertIsNotNone(args.model_kwargs)
80+
self.assertEqual(0, len(args.model_kwargs))
81+
self.assertIsNotNone(args.trainer_kwargs)
82+
self.assertEqual(0, len(args.trainer_kwargs))
4983
self.assertIsNotNone(args.model_loading_strategy)
5084
self.assertEqual(model_loading_strategy, args.model_loading_strategy)
5185
self.assertFalse(args.compile_model)
@@ -61,6 +95,10 @@ def test_transformer_model_arguments_init_with_env_override(self):
6195
args = SetFitModelArguments(sentence_transformer_model)
6296

6397
self.assertEqual(sentence_transformer_model, args.sentence_transformer_model)
98+
self.assertIsNotNone(args.model_kwargs)
99+
self.assertEqual(0, len(args.model_kwargs))
100+
self.assertIsNotNone(args.trainer_kwargs)
101+
self.assertEqual(0, len(args.trainer_kwargs))
64102
self.assertIsNotNone(args.model_loading_strategy)
65103
self.assertEqual(ModelLoadingStrategy.ALWAYS_LOCAL, args.model_loading_strategy)
66104
self.assertFalse(args.compile_model)
@@ -70,6 +108,10 @@ def test_setfit_model_arguments_init_with_compile(self):
70108
args = SetFitModelArguments(sentence_transformer_model,
71109
compile_model=True)
72110
self.assertEqual(sentence_transformer_model, args.sentence_transformer_model)
111+
self.assertIsNotNone(args.model_kwargs)
112+
self.assertEqual(0, len(args.model_kwargs))
113+
self.assertIsNotNone(args.trainer_kwargs)
114+
self.assertEqual(0, len(args.trainer_kwargs))
73115
self.assertIsNotNone(args.model_loading_strategy)
74116
self.assertEqual(ModelLoadingStrategy.DEFAULT, args.model_loading_strategy)
75117
self.assertTrue(args.compile_model)

0 commit comments

Comments
 (0)