From 1631f3077277427a198d626159fcbce2cd2a958a Mon Sep 17 00:00:00 2001 From: Mario Graff Date: Tue, 12 Mar 2024 09:29:13 -0600 Subject: [PATCH] StackBoW parameters --- EvoMSA/__init__.py | 2 +- EvoMSA/back_prop.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/EvoMSA/__init__.py b/EvoMSA/__init__.py index 70019fd..bb2c9e6 100644 --- a/EvoMSA/__init__.py +++ b/EvoMSA/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '2.0.7' +__version__ = '2.0.8' try: from EvoMSA.text_repr import BoW, TextRepresentations, StackGeneralization, DenseBoW diff --git a/EvoMSA/back_prop.py b/EvoMSA/back_prop.py index 6bb2fcf..1a7254b 100644 --- a/EvoMSA/back_prop.py +++ b/EvoMSA/back_prop.py @@ -23,6 +23,7 @@ from IngeoML.optimizer import classifier, array from IngeoML.utils import soft_BER from EvoMSA.text_repr import BoW, DenseBoW, StackGeneralization +from EvoMSA.utils import b4msa_params @jax.jit @@ -318,15 +319,21 @@ class StackBoW(StackGeneralization): def __init__(self, decision_function_models: list=None, transform_models: list=[], + voc_size_exponent: int=15, deviation=None, optimizer_kwargs: dict=None, lang: str='es', **kwargs): if decision_function_models is None: estimator_kwargs = dict(dual='auto', class_weight='balanced') + b4msa_kwargs = b4msa_params(lang=lang) + if voc_size_exponent != 17: + b4msa_kwargs['token_max_filter'] = 2**voc_size_exponent bow_np = BoW(lang=lang, pretrain=False, + b4msa_kwargs=b4msa_kwargs, estimator_kwargs=estimator_kwargs) - bow = BoW(lang=lang, + bow = BoW(lang=lang, voc_size_exponent=voc_size_exponent, estimator_kwargs=estimator_kwargs) - dense = DenseBoW(lang=lang, voc_size_exponent=15, + dense = DenseBoW(lang=lang, + voc_size_exponent=voc_size_exponent, estimator_kwargs=estimator_kwargs) decision_function_models = [bow_np, bow, dense] assert len(decision_function_models) > 1