aboutsummaryrefslogtreecommitdiff
path: root/src/math
diff options
context:
space:
mode:
Diffstat (limited to 'src/math')
-rw-r--r--src/math/big/calibrate.md180
-rw-r--r--src/math/big/calibrate_graph.go321
-rw-r--r--src/math/big/calibrate_test.go332
-rw-r--r--src/math/big/nat_test.go13
-rw-r--r--src/math/big/natdiv.go2
-rw-r--r--src/math/big/natmul.go13
6 files changed, 724 insertions, 137 deletions
diff --git a/src/math/big/calibrate.md b/src/math/big/calibrate.md
new file mode 100644
index 0000000000..ff0b4ea137
--- /dev/null
+++ b/src/math/big/calibrate.md
@@ -0,0 +1,180 @@
+# Calibration of Algorithm Thresholds
+
+This document describes the approach to calibration of algorithmic thresholds in
+`math/big`, implemented in [calibrate_test.go](calibrate_test.go).
+
+Basic operations like multiplication and division have many possible implementations.
+Most algorithms that are better asymptotically have overheads that make them
+run slower for small inputs. When presented with an operation to run, `math/big`
+must decide which algorithm to use.
+
+For example, for small inputs, multiplication using the “grade school algorithm” is fastest.
+Given multi-digit x, y and a target z: clear z, and then for each digit y[i], z[i:] += x\*y[i].
+That last operation, adding a vector times a digit to another vector (including carrying up
+the vector during the multiplication and addition), can be implemented in a tight assembly loop.
+The overall speed is O(N\*\*2) where N is the number of digits in x and y (assume they match),
+but the tight inner loop performs well for small inputs.
+
+[Karatsuba's algorithm](https://en.wikipedia.org/wiki/Karatsuba_algorithm)
+multiplies two N-digit numbers by splitting them in half, computing
+three N/2-digit products, and then reconstructing the final product using a few more
+additions and subtractions. It runs in O(N\*\*log₂ 3) = O(N\*\*1.58) time.
+The grade school loop runs faster for small inputs,
+but eventually Karatsuba's smaller asymptotic run time wins.
+
+The multiplication implementation must decide which to use.
+Under the assumption that once Karatsuba is faster for some N,
+it will be faster for all larger N as well,
+the rule is to use Karatsuba's algorithm when the input length N ≥ karatsubaThreshold.
+
+Calibration is the process of determining what karatsubaThreshold should be set to.
+It doesn't sound like it should be that hard, but it is:
+- Theoretical analysis does not help: the answer depends on the actual machines
+and the actual constant factors in the two implementations.
+- We are picking a single karatsubaThreshold for all systems,
+despite them having different relative execution speeds for the operations
+in the two algorithms.
+(We could in theory pick different thresholds for different architectures,
+but there can still be significant variation within a given architecture.)
+- The assumption that there is a single N where
+an asymptotically better algorithm becomes faster and stays faster
+is not true in general.
+- Recursive algorithms like Karatsuba's may have different optimal
+thresholds for different large input sizes.
+- Thresholds can interfere. For example, changing the karatsubaThreshold makes
+multiplication faster or slower, which in turn affects the best divRecursiveThreshold
+(because divisions use multiplication).
+
+The best we can do is measure the performance of the overall multiplication
+algorithm across a variety of inputs and thresholds and look for a threshold
+that balances all these concerns reasonably well,
+setting thresholds in dependency order (for example, multiplication before division).
+
+The code in `calibrate_test.go` does this measurement of a variety of input sizes
+and threshold values and prints the timing results as a CSV file.
+The code in `calibrate_graph.go` reads the CSV and writes out an SVG file plotting the data.
+For example:
+
+ go test -run=Calibrate/KaratsubaMul -timeout=1h -calibrate >kmul.csv
+ go run calibrate_graph.go kmul.csv >kmul.svg
+
+Any particular input is sensitive to only a few transitions in threshold.
+For example, an input of size 320 recurses on inputs of size 160,
+which recurses on inputs of size 80,
+which recurses on inputs of size 40,
+and so on, until falling below the Karatsuba threshold.
+Here is what the timing looks like for an input of size 320,
+normalized so that 1.0 is the fastest timing observed:
+
+![KaratsubaThreshold on an Apple M3 Pro, N=320 only](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.mac320.svg)
+
+For this input, all thresholds from 21 to 40 perform optimally and identically: they all mean “recurse at N=40 but not at N=20”.
+From the single input of size N=320, we cannot decide which of these 20 thresholds is best.
+
+Other inputs exercise other decision points. For example, here is the timing for N=240:
+
+![KaratsubaThreshold on an Apple M3 Pro, N=240 only](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.mac240.svg)
+
+In this case, all the thresholds from 31 to 60 perform optimally and identically, recursing at N=60 but not N=30.
+
+If we combine these two into a single graph and then plot the geometric mean of the two lines in blue,
+the optimal range becomes a little clearer:
+
+![KaratsubaThreshold on an Apple M3 Pro](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.mac240+320.svg)
+
+The actual calibration runs all possible inputs from size N=200 to N=400, in increments of 8,
+plotting all 26 lines in a faded gray (note the changed y-axis scale, zooming in near 1.0).
+
+![KaratsubaThreshold on an Apple M3 Pro](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.mac.svg)
+
+Now the optimal value is clear: the best threshold on this chip, with these algorithmic implementations, is 40.
+
+Unfortunately, other chips are different. Here is an Intel Xeon server chip:
+
+![KaratsubaThreshold on an Intel Xeon Gold 6253CL](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.c2s16.svg)
+
+On this chip, the best threshold is closer to 60. Luckily, 40 is not a terrible choice either: it is only about 2% slower on average.
+
+The rest of this document presents the timings measured for the `math/big` thresholds on a variety of machines
+and justifies the final thresholds. The timings used these machines:
+
+- The `gotip-linux-amd64_c3h88-perf_vs_release` gomote, a Google Cloud c3-high-88 machine using an Intel Xeon Platinum 8481C CPU (Sapphire Rapids).
+- The `gotip-linux-amd64_c2s16-perf_vs_release` gomote, a Google Cloud c2-standard-16 machine using an Intel Xeon Gold 6253CL CPU (Cascade Lake).
+- A home server built with an AMD Ryzen 9 7950X CPU.
+- The `gotip-linux-arm64_c4as16-perf_vs_release` gomote, a Google Cloud c4a-standard-16 machine using Google's Axion Arm CPU.
+- An Apple MacBook Pro with an Apple M3 Pro CPU.
+
+In general, we break ties in favor of the newer c3h88 x86 perf gomote, then the c4as16 arm64 perf gomote, and then the others.
+
+## Karatsuba Multiplication
+
+Here are the full results for the Karatsuba multiplication threshold.
+
+![KaratsubaThreshold on an Intel Xeon Platinum 8481C](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.c3h88.svg)
+![KaratsubaThreshold on an Intel Xeon Gold 6253CL](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.c2s16.svg)
+![KaratsubaThreshold on an AMD Ryzen 9 7950X](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.s7.svg)
+![KaratsubaThreshold on an Axion Arm](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.c4as16.svg)
+![KaratsubaThreshold on an Apple M3 Pro](https://swtch.com/math/big/_calibrate/KaratsubaMul/cal.mac.svg)
+
+The majority of systems have optimum thresholds near 40, so we chose karatsubaThreshold = 40.
+
+## Basic Squaring
+
+For squaring a number (`z.Mul(x, x)`), math/big uses grade school multiplication
+up to basicSqrThreshold, where it switches to a customized algorithm that is
+still quadratic but avoids half the word-by-word multiplies
+since the two arguments are identical.
+That algorithm's inner loops are not as tight as the grade school multiplication,
+so it is slower for small inputs. How small?
+
+Here are the timings:
+
+![BasicSqrThreshold on an Intel Xeon Platinum 8481C](https://swtch.com/math/big/_calibrate/BasicSqr/cal.c3h88.svg)
+![BasicSqrThreshold on an Intel Xeon Gold 6253CL](https://swtch.com/math/big/_calibrate/BasicSqr/cal.c2s16.svg)
+![BasicSqrThreshold on an AMD Ryzen 9 7950X](https://swtch.com/math/big/_calibrate/BasicSqr/cal.s7.svg)
+![BasicSqrThreshold on an Axion Arm](https://swtch.com/math/big/_calibrate/BasicSqr/cal.c4as16.svg)
+![BasicSqrThreshold on an Apple M3 Pro](https://swtch.com/math/big/_calibrate/BasicSqr/cal.mac.svg)
+
+These inputs are so small that the calibration times batches of 100 instead of individual operations.
+There is no one best threshold, even on a single system, because some of the sizes seem to run
+the grade school algorithm faster than others.
+For example, on the AMD CPU,
+for N=14, basic squaring is 4% faster than basic multiplication,
+suggesting the threshold has been crossed,
+but for N=16, basic multiplication is 9% faster than basic squaring,
+probably because the tight assembly can use larger chunks.
+
+It is unclear why the Axion Arm timings are so incredibly noisy.
+
+We chose basicSqrThreshold = 12.
+
+## Karatsuba Squaring
+
+Beyond the basic squaring threshold, at some point a customized Karatsuba can take over.
+It uses three half-sized squarings instead of three half-sized multiplies.
+Here are the timings:
+
+![KaratsubaSqrThreshold on an Intel Xeon Platinum 8481C](https://swtch.com/math/big/_calibrate/KaratsubaSqr/cal.c3h88.svg)
+![KaratsubaSqrThreshold on an Intel Xeon Gold 6253CL](https://swtch.com/math/big/_calibrate/KaratsubaSqr/cal.c2s16.svg)
+![KaratsubaSqrThreshold on an AMD Ryzen 9 7950X](https://swtch.com/math/big/_calibrate/KaratsubaSqr/cal.s7.svg)
+![KaratsubaSqrThreshold on an Axion Arm](https://swtch.com/math/big/_calibrate/KaratsubaSqr/cal.c4as16.svg)
+![KaratsubaSqrThreshold on an Apple M3 Pro](https://swtch.com/math/big/_calibrate/KaratsubaSqr/cal.mac.svg)
+
+The majority of chips preferred a lower threshold, around 60-70,
+but the older Intel Xeon and the AMD prefer a threshold around 100-120.
+
+We chose karatsubaSqrThreshold = 80, which is within 2% of optimal on all the chips.
+
+## Recursive Division
+
+Division uses a recursive divide-and-conquer algorithm for large inputs,
+eventually falling back to a more traditional grade-school whole-input trial-and-error division.
+Here are the timings for the threshold between the two:
+
+![DivRecursiveThreshold on an Intel Xeon Platinum 8481C](https://swtch.com/math/big/_calibrate/DivRecursive/cal.c3h88.svg)
+![DivRecursiveThreshold on an Intel Xeon Gold 6253CL](https://swtch.com/math/big/_calibrate/DivRecursive/cal.c2s16.svg)
+![DivRecursiveThreshold on an AMD Ryzen 9 7950X](https://swtch.com/math/big/_calibrate/DivRecursive/cal.s7.svg)
+![DivRecursiveThreshold on an Axion Arm](https://swtch.com/math/big/_calibrate/DivRecursive/cal.c4as16.svg)
+![DivRecursiveThreshold on an Apple M3 Pro](https://swtch.com/math/big/_calibrate/DivRecursive/cal.mac.svg)
+
+We chose divRecursiveThreshold = 40.
diff --git a/src/math/big/calibrate_graph.go b/src/math/big/calibrate_graph.go
new file mode 100644
index 0000000000..37596195a1
--- /dev/null
+++ b/src/math/big/calibrate_graph.go
@@ -0,0 +1,321 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build ignore
+
+// This program converts CSV calibration data printed by
+//
+// go test -run=Calibrate/Name -calibrate >file.csv
+//
+// into an SVG file. Invoke as:
+//
+// go run calibrate_graph.go file.csv >file.svg
+//
+// See calibrate.md for more details.
+
+package main
+
+import (
+ "bytes"
+ "encoding/csv"
+ "flag"
+ "fmt"
+ "log"
+ "math"
+ "os"
+ "strconv"
+)
+
+func usage() {
+ fmt.Fprintf(os.Stderr, "usage: go run calibrate_graph.go file.csv >file.svg\n")
+ os.Exit(2)
+}
+
+// A Point is an X, Y coordinate in the data being plotted.
+type Point struct {
+ X, Y float64
+}
+
+// A Graph is a graph to draw as SVG.
+type Graph struct {
+ Title string // title above graph
+ Geomean []Point // geomean line
+ Lines [][]Point // normalized data lines
+ XAxis string // x-axis label
+ YAxis string // y-axis label
+ Min Point // min point of data display
+ Max Point // max point of data display
+}
+
+var yMax = flag.Float64("ymax", 1.2, "maximum y axis value")
+var alphaNorm = flag.Float64("alphanorm", 0.1, "alpha for a single norm line")
+
+func main() {
+ flag.Usage = usage
+ flag.Parse()
+ if flag.NArg() != 1 {
+ usage()
+ }
+
+ // Read CSV. It may be enclosed in
+ // -- name.csv --
+ // ...
+ // -- eof --
+ // framing, in which case remove the framing.
+ fdata, err := os.ReadFile(flag.Arg(0))
+ if err != nil {
+ log.Fatal(err)
+ }
+ if _, after, ok := bytes.Cut(fdata, []byte(".csv --\n")); ok {
+ fdata = after
+ }
+ if before, _, ok := bytes.Cut(fdata, []byte("-- eof --\n")); ok {
+ fdata = before
+ }
+ rd := csv.NewReader(bytes.NewReader(fdata))
+ rd.FieldsPerRecord = -1
+ records, err := rd.ReadAll()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Construct graph from loaded CSV.
+ // CSV starts with metadata lines like
+ // goos,darwin
+ // and then has two tables of timings.
+ // Each table looks like
+ // size \ threshold,10,20,30,40
+ // 100,1,2,3,4
+ // 200,2,3,4,5
+ // 300,3,4,5,6
+ // 400,4,5,6,7
+ // 500,5,6,7,8
+ // The header line gives the threshold values and then each row
+ // gives an input size and the timings for each threshold.
+ // Omitted timings are empty strings and turn into infinities when parsing.
+ // The first table gives raw nanosecond timings.
+ // The second table gives timings normalized relative to the fastest
+ // possible threshold for a given input size.
+ // We only want the second table.
+ // The tables are followed by a list of geomeans of all the normalized
+ // timings for each threshold:
+ // geomean,1.2,1.1,1.0,1.4
+ // We turn each normalized timing row into a line in the graph,
+ // and we turn the geomean into an overlaid thick line.
+ // The metadata is used for preparing the titles.
+ g := &Graph{
+ YAxis: "Relative Slowdown",
+ Min: Point{0, 1},
+ Max: Point{1, 1.2},
+ }
+ meta := make(map[string]string)
+ table := 0 // number of table headers seen
+ var thresholds []float64
+ maxNorm := 0.0
+ for _, rec := range records {
+ if len(rec) == 0 {
+ continue
+ }
+ if len(rec) == 2 {
+ meta[rec[0]] = rec[1]
+ continue
+ }
+ if rec[0] == `size \ threshold` {
+ table++
+ if table == 2 {
+ thresholds = parseFloats(rec)
+ g.Min.X = thresholds[0]
+ g.Max.X = thresholds[len(thresholds)-1]
+ }
+ continue
+ }
+ if rec[0] == "geomean" {
+ table = 3 // end of norms table
+ geomeans := parseFloats(rec)
+ g.Geomean = floatsToLine(thresholds, geomeans)
+ continue
+ }
+ if table == 2 {
+ if _, err := strconv.Atoi(rec[0]); err != nil { // size
+ log.Fatalf("invalid table line: %q", rec)
+ }
+ norms := parseFloats(rec)
+ if len(norms) > len(thresholds) {
+ log.Fatalf("too many timings (%d > %d): %q", len(norms), len(thresholds), rec)
+ }
+ g.Lines = append(g.Lines, floatsToLine(thresholds, norms))
+ for _, y := range norms {
+ maxNorm = max(maxNorm, y)
+ }
+ continue
+ }
+ }
+
+ g.Max.Y = min(*yMax, math.Ceil(maxNorm*100)/100)
+ g.XAxis = meta["calibrate"] + "Threshold"
+ g.Title = meta["goos"] + "/" + meta["goarch"] + " " + meta["cpu"]
+
+ os.Stdout.Write(g.SVG())
+}
+
+// parseFloats parses rec[1:] as floating point values.
+// If a field is the empty string, it is represented as +Inf.
+func parseFloats(rec []string) []float64 {
+ floats := make([]float64, 0, len(rec)-1)
+ for _, v := range rec[1:] {
+ if v == "" {
+ floats = append(floats, math.Inf(+1))
+ continue
+ }
+ f, err := strconv.ParseFloat(v, 64)
+ if err != nil {
+ log.Fatalf("invalid record: %q (%v)", rec, err)
+ }
+ floats = append(floats, f)
+ }
+ return floats
+}
+
+// floatsToLine converts a sequence of floats into a line, ignoring missing (infinite) values.
+func floatsToLine(x, y []float64) []Point {
+ var line []Point
+ for i, yi := range y {
+ if !math.IsInf(yi, 0) {
+ line = append(line, Point{x[i], yi})
+ }
+ }
+ return line
+}
+
+const svgHeader = `<svg width="%d" height="%d" version="1.1" xmlns="http://www.w3.org/2000/svg">
+ <defs>
+ <style type="text/css"><![CDATA[
+ text { stroke-width: 0; white-space: pre; }
+ text.hjc { text-anchor: middle; }
+ text.hjl { text-anchor: start; }
+ text.hjr { text-anchor: end; }
+ .def { stroke-linecap: round; stroke-linejoin: round; fill: none; stroke: #000000; stroke-width: 1px; }
+ .tick { stroke: #000000; fill: #000000; font: %dpx Times; }
+ .title { stroke: #000000; fill: #000000; font: %dpx Times; font-weight: bold; }
+ .axis { stroke-width: 2px; }
+ .norm { stroke: rgba(0,0,0,%f); }
+ .geomean { stroke: #6666ff; stroke-width: 2px; }
+ ]]></style>
+ </defs>
+ <g class="def">
+`
+
+// Layout constants for drawing graph
+const (
+ DX = 600 // width of graphed data
+ DY = 150 // height of graphed data
+ ML = 80 // margin left
+ MT = 30 // margin top
+ MR = 10 // margin right
+ MB = 50 // margin bottom
+ PS = 14 // point size of text
+ W = ML + DX + MR // width of overall graph
+ H = MT + DY + MB // height of overall graph
+ Tick = 5 // axis tick length
+)
+
+// An SVGPoint is a point in the SVG image, in pixel units,
+// with Y increasing down the page.
+type SVGPoint struct {
+ X, Y int
+}
+
+func (p SVGPoint) String() string {
+ return fmt.Sprintf("%d,%d", p.X, p.Y)
+}
+
+// pt converts an x, y data value (such as from a Point) to an SVGPoint.
+func (g *Graph) pt(x, y float64) SVGPoint {
+ return SVGPoint{
+ X: ML + int((x-g.Min.X)/(g.Max.X-g.Min.X)*DX),
+ Y: H - MB - int((y-g.Min.Y)/(g.Max.Y-g.Min.Y)*DY),
+ }
+}
+
+// SVG returns the SVG text for the graph.
+func (g *Graph) SVG() []byte {
+
+ var svg bytes.Buffer
+ fmt.Fprintf(&svg, svgHeader, W, H, PS, PS, *alphaNorm)
+
+ // Draw data, clipped.
+ fmt.Fprintf(&svg, "<clipPath id=\"cp\"><path d=\"M %v L %v L %v L %v Z\" /></clipPath>\n",
+ g.pt(g.Min.X, g.Min.Y), g.pt(g.Max.X, g.Min.Y), g.pt(g.Max.X, g.Max.Y), g.pt(g.Min.X, g.Max.Y))
+ fmt.Fprintf(&svg, "<g clip-path=\"url(#cp)\">\n")
+ for _, line := range g.Lines {
+ if len(line) == 0 {
+ continue
+ }
+ fmt.Fprintf(&svg, "<path class=\"norm\" d=\"M %v", g.pt(line[0].X, line[0].Y))
+ for _, v := range line[1:] {
+ fmt.Fprintf(&svg, " L %v", g.pt(v.X, v.Y))
+ }
+ fmt.Fprintf(&svg, "\"/>\n")
+ }
+ // Draw geomean.
+ if len(g.Geomean) > 0 {
+ line := g.Geomean
+ fmt.Fprintf(&svg, "<path class=\"geomean\" d=\"M %v", g.pt(line[0].X, line[0].Y))
+ for _, v := range line[1:] {
+ fmt.Fprintf(&svg, " L %v", g.pt(v.X, v.Y))
+ }
+ fmt.Fprintf(&svg, "\"/>\n")
+ }
+ fmt.Fprintf(&svg, "</g>\n")
+
+ // Draw axes and major and minor tick marks.
+ fmt.Fprintf(&svg, "<path class=\"axis\" d=\"")
+ fmt.Fprintf(&svg, " M %v L %v", g.pt(g.Min.X, g.Min.Y), g.pt(g.Max.X, g.Min.Y)) // x axis
+ fmt.Fprintf(&svg, " M %v L %v", g.pt(g.Min.X, g.Min.Y), g.pt(g.Min.X, g.Max.Y)) // y axis
+ xscale := 10.0
+ if g.Max.X-g.Min.X < 100 {
+ xscale = 1.0
+ }
+ for x := int(math.Ceil(g.Min.X / xscale)); float64(x)*xscale <= g.Max.X; x++ {
+ if x%5 != 0 {
+ fmt.Fprintf(&svg, " M %v l 0,%d", g.pt(float64(x)*xscale, g.Min.Y), Tick)
+ } else {
+ fmt.Fprintf(&svg, " M %v l 0,%d", g.pt(float64(x)*xscale, g.Min.Y), 2*Tick)
+ }
+ }
+ yscale := 100.0
+ if g.Max.Y-g.Min.Y > 0.5 {
+ yscale = 10
+ }
+ for y := int(math.Ceil(g.Min.Y * yscale)); float64(y) <= g.Max.Y*yscale; y++ {
+ if y%5 != 0 {
+ fmt.Fprintf(&svg, " M %v l -%d,0", g.pt(g.Min.X, float64(y)/yscale), Tick)
+ } else {
+ fmt.Fprintf(&svg, " M %v l -%d,0", g.pt(g.Min.X, float64(y)/yscale), 2*Tick)
+ }
+ }
+ fmt.Fprintf(&svg, "\"/>\n")
+
+ // Draw tick labels on major marks.
+ for x := int(math.Ceil(g.Min.X / xscale)); float64(x)*xscale <= g.Max.X; x++ {
+ if x%5 == 0 {
+ p := g.pt(float64(x)*xscale, g.Min.Y)
+ fmt.Fprintf(&svg, "<text x=\"%d\" y=\"%d\" class=\"tick hjc\">%d</text>\n", p.X, p.Y+2*Tick+PS, x*int(xscale))
+ }
+ }
+ for y := int(math.Ceil(g.Min.Y * yscale)); float64(y) <= g.Max.Y*yscale; y++ {
+ if y%5 == 0 {
+ p := g.pt(g.Min.X, float64(y)/yscale)
+ fmt.Fprintf(&svg, "<text x=\"%d\" y=\"%d\" class=\"tick hjr\">%.2f</text>\n", p.X-2*Tick-Tick, p.Y+PS/3, float64(y)/yscale)
+ }
+ }
+
+ // Draw graph title and axis titles.
+ fmt.Fprintf(&svg, "<text x=\"%d\" y=\"%d\" class=\"title hjc\">%s</text>\n", ML+DX/2, MT-PS/3, g.Title)
+ fmt.Fprintf(&svg, "<text x=\"%d\" y=\"%d\" class=\"title hjc\">%s</text>\n", ML+DX/2, MT+DY+2*Tick+2*PS+PS/2, g.XAxis)
+ fmt.Fprintf(&svg, "<g transform=\"translate(%d,%d) rotate(-90)\"><text x=\"0\" y=\"0\" class=\"title hjc\">%s</text></g>\n", ML-Tick-Tick-3*PS, MT+DY/2, g.YAxis)
+
+ fmt.Fprintf(&svg, "</g></svg>\n")
+ return svg.Bytes()
+}
diff --git a/src/math/big/calibrate_test.go b/src/math/big/calibrate_test.go
index d85833aede..7d44c2ed0f 100644
--- a/src/math/big/calibrate_test.go
+++ b/src/math/big/calibrate_test.go
@@ -2,172 +2,266 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Calibration used to determine thresholds for using
-// different algorithms. Ideally, this would be converted
-// to go generate to create thresholds.go
-
-// This file prints execution times for the Mul benchmark
-// given different Karatsuba thresholds. The result may be
-// used to manually fine-tune the threshold constant. The
-// results are somewhat fragile; use repeated runs to get
-// a clear picture.
-
-// Calculates lower and upper thresholds for when basicSqr
-// is faster than standard multiplication.
-
-// Usage: go test -run='^TestCalibrate$' -v -calibrate
+// TestCalibrate determines appropriate thresholds for when to use
+// different calculation algorithms. To run it, use:
+//
+// go test -run=Calibrate -calibrate >cal.log
+//
+// Calibration data is printed in CSV format, along with the normal test output.
+// See calibrate.md for more details about using the output.
package big
import (
"flag"
"fmt"
+ "internal/sysinfo"
+ "math"
+ "runtime"
+ "slices"
+ "strings"
+ "sync"
"testing"
"time"
)
var calibrate = flag.Bool("calibrate", false, "run calibration test")
-
-const (
- sqrModeMul = "mul(x, x)"
- sqrModeBasic = "basicSqr(x)"
- sqrModeKaratsuba = "karatsubaSqr(x)"
-)
+var calibrateOnce sync.Once
func TestCalibrate(t *testing.T) {
if !*calibrate {
return
}
- computeKaratsubaThresholds()
+ t.Run("KaratsubaMul", computeKaratsubaThreshold)
+ t.Run("BasicSqr", computeBasicSqrThreshold)
+ t.Run("KaratsubaSqr", computeKaratsubaSqrThreshold)
+ t.Run("DivRecursive", computeDivRecursiveThreshold)
+}
+
+func computeKaratsubaThreshold(t *testing.T) {
+ set := func(n int) { karatsubaThreshold = n }
+ computeThreshold(t, "karatsuba", set, 0, 4, 200, benchMul, 200, 8, 400)
+}
- // compute basicSqrThreshold where overhead becomes negligible
- minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
- // compute karatsubaSqrThreshold where karatsuba is faster
- maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
- if minSqr != 0 {
- fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
- } else {
- fmt.Println("no basicSqrThreshold found")
+func benchMul(size int) func() {
+ x := rndNat(size)
+ y := rndNat(size)
+ var z nat
+ return func() {
+ z.mul(nil, x, y)
}
- if maxSqr != 0 {
- fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
- } else {
- fmt.Println("no karatsubaSqrThreshold found")
+}
+
+func computeBasicSqrThreshold(t *testing.T) {
+ setDuringTest(t, &karatsubaSqrThreshold, 1e9)
+ set := func(n int) { basicSqrThreshold = n }
+ computeThreshold(t, "basicSqr", set, 2, 1, 40, benchBasicSqr, 1, 1, 40)
+}
+
+func benchBasicSqr(size int) func() {
+ x := rndNat(size)
+ var z nat
+ return func() {
+ // Run 100 squarings because 1 is too fast at the small sizes we consider.
+ // Some systems don't even have precise enough clocks to measure it accurately.
+ for range 100 {
+ z.sqr(nil, x)
+ }
}
}
-func karatsubaLoad(b *testing.B) {
- BenchmarkMul(b)
+func computeKaratsubaSqrThreshold(t *testing.T) {
+ set := func(n int) { karatsubaSqrThreshold = n }
+ computeThreshold(t, "karatsubaSqr", set, 0, 4, 200, benchSqr, 200, 8, 400)
}
-// measureKaratsuba returns the time to run a Karatsuba-relevant benchmark
-// given Karatsuba threshold th.
-func measureKaratsuba(th int) time.Duration {
- th, karatsubaThreshold = karatsubaThreshold, th
- res := testing.Benchmark(karatsubaLoad)
- karatsubaThreshold = th
- return time.Duration(res.NsPerOp())
+func benchSqr(size int) func() {
+ x := rndNat(size)
+ var z nat
+ return func() {
+ z.sqr(nil, x)
+ }
}
-func computeKaratsubaThresholds() {
- fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
- fmt.Printf("(run repeatedly for good results)\n")
+func computeDivRecursiveThreshold(t *testing.T) {
+ set := func(n int) { divRecursiveThreshold = n }
+ computeThreshold(t, "divRecursive", set, 4, 4, 200, benchDiv, 200, 8, 400)
+}
- // determine Tk, the work load execution time using basic multiplication
- Tb := measureKaratsuba(1e9) // th == 1e9 => Karatsuba multiplication disabled
- fmt.Printf("Tb = %10s\n", Tb)
+func benchDiv(size int) func() {
+ divx := rndNat(2 * size)
+ divy := rndNat(size)
+ var z, r nat
+ return func() {
+ z.div(nil, r, divx, divy)
+ }
+}
- // thresholds
- th := 4
- th1 := -1
- th2 := -1
+func computeThreshold(t *testing.T, name string, set func(int), thresholdLo, thresholdStep, thresholdHi int, bench func(int) func(), sizeLo, sizeStep, sizeHi int) {
+ // Start CSV output; wrapped in txtar framing to separate CSV from other test output.
+ fmt.Printf("-- calibrate-%s.csv --\n", name)
+ defer fmt.Printf("-- eof --\n")
- var deltaOld time.Duration
- for count := -1; count != 0 && th < 128; count-- {
- // determine Tk, the work load execution time using Karatsuba multiplication
- Tk := measureKaratsuba(th)
+ fmt.Printf("goos,%s\n", runtime.GOOS)
+ fmt.Printf("goarch,%s\n", runtime.GOARCH)
+ fmt.Printf("cpu,%s\n", sysinfo.CPUName())
+ fmt.Printf("calibrate,%s\n", name)
- // improvement over Tb
- delta := (Tb - Tk) * 100 / Tb
+ // Expand lists of sizes and thresholds we will test.
+ var sizes, thresholds []int
+ for size := sizeLo; size <= sizeHi; size += sizeStep {
+ sizes = append(sizes, size)
+ }
+ for thresh := thresholdLo; thresh <= thresholdHi; thresh += thresholdStep {
+ thresholds = append(thresholds, thresh)
+ }
- fmt.Printf("th = %3d Tk = %10s %4d%%", th, Tk, delta)
+ fmt.Printf("%s\n", csv("size \\ threshold", thresholds))
- // determine break-even point
- if Tk < Tb && th1 < 0 {
- th1 = th
- fmt.Print(" break-even point")
+ // Track minimum time observed for each size, threshold pair.
+ times := make([][]float64, len(sizes))
+ for i := range sizes {
+ times[i] = make([]float64, len(thresholds))
+ for j := range thresholds {
+ times[i][j] = math.Inf(+1)
}
+ }
- // determine diminishing return
- if 0 < delta && delta < deltaOld && th2 < 0 {
- th2 = th
- fmt.Print(" diminishing return")
- }
- deltaOld = delta
+ // For each size, run at most MaxRounds of considering every threshold.
+ // If we run a threshold Stable times in a row without seeing more
+ // than a 1% improvement in the observed minimum, move on to the next one.
+ // After we run Converged rounds (not necessarily in a row)
+ // without seeing any threshold improve by more than 1%, stop.
+ const (
+ MaxRounds = 1600
+ Stable = 20
+ Converged = 200
+ )
- fmt.Println()
+ for i, size := range sizes {
+ b := bench(size)
+ same := 0
+ for range MaxRounds {
+ better := false
+ for j, threshold := range thresholds {
+ // No point if threshold is far beyond size
+ if false && threshold > size+2*sizeStep {
+ continue
+ }
- // trigger counter
- if th1 >= 0 && th2 >= 0 && count < 0 {
- count = 10 // this many extra measurements after we got both thresholds
+ // BasicSqr is different from the recursive thresholds: it either applies or not,
+ // without any question of recursive subproblems. Only try the thresholds
+ // size-1, size, size+1, size+2
+ // to get two data points using basic multiplication and two using basic squaring.
+ // This avoids gathering many redundant data points.
+ // (The others have redundant data points as well, but for them the math is less trivial
+ // and best not duplicated in the calibration code.)
+ if false && name == "basicSqr" && (threshold < size-1 || threshold > size+3) {
+ continue
+ }
+
+ set(threshold)
+ b() // warm up
+ b()
+ tmin := times[i][j]
+ for k := 0; k < Stable; k++ {
+ start := time.Now()
+ b()
+ t := float64(time.Since(start))
+ if t < tmin {
+ if t < tmin*99/100 {
+ better = true
+ k = 0
+ }
+ tmin = t
+ }
+ }
+ times[i][j] = tmin
+ }
+ if !better {
+ if same++; same >= Converged {
+ break
+ }
+ }
}
- th++
+ fmt.Printf("%s\n", csv(fmt.Sprint(size), times[i]))
}
-}
-func measureSqr(words, nruns int, mode string) time.Duration {
- // more runs for better statistics
- initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
-
- switch mode {
- case sqrModeMul:
- basicSqrThreshold = words + 1
- case sqrModeBasic:
- basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
- case sqrModeKaratsuba:
- karatsubaSqrThreshold = words - 1
+ // For each size, normalize timings by the minimum achieved for that size.
+ fmt.Printf("%s\n", csv("size \\ threshold", thresholds))
+ norms := make([][]float64, len(sizes))
+ for i, times := range times {
+ m := min(1e100, slices.Min(times)) // make finite so divide preserves inf values
+ norms[i] = make([]float64, len(times))
+ for j, d := range times {
+ norms[i][j] = d / m
+ }
+ fmt.Printf("%s\n", csv(fmt.Sprint(sizes[i]), norms[i]))
}
- var testval int64
- for i := 0; i < nruns; i++ {
- res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
- testval += res.NsPerOp()
+ // For each threshold, compute geomean of normalized timings across all sizes.
+ geomeans := make([]float64, len(thresholds))
+ for j := range thresholds {
+ p := 1.0
+ n := 0
+ for i := range sizes {
+ if v := norms[i][j]; !math.IsInf(v, +1) {
+ p *= v
+ n++
+ }
+ }
+ if n == 0 {
+ geomeans[j] = math.Inf(+1)
+ } else {
+ geomeans[j] = math.Pow(p, 1/float64(n))
+ }
}
- testval /= int64(nruns)
+ fmt.Printf("%s\n", csv("geomean", geomeans))
- basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
+ // Add best threshold and smallest, largest within 10% and 5% of best.
+ var lo10, lo5, best, hi5, hi10 int
+ for i, g := range geomeans {
+ if g < geomeans[best] {
+ best = i
+ }
+ }
+ lo5 = best
+ for lo5 > 0 && geomeans[lo5-1] <= 1.05 {
+ lo5--
+ }
+ lo10 = lo5
+ for lo10 > 0 && geomeans[lo10-1] <= 1.10 {
+ lo10--
+ }
+ hi5 = best
+ for hi5+1 < len(geomeans) && geomeans[hi5+1] <= 1.05 {
+ hi5++
+ }
+ hi10 = hi5
+ for hi10+1 < len(geomeans) && geomeans[hi10+1] <= 1.10 {
+ hi10++
+ }
+ fmt.Printf("lo10%%,%d\n", thresholds[lo10])
+ fmt.Printf("lo5%%,%d\n", thresholds[lo5])
+ fmt.Printf("min,%d\n", thresholds[best])
+ fmt.Printf("hi5%%,%d\n", thresholds[hi5])
+ fmt.Printf("hi10%%,%d\n", thresholds[hi10])
- return time.Duration(testval)
+ set(thresholds[best])
}
-func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
- fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
- fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
- var initPos bool
- var threshold int
- for i := from; i <= to; i += step {
- baseline := measureSqr(i, nruns, lower)
- testval := measureSqr(i, nruns, upper)
- pos := baseline > testval
- delta := baseline - testval
- percent := delta * 100 / baseline
- fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
- if i == from {
- initPos = pos
- }
- if threshold == 0 && pos != initPos {
- threshold = i
- fmt.Printf(" threshold found")
+// csv returns a single csv line starting with name and followed by the values.
+// Values that are float64 +infinity, denoting missing data, are replaced by an empty string.
+func csv[T int | float64](name string, values []T) string {
+ line := []string{name}
+ for _, v := range values {
+ if math.IsInf(float64(v), +1) {
+ line = append(line, "")
+ } else {
+ line = append(line, fmt.Sprint(v))
}
- fmt.Println()
-
- }
- if threshold != 0 {
- fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
- } else {
- fmt.Printf("Found NO threshold between %d - %d\n", from, to)
}
- return threshold
+ return strings.Join(line, ",")
}
diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go
index 251877b506..f99fd19293 100644
--- a/src/math/big/nat_test.go
+++ b/src/math/big/nat_test.go
@@ -378,19 +378,6 @@ func rndNat1(n int) nat {
return x
}
-func BenchmarkMul(b *testing.B) {
- stk := getStack()
- defer stk.free()
-
- mulx := rndNat(1e4)
- muly := rndNat(1e4)
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- var z nat
- z.mul(stk, mulx, muly)
- }
-}
-
func benchmarkNatMul(b *testing.B, nwords int) {
x := rndNat(nwords)
y := rndNat(nwords)
diff --git a/src/math/big/natdiv.go b/src/math/big/natdiv.go
index b67d6afeda..1244fb61c5 100644
--- a/src/math/big/natdiv.go
+++ b/src/math/big/natdiv.go
@@ -722,7 +722,7 @@ func greaterThan(x1, x2, y1, y2 Word) bool {
// divRecursiveThreshold is the number of divisor digits
// at which point divRecursive is faster than divBasic.
-const divRecursiveThreshold = 100
+var divRecursiveThreshold = 40 // see calibrate_test.go
// divRecursive implements recursive division as described above.
// It overwrites z with ⌊u/v⌋ and overwrites u with the remainder r.
diff --git a/src/math/big/natmul.go b/src/math/big/natmul.go
index 8ab4d13cba..bd6ab3851c 100644
--- a/src/math/big/natmul.go
+++ b/src/math/big/natmul.go
@@ -9,7 +9,7 @@ package big
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
-var karatsubaThreshold = 40 // computed by calibrate_test.go
+var karatsubaThreshold = 40 // see calibrate_test.go
// mul sets z = x*y, using stk for temporary storage.
// The caller may pass stk == nil to request that mul obtain and release one itself.
@@ -65,8 +65,8 @@ func (z nat) mul(stk *stack, x, y nat) nat {
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
// we use the Karatsuba algorithm optimized for x == y.
-var basicSqrThreshold = 20 // computed by calibrate_test.go
-var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
+var basicSqrThreshold = 12 // see calibrate_test.go
+var karatsubaSqrThreshold = 80 // see calibrate_test.go
// sqr sets z = x*x, using stk for temporary storage.
// The caller may pass stk == nil to request that sqr obtain and release one itself.
@@ -87,7 +87,7 @@ func (z nat) sqr(stk *stack, x nat) nat {
}
z = z.make(2 * n)
- if n < basicSqrThreshold {
+ if n < basicSqrThreshold && n < karatsubaSqrThreshold {
basicMul(z, x, x)
return z.norm()
}
@@ -112,6 +112,11 @@ func (z nat) sqr(stk *stack, x nat) nat {
// The (non-normalized) result is placed in z.
func basicSqr(stk *stack, z, x nat) {
n := len(x)
+ if n < basicSqrThreshold {
+ basicMul(z, x, x)
+ return
+ }
+
defer stk.restore(stk.save())
t := stk.nat(2 * n)
clear(t)