Skip to content

Commit

Permalink
Fix an error causing xtick label wrong
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacobluke- committed Mar 25, 2024
1 parent d712e8e commit e5b26a1
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 38 deletions.
38 changes: 19 additions & 19 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,32 +805,32 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
ticks_with_counts = []
ticks_loc = rawdata_axes.get_xticks()
rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))
for xticklab in rawdata_axes.xaxis.get_ticklabels():
t = xticklab.get_text()
# Extract the text after the last newline, if present
te = t[t.rfind("\n") + len("\n"):] if t.rfind("\n") != -1 else t

def lookup_value(text, counts):
try:
# Try to access 'counts' directly with 'te'.
N = str(counts.loc[te])
return str(counts.loc[text])
except KeyError:
# If direct access fails, attempt a numeric interpretation.
try:
# Attempt to convert 'te' to numeric (float or int, as appropriate)
numeric_key = pd.to_numeric(te, errors='coerce')
# 'pd.to_numeric()' will convert strings to float or int, as appropriate,
# and will return NaN if conversion fails. It preserves integers.
if pd.notnull(numeric_key): # Check if conversion was successful
N = str(counts.loc[numeric_key])
numeric_key = pd.to_numeric(text, errors='coerce')
if pd.notnull(numeric_key):
return str(counts.loc[numeric_key])
else:
raise ValueError # Raise an error to trigger the except block
raise ValueError
except (ValueError, KeyError):
# Handle cases where 'te' cannot be converted or the converted key doesn't exist
print(f"Key '{te}' not found in counts.")
N = "N/A"
print(f"Key '{text}' not found in counts.")
return "N/A"
for xticklab in rawdata_axes.xaxis.get_ticklabels():
t = xticklab.get_text()
# Extract the text after the last newline, if present
if t.rfind("\n") != -1:
te = t[t.rfind("\n") + len("\n"):]
value = lookup_value(te, counts)
te = t
else:
te = t
value = lookup_value(te, counts)

# Append the modified tick label with the count to the list
ticks_with_counts.append(f"{te}\nN = {N}")
ticks_with_counts.append(f"{te}\nN = {value}")


if plot_kwargs["fontsize_rawxlabel"] is not None:
Expand Down
38 changes: 19 additions & 19 deletions nbs/API/plotter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -864,32 +864,32 @@
" ticks_with_counts = []\n",
" ticks_loc = rawdata_axes.get_xticks()\n",
" rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))\n",
" for xticklab in rawdata_axes.xaxis.get_ticklabels():\n",
" t = xticklab.get_text()\n",
" # Extract the text after the last newline, if present\n",
" te = t[t.rfind(\"\\n\") + len(\"\\n\"):] if t.rfind(\"\\n\") != -1 else t\n",
"\n",
" def lookup_value(text, counts):\n",
" try:\n",
" # Try to access 'counts' directly with 'te'.\n",
" N = str(counts.loc[te])\n",
" return str(counts.loc[text])\n",
" except KeyError:\n",
" # If direct access fails, attempt a numeric interpretation.\n",
" try:\n",
" # Attempt to convert 'te' to numeric (float or int, as appropriate)\n",
" numeric_key = pd.to_numeric(te, errors='coerce')\n",
" # 'pd.to_numeric()' will convert strings to float or int, as appropriate,\n",
" # and will return NaN if conversion fails. It preserves integers.\n",
" if pd.notnull(numeric_key): # Check if conversion was successful\n",
" N = str(counts.loc[numeric_key])\n",
" numeric_key = pd.to_numeric(text, errors='coerce')\n",
" if pd.notnull(numeric_key):\n",
" return str(counts.loc[numeric_key])\n",
" else:\n",
" raise ValueError # Raise an error to trigger the except block\n",
" raise ValueError\n",
" except (ValueError, KeyError):\n",
" # Handle cases where 'te' cannot be converted or the converted key doesn't exist\n",
" print(f\"Key '{te}' not found in counts.\")\n",
" N = \"N/A\"\n",
" print(f\"Key '{text}' not found in counts.\")\n",
" return \"N/A\"\n",
" for xticklab in rawdata_axes.xaxis.get_ticklabels():\n",
" t = xticklab.get_text()\n",
" # Extract the text after the last newline, if present\n",
" if t.rfind(\"\\n\") != -1:\n",
" te = t[t.rfind(\"\\n\") + len(\"\\n\"):]\n",
" value = lookup_value(te, counts)\n",
" te = t\n",
" else:\n",
" te = t\n",
" value = lookup_value(te, counts)\n",
"\n",
" # Append the modified tick label with the count to the list\n",
" ticks_with_counts.append(f\"{te}\\nN = {N}\")\n",
" ticks_with_counts.append(f\"{te}\\nN = {value}\")\n",
"\n",
"\n",
" if plot_kwargs[\"fontsize_rawxlabel\"] is not None:\n",
Expand Down
88 changes: 88 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import numpy as np
from scipy.stats import norm
import pandas as pd
import matplotlib as mpl
import os
from pathlib import Path

