Skip to content

Commit

Permalink
Refactore apply task to use collection interface
Browse files Browse the repository at this point in the history
  • Loading branch information
lukfor committed Dec 2, 2023
1 parent b69c08b commit 727ee42
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 67 deletions.
35 changes: 35 additions & 0 deletions src/main/java/genepi/riskscore/io/scores/IRiskScoreCollection.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package genepi.riskscore.io.scores;

import java.util.Map.Entry;
import java.util.Set;
import java.util.SortedSet;

import genepi.riskscore.io.Chunk;
import genepi.riskscore.model.ReferenceVariant;
import genepi.riskscore.model.RiskScoreSummary;

public interface IRiskScoreCollection {

public String getBuild();

public String getName();

public String getVersion();

public void buildIndex(String chromosome, Chunk chunk, String dbsnp, String proxies) throws Exception;

public RiskScoreSummary getSummary(int index);

public boolean contains(int index, int position);

public Set<Entry<Integer, ReferenceVariant>> getAllVariants(int index);

public ReferenceVariant getVariant(int index, int position);

public int getSize();

public boolean isEmpty();

public RiskScoreSummary[] getSummaries();

}
147 changes: 147 additions & 0 deletions src/main/java/genepi/riskscore/io/scores/RiskScoreCollection.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package genepi.riskscore.io.scores;

import java.io.File;
import java.util.*;

import genepi.riskscore.io.Chunk;
import genepi.riskscore.io.RiskScoreFile;
import genepi.riskscore.io.formats.RiskScoreFormatFactory.RiskScoreFormat;
import genepi.riskscore.model.ReferenceVariant;
import genepi.riskscore.model.RiskScoreSummary;

