This commit is contained in:
Nuno Cruces
2025-01-21 13:59:32 +00:00
parent d2f162972d
commit 40090d8250
6 changed files with 202 additions and 79 deletions

19
ext/stats/kahan.go Normal file
View File

@@ -0,0 +1,19 @@
package stats
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

90
ext/stats/moments.go Normal file
View File

@@ -0,0 +1,90 @@
package stats
import "math"
// Terriberry's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
type moments struct {
m1, m2, m3, m4 kahan
n int64
}
func (m moments) mean() float64 {
return m.m1.hi
}
func (m moments) var_pop() float64 {
return m.m2.hi / float64(m.n)
}
func (m moments) var_samp() float64 {
return m.m2.hi / float64(m.n-1) // Bessel's correction
}
func (m moments) stddev_pop() float64 {
return math.Sqrt(m.var_pop())
}
func (m moments) stddev_samp() float64 {
return math.Sqrt(m.var_samp())
}
func (m moments) skewness_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2 * m2; div != 0 {
return m.m3.hi * math.Sqrt(float64(m.n)/div)
}
return math.NaN()
}
func (m moments) skewness_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/skewness.html#f1132178
return m.skewness_pop() * math.Sqrt(float64(n*(n-1))) / float64(n-2)
}
func (m moments) kurtosis_pop() float64 {
m2 := m.m2.hi
if div := m2 * m2; div != 0 {
return m.m4.hi * float64(m.n) / div
}
return math.NaN()
}
func (m moments) kurtosis_samp() float64 {
n := m.n
// https://mathworks.com/help/stats/kurtosis.html#f4975293
k := math.FMA(m.kurtosis_pop(), float64(n+1), float64(3-3*n))
return math.FMA(k, float64(n-1)/float64((n-2)*(n-3)), 3)
}
func (m *moments) enqueue(x float64) {
n := m.n + 1
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n-1)
m.m4.add(t1*d2*float64(n*n-3*n+3) + 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.add(t1*dn*float64(n-2) - 3*dn*m.m2.hi)
m.m2.add(t1)
m.m1.add(dn)
}
func (m *moments) dequeue(x float64) {
n := m.n - 1
if n <= 0 {
*m = moments{}
return
}
m.n = n
d1 := x - m.m1.hi - m.m1.lo
dn := d1 / float64(n)
d2 := dn * dn
t1 := d1 * dn * float64(n+1)
m.m4.sub(t1*d2*float64(n*n+3*n+3) - 6*d2*m.m2.hi - 4*dn*m.m3.hi)
m.m3.sub(t1*dn*float64(n+2) - 3*dn*m.m2.hi)
m.m2.sub(t1)
m.m1.sub(dn)
}

View File

@@ -1,38 +0,0 @@
package stats
import "math"
type moment struct {
m1, m2, m3, m4 kahan
n int64
}
func (w *moment) enqueue(x float64) {
n := w.n + 1
w.n = n
y := x - w.m1.hi - w.m1.lo
w.m1.add(y / float64(n))
y = math.FMA(y, x, -w.m2.hi) - w.m2.lo
w.m2.add(y / float64(n))
y = math.FMA(y, x, -w.m3.hi) - w.m3.lo
w.m3.add(y / float64(n))
y = math.FMA(y, x, -w.m4.hi) - w.m4.lo
w.m4.add(y / float64(n))
}
func (w *moment) dequeue(x float64) {
n := w.n - 1
if n <= 0 {
*w = moment{}
return
}
w.n = n
y := x - w.m1.hi + w.m1.lo
w.m1.sub(y / float64(n))
y = math.FMA(y, x, w.m2.hi) + w.m2.lo
w.m2.sub(y / float64(n))
y = math.FMA(y, x, w.m3.hi) + w.m3.lo
w.m3.sub(y / float64(n))
y = math.FMA(y, x, w.m4.hi) + w.m4.lo
w.m4.sub(y / float64(n))
}

87
ext/stats/moments_test.go Normal file
View File

