Skip to content

Commit

Permalink
StackBoW parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Mar 12, 2024
1 parent 478f352 commit 1631f30
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion EvoMSA/__init__.py
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = '2.0.7'
__version__ = '2.0.8'

try:
from EvoMSA.text_repr import BoW, TextRepresentations, StackGeneralization, DenseBoW
Expand Down
11 changes: 9 additions & 2 deletions EvoMSA/back_prop.py
Expand Up @@ -23,6 +23,7 @@
from IngeoML.optimizer import classifier, array
from IngeoML.utils import soft_BER
from EvoMSA.text_repr import BoW, DenseBoW, StackGeneralization
from EvoMSA.utils import b4msa_params


@jax.jit
Expand Down Expand Up @@ -318,15 +319,21 @@ class StackBoW(StackGeneralization):

def __init__(self, decision_function_models: list=None,
transform_models: list=[],
voc_size_exponent: int=15,
deviation=None, optimizer_kwargs: dict=None,
lang: str='es', **kwargs):
if decision_function_models is None:
estimator_kwargs = dict(dual='auto', class_weight='balanced')
b4msa_kwargs = b4msa_params(lang=lang)
if voc_size_exponent != 17:
b4msa_kwargs['token_max_filter'] = 2**voc_size_exponent
bow_np = BoW(lang=lang, pretrain=False,
b4msa_kwargs=b4msa_kwargs,
estimator_kwargs=estimator_kwargs)
bow = BoW(lang=lang,
bow = BoW(lang=lang, voc_size_exponent=voc_size_exponent,
estimator_kwargs=estimator_kwargs)
dense = DenseBoW(lang=lang, voc_size_exponent=15,
dense = DenseBoW(lang=lang,
voc_size_exponent=voc_size_exponent,
estimator_kwargs=estimator_kwargs)
decision_function_models = [bow_np, bow, dense]
assert len(decision_function_models) > 1
Expand Down

0 comments on commit 1631f30

Please sign in to comment.