Graceful shutdown of a concurrent Go program with WaitGroup and Context

One of the well-known advantages of Go is its support for concurrency. Thanks to goroutines and channels, writing high-performance concurrent code becomes much easier, and it is also fun to implement different concurrency patterns. I personally use the pattern described here a lot in crawlers and when downloading resources concurrently. Hope it helps!

Let's start with a simple Go program:

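// main() not waiting
package main

import (
	"fmt"
	"time"
)

func main() {
	go task() // start task in a new goroutine; main does not wait for it
	fmt.Println("main exiting...")
}

// task simulates a time-consuming job by sleeping for one second.
func task() {
	time.Sleep(time.Second)
	fmt.Println("task finished!")
}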

The task() function just sleeps for one second to simulate a time-consuming task. We want it to run concurrently, so we put the go keyword in front of the function call to start it in a goroutine.

go run main.go
main exiting...

As expected, the program exits immediately because main() doesn't wait for the goroutine to finish.

To fix it, we can simply add a channel to block the main function:

// main() waiting through channel
package main

import (
	"fmt"
	"time"
)

func main() {
	ch := make(chan struct{})
	go task(ch)

	<-ch // block until task sends a value
	fmt.Println("main exiting...")
}

// task signals completion by sending an empty struct on ch.
func task(ch chan<- struct{}) {
	time.Sleep(time.Second)
	fmt.Println("task finished!")
	ch <- struct{}{}
}

We create an unbuffered channel of type struct{} (we only use the channel for signalling, so the element type doesn't matter, and the empty struct has zero size). After starting the goroutine, main() immediately tries to receive from the channel with <-ch, which blocks until a value is available. When task() finishes, it sends an empty struct to ch; at that point main() receives it and continues.

go run main.go
task finished!
main exiting...

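Sends and receives on an unbuffered channel are blocking operations: each side waits until the other is ready, which is why an unbuffered channel can be used to synchronize goroutines as well as to pass data between them. A buffered channel, by contrast, blocks the sender only when the buffer is full and the receiver only when the buffer is empty.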
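The difference can be probed with a non-blocking send. The sketch below is only an illustration: trySend is a hypothetical helper, not part of the programs in this post. It uses select with a default case to report whether a send would have blocked.

```go
package main

import "fmt"

// trySend attempts a non-blocking send and reports whether it succeeded.
// (Hypothetical helper, used only for this illustration.)
func trySend(ch chan int, v int) bool {
	select {
	case ch <- v:
		return true
	default:
		return false // the send would have blocked
	}
}

func main() {
	unbuffered := make(chan int)
	buffered := make(chan int, 2)

	fmt.Println(trySend(unbuffered, 1)) // false: no receiver is waiting
	fmt.Println(trySend(buffered, 1))   // true: the buffer has room
	fmt.Println(trySend(buffered, 2))   // true: the buffer is now full
	fmt.Println(trySend(buffered, 3))   // false: the buffer is full
}
```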

Besides using a channel, we can also use a sync.WaitGroup to make main() wait, which is handier when there are multiple goroutines:

// main() waiting through WaitGroup
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup

	for i := 0; i < 3; i++ {
		wg.Add(1) // increment before starting the goroutine
		go func(i int) {
			defer wg.Done() // decrement when the goroutine returns
			task(i)
		}(i)
	}

	fmt.Println("waiting...")
	wg.Wait() // block until the WaitGroup counter becomes zero
	fmt.Println("main exiting...")
}

func task(id int) {
	time.Sleep(time.Second)
	fmt.Println("task", id, "finished!")
}

All we need to do is declare a sync.WaitGroup variable. Before starting each concurrent job, call wg.Add(1) to increment the counter; when the job is done, call wg.Done() to decrement it. At the end of main() we call wg.Wait(), which blocks until the counter reaches zero. Because the three tasks run concurrently, the order of the "finished" lines can vary between runs.

go run main.go
waiting...
task 0 finished!
task 1 finished!
task 2 finished!
main exiting...
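The same Add/Done/Wait sequence also works when the goroutines produce results. Here is a minimal sketch, assuming a hypothetical helper squareAll that is not part of the original programs: each goroutine writes only to its own slot of a pre-sized slice, so no extra locking is needed, and wg.Wait() guarantees that all writes are finished before the slice is read.

```go
package main

import (
	"fmt"
	"sync"
)

// squareAll squares every input in its own goroutine and waits for all of
// them. (Hypothetical helper, used only for this illustration.)
func squareAll(inputs []int) []int {
	results := make([]int, len(inputs))
	var wg sync.WaitGroup

	for i, v := range inputs {
		wg.Add(1) // call Add before starting the goroutine, not inside it
		go func(i, v int) {
			defer wg.Done()    // runs even if the body returns early
			results[i] = v * v // each goroutine writes only its own slot
		}(i, v)
	}

	wg.Wait() // all writes to results are complete after this returns
	return results
}

func main() {
	fmt.Println(squareAll([]int{1, 2, 3})) // [1 4 9]
}
```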

Things become more interesting when we implement a worker pool pattern:

// simple worker pool
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup
	pool := make(chan int, 5)

	// create a worker that keeps fetching tasks from the pool and runs them
	go func() {
		for id := range pool {
			task(id)
			wg.Done()
		}
	}()

	// add 5 tasks to the pool
	for i := 1; i <= 5; i++ {
		wg.Add(1)
		pool <- i
		fmt.Println("task", i, "added!")
	}

	close(pool) // no more tasks: lets the worker's range loop end

	fmt.Println("waiting...")
	wg.Wait()
	fmt.Println("main exiting...")
}

func task(id int) {
	time.Sleep(time.Second)
	fmt.Println("task", id, "finished!")
}

First we declare a buffered int channel, pool. Then we start a goroutine that keeps receiving task ids from pool and executing task(); this is the worker. If we want multiple workers, we can copy the goroutine code or, better, start it in a loop. Initially the worker blocks because the pool is empty, so we need to feed it some jobs. We do that with a simple for loop, sending the loop index to the pool as the task id. As soon as the worker receives something from the pool, it starts working.

Finally, don't forget to close() the channel once the sender (main) has finished sending. Otherwise the receiver (the worker) blocks forever waiting for new data, which leaks the goroutine and, in a program that waits for the worker to return, causes a deadlock.

go run main.go
task 1 added!
task 2 added!
task 3 added!
task 4 added!
task 5 added!
waiting...
task 1 finished!
task 2 finished!
task 3 finished!
task 4 finished!
task 5 finished!
main exiting...
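Starting several workers in a loop could look like the sketch below. It is a variation, not the code from this post: runPool is a hypothetical helper, and here the WaitGroup counts workers rather than tasks, so main waits for the workers to drain the closed channel and return. A mutex protects the shared sum because several workers update it.

```go
package main

import (
	"fmt"
	"sync"
)

// runPool starts the given number of workers, feeds them the ids, and
// returns the sum of all processed ids.
// (Hypothetical helper, used only for this illustration.)
func runPool(workers int, ids []int) int {
	pool := make(chan int, len(ids))
	var wg sync.WaitGroup
	var mu sync.Mutex
	sum := 0

	for w := 0; w < workers; w++ {
		wg.Add(1) // in this variation the WaitGroup counts workers, not tasks
		go func() {
			defer wg.Done()
			for id := range pool { // ends once pool is closed and drained
				mu.Lock()
				sum += id
				mu.Unlock()
			}
		}()
	}

	for _, id := range ids {
		pool <- id
	}
	close(pool) // tell the workers that no more tasks are coming

	wg.Wait() // every worker has returned, so every task was processed
	return sum
}

func main() {
	fmt.Println(runPool(3, []int{1, 2, 3, 4, 5})) // 15
}
```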

In real-world situations, it is more likely that we don't know in advance how many jobs there are, or we just want to keep feeding jobs until we tell the program to stop.

// panic: send on closed channel
package main

import (
	"fmt"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

func main() {
	// create a channel to capture the SIGTERM and SIGINT signals
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT)

	var wg sync.WaitGroup
	pool := make(chan int, 10)
	id := 1

	// create a worker that keeps fetching tasks from the pool and runs them
	go func() {
		for id := range pool {
			task(id)
			wg.Done()
		}
	}()

	// keep adding tasks to the pool indefinitely
	go func() {
		for {
			wg.Add(1)
			pool <- id
			fmt.Println("task", id, "added!")
			id++
			time.Sleep(time.Millisecond * 500)
		}
	}()

	<-quit      // block until SIGTERM or SIGINT is received
	close(pool) // the producer above is still sending: this leads to the panic
	wg.Wait()
	fmt.Println("main exiting...")
}

func task(id int) {
	time.Sleep(time.Second)
	fmt.Println("task", id, "finished!")
}

To achieve that, we remove the loop condition to make the loop infinite and wrap it in a goroutine so that it doesn't block main(). We also need a channel of type os.Signal to block main(): signal.Notify() relays SIGTERM and SIGINT to that channel when the program receives them.

go run main.go
task 1 added!
task 2 added!
task 1 finished!
task 3 added!
task 4 added!
^Ctask 2 finished!
panic: send on closed channel
goroutine 34 [running]:
main.main.func2()
/Users/yk/Project/test/main.go:77 +0x59
created by main.main
/Users/yk/Project/test/main.go:74 +0x185
exit status 2

What!? A panic... It happens because we close the pool channel after receiving the quit signal, but the producer goroutine is still trying to send jobs to it, and sending on a closed channel panics. We also need a way to stop the producer goroutine.
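The rule behind the panic can be seen in isolation. The sketch below uses a hypothetical helper, sendPanics, that recovers from the panic only to make it observable; it is not a suggested fix. The practical takeaway is that the goroutine that sends should be the one that closes the channel.

```go
package main

import "fmt"

// sendPanics reports whether sending a value on ch panics.
// (Hypothetical helper; recover is used only to observe the panic.)
func sendPanics(ch chan int) (panicked bool) {
	defer func() {
		if r := recover(); r != nil {
			panicked = true
		}
	}()
	ch <- 1
	return false
}

func main() {
	open := make(chan int, 1)
	closed := make(chan int, 1)
	close(closed)

	fmt.Println(sendPanics(open))   // false: the value goes into the buffer
	fmt.Println(sendPanics(closed)) // true: send on closed channel
}
```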

Worker pool graceful shutdown with WaitGroup and Context:

// worker pool graceful shutdown with WaitGroup and Context
package main

import (
	"context"
	"fmt"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

func main() {
	var wg sync.WaitGroup
	pool := make(chan int, 10)
	id := 1

	// create a worker that keeps fetching tasks from the pool and runs them
	go func() {
		for id := range pool {
			task(id)
			wg.Done()
		}
	}()

	// create a context that is cancelled on SIGTERM or SIGINT
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
	defer stop()

	// producerDone is closed once the producer goroutine has returned
	producerDone := make(chan struct{})

	// keep adding tasks to the pool until ctx.Done() is closed
	go func() {
		defer close(producerDone)
		for {
			select {
			case <-ctx.Done():
				fmt.Println("stop filling the pool!")
				close(pool) // only the sender closes the channel
				return
			default:
				wg.Add(1)
				pool <- id
				fmt.Println("task", id, "added!")
				id++
				time.Sleep(time.Millisecond * 500)
			}
		}
	}()

	// Wait for the producer to stop first, so that no wg.Add call can race
	// with wg.Wait, then wait for the tasks that are still in the pool.
	<-producerDone
	wg.Wait()
	fmt.Println("main exiting...")
}

func task(id int) {
	time.Sleep(time.Second)
	fmt.Println("task", id, "finished!")
}

Based on the previous version, we create a context ctx with signal.NotifyContext() (available since Go 1.16); the context's Done channel is closed when SIGTERM or SIGINT arrives. In the producer goroutine, instead of a plain for loop, we add a select statement. When a signal arrives, the Done channel is closed, the case <-ctx.Done(): branch runs, and the goroutine closes the pool and exits. Because the producer itself closes the pool, nothing can send on it afterwards. Otherwise the default case runs and feeds another job to the pool.

go run main.go
task 1 added!
task 2 added!
task 3 added!
task 1 finished!
^Cstop filling the pool!
task 2 finished!
task 3 finished!
main exiting...

Now when we send SIGTERM or SIGINT to the program, it first stops feeding jobs to the pool and exits the producer goroutine, then waits for the worker goroutine to finish all the tasks already in the pool, and finally exits the main program.
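One possible refinement, sketched here under assumptions rather than taken from the original program: with a default case, the producer only notices cancellation between iterations, so it can sit in time.Sleep or in a send to a full pool for a while. Putting the wait and the send into select statements alongside ctx.Done() lets it react right away. The names produce and demo are hypothetical, and context.WithTimeout stands in for the signal so that the example terminates on its own.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// produce sends increasing ids to pool, one per interval, until ctx is
// cancelled. It closes pool before returning and reports how many ids it sent.
// (Hypothetical helper, used only for this illustration.)
func produce(ctx context.Context, pool chan<- int, interval time.Duration) int {
	defer close(pool) // the producer is the only sender, so it owns the close
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	sent := 0
	for {
		select {
		case <-ctx.Done():
			return sent
		case <-ticker.C: // time to produce the next task
		}
		select {
		case <-ctx.Done(): // do not stay blocked on a full pool after cancellation
			return sent
		case pool <- sent + 1:
			sent++
		}
	}
}

// demo runs produce for about 50ms and counts what a consumer receives.
func demo() (sent, received int) {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	pool := make(chan int, 10)
	result := make(chan int, 1)
	go func() { result <- produce(ctx, pool, 10*time.Millisecond) }()

	for range pool { // ends when produce closes pool
		received++
	}
	return <-result, received
}

func main() {
	sent, received := demo()
	fmt.Println(sent == received) // true: every task that was sent is received
}
```

In the real program, the context returned by signal.NotifyContext() would take the place of the timeout context.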

https://medium.com/@yu-yk/graceful-shutdown-concurrent-go-program-with-waitgroup-and-context-33166210e170