Commit: chore: release v4.0
Optimox committed Jun 27, 2022
1 parent 464c54c commit 1353379
Showing 26 changed files with 2,064 additions and 1,391 deletions.
3 changes: 0 additions & 3 deletions CHANGELOG.md
@@ -1,7 +1,4 @@

## [3.1.1](https://github.com/dreamquark-ai/tabnet/compare/v3.1.0...v3.1.1) (2021-02-02)


### Bug Fixes

* add preds_mapper to pretraining ([76f2c85](https://github.com/dreamquark-ai/tabnet/commit/76f2c852f59c6ed2c5dc5f0766cb99310bae5f2c))
8 changes: 6 additions & 2 deletions docs/_modules/index.html
@@ -82,14 +82,17 @@



<p class="caption"><span class="caption-text">Contents:</span></p>
<p><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html">README</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#tabnet-attentive-interpretable-tabular-learning">TabNet : Attentive Interpretable Tabular Learning</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#installation">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#what-problems-does-pytorch-tabnet-handles">What problems does pytorch-tabnet handles?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#contributing">Contributing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#what-problems-does-pytorch-tabnet-handle">What problems does pytorch-tabnet handle?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#how-to-use-it">How to use it?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#semi-supervised-pre-training">Semi-supervised pre-training</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#data-augmentation-on-the-fly">Data augmentation on the fly</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#easy-saving-and-loading">Easy saving and loading</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/README.html#useful-links">Useful links</a></li>
<li class="toctree-l1"><a class="reference internal" href="../generated_docs/pytorch_tabnet.html">pytorch_tabnet package</a></li>
</ul>
@@ -155,6 +158,7 @@

<h1>All modules for which code is available</h1>
<ul><li><a href="pytorch_tabnet/abstract_model.html">pytorch_tabnet.abstract_model</a></li>
<li><a href="pytorch_tabnet/augmentations.html">pytorch_tabnet.augmentations</a></li>
<li><a href="pytorch_tabnet/callbacks.html">pytorch_tabnet.callbacks</a></li>
<li><a href="pytorch_tabnet/metrics.html">pytorch_tabnet.metrics</a></li>
<li><a href="pytorch_tabnet/multiclass_utils.html">pytorch_tabnet.multiclass_utils</a></li>
83 changes: 53 additions & 30 deletions docs/_modules/pytorch_tabnet/abstract_model.html

Large diffs are not rendered by default.

295 changes: 295 additions & 0 deletions docs/_modules/pytorch_tabnet/augmentations.html

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions docs/_modules/pytorch_tabnet/callbacks.html
@@ -82,14 +82,17 @@



<p class="caption"><span class="caption-text">Contents:</span></p>
<p><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html">README</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#tabnet-attentive-interpretable-tabular-learning">TabNet : Attentive Interpretable Tabular Learning</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#installation">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#what-problems-does-pytorch-tabnet-handles">What problems does pytorch-tabnet handles?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#contributing">Contributing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#what-problems-does-pytorch-tabnet-handle">What problems does pytorch-tabnet handle?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#how-to-use-it">How to use it?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#semi-supervised-pre-training">Semi-supervised pre-training</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#data-augmentation-on-the-fly">Data augmentation on the fly</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#easy-saving-and-loading">Easy saving and loading</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#useful-links">Useful links</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/pytorch_tabnet.html">pytorch_tabnet package</a></li>
</ul>
@@ -162,6 +165,7 @@ <h1>Source code for pytorch_tabnet.callbacks</h1><div class="highlight"><pre>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span><span class="p">,</span> <span class="n">field</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Any</span>
<span class="kn">import</span> <span class="nn">warnings</span>


<div class="viewcode-block" id="Callback"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.callbacks.Callback">[docs]</a><span class="k">class</span> <span class="nc">Callback</span><span class="p">:</span>
@@ -325,7 +329,8 @@ <h1>Source code for pytorch_tabnet.callbacks</h1><div class="highlight"><pre>
<span class="o">+</span> <span class="sa">f</span><span class="s2">&quot;best_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">early_stopping_metric</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_loss</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Best weights from best epoch are automatically used!&quot;</span><span class="p">)</span></div></div>
<span class="n">wrn_msg</span> <span class="o">=</span> <span class="s2">&quot;Best weights from best epoch are automatically used!&quot;</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="n">wrn_msg</span><span class="p">)</span></div></div>
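The hunk above replaces a bare `print` with `warnings.warn`, so the early-stopping notice now goes through Python's warnings machinery and becomes filterable by callers. A minimal sketch of what that enables (the filter pattern below is an illustration, not part of the diff):

```python
import warnings

# The message emitted by the EarlyStopping callback after this change.
msg = "Best weights from best epoch are automatically used!"

# Because it is a warning rather than a print, callers can capture
# or suppress it with the standard warnings machinery:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(msg)

print(len(caught))  # 1
```

A user who finds the notice noisy could register `warnings.filterwarnings("ignore", message="Best weights.*")` before training, which was impossible with the old `print`.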


<div class="viewcode-block" id="History"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.callbacks.History">[docs]</a><span class="nd">@dataclass</span>
66 changes: 63 additions & 3 deletions docs/_modules/pytorch_tabnet/metrics.html
@@ -82,14 +82,17 @@



<p class="caption"><span class="caption-text">Contents:</span></p>
<p><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html">README</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#tabnet-attentive-interpretable-tabular-learning">TabNet : Attentive Interpretable Tabular Learning</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#installation">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#what-problems-does-pytorch-tabnet-handles">What problems does pytorch-tabnet handles?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#contributing">Contributing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#what-problems-does-pytorch-tabnet-handle">What problems does pytorch-tabnet handle?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#how-to-use-it">How to use it?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#semi-supervised-pre-training">Semi-supervised pre-training</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#data-augmentation-on-the-fly">Data augmentation on the fly</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#easy-saving-and-loading">Easy saving and loading</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/README.html#useful-links">Useful links</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../generated_docs/pytorch_tabnet.html">pytorch_tabnet package</a></li>
</ul>
@@ -197,7 +200,11 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">errors</span> <span class="o">=</span> <span class="n">y_pred</span> <span class="o">-</span> <span class="n">embedded_x</span>
<span class="n">reconstruction_errors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">errors</span><span class="p">,</span> <span class="n">obf_vars</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="n">batch_stds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">embedded_x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">eps</span>
<span class="n">batch_means</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">embedded_x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">batch_means</span><span class="p">[</span><span class="n">batch_means</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>

<span class="n">batch_stds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">embedded_x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="n">batch_stds</span><span class="p">[</span><span class="n">batch_stds</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch_means</span><span class="p">[</span><span class="n">batch_stds</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span>
<span class="n">features_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">reconstruction_errors</span><span class="p">,</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">batch_stds</span><span class="p">)</span>
<span class="c1"># compute the number of obfuscated variables to reconstruct</span>
<span class="n">nb_reconstructed_variables</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">obf_vars</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
@@ -208,6 +215,24 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<span class="k">return</span> <span class="n">loss</span></div>


<div class="viewcode-block" id="UnsupervisedLossNumpy"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.metrics.UnsupervisedLossNumpy">[docs]</a><span class="k">def</span> <span class="nf">UnsupervisedLossNumpy</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">embedded_x</span><span class="p">,</span> <span class="n">obf_vars</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-9</span><span class="p">):</span>
<span class="n">errors</span> <span class="o">=</span> <span class="n">y_pred</span> <span class="o">-</span> <span class="n">embedded_x</span>
<span class="n">reconstruction_errors</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">errors</span><span class="p">,</span> <span class="n">obf_vars</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="n">batch_means</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">embedded_x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">batch_means</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">batch_means</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">batch_means</span><span class="p">)</span>

<span class="n">batch_stds</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">embedded_x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">ddof</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
<span class="n">batch_stds</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">batch_stds</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="n">batch_means</span><span class="p">,</span> <span class="n">batch_stds</span><span class="p">)</span>
<span class="n">features_loss</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">reconstruction_errors</span><span class="p">,</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">batch_stds</span><span class="p">)</span>
<span class="c1"># compute the number of obfuscated variables to reconstruct</span>
<span class="n">nb_reconstructed_variables</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">obf_vars</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># take the mean of the reconstructed variable errors</span>
<span class="n">features_loss</span> <span class="o">=</span> <span class="n">features_loss</span> <span class="o">/</span> <span class="p">(</span><span class="n">nb_reconstructed_variables</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
<span class="c1"># here we take the mean per batch, contrary to the paper</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">features_loss</span><span class="p">)</span>
<span class="k">return</span> <span class="n">loss</span></div>
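Stripped of the HTML highlighting markup, the `UnsupervisedLossNumpy` function added in this hunk reads as follows. Note the zero-variance guard: columns whose sample variance is zero fall back to the column mean (itself replaced by 1 when zero), which avoids the division-by-zero the earlier torch version handled with a bare `eps`. The tiny arrays at the bottom are illustrative only:

```python
import numpy as np

def UnsupervisedLossNumpy(y_pred, embedded_x, obf_vars, eps=1e-9):
    # squared reconstruction error, restricted to the obfuscated cells
    errors = y_pred - embedded_x
    reconstruction_errors = np.multiply(errors, obf_vars) ** 2
    # guard against zero-variance columns: fall back to the column mean,
    # and to 1 when the mean itself is zero
    batch_means = np.mean(embedded_x, axis=0)
    batch_means = np.where(batch_means == 0, 1, batch_means)
    batch_stds = np.std(embedded_x, axis=0, ddof=1) ** 2
    batch_stds = np.where(batch_stds == 0, batch_means, batch_stds)
    features_loss = np.matmul(reconstruction_errors, 1 / batch_stds)
    # number of obfuscated variables to reconstruct per sample
    nb_reconstructed_variables = np.sum(obf_vars, axis=1)
    # mean of the reconstructed-variable errors per sample
    features_loss = features_loss / (nb_reconstructed_variables + eps)
    # mean over the batch, contrary to the paper
    return np.mean(features_loss)

# sanity check: perfect reconstruction gives zero loss
x = np.array([[1.0, 2.0], [3.0, 5.0], [4.0, 9.0]])
mask = np.ones_like(x)
print(UnsupervisedLossNumpy(x, x, mask))  # 0.0
```

`ddof=1` matches the unbiased sample variance used by `torch.std` in the torch counterpart above, so the two implementations normalize by the same quantity.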


<div class="viewcode-block" id="UnsupMetricContainer"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.metrics.UnsupMetricContainer">[docs]</a><span class="nd">@dataclass</span>
<span class="k">class</span> <span class="nc">UnsupMetricContainer</span><span class="p">:</span>
<span class="sd">&quot;&quot;&quot;Container holding a list of metrics.</span>
@@ -571,6 +596,41 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span></div>


<div class="viewcode-block" id="UnsupervisedNumpyMetric"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.metrics.UnsupervisedNumpyMetric">[docs]</a><span class="k">class</span> <span class="nc">UnsupervisedNumpyMetric</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Unsupervised metric</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;unsup_loss_numpy&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_maximize</span> <span class="o">=</span> <span class="kc">False</span>

<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">embedded_x</span><span class="p">,</span> <span class="n">obf_vars</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Compute MSE (Mean Squared Error) of predictions.</span>

<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> y_pred : torch.Tensor or np.array</span>
<span class="sd"> Reconstructed prediction (with embeddings)</span>
<span class="sd"> embedded_x : torch.Tensor</span>
<span class="sd"> Original input embedded by network</span>
<span class="sd"> obf_vars : torch.Tensor</span>
<span class="sd"> Binary mask for obfuscated variables.</span>
<span class="sd"> 1 means the variable was obfuscated, so reconstruction is based on it.</span>

<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> float</span>
<span class="sd"> MSE of predictions vs targets.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">UnsupervisedLossNumpy</span><span class="p">(</span>
<span class="n">y_pred</span><span class="p">,</span>
<span class="n">embedded_x</span><span class="p">,</span>
<span class="n">obf_vars</span>
<span class="p">)</span></div>


<div class="viewcode-block" id="RMSE"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.metrics.RMSE">[docs]</a><span class="k">class</span> <span class="nc">RMSE</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Root Mean Squared Error.</span>