@@ -0,0 +1,87 @@
package stats
import (
"math"
"testing"
)
func Test_moments(t *testing.T) {
t.Parallel()
var s1 moments
s1.enqueue(1)
s1.dequeue(1)
if !math.IsNaN(s1.skewness_pop()) {
t.Errorf("want NaN")
}
if !math.IsNaN(s1.kurtosis_pop()) {
t.Errorf("want NaN")
}
s1.enqueue(+0.5377)
s1.enqueue(+1.8339)
s1.enqueue(-2.2588)
s1.enqueue(+0.8622)
s1.enqueue(+0.3188)
s1.enqueue(-1.3077)
s1.enqueue(-0.4336)
s1.enqueue(+0.3426)
s1.enqueue(+3.5784)
s1.enqueue(+2.7694)
if got := float32(s1.skewness_pop()); got != 0.106098293 {
t.Errorf("got %v, want 0.1061", got)
}
if got := float32(s1.skewness_samp()); got != 0.1258171 {
t.Errorf("got %v, want 0.1258", got)
}
if got := float32(s1.kurtosis_pop()); got != 2.3121266 {
t.Errorf("got %v, want 2.3121", got)
}
if got := float32(s1.kurtosis_samp()); got != 2.7482237 {
t.Errorf("got %v, want 2.7483", got)
}
var s2 welford
s2.enqueue(+0.5377)
s2.enqueue(+1.8339)
s2.enqueue(-2.2588)
s2.enqueue(+0.8622)
s2.enqueue(+0.3188)
s2.enqueue(-1.3077)
s2.enqueue(-0.4336)
s2.enqueue(+0.3426)
s2.enqueue(+3.5784)
s2.enqueue(+2.7694)
if got, want := s1.mean(), s2.mean(); got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := s1.stddev_pop(), s2.stddev_pop(); got != want {
t.Errorf("got %v, want %v", got, want)
}
if got, want := s1.stddev_samp(), s2.stddev_samp(); got != want {
t.Errorf("got %v, want %v", got, want)
}
s1.enqueue(math.Pi)
s1.enqueue(math.Sqrt2)
s1.enqueue(math.E)
s1.dequeue(math.Pi)
s1.dequeue(math.E)
s1.dequeue(math.Sqrt2)
if got := float32(s1.skewness_pop()); got != 0.106098293 {
t.Errorf("got %v, want 0.1061", got)
}
if got := float32(s1.skewness_samp()); got != 0.1258171 {
t.Errorf("got %v, want 0.1258", got)
}
if got := float32(s1.kurtosis_pop()); got != 2.3121266 {
t.Errorf("got %v, want 2.3121", got)
}
if got := float32(s1.kurtosis_samp()); got != 2.7482237 {
t.Errorf("got %v, want 2.7483", got)
}
}

View File

@@ -10,14 +10,13 @@ import (
// Welford's algorithm with Kahan summation:
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
type welford struct {
m1, m2 kahan
n int64
}
func (w welford) average() float64 {
func (w welford) mean() float64 {
return w.m1.hi
}
@@ -170,19 +169,3 @@ func (w *welford2) dequeue(y, x float64) {
w.m2x.sub(d1x * d2x)
w.cov.sub(d1y * d2x)
}
type kahan struct{ hi, lo float64 }
func (k *kahan) add(x float64) {
y := k.lo + x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}
func (k *kahan) sub(x float64) {
y := k.lo - x
t := k.hi + y
k.lo = y - (t - k.hi)
k.hi = t
}

View File

@@ -9,12 +9,14 @@ func Test_welford(t *testing.T) {
t.Parallel()
var s1, s2 welford
s1.enqueue(1)
s1.dequeue(1)
s1.enqueue(4)
s1.enqueue(7)
s1.enqueue(13)
s1.enqueue(16)
if got := s1.average(); got != 10 {
if got := s1.mean(); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := s1.var_samp(); got != 30 {
@@ -37,22 +39,14 @@ func Test_welford(t *testing.T) {
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
s1.dequeue(16)
s1.dequeue(7)
s1.dequeue(13)
s1.enqueue(16)
s1.enqueue(7)
s1.enqueue(13)
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
}
func Test_covar(t *testing.T) {
t.Parallel()
var c1, c2 welford2
c1.enqueue(1, 1)
c1.dequeue(1, 1)
c1.enqueue(3, 70)
c1.enqueue(5, 80)
@@ -75,18 +69,6 @@ func Test_covar(t *testing.T) {
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
c1.dequeue(2, 60)
c1.dequeue(5, 80)
c1.dequeue(4, 75)
c1.dequeue(7, 90)
c1.enqueue(2, 60)
c1.enqueue(5, 80)
c1.enqueue(4, 75)
c1.enqueue(7, 90)
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
}
func Test_correlation(t *testing.T) {