Using WaitGroup to Track Work Items, Not Workers: A Multi-threaded BFS Example
WaitGroup and channels are two powerful primitives in Go for synchronizing goroutines. A common pattern uses a WaitGroup to wait for goroutines to complete:
wg.Add(1)
go func() {
    defer wg.Done()
    for {
        select {
        case <-done:
            return
        case task := <-tasks:
            handle(task)
        }
    }
}()
wg.Wait()
In this post, we'll explore a less common use of WaitGroup and channels, where the WaitGroup counts the work items being processed rather than the worker goroutines. We'll implement a multi-threaded Breadth-First Search (BFS) on a graph.
Each node in our graph has:
- A score (1-100) that we want to accumulate
- A latency (1-10 seconds) that simulates processing time
- Neighbors to visit
The goal is to traverse the entire graph starting from node 0, accumulate all scores, and minimize total execution time.
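The exact type definitions live in the linked repos; a minimal sketch consistent with the snippets below (field names inferred from the concurrent versions; the single-threaded baseline indexes an equivalent plain map with unexported fields) might look like this:

// Sketch of the graph types assumed throughout this post.
type Node struct {
    Score     int   // 1-100, accumulated into the total
    Latency   int   // 1-10 seconds of simulated processing time
    Neighbors []int // IDs of adjacent nodes
}

type Graph struct {
    Nodes map[int]*Node // keyed by node ID; node 0 is the start
}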
Version 1: Single-threaded BFS
Let's start with a baseline single-threaded implementation. Full code: bfs-single-thread/main.go.
func main() {
    graph := readGraph("graph.txt")
    startTime := time.Now()

    visited := make(map[int]bool)
    queue := []int{0}
    visited[0] = true
    totalScore := 0

    for len(queue) > 0 {
        node := queue[0]
        queue = queue[1:]

        // Process node
        time.Sleep(time.Duration(graph[node].latency) * time.Second)
        totalScore += graph[node].score

        for _, neighbor := range graph[node].neighbors {
            if !visited[neighbor] {
                visited[neighbor] = true
                queue = append(queue, neighbor)
            }
        }
    }

    elapsed := time.Since(startTime)
    fmt.Printf("\nTotal time elapsed: %v\n", elapsed)
    fmt.Printf("Total score: %d\n", totalScore)
}
Results:
Total time elapsed: 4m29.065671s

=== BFS Complete ===
Total score: 5596
It's quite slow. Let's add concurrency!
Version 2: Naive Multi-threaded with Shared Memory (Doesn't Work)
A naive approach uses multiple worker goroutines accessing a shared queue and visited map. However, this doesn't work because detecting BFS completion is difficult. For example, if workers exit when the queue is empty, they might all exit before the starting node finishes processing and adds its neighbors to the queue.
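To see the race concretely, here's a minimal sketch of such a flawed worker (hypothetical code, assuming the Graph type above and a mutex-guarded slice as the shared queue):

func naiveWorker(graph *Graph, queue *[]int, mu *sync.Mutex, visited map[int]bool) {
    for {
        mu.Lock()
        if len(*queue) == 0 {
            mu.Unlock()
            // BUG: the queue may be empty only because another worker is
            // still processing a node and hasn't appended its neighbors
            // yet, so exiting here abandons unfinished work.
            return
        }
        nodeID := (*queue)[0]
        *queue = (*queue)[1:]
        mu.Unlock()

        node := graph.Nodes[nodeID]
        time.Sleep(time.Duration(node.Latency) * time.Second)

        mu.Lock()
        for _, neighbor := range node.Neighbors {
            if !visited[neighbor] {
                visited[neighbor] = true
                *queue = append(*queue, neighbor)
            }
        }
        mu.Unlock()
    }
}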
Version 3: Multi-threaded with Channels (Centralized Coordinator)
We can use a coordinator goroutine to manage enqueue/dequeue operations while worker goroutines process nodes. Full code: bfs-multithread-central-coordinator/main.go.
func bfsWorker(workerID int, graph *Graph, toVisitChan <-chan int, visitedChan chan<- int, wg *sync.WaitGroup) {
    defer wg.Done()
    for nodeID := range toVisitChan {
        node := graph.Nodes[nodeID]
        time.Sleep(time.Duration(node.Latency) * time.Second)
        visitedChan <- nodeID
    }
}

func bfsConcurrent(graph *Graph, startNode int, numWorkers int) int {
    visited := make(map[int]bool)
    visited[startNode] = true

    toVisitChan := make(chan int, 100)
    visitedChan := make(chan int, 100)

    var wg sync.WaitGroup
    var totalScore int64

    startTime := time.Now()

    for i := 0; i < numWorkers; i++ {
        wg.Add(1)
        go bfsWorker(i, graph, toVisitChan, visitedChan, &wg)
    }

    // Close visitedChan when all workers are done
    go func() {
        wg.Wait()
        close(visitedChan)
    }()

    // Send the start node
    toVisitChan <- startNode
    nodesInFlight := 1

    // Coordinator: process results and send new nodes
    for nodeID := range visitedChan {
        nodesInFlight--

        node := graph.Nodes[nodeID]
        atomic.AddInt64(&totalScore, int64(node.Score))

        // Add unvisited neighbors to the channel
        for _, neighbor := range node.Neighbors {
            if !visited[neighbor] {
                visited[neighbor] = true
                toVisitChan <- neighbor
                nodesInFlight++
            }
        }

        // If no more nodes in flight, we're done
        if nodesInFlight == 0 {
            close(toVisitChan)
            break
        }
    }

    elapsed := time.Since(startTime)
    fmt.Printf("\nTotal time elapsed: %v\n", elapsed)
    return int(atomic.LoadInt64(&totalScore))
}
Running with 10 workers is much faster:
Total time elapsed: 32.009325958s

=== BFS Complete ===
Total score: 5596
However, there are two issues:
- The coordinator is a bottleneck: every visited node must round-trip through a single goroutine
- We need two channels (toVisitChan and visitedChan) plus a nodesInFlight counter, which is unnecessarily complex
Version 4: Decentralized with Channels and WaitGroup
We can eliminate the coordinator bottleneck by letting workers enqueue neighbors directly. This also reduces the number of channels from two to one. Full code: bfs-multithread-decentralized/main.go.
func bfsWorker(workerID int, graph *Graph, nodeChan chan int, visited *sync.Map, totalScore *int64, wg *sync.WaitGroup) {
    for nodeID := range nodeChan {
        node := graph.Nodes[nodeID]
        time.Sleep(time.Duration(node.Latency) * time.Second)
        atomic.AddInt64(totalScore, int64(node.Score))

        // Each worker independently enqueues unvisited neighbors
        for _, neighbor := range node.Neighbors {
            _, loaded := visited.LoadOrStore(neighbor, true)
            if !loaded {
                wg.Add(1)
                // Enqueue asynchronously. This is crucial to avoid deadlock:
                // if every worker blocked here writing to the channel, none
                // would be left to read from it.
                go func(n int) {
                    nodeChan <- n
                }(neighbor)
            }
        }
        wg.Done()
    }
}

func bfsConcurrent(graph *Graph, startNode int, numWorkers int) int {
    var visited sync.Map
    var totalScore int64
    var wg sync.WaitGroup
    nodeChan := make(chan int, 10)

    startTime := time.Now()

    for i := 0; i < numWorkers; i++ {
        go bfsWorker(i, graph, nodeChan, &visited, &totalScore, &wg)
    }

    visited.Store(startNode, true)
    wg.Add(1)
    nodeChan <- startNode

    // The WaitGroup reaches zero only when every enqueued node has been
    // processed, so it is safe to close the channel and let workers exit.
    wg.Wait()
    close(nodeChan)

    elapsed := time.Since(startTime)
    fmt.Printf("\nTotal time elapsed: %v\n", elapsed)
    return int(atomic.LoadInt64(&totalScore))
}
Running with 10 workers:
Total time elapsed: 32.006704666s

=== BFS Complete ===
Total score: 5596
The key insight: the WaitGroup tracks work items, not workers! A WaitGroup is simply a concurrent counter whose Wait blocks until the count reaches zero; nothing requires Add and Done to bracket a goroutine's lifetime.
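To distill the pattern outside of BFS, here's a minimal, self-contained toy example (not from the post's repos): the WaitGroup counts items, each of which may spawn a follow-up item, while the fixed pool of workers is never counted at all.

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    items := make(chan int, 8)

    // Three workers; the WaitGroup never counts them.
    for w := 0; w < 3; w++ {
        go func() {
            for n := range items {
                if n > 0 {
                    // An item may spawn a child item: Add before enqueueing
                    // so the counter can never hit zero prematurely.
                    wg.Add(1)
                    go func(child int) { items <- child }(n - 1)
                }
                fmt.Println("processed", n)
                wg.Done() // this item is settled
            }
        }()
    }

    wg.Add(1) // the initial work item
    items <- 3

    wg.Wait()    // zero outstanding items means all work is done
    close(items) // now the workers can exit
}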