Skip to content

Commit

Permalink
hotfix PLSDA predicted groups
Browse files Browse the repository at this point in the history
- predicted group now correctly assigned based on ingroup probability or yhat value
  • Loading branch information
grlloyd committed Sep 5, 2023
1 parent 710dd2a commit 46c7973
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 91 deletions.
186 changes: 105 additions & 81 deletions R/PLSDA_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
#' @include PLSR_class.R
#' @examples
#' M = PLSDA('number_components'=2,factor_name='Species')
PLSDA = function(number_components=2,factor_name,...) {
PLSDA = function(number_components=2,factor_name,pred_method='max_prob',...) {
out=struct::new_struct('PLSDA',
number_components=number_components,
factor_name=factor_name,
...)
number_components=number_components,
factor_name=factor_name,
pred_method=pred_method,
...)
return(out)
}

Expand All @@ -29,19 +30,24 @@ PLSDA = function(number_components=2,factor_name,...) {
pred='data.frame',
threshold='numeric',
sr = 'entity',
sr_pvalue='entity'
sr_pvalue='entity',
pred_method='entity'

),
prototype = list(name='Partial least squares discriminant analysis',
prototype = list(
name='Partial least squares discriminant analysis',
type="classification",
predicted='pred',
libraries='pls',
description=paste0('PLS is a multivariate regression technique that ',
description=paste0(
'PLS is a multivariate regression technique that ',
'extracts latent variables maximising covariance between the input ',
'data and the response. The Discriminant Analysis variant uses group ',
'labels in the response variable and applies a threshold to the ',
'predicted values in order to predict group membership for new samples.'),
.params=c('number_components','factor_name'),
'labels in the response variable. For >2 groups a 1-vs-all ',
'approach is used. Group membership can be predicted for test ',
'samples based on a probability estimate of group membership, ',
'or the estimated y-value.'),
.params=c('number_components','factor_name','pred_method'),
.outputs=c(
'scores',
'loadings',
Expand All @@ -57,12 +63,28 @@ PLSDA = function(number_components=2,factor_name,...) {
'sr',
'sr_pvalue'),

number_components=entity(value = 2,
number_components=entity(
value = 2,
name = 'Number of components',
description = 'The number of PLS components',
type = c('numeric','integer')
),
factor_name=ents$factor_name,
pred_method=enum(
name='Prediction method',
description=c(
'max_yhat'=
paste0('The predicted group is selected based on the ',
'largest value of y_hat.'),
'max_prob'=
paste0('The predicted group is selected based on the ',
'largest probability of group membership.')
),
value='max_prob',
allowed=c('max_yhat','max_prob'),
type='character',
max_length=1
),
sr = entity(
name = 'Selectivity ratio',
description = paste0(
Expand Down Expand Up @@ -92,8 +114,8 @@ PLSDA = function(number_components=2,factor_name,...) {
pages = '122-128',
author = as.person("Nestor F. Perez and Joan Ferre and Ricard Boque"),
title = paste0('Calculation of the reliability of ',
'classification in discriminant partial least-squares ',
'binary classification'),
'classification in discriminant partial least-squares ',
'binary classification'),
journal = "Chemometrics and Intelligent Laboratory Systems"
),
bibentry(
Expand All @@ -113,80 +135,83 @@ PLSDA = function(number_components=2,factor_name,...) {
#' @export
#' @template model_train
setMethod(f="model_train",
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
SM=D$sample_meta
y=SM[[M$factor_name]]
# convert the factor to a design matrix
z=model.matrix(~y+0)
z[z==0]=-1 # +/-1 for PLS

X=as.matrix(D$data) # convert X to matrix

Z=as.data.frame(z)
colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))

D$sample_meta=cbind(D$sample_meta,Z)

# PLSR model
N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
N = model_apply(N,D)

# copy outputs across
output_list(M) = output_list(N)

# some specific outputs for PLSDA
output_value(M,'design_matrix')=Z
output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]

# for PLSDA compute probabilities
probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
output_value(M,'probability')=as.data.frame(probs$ingroup)
output_value(M,'threshold')=probs$threshold

# update column names for outputs
colnames(M$reg_coeff)=levels(y)
colnames(M$sr)=levels(y)
colnames(M$vip)=levels(y)
colnames(M$yhat)=levels(y)
colnames(M$design_matrix)=levels(y)
colnames(M$probability)=levels(y)
names(M$threshold)=levels(y)
colnames(M$sr_pvalue)=levels(y)

return(M)
}
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
SM=D$sample_meta
y=SM[[M$factor_name]]
# convert the factor to a design matrix
z=model.matrix(~y+0)
z[z==0]=-1 # +/-1 for PLS
X=as.matrix(D$data) # convert X to matrix
Z=as.data.frame(z)
colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))
D$sample_meta=cbind(D$sample_meta,Z)
# PLSR model
N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
N = model_apply(N,D)
# copy outputs across
output_list(M) = output_list(N)
# some specific outputs for PLSDA
output_value(M,'design_matrix')=Z
output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]
# for PLSDA compute probabilities
probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
output_value(M,'probability')=as.data.frame(probs$ingroup)
output_value(M,'threshold')=probs$threshold
# update column names for outputs
colnames(M$reg_coeff)=levels(y)
colnames(M$sr)=levels(y)
colnames(M$vip)=levels(y)
colnames(M$yhat)=levels(y)
colnames(M$design_matrix)=levels(y)
colnames(M$probability)=levels(y)
names(M$threshold)=levels(y)
colnames(M$sr_pvalue)=levels(y)
return(M)
}
)

