Deep learning has almost exclusively been applied to structurally simple objects: images and text. By simple I am here referring to the graphical structure of these objects: a word is a linear sequence of letters, a document is a linear sequence of words, and an image is a rectangular grid of pixels. Graph Neural Networks (GNNs) were introduced in Gori et al (2005) by researchers from Università di Siena in Italy, and are networks which process information without requiring the input to have a particular rigid structure.

Recent work within GNNs has focused on the development of representation learning for graphs, and I’ll be writing about a handful of these ideas, each superseding the former. We start with the DeepWalk algorithm, introduced in Perozzi et al (2014) by researchers from Stony Brook University in the USA.

This post is part of my series on graph algorithms:

  1. PageRank
  2. DeepWalk
  3. Graph Convolutional Neural Networks

The SkipGram Algorithm

The DeepWalk algorithm is intimately connected to the SkipGram algorithm introduced a year before, in the Google paper Mikolov et al (2013). The goal of the SkipGram algorithm is to produce vector representations of words, solely from data. The fundamental idea in the SkipGram algorithm is that the model should attempt to predict the neighbouring words of a given input word.

A crucial notion here is that of a neighbour. We define an $n$-neighbour of a given word to be any word at most $n$ positions away from it. For example, in the sentence “We are learning about SkipGram”, the 2-neighbours of “about” are “are”, “learning” and “SkipGram”. Note that we do not care how close the neighbours are, as we are not ranking them in any way. Here $n$ is a hyperparameter of the algorithm, and fixing it is what gives “neighbouring words” a precise meaning.
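The definition can be sketched in a few lines of Julia. The helper name `n_neighbours` is my own and is not part of any SkipGram implementation; it only illustrates the windowing on the example sentence.

```julia
# Return the n-neighbours of the word at position `idx` in `words`:
# every word at most `n` positions away, excluding the word itself.
function n_neighbours(words::Vector{String}, idx::Int, n::Int)
    left = max(1, idx - n)
    right = min(length(words), idx + n)
    return [words[i] for i in left:right if i != idx]
end

sentence = String.(split("We are learning about SkipGram"))

# The 2-neighbours of "about", which is the 4th word
n_neighbours(sentence, 4, 2)  # ["are", "learning", "SkipGram"]
```

Note that the window is clipped at the ends of the sentence, which is why “about” only has one neighbour to its right.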

But where is the vector representation of the word then? The missing piece is that the SkipGram model consists of an encoder and a decoder, see Figure 1. These can in principle be arbitrary neural networks, and the output of the encoder-decoder model should then be a probability distribution over the possible neighbouring words of the input word.

A diagram of the SkipGram architecture, which takes an input word and attempts to predict one of the neighbouring words.

Figure 1. The SkipGram architecture, from the original paper.

This way of designing the architecture means that we have an intermediate representation, namely the output of the encoder, which we can use as the representation of the input word after we have trained the model.

The original implementation from the above paper by Mikolov et al, which they called Word2Vec, simply used a linear projection for the encoder and another linear projection for the decoder (no non-linearities at all). This made it highly computationally efficient, making it possible to process millions of words in a reasonable amount of time.
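To make the purely linear setup concrete, here is a minimal sketch with two plain matrices and a softmax on top. This is not the Word2Vec code: the sizes, the random weights and the helper names `onehot_vec` and `softmax_sketch` are all assumptions made for illustration.

```julia
using Random

Random.seed!(0)
vocab_size, emb_dim = 5, 3

# Encoder and decoder are plain matrices: no biases and no non-linearities
W_enc = randn(emb_dim, vocab_size)
W_dec = randn(vocab_size, emb_dim)

# One-hot encoding of the word with index `i` in a vocabulary of size `n`
onehot_vec(i, n) = [j == i ? 1.0 : 0.0 for j in 1:n]

# Numerically stable softmax, turning scores into probabilities
function softmax_sketch(scores)
    exps = exp.(scores .- maximum(scores))
    return exps ./ sum(exps)
end

word = 4
embedding = W_enc * onehot_vec(word, vocab_size)  # encoder output
probs = softmax_sketch(W_dec * embedding)         # distribution over the vocabulary

# Multiplying by a one-hot vector just selects a column of the encoder matrix
all(isapprox.(embedding, W_enc[:, word]))  # true
```

The last line is the reason the approach is cheap: the embedding of a word is simply a column of the encoder matrix, so no matrix multiplication is needed to look it up.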

From SkipGram to DeepWalk

Knowing what the SkipGram algorithm is about, the leap to DeepWalk is not far. As I mentioned above, the context in SkipGram is a linear chain (of words). So when we go from this linear context to an arbitrary graph structure, we only have to change the part of SkipGram that relies on linearity, which is the definition of a neighbour.

In a general graph we could mimic SkipGram and simply define $n$-neighbours in the same way: as the nodes which are at most $n$ hops away from the input node. The problem with this approach is that graphs are usually highly connected, so even going only 5 hops away from your node, you might suddenly have reached every node in the graph. As our graph might contain millions of nodes, this becomes computationally infeasible.
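To get a feel for how quickly hop-based neighbourhoods grow, here is a small sketch that counts the nodes within $n$ hops using breadth-first search. The adjacency-list representation, the helper `count_within_hops` and the toy binary tree are my own assumptions, chosen to keep the example free of dependencies.

```julia
# Count the nodes at most `n` hops from `source`, using breadth-first search.
# `adj[v]` holds the neighbours of node `v` (an adjacency list).
function count_within_hops(adj::Vector{Vector{Int}}, source::Int, n::Int)
    dist = fill(-1, length(adj))  # -1 marks nodes not reached yet
    dist[source] = 0
    frontier = [source]
    for hop in 1:n
        next_frontier = Int[]
        for v in frontier, u in adj[v]
            if dist[u] == -1
                dist[u] = hop
                push!(next_frontier, u)
            end
        end
        frontier = next_frontier
    end
    return count(>=(0), dist) - 1  # exclude the source itself
end

# A complete binary tree on 15 nodes: node v has children 2v and 2v + 1
num_nodes = 15
adj = [Int[] for _ in 1:num_nodes]
for v in 2:num_nodes
    up = v ÷ 2
    push!(adj[up], v)
    push!(adj[v], up)
end

[count_within_hops(adj, 1, n) for n in 1:3]  # [2, 6, 14]
```

Even in this sparse tree the count roughly doubles with every hop, and after 3 hops from the root we have reached every other node. Real graphs with higher degrees blow up much faster.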

What DeepWalk does is to sample the neighbours in a particular way, rather than considering all of them at once. This is done through random walks, intuitively being finite sequences of nodes in the graph, obtained by starting from a random node and “walking” randomly around the graph. We can define this formally as a finite Markov chain with uniform transition probabilities; we defined these terms last time.
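A uniform random walk is short to write down. The sketch below works on a plain adjacency list rather than a LightGraphs graph, and the function name `uniform_random_walk` and the toy 4-cycle are assumptions for illustration only.

```julia
using Random

# Uniform random walk: from the current node, move to one of its neighbours,
# each chosen with equal probability. `adj[v]` lists the neighbours of `v`.
function uniform_random_walk(rng::AbstractRNG, adj::Vector{Vector{Int}},
                             start::Int, walk_len::Int)
    walk = [start]
    while length(walk) < walk_len
        nbrs = adj[walk[end]]
        isempty(nbrs) && break  # dead end: stop the walk early
        push!(walk, rand(rng, nbrs))
    end
    return walk
end

# The 4-cycle 1 - 2 - 3 - 4 - 1 as a toy graph
adj = [[2, 4], [1, 3], [2, 4], [1, 3]]
walk = uniform_random_walk(MersenneTwister(42), adj, 1, 6)
```

The next node only depends on the current one, which is exactly the Markov property, and `rand(rng, nbrs)` picks each neighbour with the same probability.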

An example graph in which a random walk of length 4 is shown, starting from a node x. Another node y occurring on the walk is marked, with a green boundary around the two nodes neighbouring y on the random walk, being the 1-neighbourhood of y.

What we then do is two-fold. First, we run a random walk $w_x$ starting from every node $x$ in the graph. Then, for every random walk $w_x$ and every node $y\in w_x$, we take the $n$-neighbours of $y$ to be the nodes at most $n$ steps away from $y$ within this random walk $w_x$. This point is worth reiterating: we are using the random walks to reduce the neighbour concept back to the linear case!

This also means that in every epoch we will have started a random walk at every node, and we will therefore have processed many nodes multiple times, as they could have occurred in multiple random walks. Thus, an epoch here is slightly different from an epoch in a normal deep learning training loop.

Julia Implementation

We have two components that we need to implement. Firstly, we need to generate new samples through the use of random walks, and next we need to code a training loop.

We will be using the LightGraphs package for dealing with graphs, and the Flux package to implement and train the neural networks.

using LightGraphs
using Flux

The following function generates our training samples:

  1. It starts from a specified node
  2. It generates a random walk of a fixed length from that node
  3. For every node in the random walk, it collects all the neighbours of that node, with neighbours being nodes within a pre-specified window
  4. It outputs the nodes in the random walk along with their neighbours
function generate_training_samples(
  graph::AbstractGraph,
  start_node::Integer,
  walk_len::Integer,
  window::Integer)::Tuple

  # Generate random walk
  walk = randomwalk(graph, start_node, walk_len)

  # Collect nodes and their neighbours within the random walk
  nodes = Vector{Integer}()
  neighbours = Vector{Integer}()
  for (idx, node) in enumerate(walk)
    window_left = max(1, idx - window)
    window_right = min(length(walk), idx + window)
    for neighbour in walk[window_left:window_right]
      if neighbour != node
        push!(nodes, node)
        push!(neighbours, neighbour)
      end
    end
  end
  nodes, neighbours
end
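To see what the windowing step produces without needing a graph, here is the same loop applied to a hand-written walk. The helper `window_pairs` is hypothetical and only mirrors the inner part of the function above, returning (node, neighbour) pairs instead of two separate vectors.

```julia
# Same windowing as in generate_training_samples, but on a given walk
function window_pairs(walk::Vector{Int}, window::Int)
    pairs = Tuple{Int,Int}[]
    for (idx, node) in enumerate(walk)
        lo = max(1, idx - window)
        hi = min(length(walk), idx + window)
        for neighbour in walk[lo:hi]
            # Skip the node itself, including any revisits inside the window
            neighbour != node && push!(pairs, (node, neighbour))
        end
    end
    return pairs
end

window_pairs([7, 3, 9, 3], 1)
# [(7, 3), (3, 7), (3, 9), (9, 3), (9, 3), (3, 9)]

length(window_pairs(collect(1:10), 3))  # 48
```

So a walk through 10 distinct nodes with a window of 3 gives 48 training pairs, and fewer if the walk revisits a node within the window. Pairs can also repeat, as with (9, 3) above, which simply gives that pair more weight in the loss.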

The next function is a training loop, which is very similar to training loops in PyTorch and TensorFlow. The main difference is that we’re calling the generate_training_samples helper function above to generate our batches:

function train_model(
  graph::AbstractGraph;
  walk_len::Integer=10,
  window::Integer=3,
  num_epochs::Integer=1,
  emb_dim::Integer=10)

  # Collect all the nodes in the graph
  num_nodes = nv(graph)
  all_nodes = collect(vertices(graph))

  # Build the model
  model = Chain(Dense(num_nodes, emb_dim), Dense(emb_dim, num_nodes))

  # Define our loss function, being cross entropy on a one-hot encoding
  # of the nodes
  function criterion(nodes::AbstractVector{<:Integer},
                     neighbours::AbstractVector{<:Integer})::AbstractFloat
    nodes = Flux.onehotbatch(nodes, all_nodes)
    neighbours = Flux.onehotbatch(neighbours, all_nodes)
    Flux.Losses.logitcrossentropy(model(nodes), neighbours)
  end

  # Set up the optimiser and fetch the parameters of the model
  optimiser = Flux.Optimise.ADAMW(3e-4)
  params = Flux.params(model)

  # Main training loop
  losses = []
  for epoch in 1:num_epochs
    avg_loss = 0

    # Looping over all nodes in the graph
    for start_node in vertices(graph)

      # Generate samples using the helper function
      (nodes, neighbours) = generate_training_samples(graph, start_node,
                                                      walk_len, window)

      # Compute the loss and its gradients with respect to the parameters
      gradients = gradient(params) do
        loss = criterion(nodes, neighbours)
        avg_loss += loss
        return loss
      end

      # Update the parameters using the gradients
      Flux.Optimise.update!(optimiser, params, gradients)
    end

    # Get the average loss of the epoch, save it to `losses`
    # and print out the status
    avg_loss /= nv(graph)
    push!(losses, avg_loss)
    println("Epoch ", epoch, " - Average loss: ", avg_loss)
  end

  # Return the model along with the array of all the average losses
  model, losses
end

Let’s see this in action! I’ll apply this here to a fairly simple graph, a clique graph with 20 cliques, each having 5 nodes:

> using GraphPlot
> graph = CliqueGraph(5, 20)
> gplot(graph)

A connected graph with 20 groupings, each having 5 nodes

We can then train our model using our train_model function above:

> model, losses = train_model(graph, num_epochs=200, emb_dim=50)
Epoch 1 - Average loss: 4.609087
Epoch 2 - Average loss: 4.5906944
(...)
Epoch 200 - Average loss: 1.908743
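As a rough check on the starting value: a model that knows nothing should predict every one of the 100 nodes with equal probability, and the cross entropy of such a prediction is $\log(100)\approx 4.605$, which is close to the loss in the first epoch. The sketch below computes this for a single sample. The function `logit_cross_entropy` is my own simplified stand-in and not Flux's implementation, which by default also averages over the batch.

```julia
# Cross entropy of a one-hot target against the softmax of `logits`,
# computed via a numerically stable log-sum-exp
function logit_cross_entropy(logits::Vector{Float64}, target::Int)
    m = maximum(logits)
    log_normaliser = m + log(sum(exp.(logits .- m)))
    return log_normaliser - logits[target]
end

num_nodes = 100
uniform_logits = zeros(num_nodes)  # equal scores give a uniform prediction

logit_cross_entropy(uniform_logits, 1)  # equals log(100), roughly 4.605
```

We should not expect the loss to go all the way to zero either, since every node has several different neighbours and the model can at best spread its probability over them.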

Let’s look at a loss graph as a sanity check:

> using Plots
> plot(losses, legend=false, title="Average Loss per Epoch")

A plot of the loss, which converges to around 2

Now, using the model we can fetch the embeddings for all the nodes:

> all_nodes = collect(vertices(graph))
> nodes = Flux.onehotbatch(all_nodes, all_nodes)
> embeddings = model[1](nodes) # model[1] means that we only apply the encoding layer
> size(embeddings)
(50, 100)

And there we go!

We can evaluate the embeddings by clustering them and colouring the graph based on the clusters, here using the DBSCAN clustering algorithm:

> using Clustering
> clusters = dbscan(embeddings, 3) # radius is 3
> size(clusters)
(19,)

We next map the indices to colours and plot our coloured graph (there might be an easier way to do this):

> using ColorSchemes
> colour_idxs = zeros(Int8, nv(graph))
> for (idx, cluster) in enumerate(clusters)
>     colour_idxs[cluster.core_indices] .= idx
> end
> colours = ColorSchemes.gist_rainbow[colour_idxs .* 5]
> gplot(graph, nodefillc = colours)

The same graph as before, but the cliques now have different colours

Final Comments

DeepWalk was one of the first (very successful) attempts at producing representations for the nodes in a graph. Several other methods have superseded this algorithm by now, so this post is mainly meant to show how the field of graph representation learning has progressed, and to give some context for the representation algorithms that I will cover in future posts.