import matplotlib.ticker as Ticker
import matplotlib.pyplot as plt

from dabest._api import load

import dabest

columns = [1, 2.0]
columns_str = ["1", "2.0"]
# create a test database
N = 100
df = pd.DataFrame(np.vstack([np.random.normal(loc=i, size=(N,)) for i in range(len(columns))]).T, columns=columns_str)
females = np.repeat("Female", N / 2).tolist()
males = np.repeat("Male", N / 2).tolist()
df['gender'] = females + males

# Add an `id` column for paired data plotting.
df['ID'] = pd.Series(range(1, N + 1))


db = dabest.load(data=df, idx=columns_str, paired="baseline", id_col="ID")
print(db.mean_diff)
db.mean_diff.plot();

# def create_demo_dataset(seed=9999, N=20):
# import numpy as np
# import pandas as pd
# from scipy.stats import norm # Used in generation of populations.

# np.random.seed(9999) # Fix the seed so the results are replicable.
# # pop_size = 10000 # Size of each population.

# # Create samples
# c1 = norm.rvs(loc=3, scale=0.4, size=N)
# c2 = norm.rvs(loc=3.5, scale=0.75, size=N)
# c3 = norm.rvs(loc=3.25, scale=0.4, size=N)

# t1 = norm.rvs(loc=3.5, scale=0.5, size=N)
# t2 = norm.rvs(loc=2.5, scale=0.6, size=N)
# t3 = norm.rvs(loc=3, scale=0.75, size=N)
# t4 = norm.rvs(loc=3.5, scale=0.75, size=N)
# t5 = norm.rvs(loc=3.25, scale=0.4, size=N)
# t6 = norm.rvs(loc=3.25, scale=0.4, size=N)

# # Add a `gender` column for coloring the data.
# females = np.repeat("Female", N / 2).tolist()
# males = np.repeat("Male", N / 2).tolist()
# gender = females + males

# # Add an `id` column for paired data plotting.
# id_col = pd.Series(range(1, N + 1))

# # Combine samples and gender into a DataFrame.
# df = pd.DataFrame(
# {
# "Control 1": c1,
# "Test 1": t1,
# "Control 2": c2,
# "Test 2": t2,
# "Control 3": c3,
# "Test 3": t3,
# "Test 4": t4,
# "Test 5": t5,
# "Test 6": t6,
# "Gender": gender,
# "ID": id_col,
# }
# )

# return df


# df = create_demo_dataset()

# two_groups_unpaired = load(df, idx=("Control 1", "Test 1"))

# two_groups_paired = load(
# df, idx=("Control 1", "Test 1"), paired="baseline", id_col="ID"
# )

# two_groups_unpaired.mean_diff.plot()

0 comments on commit e5b26a1

Please sign in to comment.