#' @export
#' @template model_predict
setMethod(f="model_predict",
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
# call PLSR predict
N=callNextMethod(M,D)
SM=N$y

## probability estimate
# http://www.eigenvector.com/faq/index.php?id=38%7C
p=as.matrix(N$pred)
d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=SM[[M$factor_name]])
pred=(p>d$threshold)*1
pred=apply(pred,MARGIN=1,FUN=which.max)
hi=apply(d$ingroup,MARGIN=1,FUN=which.max) # max probability
if (sum(is.na(pred)>0)) {
pred[is.na(pred)]=hi[is.na(pred)] # if none above threshold, use group with highest probability
}
pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
q=data.frame("pred"=pred)
output_value(M,'pred')=q
return(M)
}
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
# call PLSR predict
N=callNextMethod(M,D)
SM=N$y

## probability estimate
# http://www.eigenvector.com/faq/index.php?id=38%7C
p=as.matrix(N$pred)
d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=M$y[[M$factor_name]])

# predictions
if (M$pred_method=='max_yhat') {
pred=apply(p,MARGIN=1,FUN=which.max)
} else if (M$pred_method=='max_prob') {
pred=apply(d$ingroup,MARGIN=1,FUN=which.max)
}
pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
q=data.frame("pred"=pred)
output_value(M,'pred')=q
return(M)
}
)




prob=function(x,yhat,ytrue)
{
# x is predicted values
Expand Down Expand Up @@ -250,8 +275,7 @@ prob=function(x,yhat,ytrue)
}


gauss_intersect=function(m1,m2,s1,s2)
{
gauss_intersect=function(m1,m2,s1,s2) {
#https://stackoverflow.com/questions/22579434/python-finding-the-intersection-point-of-two-gaussian-curves
a=(1/(2*s1*s1))-(1/(2*s2*s2))
b=(m2/(s2*s2)) - (m1/(s1*s1))
Expand Down
6 changes: 4 additions & 2 deletions man/PLSDA.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-gridsearch1d.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ test_that('grid_search iterator',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$value,0.3,tolerance=0.05)
expect_equal(I$metric$value,0.045,tolerance=0.0005)
})

# test grid search
Expand All @@ -36,7 +36,7 @@ test_that('grid_search wf',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$value[1],0.3,tolerance=0.05)
expect_equal(I$metric$value[1],0.04,tolerance=0.005)
})

# test grid search
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-kfold-xval.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test_that('kfold xval venetian',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$mean,0.23,tolerance=0.05)
expect_equal(I$metric$mean,0.11,tolerance=0.005)
})

test_that('kfold xval blocks',{
Expand All @@ -26,7 +26,7 @@ test_that('kfold xval blocks',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$mean,0.23,tolerance=0.05)
expect_equal(I$metric$mean,0.115,tolerance=0.005)
})

test_that('kfold xval random',{
Expand All @@ -40,7 +40,7 @@ test_that('kfold xval random',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$mean,0.23,tolerance=0.05)
expect_equal(I$metric$mean,0.105,tolerance=0.0005)
})

test_that('kfold xval metric plot',{
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-permutation_test.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test_that('permutation test',{
# calculate metric
B=calculate(B,Yhat=output_value(I,'results.unpermuted')$predicted,
Y=output_value(I,'results.unpermuted')$actual)
expect_equal(value(B),expected=0.211,tolerance=0.004)
expect_equal(value(B),expected=0.105,tolerance=0.0005)
})

# permutation test box plot
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-permute-sample-order.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test_that('permute sample order model_seq',{
B=balanced_accuracy()
# run
I=run(I,D,B)
expect_equal(I$metric$mean,expected=0.335,tolerance=0.05)
expect_equal(I$metric$mean,expected=0.04,tolerance=0.005)
})

# permute sample order
Expand All @@ -23,5 +23,5 @@ test_that('permute sample order iterator',{
B=balanced_accuracy()
# run
I=run(I,D,B)
expect_equal(I$metric$mean,expected=0.339,tolerance=0.05)
expect_equal(I$metric$mean,expected=0.048,tolerance=0.0005)
})

0 comments on commit 46c7973

Please sign in to comment.