Binary Search in the Go Standard Library

Given a non-decreasing array and a target value, we can find the target in logarithmic time using binary search. My first programming language was C++, and the C++ Standard Template Library (STL) provides two functions for this task.

  1. iterator lower_bound(first, last, value) returns an iterator to the first element that is not less than the target, i.e. the smallest index whose value is greater than or equal to the target. Put another way, it is the smallest index at which the target could be inserted while keeping the array sorted.
  2. iterator upper_bound(first, last, value) returns an iterator to the first element that is greater than the target, i.e. the smallest index whose value is strictly greater than the target. Put another way, it is the largest index at which the target could be inserted while keeping the array sorted.
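To make the two C++ semantics concrete, here is a hand-written Go sketch of the classic loop over a half-open index range [first, last). It is only an illustration, not the author's original port (which this post does not show); the names lowerBoundRange and upperBoundRange are hypothetical, and integer indices stand in for C++ iterators.

```go
package main

import "fmt"

// lowerBoundRange (hypothetical name) mirrors std::lower_bound on the
// half-open range [first, last) of a non-decreasing slice: it returns the
// first index i with nums[i] >= target, or last if there is none.
func lowerBoundRange(nums []int, first, last, target int) int {
	for first < last {
		mid := first + (last-first)/2
		if nums[mid] < target {
			first = mid + 1 // the answer lies strictly right of mid
		} else {
			last = mid // mid itself may be the answer
		}
	}
	return first
}

// upperBoundRange (hypothetical name) mirrors std::upper_bound: the first
// index i in [first, last) with nums[i] > target, or last if there is none.
func upperBoundRange(nums []int, first, last, target int) int {
	for first < last {
		mid := first + (last-first)/2
		if nums[mid] <= target {
			first = mid + 1
		} else {
			last = mid
		}
	}
	return first
}

func main() {
	nums := []int{1, 3, 4, 4, 6, 7, 9, 9}
	fmt.Println(lowerBoundRange(nums, 0, len(nums), 4)) // 2
	fmt.Println(upperBoundRange(nums, 0, len(nums), 4)) // 4
}
```

The only difference between the two loops is `<` versus `<=` in the comparison, which decides whether elements equal to the target are skipped.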

Both functions take a half-open range [first, last). I appreciate the simplicity of lower_bound and upper_bound so much that I ported them to Go. However, now that newer versions of Go offer binary search in the standard library, I am deprecating my own lowerBound and upperBound implementations. The standard library provides four binary search functions.

  1. func sort.Find(n int, cmp func(int) int) (i int, found bool). It returns the smallest index i in [0, n) at which cmp(i) <= 0. If there is no such index i, Find returns i = n. The found result is true if i < n and cmp(i) == 0.
  2. func sort.Search(n int, f func(int) bool) int. It returns the smallest index i in [0, n) at which f(i) is true, assuming that on the range [0, n), f(i) == true implies f(i+1) == true. If there is no such index, Search returns n.
  3. func slices.BinarySearch(x []E, target E) (int, bool). It returns the earliest position where target is found, or the position where target would appear in the sort order. It also returns a bool saying whether the target is really found in the slice. The slice must be sorted in increasing order.
  4. func slices.BinarySearchFunc(x []E, target T, cmp func(E, T) int) (int, bool). It works like BinarySearch, but uses a custom comparison function.
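To compare the four side by side, here is a small sketch that runs each of them on the same sorted slice. It assumes Go 1.21 or later, since the slices and cmp packages were added in that release; searchAll and the results struct are hypothetical names used only for this illustration.

```go
package main

import (
	"cmp"
	"fmt"
	"slices"
	"sort"
)

// results (hypothetical) collects what each standard-library search returns.
type results struct {
	find     int
	findOK   bool
	search   int
	bs       int
	bsOK     bool
	bsFunc   int
	bsFuncOK bool
}

// searchAll (hypothetical) runs all four searches for target on nums,
// which must be sorted in non-decreasing order.
func searchAll(nums []int, target int) results {
	var r results
	// sort.Find expects cmp(i) > 0 before the target, 0 at it, < 0 after.
	r.find, r.findOK = sort.Find(len(nums), func(i int) int {
		return cmp.Compare(target, nums[i])
	})
	// sort.Search returns the first index where the predicate turns true.
	r.search = sort.Search(len(nums), func(i int) bool {
		return target <= nums[i]
	})
	// slices.BinarySearch requires the slice sorted in increasing order.
	r.bs, r.bsOK = slices.BinarySearch(nums, target)
	// slices.BinarySearchFunc compares each element e with the target t.
	r.bsFunc, r.bsFuncOK = slices.BinarySearchFunc(nums, target, func(e, t int) int {
		return cmp.Compare(e, t)
	})
	return r
}

func main() {
	nums := []int{1, 3, 4, 4, 6, 7, 9, 9}
	fmt.Printf("%+v\n", searchAll(nums, 4)) // present: every index is 2, found is true
	fmt.Printf("%+v\n", searchAll(nums, 5)) // absent: every index is 4, found is false
}
```

For a non-decreasing slice, all four land on the same index, which is exactly lower_bound; they differ only in how the comparison is expressed and whether a found flag comes back.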

I find it overwhelming to remember four functions for binary search. Luckily, sort.Search alone gives me everything I need for lower_bound and upper_bound, even when the array is sorted in reverse order.

import (
  "fmt"
  "sort"
)

// Given nums non-decreasing, returns the smallest index i such that
// target <= nums[i]. Returns len(nums) if no such nums[i] exists.
func lowerBound(nums []int, target int) int {
  return sort.Search(len(nums), func(i int) bool {
    return target <= nums[i]
  })
}

// Given nums non-decreasing, returns the smallest index i such that
// target < nums[i]. Returns len(nums) if no such nums[i] exists.
func upperBound(nums []int, target int) int {
  return sort.Search(len(nums), func(i int) bool {
    return target < nums[i]
  })
}

func nonDecreasing() {
  //  []int{1, 3, 4, 4, 6, 7, 9, 9}
  // index: 0, 1, 2, 3, 4, 5, 6, 7
  nums := []int{1, 3, 4, 4, 6, 7, 9, 9}

  fmt.Printf("lowerBound(%d) %d\n", 0, lowerBound(nums, 0))   // 0
  fmt.Printf("lowerBound(%d) %d\n", 1, lowerBound(nums, 1))   // 0
  fmt.Printf("lowerBound(%d) %d\n", 9, lowerBound(nums, 9))   // 6
  fmt.Printf("lowerBound(%d) %d\n", 10, lowerBound(nums, 10)) // 8

  fmt.Printf("upperBound(%d) %d\n", 0, upperBound(nums, 0))   // 0
  fmt.Printf("upperBound(%d) %d\n", 1, upperBound(nums, 1))   // 1
  fmt.Printf("upperBound(%d) %d\n", 9, upperBound(nums, 9))   // 8
  fmt.Printf("upperBound(%d) %d\n", 10, upperBound(nums, 10)) // 8
}
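Because lowerBound and upperBound bracket every copy of the target, their difference counts occurrences, much like C++ std::equal_range. The sketch below repeats the two functions so it stands alone; equalRange is a hypothetical helper, not something from the standard library.

```go
package main

import (
	"fmt"
	"sort"
)

// Smallest index i with target <= nums[i] (nums non-decreasing).
func lowerBound(nums []int, target int) int {
	return sort.Search(len(nums), func(i int) bool { return target <= nums[i] })
}

// Smallest index i with target < nums[i] (nums non-decreasing).
func upperBound(nums []int, target int) int {
	return sort.Search(len(nums), func(i int) bool { return target < nums[i] })
}

// equalRange (hypothetical) returns the half-open range [lo, hi) that holds
// exactly the elements equal to target; hi-lo is their count.
func equalRange(nums []int, target int) (lo, hi int) {
	return lowerBound(nums, target), upperBound(nums, target)
}

func main() {
	nums := []int{1, 3, 4, 4, 6, 7, 9, 9}
	for _, t := range []int{4, 5, 9} {
		lo, hi := equalRange(nums, t)
		fmt.Printf("target %d: [%d, %d), count %d\n", t, lo, hi, hi-lo)
	}
	// target 4: [2, 4), count 2
	// target 5: [4, 4), count 0
	// target 9: [6, 8), count 2
}
```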

What about a non-increasing array?

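sort.Search takes a length n and a predicate f, and the predicate effectively transforms the original array into a boolean array of the form [false, false, ..., true, true, ...]; Search returns the index of the first true. On a non-increasing array, the original comparisons produce [true, ..., true, false, ..., false] instead, so we simply need to flip the comparison to restore the false-then-true shape.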
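The boolean view is easy to print. This sketch (predicateView is a hypothetical helper) evaluates a predicate at every index of a non-increasing array, first with the non-decreasing comparison reused as-is and then with the flipped one.

```go
package main

import "fmt"

// predicateView (hypothetical) applies pred to every index of nums,
// exposing the boolean sequence that sort.Search binary-searches over.
func predicateView(nums []int, pred func(i int) bool) []bool {
	view := make([]bool, len(nums))
	for i := range nums {
		view[i] = pred(i)
	}
	return view
}

func main() {
	nums := []int{9, 9, 7, 6, 4, 4, 3, 1} // non-increasing
	target := 4

	// Reusing target <= nums[i] yields true...true, false...false,
	// which violates sort.Search's precondition.
	fmt.Println(predicateView(nums, func(i int) bool { return target <= nums[i] }))
	// [true true true true true true false false]

	// The flipped comparison yields false...false, true...true.
	fmt.Println(predicateView(nums, func(i int) bool { return target >= nums[i] }))
	// [false false false false true true true true]
}
```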

// Given nums non-increasing, returns the smallest index i such that
// target >= nums[i]. Returns len(nums) if no such nums[i] exists.
func lowerBoundNonIncreasing(nums []int, target int) int {
  return sort.Search(len(nums), func(i int) bool {
    return target >= nums[i] // flip the order from lowerBound, target <= nums[i]
  })
}

// Given nums non-increasing, returns the smallest index i such that
// target > nums[i]. Returns len(nums) if no such nums[i] exists.
func upperBoundNonIncreasing(nums []int, target int) int {
  return sort.Search(len(nums), func(i int) bool {
    return target > nums[i] // flip the order from upperBound, target < nums[i]
  })
}

func nonIncreasing() {
  //  []int{9, 9, 7, 6, 4, 4, 3, 1}
  // index: 0, 1, 2, 3, 4, 5, 6, 7
  nums := []int{9, 9, 7, 6, 4, 4, 3, 1}

  fmt.Printf("lowerBoundNonIncreasing(%d) %d\n", 0, lowerBoundNonIncreasing(nums, 0))   // 8
  fmt.Printf("lowerBoundNonIncreasing(%d) %d\n", 1, lowerBoundNonIncreasing(nums, 1))   // 7
  fmt.Printf("lowerBoundNonIncreasing(%d) %d\n", 9, lowerBoundNonIncreasing(nums, 9))   // 0
  fmt.Printf("lowerBoundNonIncreasing(%d) %d\n", 10, lowerBoundNonIncreasing(nums, 10)) // 0

  fmt.Printf("upperBoundNonIncreasing(%d) %d\n", 0, upperBoundNonIncreasing(nums, 0))   // 8
  fmt.Printf("upperBoundNonIncreasing(%d) %d\n", 1, upperBoundNonIncreasing(nums, 1))   // 8
  fmt.Printf("upperBoundNonIncreasing(%d) %d\n", 9, upperBoundNonIncreasing(nums, 9))   // 2
  fmt.Printf("upperBoundNonIncreasing(%d) %d\n", 10, upperBoundNonIncreasing(nums, 10)) // 0
}
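For comparison, slices.BinarySearchFunc can handle the non-increasing case too, by swapping the arguments of the comparator so that the descending slice looks ascending to it. This is a sketch assuming Go 1.21 or later; lowerBoundDesc is a hypothetical name, and it matches lowerBoundNonIncreasing while also reporting whether the target is present.

```go
package main

import (
	"cmp"
	"fmt"
	"slices"
)

// lowerBoundDesc (hypothetical) finds, in a non-increasing slice, the first
// index i with nums[i] <= target. cmp.Compare(t, e) is negative exactly when
// e > t, so larger elements are treated as coming "before" the target.
func lowerBoundDesc(nums []int, target int) (int, bool) {
	return slices.BinarySearchFunc(nums, target, func(e, t int) int {
		return cmp.Compare(t, e)
	})
}

func main() {
	nums := []int{9, 9, 7, 6, 4, 4, 3, 1}
	for _, t := range []int{0, 1, 4, 5, 10} {
		i, found := lowerBoundDesc(nums, t)
		fmt.Println(t, i, found)
	}
}
```

BinarySearchFunc always returns the earliest matching position, so it covers lower_bound directly; sort.Search with a flipped predicate remains the simpler route for upper_bound.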