@@ -29,7 +29,6 @@ import com.fulcrumgenomics.util.NumericTypes._
2929import htsjdk .samtools .util .SequenceUtil
3030
3131import java .util
32- import scala .collection .mutable .ArrayBuffer
3332
3433object ConsensusCaller {
3534 type Base = Byte
@@ -83,14 +82,19 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
8382 class ConsensusBaseBuilder {
8483 private val observations = new Array [Int ](DnaBaseCount )
8584
86- // Note: to ensure numerical stability, we store the terms we want to eventually sum, rather than storing the sum
87- // itself. We can then sum smallest (in magnitude) to largest (in magnitude) for numerical stability.
88- private val likelihoods = Range .inclusive(1 , DnaBaseCount ).map { _ => new ArrayBuffer [LogProbability ](256 ) }.toArray
85+ // Note: to ensure numerical stability, we use Kahan (compensated) summation to accumulate the likelihoods.
86+ // This tracks rounding errors in a separate compensation array and corrects for them on each addition.
87+ private val likelihoods = new Array [LogProbability ](DnaBaseCount )
88+ private val compensation = new Array [LogProbability ](DnaBaseCount )
89+
90+ // Initialize on construction
91+ reset()
8992
9093 /** Resets the likelihoods to p=1 so that the builder can be re-used. */
9194 def reset (): Unit = {
9295 util.Arrays .fill(observations, 0 )
93- this .likelihoods.foreach { arr => arr.clear() }
96+ util.Arrays .fill(likelihoods, LnOne )
97+ util.Arrays .fill(compensation, 0.0 )
9498 }
9599
96100 /** Adds a base and un-adjusted base quality to the consensus likelihoods. */
@@ -108,18 +112,26 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
108112 while (i < DnaBaseCount ) {
109113 val candidateBase = DnaBasesUpperCase (i)
110114 if (base == candidateBase) {
111- likelihoods(i) += pTruth
115+ kahanAdd(i, pTruth)
112116 observations(i) += 1
113117 }
114118 else {
115- likelihoods(i) += pErrorPerBase
119+ kahanAdd(i, pErrorPerBase)
116120 }
117121
118122 i += 1
119123 }
120124 }
121125 }
122126
127+ /** Adds a term to the likelihood at the given index using Kahan (compensated) summation. */
128+ private def kahanAdd (index : Int , term : LogProbability ): Unit = {
129+ val compensatedTerm = term - compensation(index)
130+ val newSum = likelihoods(index) + compensatedTerm
131+ compensation(index) = (newSum - likelihoods(index)) - compensatedTerm
132+ likelihoods(index) = newSum
133+ }
134+
123135 /**
124136 * Returns the number of reads that contributed evidence to the consensus. The value is equal
125137 * to the number of times add() was called with non-ambiguous bases.
@@ -137,19 +149,12 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
137149 case x => throw new IllegalArgumentException (" Unsupported base: " + x.toChar)
138150 }
139151
140- /** Produces the final likelihoods per base by sorting the accumulated likelihood terms from smallest
141- * (in magnitude) to largest (in magnitude) and then summing them in that order. The values can be
142- * negated as we are operating using log probabilities, which should always be either zero or negative.*/
143- private def finalLikelihoods : Array [LogProbability ] = {
144- likelihoods.map(_.sortInPlaceBy(v => - v).sum)
145- }
146-
147152 /** Call the consensus base and quality score given the current set of likelihoods. */
148153 def call () : (Base , PhredScore ) = {
149- // sum the likelihood terms in a numerically stable way
154+ // likelihoods are accumulated using Kahan summation for numerical stability
150155 // pick the base with the maximum posterior
151- val lls = finalLikelihoods
152- val likelihoodSum = LogProbability .or(lls)
156+ val lls = likelihoods
157+ val likelihoodSum = LogProbability .or(lls)
153158 val (maxLikelihood, maxLlIndex) = MathUtil .maxWithIndex(lls, requireUniqueMaximum= true )
154159
155160 maxLlIndex match {
@@ -177,7 +182,7 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
177182 * labels.
178183 */
179184 private [umi] def logLikelihoods : Array [LogProbability ] = {
180- val lls = finalLikelihoods
185+ val lls = likelihoods
181186 val likelihoodSum = LogProbability .or(lls)
182187 val posteriors = lls.map(l => LogProbability .normalizeByLogProbability(l, likelihoodSum))
183188 val errors = posteriors.map(LogProbability .not)
0 commit comments