public class RiskScoreCollection implements IRiskScoreCollection {

private String build;

private String name;

private String version;

private RiskScoreFile[] riskscores;

private RiskScoreSummary[] summaries;

private int numberRiskScores;

private String[] filenames;

private boolean verbose = false;

private Map<String, RiskScoreFormat> formats;

public RiskScoreCollection(String... filenames) {
this.filenames = filenames;
}

public RiskScoreCollection(String[] filenames, Map<String, RiskScoreFormat> formats) {
this.filenames = filenames;
this.formats = formats;
}

@Override
public String getBuild() {
return build;
}

@Override
public String getName() {
return name;
}

@Override
public String getVersion() {
return version;
}

@Override
public void buildIndex(String chromosome, Chunk chunk, String dbsnp, String proxies) throws Exception {

numberRiskScores = filenames.length;

summaries = new RiskScoreSummary[numberRiskScores];
for (int i = 0; i < numberRiskScores; i++) {
String name = RiskScoreFile.getName(filenames[i]);
summaries[i] = new RiskScoreSummary(name);
}

int total = 0;

riskscores = new RiskScoreFile[numberRiskScores];
for (int i = 0; i < numberRiskScores; i++) {

if (verbose) {
System.out.println("Loading file " + filenames[i] + "...");
}

RiskScoreFormat format = formats.get(filenames[i]);
RiskScoreFile riskscore = new RiskScoreFile(filenames[i], format, dbsnp, proxies);

if (chunk != null) {
riskscore.buildIndex(chromosome, chunk);
} else {
riskscore.buildIndex(chromosome);
}

summaries[i].setVariants(riskscore.getTotalVariants());
summaries[i].setVariantsIgnored(riskscore.getIgnoredVariants());

if (verbose) {
System.out.println("Loaded " + riskscore.getLoadedVariants() + " weights for chromosome " + chromosome);
}
total += riskscore.getLoadedVariants();
riskscores[i] = riskscore;

}

if (verbose) {
System.out.println();
System.out.println("Collection contains " + total + " weights for chromosome " + chromosome);
System.out.println();
}
}

@Override
public RiskScoreSummary getSummary(int index) {
return summaries[index];
}

@Override
public boolean contains(int index, int position) {
return riskscores[index].contains(position);
}

@Override
public ReferenceVariant getVariant(int index, int position) {
return riskscores[index].getVariant(position);
}

@Override
public Set<Map.Entry<Integer, ReferenceVariant>> getAllVariants(int index) {
return riskscores[index].getVariants().entrySet();
}

@Override
public int getSize() {
return numberRiskScores;
}

@Override
public boolean isEmpty() {
for (RiskScoreFile riskscore : riskscores) {
if (riskscore.getLoadedVariants() > 0) {
return false;
}
}
return true;
}

@Override
public RiskScoreSummary[] getSummaries() {
return summaries;
}

public void setVerbose(boolean verbose) {
this.verbose = verbose;
}

}
87 changes: 20 additions & 67 deletions src/main/java/genepi/riskscore/tasks/ApplyScoreTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import genepi.riskscore.io.SamplesFile;
import genepi.riskscore.io.VariantFile;
import genepi.riskscore.io.formats.RiskScoreFormatFactory.RiskScoreFormat;
import genepi.riskscore.io.scores.IRiskScoreCollection;
import genepi.riskscore.io.scores.RiskScoreCollection;
import genepi.riskscore.io.vcf.FastVCFFileReader;
import genepi.riskscore.io.vcf.MinimalVariantContext;
import genepi.riskscore.model.ReferenceVariant;
Expand Down Expand Up @@ -55,10 +57,6 @@ public class ApplyScoreTask implements ITaskRunnable {

private String genotypeFormat = DOSAGE_FORMAT;

private int numberRiskScores = 0;

private RiskScoreSummary[] summaries;

private String output;

private String outputEffectsFilename;
Expand All @@ -74,6 +72,8 @@ public class ApplyScoreTask implements ITaskRunnable {
private boolean inverseDosage = false;

private boolean averaging = false;

private IRiskScoreCollection collection;

public static final String INFO_R2 = "R2";

Expand Down Expand Up @@ -171,30 +171,16 @@ public void run(ITaskMonitor monitor) throws Exception {
monitor.begin(taskName, new File(vcf).length());
monitor.worked(0);

numberRiskScores = riskScoreFilenames.length;
summaries = new RiskScoreSummary[numberRiskScores];
for (int i = 0; i < numberRiskScores; i++) {
String name = RiskScoreFile.getName(riskScoreFilenames[i]);
summaries[i] = new RiskScoreSummary(name);
}

RiskScoreFile[] riskscores = loadReferenceFiles(monitor, chromosome, dbsnp, proxies, riskScoreFilenames);

boolean empty = true;
for (RiskScoreFile riskscore : riskscores) {
if (riskscore.getLoadedVariants() > 0) {
empty = false;
break;
}
}
collection = new RiskScoreCollection(riskScoreFilenames, formats);
collection.buildIndex(chromosome, chunk, dbsnp, proxies);

processVCF(monitor, chromosome, vcf, riskscores, empty);
processVCF(monitor, chromosome, vcf, collection);

OutputFileWriter outputFile = new OutputFileWriter(riskScores, summaries);
OutputFileWriter outputFile = new OutputFileWriter(riskScores, collection.getSummaries());
outputFile.save(output);

if (outputReportFilename != null) {
ReportFile reportFile = new ReportFile(summaries);
ReportFile reportFile = new ReportFile(collection.getSummaries());
reportFile.save(outputReportFilename);
}

Expand All @@ -215,37 +201,7 @@ public void run(ITaskMonitor monitor) throws Exception {

}

private RiskScoreFile[] loadReferenceFiles(ITaskMonitor monitor, String chromosome, String dbsnp, String proxies,
String... riskScoreFilenames) throws Exception {

RiskScoreFile[] riskscores = new RiskScoreFile[numberRiskScores];
for (int i = 0; i < numberRiskScores; i++) {

debug("Loading file " + riskScoreFilenames[i] + "...");

RiskScoreFormat format = formats.get(riskScoreFilenames[i]);
RiskScoreFile riskscore = new RiskScoreFile(riskScoreFilenames[i], format, dbsnp, proxies);

if (chunk != null) {
riskscore.buildIndex(chromosome, chunk);
} else {
riskscore.buildIndex(chromosome);
}

summaries[i].setVariants(riskscore.getTotalVariants());
summaries[i].setVariantsIgnored(riskscore.getIgnoredVariants());

debug("Loaded " + riskscore.getLoadedVariants() + " weights for chromosome " + chromosome);
riskscores[i] = riskscore;
monitor.worked(0);
}

return riskscores;

}

private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilename, RiskScoreFile[] riskscores,
boolean empty) throws Exception {
private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilename,IRiskScoreCollection collection) throws Exception {

debug("Loading file " + vcfFilename + "...");

Expand Down Expand Up @@ -294,7 +250,7 @@ private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilen

int proxy = 0;

while (vcfReader.next() && !outOfChunk && !empty) {
while (vcfReader.next() && !outOfChunk && !collection.isEmpty()) {

if (monitor.isCanceled()) {
return;
Expand Down Expand Up @@ -324,12 +280,11 @@ private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilen

}

for (int j = 0; j < riskScoreFilenames.length; j++) {
for (int j = 0; j < collection.getSize(); j++) {

RiskScoreSummary summary = summaries[j];
RiskScoreSummary summary = collection.getSummaries()[j];

RiskScoreFile riskscore = riskscores[j];
boolean isPartOfRiskScore = riskscore.contains(position);
boolean isPartOfRiskScore = collection.contains(j, position);

if (!isPartOfRiskScore) {
summary.incNotFound();
Expand All @@ -349,8 +304,7 @@ private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilen
summary.incR2Filtered();
continue;
}

ReferenceVariant referenceVariant = riskscore.getVariant(position);
ReferenceVariant referenceVariant = collection.getVariant(j, position);

float effectWeight = referenceVariant.getEffectWeight();

Expand Down Expand Up @@ -477,10 +431,9 @@ private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilen
if (variantsWriter != null) {

// write all unused variants to file!
for (int j = 0; j < riskScoreFilenames.length; j++) {
RiskScoreSummary summary = summaries[j];
RiskScoreFile riskscore = riskscores[j];
for (Entry<Integer, ReferenceVariant> item : riskscore.getVariants().entrySet()) {
for (int j = 0; j < collection.getSize(); j++) {
RiskScoreSummary summary = collection.getSummaries()[j];
for (Entry<Integer, ReferenceVariant> item : collection.getAllVariants(j)) {
ReferenceVariant variant = item.getValue();
int position = item.getKey();

Expand Down Expand Up @@ -518,7 +471,7 @@ private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilen
}

vcfReader.close();
debug("Used " + proxy + " proxies");
debug("Used " + proxy + " proxies");
debug("Loaded " + countSamples + " samples and " + countVariants + " variants.");

}
Expand All @@ -545,7 +498,7 @@ public int getCountSamples() {
}

public RiskScoreSummary[] getSummaries() {
return summaries;
return collection.getSummaries();
}

int getCountVariants() {
Expand Down

0 comments on commit 727ee42

Please sign in to comment.