Quick & simple intro to Go concurrency: Building a TCP chat server

A great project for getting comfortable with Go's concurrency tools (GitHub repo)

Why should I read this article?

If you are new to Go and want to learn how to write concurrent programs, this article is for you.

Building this TCP chat server will get you comfortable with Go's concurrency toolkit, specifically goroutines, channels, and WaitGroups.
Contexts are another useful tool, but we will not be using them in this article.


Contexts are still a very important tool to understand.
(For example, they make it easy to enforce a timeout on an API call without wiring up your own channel and a separate goroutine.)
A great article on the subject can be found here.

This article describes my thought process for designing this project, as well as some helpful notes and understanding-checks along the way. Let's get started!

Defining the functionality:

The design of this project is easiest to organize if we first lay out the major functionality we expect our chat server to have
(from the point of view of our clients):

  1. TCP connect to chatroom
  2. Post messages to chatroom
  3. Receive messages from chatroom
  4. Disconnect from chatroom

Handling incoming connections:

The first piece of functionality to handle is #1, TCP connect to chatroom.
We can spin up a simple TCP server that accepts incoming TCP connections and prints their messages to the console:

package main

import (
    "fmt"
    "log"
    "net"
)

func main() {
    listenAddr := ":3030"
    ln, err := net.Listen("tcp", listenAddr)
    if err != nil {
        // Without a listener there is nothing to do, so exit here
        log.Fatalln("Failed to create listener for server:", err)
    }
    defer ln.Close()

    log.Println("Server listening at", listenAddr)
    for {
        conn, err := ln.Accept()
        if err != nil {
            log.Println("Error during Accept - ", err)
            continue
        }
        log.Println("New conn from", conn.RemoteAddr())

        // Read loop: service this one connection until the client goes away
        buf := make([]byte, 2048)
        for {
            n, err := conn.Read(buf)
            if err != nil {
                // io.EOF means the client disconnected; stop reading this conn
                log.Println("Error during Read - ", err)
                conn.Close()
                break
            }
            msg := string(buf[:n])
            fmt.Println("Got message:", msg)
        }
    }
}

                        

We already have a need for our first concurrency tool: the goroutine.
In the code above, our server is stuck servicing only one connection at a time. What a useless chatroom server!

The inner for loop could be referred to as a read loop.
If we put our read loop in a goroutine, we can accept multiple connections, and print everyone's messages:

func readLoop(conn net.Conn) {
    buf := make([]byte, 2048)
    for {
        n, err := conn.Read(buf)
        if err != nil {
            // The client disconnected (io.EOF) or the conn broke: end this goroutine
            log.Println("Error during Read - ", err)
            conn.Close()
            return
        }
        msg := string(buf[:n])
        fmt.Println("Got message:", msg)
    }
}

func main() {
    //...
    for {
        conn, err := ln.Accept()
        if err != nil {
            log.Println("Error during Accept - ", err)
            continue
        }
        log.Println("New conn from", conn.RemoteAddr())

        // Each conn gets its own goroutine, so we can go straight back to Accept
        go readLoop(conn)
    }
}

                        

We can quickly show how we would handle #4, disconnect from chatroom.
In our read loop, simply check for a special message ($exit) and close the conn if we find it:

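// (This version also needs "strings" in main.go's imports.)
func readLoop(conn net.Conn) {
    buf := make([]byte, 2048)
    for {
        n, err := conn.Read(buf)
        if err != nil {
            log.Println("Error during Read - ", err)
            conn.Close()
            return
        }
        // telnet ends each line with "\r\n", so trim it before comparing
        msg := strings.TrimSpace(string(buf[:n]))

        if msg == "$exit" {
            conn.Close()
            return
        }

        fmt.Println("Got message:", msg)
    }
}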

                        

Save this code in main.go and run it using go run main.go.
You can connect to the server and send messages using:

    ~$ telnet localhost 3030

Sending and receiving messages:

Now that we understand the mechanics of accepting and reading from our TCP connections, how do we handle sending and receiving messages (functions #2 and #3)?

NOTE: We will now be using our second Go concurrency tool, the channel. Take a moment to play around with channels in the Go Playground so you understand the syntax before continuing.
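Here is a minimal sketch you could paste into the Go Playground. It shows the channel operations this article relies on: sending with out <- v, receiving with range, and close to signal that no more values are coming. The names greet and collect are just illustrative.

```go
package main

import "fmt"

// greet sends a few messages on out, then closes it so the receiver
// knows no more values are coming.
func greet(out chan<- string) {
	for _, name := range []string{"alice", "bob"} {
		out <- "hello, " + name // blocks until someone receives (unbuffered channel)
	}
	close(out)
}

// collect receives every value from in until the channel is closed.
func collect(in <-chan string) []string {
	var got []string
	for msg := range in {
		got = append(got, msg)
	}
	return got
}

func main() {
	msgs := make(chan string) // unbuffered: each send waits for a matching receive
	go greet(msgs)            // the sender runs in its own goroutine
	for _, m := range collect(msgs) {
		fmt.Println(m) // prints "hello, alice" then "hello, bob"
	}
}
```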

First, we should introduce some structs to organize our code a little.
The first struct I introduced is called a ConnHandler.
The entire purpose of the ConnHandler is to:

  • Hold the net.Conn that we read from and write to.
  • Publish messages to the chatroom using a channel called PublishCh
  • Receive messages from the chatroom using a channel called ReadCh
  • Have an integer ID so we have an easy way to identify conns

Its code would be something like this:

type ConnHandler struct {
    ConnId      int
    conn        net.Conn    // We really don't want multiple things touching the conn
    ReadCh      chan string
    PublishCh   chan string
}

In my head, I like to picture them something like this:

Next, we will add our top-level struct, the ChatroomBroker.
The ChatroomBroker has two jobs. First, it keeps track of all active ConnHandlers.
Second, it listens for incoming messages and broadcasts each one to everyone when it arrives.
Its code would look something like:

type ChatroomBroker struct {
    Conns       []*ConnHandler
    PublishCh   chan string
}

So the ChatroomBroker holds the single PublishCh that is shared among all ConnHandlers, and only the ChatroomBroker reads from it. It also holds a slice of ConnHandlers, each with its own ReadCh, which the ChatroomBroker writes to whenever it receives a message on PublishCh.
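The fan-out just described can be tried on its own. This is a simplified sketch, not the article's ChatroomBroker. It assumes plain string messages and a hypothetical broadcast function, with one shared publish channel and one read channel per subscriber. It also closes the read channels when publishing stops, which the chat server itself never does.

```go
package main

import "fmt"

// broadcast forwards every message received on publish to each subscriber's
// read channel, then closes the read channels once publish is closed.
func broadcast(publish <-chan string, subscribers []chan string) {
	for msg := range publish {
		for _, readCh := range subscribers {
			readCh <- msg // with an unbuffered channel this waits for the subscriber
		}
	}
	for _, readCh := range subscribers {
		close(readCh) // signal "no more messages" to every subscriber
	}
}

func main() {
	publish := make(chan string)
	// Buffered so this demo can drain them after broadcasting is done; in the
	// chat server each ReadCh is instead drained by its own ConnHandler goroutine.
	subscribers := []chan string{make(chan string, 10), make(chan string, 10)}

	done := make(chan struct{})
	go func() {
		broadcast(publish, subscribers)
		close(done)
	}()

	publish <- "hello"
	publish <- "world"
	close(publish)
	<-done

	for i, readCh := range subscribers {
		for msg := range readCh {
			fmt.Printf("subscriber %d got %q\n", i, msg)
		}
	}
}
```

Because every send on an unbuffered channel waits for a receiver, a single subscriber that stops reading would stall the whole broadcast. Keep that in mind when you read the broker in the full code.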

In my head, the ChatroomBroker with the individual ConnHandlers looks like this:

One more thing to mention is the ChatMsg in the diagram.
Up until now, the channels ReadCh and PublishCh have been of type chan string.
I actually want them to be of type chan ChatMsg, where ChatMsg is a struct:

type ChatMsg struct {
    ConnId  int
    Message string
}

This way, every ConnHandler that receives a copy of a ChatMsg knows who the message came from,
so it can display messages to its client like this:

[1234]: Hey guys, what's up?
[4321]: Hey 1234, just gaming - how about you?
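Here is a minimal sketch of how a receiving ConnHandler can turn a ChatMsg into one of those lines. formatChatMsg is a hypothetical helper name. The [%d]: %s layout is the same one the full code below writes to each connection.

```go
package main

import "fmt"

type ChatMsg struct {
	ConnId  int
	Message string
}

// formatChatMsg renders a message as "[<sender id>]: <text>" plus a newline,
// which is the line a client would see in its terminal.
func formatChatMsg(m ChatMsg) string {
	return fmt.Sprintf("[%d]: %s\n", m.ConnId, m.Message)
}

func main() {
	fmt.Print(formatChatMsg(ChatMsg{ConnId: 1234, Message: "Hey guys, what's up?"}))
	fmt.Print(formatChatMsg(ChatMsg{ConnId: 4321, Message: "Hey 1234, just gaming - how about you?"}))
}
```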

One final thing we need is a way for the ChatroomBroker to know that ConnId #xyz is leaving the chat.
We can define a DisconnectFunc() and pass it as a parameter to our ConnHandlers when we create them.
A ConnHandler calls this function when it disconnects (after receiving $exit):

func (cb *ChatroomBroker) DisconnectFunc(connId int) {
    fmt.Printf("ConnId %d is disconnecting!\n", connId)
    for i, ch := range cb.Conns {
        if ch.ConnId == connId {
            cb.Conns = append(cb.Conns[:i], cb.Conns[i+1:]...)
            break // the slice has shifted, so stop iterating after the removal
        }
    }
}

NOTE: Can you spot the concurrency problem we have just introduced? How would you solve it?


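The ChatroomBroker's slice of ConnHandlers, cb.Conns, is accessed by multiple goroutines: the accept loop appends to it, the broadcast loop ranges over it, and each disconnecting ConnHandler removes from it through DisconnectFunc.
By not protecting these reads and writes, we have introduced a race condition (the race detector, go run -race main.go, can flag it once clients connect and disconnect).
Try to fix the concurrency issue using a sync.Mutex to protect the reads and writes.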
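If you want to compare notes with your own fix, here is one possible sketch. It assumes a sync.Mutex field named mu on ChatroomBroker and trims ConnHandler down to the two fields it needs. The names mu, AddConn, snapshot, and broadcastOne are hypothetical and not part of the article's code. Every read or write of Conns happens while the lock is held. The broadcast copies the slice under the lock but sends outside it, so the lock is never held while the broker waits on a ConnHandler's ReadCh.

```go
package main

import (
	"fmt"
	"sync"
)

type ChatMsg struct {
	ConnId  int
	Message string
}

// ConnHandler is trimmed down to the fields this sketch needs.
type ConnHandler struct {
	ConnId int
	ReadCh chan ChatMsg
}

type ChatroomBroker struct {
	mu        sync.Mutex // guards Conns
	Conns     []*ConnHandler
	PublishCh chan ChatMsg
}

// AddConn registers a handler (hypothetical helper; the article appends inline).
func (cb *ChatroomBroker) AddConn(ch *ConnHandler) {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.Conns = append(cb.Conns, ch)
}

// DisconnectFunc removes the handler with the given ID while holding the lock.
func (cb *ChatroomBroker) DisconnectFunc(connId int) {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	for i, ch := range cb.Conns {
		if ch.ConnId == connId {
			cb.Conns = append(cb.Conns[:i], cb.Conns[i+1:]...)
			break
		}
	}
}

// snapshot returns a copy of Conns that can be ranged over without the lock.
func (cb *ChatroomBroker) snapshot() []*ConnHandler {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	return append([]*ConnHandler(nil), cb.Conns...)
}

// broadcastOne delivers one message to everyone except its sender.
func (cb *ChatroomBroker) broadcastOne(msg ChatMsg) {
	for _, ch := range cb.snapshot() { // no lock held while sending
		if ch.ConnId != msg.ConnId {
			ch.ReadCh <- msg
		}
	}
}

func main() {
	cb := &ChatroomBroker{PublishCh: make(chan ChatMsg)}
	var wg sync.WaitGroup // only used here to wait for the demo goroutines
	for id := 0; id < 100; id++ {
		wg.Add(1)
		go func(id int) { // concurrent adds and removes are now safe
			defer wg.Done()
			cb.AddConn(&ConnHandler{ConnId: id})
			if id%2 == 0 {
				cb.DisconnectFunc(id)
			}
		}(id)
	}
	wg.Wait()
	fmt.Println("remaining conns:", len(cb.snapshot())) // prints "remaining conns: 50"
}
```

Because sends happen outside the lock, a handler can be removed after the snapshot is taken and still receive one last message. That is harmless in this design, because each ConnHandler keeps draining its ReadCh. It matters if you later make ConnHandler.Start return on disconnect.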

The full code can be found below. Try to go through each line and understand what the program is doing
(especially from the POV of individual goroutines).

package main

import (
    "fmt"
    "log"
    "math/rand"
    "net"
    "strings"
)

type ChatMsg struct {
    ConnId  int
    Message string
}

type ConnHandler struct {
    ConnId         int
    Conn           net.Conn // We really don't want multiple things touching the conn
    ReadCh         chan ChatMsg
    PublishCh      chan ChatMsg
    DisconnectFunc func(int) // To tell ChatroomBroker which ConnId is disconnecting
}

type ChatroomBroker struct {
    Conns     []*ConnHandler // NOTE: not yet protected by a mutex (see the race condition above)
    PublishCh chan ChatMsg
}

func (ch *ConnHandler) readLoop() {
    buf := make([]byte, 2048)
    for {
        n, err := ch.Conn.Read(buf)
        if err != nil {
            // The client went away (e.g. io.EOF): clean up the same way as $exit
            log.Println("Error during Read - ", err)
            ch.DisconnectFunc(ch.ConnId)
            ch.Conn.Close()
            return
        }
        // Strip the trailing newline (telnet sends "\r\n") without risking a panic
        msg := strings.TrimSpace(string(buf[:n]))

        if msg == "$exit" {
            ch.DisconnectFunc(ch.ConnId)
            ch.Conn.Close()
            return
        }

        chatMsg := ChatMsg{
            ConnId:  ch.ConnId,
            Message: msg,
        }
        ch.PublishCh <- chatMsg
    }
}

func (ch *ConnHandler) Start() {
    go ch.readLoop()
    // Write every message the broker forwards to us out to our client
    for msg := range ch.ReadCh {
        fmt.Fprintf(ch.Conn, "[%d]: %s\n", msg.ConnId, msg.Message)
    }
}

func (cb *ChatroomBroker) startListeningToPublishCh() {
    fmt.Println("ChatroomBroker listening to PublishCh")
    for chatMsg := range cb.PublishCh {
        fmt.Printf("[%d]: %s\n", chatMsg.ConnId, chatMsg.Message)
        for _, ch := range cb.Conns { // unprotected read of cb.Conns
            if ch.ConnId != chatMsg.ConnId { // Don't broadcast to ourselves
                fmt.Println("Sending to", ch.ConnId)
                ch.ReadCh <- chatMsg
            }
        }
    }
}

func (cb *ChatroomBroker) DisconnectFunc(connId int) {
    fmt.Printf("ConnId %d is disconnecting!\n", connId)
    for i, ch := range cb.Conns { // unprotected read/write of cb.Conns
        if ch.ConnId == connId {
            cb.Conns = append(cb.Conns[:i], cb.Conns[i+1:]...)
            break // the slice has shifted, so stop iterating after the removal
        }
    }
}

func (cb *ChatroomBroker) Start(ln net.Listener) {
    go cb.startListeningToPublishCh()
    for {
        conn, err := ln.Accept()
        if err != nil {
            log.Println("Error during Accept - ", err)
            continue
        }

        ch := ConnHandler{
            ConnId:         rand.Intn(1000), // random IDs can collide; a counter would be safer
            Conn:           conn,
            ReadCh:         make(chan ChatMsg),
            PublishCh:      cb.PublishCh,
            DisconnectFunc: cb.DisconnectFunc,
        }
        log.Println("New conn created with ID", ch.ConnId)
        cb.Conns = append(cb.Conns, &ch) // unprotected write of cb.Conns
        go ch.Start()
    }
}

func main() {
    listenAddr := ":3030"
    ln, err := net.Listen("tcp", listenAddr)
    if err != nil {
        log.Fatalln("Failed to create listener for server:", err)
    }
    defer ln.Close()

    log.Println("Server listening at", listenAddr)
    cb := ChatroomBroker{
        Conns:     []*ConnHandler{},
        PublishCh: make(chan ChatMsg),
    }
    cb.Start(ln)
}

A note on waitGroups and goroutine hygiene:

In the words of legendary West Coast rapper Del the Funky Homosapien, "It's important to practice good hygiene"

WaitGroups, our third concurrency tool, are a great way to practice concurrency hygiene
and make sure we don't have stray goroutines running unexpectedly.

Consider the main() function of our program as it stands:

func main() {
	listenAddr := ":3030"
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatalln("Failed to create listener for server:", err)
	}
	defer ln.Close()

	log.Println("Server listening at", listenAddr)
	cb := ChatroomBroker{
		Conns:     []*ConnHandler{},
		PublishCh: make(chan ChatMsg),
	}

	cb.Start(ln)
}

Currently, our program runs just one chatroom. But suppose we had multiple chatrooms being created and deleted all the time.
We would need to make sure that every goroutine we spawn eventually returns. Otherwise those goroutines leak: they stay running (or stay blocked) and hold on to their memory for the life of the program.

WaitGroups are a perfect tool for this. A WaitGroup is a counter: wg.Add(1) increments it before a goroutine starts, wg.Done() decrements it when that goroutine finishes, and wg.Wait() blocks until the counter is back to zero.
First, we add a WaitGroup to our ChatroomBroker so that main() can find out when all of its goroutines have finished.

type ChatroomBroker struct {
	Conns     []*ConnHandler
	PublishCh chan ChatMsg
	Wg        *sync.WaitGroup
}

func main() {
	var wg sync.WaitGroup
	listenAddr := ":3030"
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatalln("Failed to create listener for server:", err)
	}
	defer ln.Close()

	log.Println("Server listening at", listenAddr)
	cb := ChatroomBroker{
		Conns:     []*ConnHandler{},
		PublishCh: make(chan ChatMsg),
		Wg:        &wg,
	}

	wg.Add(1)
	go cb.Start(ln) // cb.Start must call cb.Wg.Done() when it returns (e.g. defer cb.Wg.Done())
	wg.Wait()
}

wg.Wait() blocks until the WaitGroup's counter is back to zero, i.e. until other goroutines have called wg.Done() once for every wg.Add(1).
Pairing each wg.Add(1) with a wg.Done() in the goroutine it accounts for is a very useful habit: if Wait() never returns, you know some goroutine never finished.

NOTE: In the code above, what would happen if ChatroomBroker held a regular sync.WaitGroup instead of a *sync.WaitGroup?


Go is pass-by-value: every function call or assignment works with a COPY of the value. Some types (slices, maps, channels, and pointers, for example) are small descriptors that refer to shared underlying data, so even a copy of one of them still points at the same data.

If Wg were a plain sync.WaitGroup, the struct literal in main() would store a COPY of wg inside the ChatroomBroker.
Any wg.Done() called through the broker would act on that copy, never on the waitGroup main() is waiting on, so main() could never get past wg.Wait(). In this exact code it would actually fail more loudly: the copy is taken before wg.Add(1), so its counter is zero, and the first Done() on it would panic with "sync: negative WaitGroup counter". go vet's copylocks check flags this kind of copy.
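To see these copy semantics on something simpler than a WaitGroup, here is a small sketch with a hypothetical counter struct. A sync.WaitGroup behaves like counter here: passing it by value hands out an independent copy, while passing a pointer shares the original. The slice case shows why slices appear to be "passed by reference" even though the slice header itself is copied.

```go
package main

import "fmt"

type counter struct{ n int }

// bumpValue gets a copy of the struct, so the caller never sees the change.
func bumpValue(c counter) { c.n++ }

// bumpPointer gets a copy of the pointer, which still points at the caller's struct.
func bumpPointer(c *counter) { c.n++ }

// setFirst gets a copy of the slice header, but that header points at the
// same backing array, so element writes are visible to the caller.
func setFirst(s []int) { s[0] = 99 }

func main() {
	c := counter{}
	bumpValue(c)
	fmt.Println(c.n) // 0
	bumpPointer(&c)
	fmt.Println(c.n) // 1

	s := []int{1, 2, 3}
	setFirst(s)
	fmt.Println(s) // [99 2 3]
}
```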

Closing notes and future improvements:

The code in this article only supports one chatroom.
A great next step is to take this example code and add the ability to create new chatrooms and join different ones.
You can see how I did it on the GitHub page for this project.

Some future improvements I would like to make are: