\n",
+ "
\n",
"\n",
"\n",
"\n",
"\n",
"\n",
- "
\n",
+ "
\n",
"
\n",
""
],
"text/plain": [
- ":Scatter [feature_importance_mean] (nrd_mean,nrd_std,fwe_std,fwe_mean,feature_importance_std)"
+ ":Scatter [feature_importance_mean] (mProbes_mean,nFDR_mean,feature_importance_std,nFDR_std,mProbes_std)"
]
},
"execution_count": 22,
"metadata": {
"application/vnd.holoviews_exec.v0+json": {
- "id": "1003"
+ "id": "1111"
}
},
"output_type": "execute_result"
}
],
"source": [
- "df = ds[[\"feature_importance_mean\", \"nrd_mean\"]].to_dataframe()\n",
- "hv.Scatter(df).opts(padding=0.1, width=500, height=500)"
+ "df = ds[[\"feature_importance_mean\", \"mProbes_mean\"]].to_dataframe()\n",
+ "hv.Scatter(df).opts(padding=0.1, width=400, height=400)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Receiver Operating Characteristic (ROC)\n",
+ "\n",
+ "Another way to evaluate a classification model is the [ROC](https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py).\n",
+ "\n",
+ "Viewing ROC curves for multi-class models is a bit indirect, as we have to fit a binary (one-vs-rest) classifier for each unique target label, so we provide this walkthrough.\n",
+ "Also examine the [sklearn demo](https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.multiclass import OneVsRestClassifier\n",
+ "from sklearn.metrics import roc_curve, auc, roc_auc_score, plot_roc_curve\n",
+ "from sklearn import preprocessing, model_selection"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Get the count and annotation data from GSForge."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array(['CONTROL', 'HEAT', 'RECOV_HEAT', 'DROUGHT', 'RECOV_DROUGHT'],\n",
+ " dtype=object)"
+ ]
+ },
+ "execution_count": 97,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "counts, treatment = gsf.get_data(gsc, annotation_variables=\"Treatment\",selected_gene_sets=[\"Boruta_Treatment\"])\n",
+ "classes = treatment.to_series().unique()\n",
+ "classes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Encode the annotation labels with a [one hot encoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html#sklearn.preprocessing.OneHotEncoder)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[1., 0., 0., 0., 0.],\n",
+ " [1., 0., 0., 0., 0.],\n",
+ " [1., 0., 0., 0., 0.],\n",
+ " ...,\n",
+ " [0., 0., 0., 1., 0.],\n",
+ " [0., 0., 0., 1., 0.],\n",
+ " [0., 0., 0., 1., 0.]])"
+ ]
+ },
+ "execution_count": 98,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "enc = preprocessing.OneHotEncoder().fit(treatment.values.reshape(-1, 1))\n",
+ "treatment_onehot = enc.transform(treatment.values.reshape(-1, 1)).toarray()\n",
+ "treatment_onehot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Split the data and encoded annotations into a train and test set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_train, x_test, y_train, y_test = model_selection.train_test_split(counts, treatment_onehot)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fit the model with the training data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5 s, sys: 322 ms, total: 5.32 s\n",
+ "Wall time: 7.05 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "roc_rf_model = OneVsRestClassifier(RandomForestClassifier(class_weight='balanced', max_depth=3,\n",
+ " n_estimators=1000, n_jobs=-1))\n",
+ "roc_rf_model = roc_rf_model.fit(x_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now predict class probabilities for the test count data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[0.73333022, 0.66598266, 0.00726503, 0.15976397, 0.01979039],\n",
+ " [0.05407875, 0.00253519, 0.87441341, 0.01685384, 0.1158844 ],\n",
+ " [0.12352609, 0.02328239, 0.52621856, 0.02173646, 0.57647534],\n",
+ " [0.26545577, 0.00429616, 0.0633575 , 0.03606417, 0.83556347],\n",
+ " [0.59150757, 0.1084059 , 0.0406514 , 0.4259592 , 0.03584087]])"
+ ]
+ },
+ "execution_count": 101,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_score = roc_rf_model.predict_proba(x_test)\n",
+ "y_score[:5]"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 102,
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "# One-hot column i corresponds to enc.categories_[0][i] (the encoder's own label order),\n",
+ "# so iterate over that rather than `classes`, which is in first-appearance order.\n",
+ "fpr, tpr, roc_auc = {}, {}, {}\n",
+ "for i, class_ in enumerate(enc.categories_[0]):\n",
+ "    fpr[class_], tpr[class_], _ = roc_curve(y_test[:, i], y_score[:, i])\n",
+ "    roc_auc[class_] = auc(fpr[class_], tpr[class_])"
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 103,
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "roc_curves = {class_: hv.Curve((fpr[class_], tpr[class_]))\n",
+ " for class_ in classes}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.holoviews_exec.v0+json": "",
+ "text/html": [
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ ""
+ ],
+ "text/plain": [
+ ":NdOverlay [Element]\n",
+ " :Curve [x] (y)"
+ ]
+ },
+ "execution_count": 104,
+ "metadata": {
+ "application/vnd.holoviews_exec.v0+json": {
+ "id": "4195"
+ }
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "hv.NdOverlay(roc_curves).opts(padding=0.05, legend_position=\"top\", width=400, height=450)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## View a Decision Tree"
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 105,
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "from sklearn import tree\n",
+ "import graphviz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Extract a single tree from a list of estimators."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "There are 1000 trees.\n"
+ ]
+ }
+ ],
+ "source": [
+ "control_rfc = roc_rf_model.estimators_[0]\n",
+ "print(f\"There are {len(control_rfc.estimators_)} trees.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 109,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "
"
+ ]
+ },
+ "execution_count": 109,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "selected_tree = control_rfc.estimators_[0]\n",
+ "graph = graphviz.Source(tree.export_graphviz(\n",
+ " selected_tree, \n",
+ " feature_names=x_train.Gene.values, \n",
+ " class_names=[\"Not CONTROL\", \"CONTROL\"],\n",
+ " filled=True, \n",
+ " rounded=True, \n",
+ " special_characters=True))\n",
+ "\n",
+ "graph\n",
+ "# graph.render('decision_tree', format=\"svg\") "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### View Genes Used in a Given Tree"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 127,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array(['LOC_Os05g49770', 'LOC_Os07g42280', 'LOC_Os06g04070',\n",
+ " 'LOC_Os06g12370', 'LOC_Os01g12420'], dtype=object)"
+ ]
+ },
+ "execution_count": 127,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "counts.Gene.isel(Gene=np.argwhere(selected_tree.feature_importances_).flatten()).values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 130,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.holoviews_exec.v0+json": "",
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "
\n",
+ "
\n",
+ ""
+ ],
+ "text/plain": [
+ ":Raster [x,y] (z)"
+ ]
+ },
+ "execution_count": 130,
+ "metadata": {
+ "application/vnd.holoviews_exec.v0+json": {
+ "id": "4891"
+ }
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tree_subset = counts.isel(Gene=np.argwhere(selected_tree.feature_importances_).flatten())\n",
+ "hv.Raster(tree_subset.values).opts(logz=True)"
+ ]
},
{
"cell_type": "markdown",
@@ -2224,13 +2862,13 @@
"pycharm": {
"stem_cell": {
"cell_type": "raw",
- "source": [],
"metadata": {
"collapsed": false
- }
+ },
+ "source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/examples/user_guide/plotting_guide/Genewise_Scatter_and_Distributions.ipynb b/examples/user_guide/plotting_guide/Genewise_Scatter_and_Distributions.ipynb
index c6393b0..3f5b0f5 100644
--- a/examples/user_guide/plotting_guide/Genewise_Scatter_and_Distributions.ipynb
+++ b/examples/user_guide/plotting_guide/Genewise_Scatter_and_Distributions.ipynb
@@ -745,59 +745,6 @@
" \n",
"\n",
"\n",
- " \n",
- " \n",
"\n",
"\n",
" \n"
@@ -1377,47 +1324,6 @@
" }\n",
"}\n",
"\n",
- "// Define MPL specific subclasses\n",
- "function MPLSelectionWidget() {\n",
- " SelectionWidget.apply(this, arguments);\n",
- "}\n",
- "\n",
- "function MPLScrubberWidget() {\n",
- " ScrubberWidget.apply(this, arguments);\n",
- "}\n",
- "\n",
- "// Let them inherit from the baseclasses\n",
- "MPLSelectionWidget.prototype = Object.create(SelectionWidget.prototype);\n",
- "MPLScrubberWidget.prototype = Object.create(ScrubberWidget.prototype);\n",
- "\n",
- "// Define methods to override on widgets\n",
- "var MPLMethods = {\n",
- " init_slider : function(init_val){\n",
- " if(this.load_json) {\n",
- " this.from_json()\n",
- " } else {\n",
- " this.update_cache();\n",
- " }\n",
- " if (this.dynamic | !this.cached | (this.current_vals === undefined)) {\n",
- " this.update(0)\n",
- " } else {\n",
- " this.set_frame(this.current_vals[0], 0)\n",
- " }\n",
- " },\n",
- " process_msg : function(msg) {\n",
- " var data = msg.content.data;\n",
- " this.frames[this.current] = data;\n",
- " this.update_cache(true);\n",
- " this.update(this.current);\n",
- " }\n",
- "}\n",
- "// Extend MPL widgets with backend specific methods\n",
- "extend(MPLSelectionWidget.prototype, MPLMethods);\n",
- "extend(MPLScrubberWidget.prototype, MPLMethods);\n",
- "\n",
- "window.HoloViews.MPLSelectionWidget = MPLSelectionWidget\n",
- "window.HoloViews.MPLScrubberWidget = MPLScrubberWidget\n",
- "\n",
"// Define Bokeh specific subclasses\n",
"function BokehSelectionWidget() {\n",
" SelectionWidget.apply(this, arguments);\n",
@@ -1679,7 +1585,7 @@
" }\n",
"}\n"
],
- "application/vnd.holoviews_load.v0+json": "function HoloViewsWidget() {\n}\n\nHoloViewsWidget.prototype.init_slider = function(init_val){\n if(this.load_json) {\n this.from_json()\n } else {\n this.update_cache();\n }\n}\n\nHoloViewsWidget.prototype.populate_cache = function(idx){\n this.cache[idx].innerHTML = this.frames[idx];\n if (this.embed) {\n delete this.frames[idx];\n }\n}\n\nHoloViewsWidget.prototype.process_error = function(msg){\n}\n\nHoloViewsWidget.prototype.from_json = function() {\n var data_url = this.json_path + this.id + '.json';\n $.getJSON(data_url, $.proxy(function(json_data) {\n this.frames = json_data;\n this.update_cache();\n this.update(0);\n }, this));\n}\n\nHoloViewsWidget.prototype.dynamic_update = function(current){\n if (current === undefined) {\n return\n }\n this.current = current;\n if (this.comm) {\n var msg = {comm_id: this.id+'_client', content: current}\n this.comm.send(msg);\n }\n}\n\nHoloViewsWidget.prototype.update_cache = function(force){\n var frame_len = Object.keys(this.frames).length;\n for (var i=0; i