WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY NOT BE SUITABLE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 3478476

Browse files
committed
feat: replace likelihood summation with Kahan summation
1 parent 941433d commit 3478476

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import com.fulcrumgenomics.util.NumericTypes._
2929
import htsjdk.samtools.util.SequenceUtil
3030

3131
import java.util
32-
import scala.collection.mutable.ArrayBuffer
3332

3433
object ConsensusCaller {
3534
type Base = Byte
@@ -83,14 +82,19 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
8382
class ConsensusBaseBuilder {
8483
private val observations = new Array[Int](DnaBaseCount)
8584

86-
// Note: to ensure numerical stability, we store the terms we want to eventually sum, rather than storing the sum
87-
// itself. We can then sum smallest (in magnitude) to largest (in magnitude) for numerical stability.
88-
private val likelihoods = Range.inclusive(1, DnaBaseCount).map { _ => new ArrayBuffer[LogProbability](256) }.toArray
85+
// Note: to ensure numerical stability, we use Kahan (compensated) summation to accumulate the likelihoods.
86+
// This tracks rounding errors in a separate compensation array and corrects for them on each addition.
87+
private val likelihoods = new Array[LogProbability](DnaBaseCount)
88+
private val compensation = new Array[LogProbability](DnaBaseCount)
89+
90+
// Initialize on construction
91+
reset()
8992

9093
/**
  * Resets the builder so that it can be re-used: the observation counts are zeroed, every likelihood is set
  * back to `LnOne` (p=1) and the Kahan compensation terms are set back to 0.0.
  */
9194
def reset(): Unit = {
9295
util.Arrays.fill(observations, 0)
93-
this.likelihoods.foreach { arr => arr.clear() }
96+
util.Arrays.fill(likelihoods, LnOne)
97+
util.Arrays.fill(compensation, 0.0)
9498
}
9599

96100
/** Adds a base and un-adjusted base quality to the consensus likelihoods. */
@@ -108,18 +112,26 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
108112
while (i < DnaBaseCount) {
109113
val candidateBase = DnaBasesUpperCase(i)
110114
if (base == candidateBase) {
111-
likelihoods(i) += pTruth
115+
kahanAdd(i, pTruth)
112116
observations(i) += 1
113117
}
114118
else {
115-
likelihoods(i) += pErrorPerBase
119+
kahanAdd(i, pErrorPerBase)
116120
}
117121

118122
i += 1
119123
}
120124
}
121125
}
122126

127+
/**
  * Adds a term to the likelihood at the given index using Kahan (compensated) summation.
  *
  * The running sum is held in `likelihoods(index)` and the rounding error lost by earlier additions is held in
  * `compensation(index)`; that error is subtracted from the incoming term before the term is added.
  *
  * When the new sum is not finite (e.g. the term is negative infinity, i.e. ln(0)), the usual error-recovery
  * step would evaluate `(-Inf - sum) - (-Inf)`, which is NaN, and that NaN would then be folded into every later
  * addition. In that case there is no meaningful rounding error to carry forward, so the compensation is reset
  * to zero and the non-finite sum is stored as-is.
  *
  * @param index the index into `likelihoods` / `compensation` of the sum to update
  * @param term  the log-probability term to add to that sum
  */
private def kahanAdd(index: Int, term: LogProbability): Unit = {
  val previousSum     = likelihoods(index)
  val compensatedTerm = term - compensation(index)
  val newSum          = previousSum + compensatedTerm

  // (newSum - previousSum) recovers what was actually added; subtracting the intended term leaves the error.
  // Guard against non-finite sums, for which that expression is NaN rather than a usable error term.
  compensation(index) =
    if (java.lang.Double.isFinite(newSum)) (newSum - previousSum) - compensatedTerm
    else 0.0
  likelihoods(index) = newSum
}
134+
123135
/**
124136
* Returns the number of reads that contributed evidence to the consensus. The value is equal
125137
* to the number of times add() was called with non-ambiguous bases.
@@ -137,19 +149,12 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
137149
case x => throw new IllegalArgumentException("Unsupported base: " + x.toChar)
138150
}
139151

140-
/**
  * Collapses the accumulated likelihood terms into one likelihood per base. Log probabilities should always be
  * zero or negative, so ordering a base's terms by their negated value puts the terms of smallest magnitude
  * first; the terms are then summed in that order. Each buffer of terms is sorted in place.
  */
private def finalLikelihoods: Array[LogProbability] = {
  likelihoods.map { terms =>
    val ordered = terms.sortInPlaceBy(term => -term)
    ordered.sum
  }
}
146-
147152
/** Call the consensus base and quality score given the current set of likelihoods. */
148153
def call() : (Base, PhredScore) = {
149-
// sum the likelihood terms in a numerically stable way
154+
// likelihoods are accumulated using Kahan summation for numerical stability
150155
// pick the base with the maximum posterior
151-
val lls = finalLikelihoods
152-
val likelihoodSum = LogProbability.or(lls)
156+
val lls = likelihoods
157+
val likelihoodSum = LogProbability.or(lls)
153158
val (maxLikelihood, maxLlIndex) = MathUtil.maxWithIndex(lls, requireUniqueMaximum=true)
154159

155160
maxLlIndex match {
@@ -177,7 +182,7 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
177182
* labels.
178183
*/
179184
private[umi] def logLikelihoods: Array[LogProbability] = {
180-
val lls = finalLikelihoods
185+
val lls = likelihoods
181186
val likelihoodSum = LogProbability.or(lls)
182187
val posteriors = lls.map(l => LogProbability.normalizeByLogProbability(l, likelihoodSum))
183188
val errors = posteriors.map(LogProbability.not)

0 commit comments

Comments
 (0)