Oct 6, 2023
TENSORFLOW, RUST, AND GO: AN UNHOLY UNION FOR CROSS-PLATFORM INFERENCE
How I ported small TensorFlow models to Rust, compiled them to WebAssembly, and embedded them in Go for cross-platform inference on end-user hardware.
clarkmcc/tensorflow-inference-go
Running TensorFlow model inference in pure Go using WebAssembly

I’ve been keeping my eye out for the past year or so for a way to embed machine learning models directly inside a Go service. This service runs on end-user hardware and on every major platform, so distributing the model was going to be difficult if it meant bundling a bunch of extra runtime dependencies just to run it. I explored the following projects:

  • onnx-go - This seemed promising, but the Gorgonia backend didn’t implement several of the operators required by my TensorFlow model.
  • tfgo - TensorFlow bindings in Go, could not be compiled to WebAssembly.
  • tensorflow/rust - Maybe I could embed the model in a Rust project and then compile to WebAssembly? Sadly, since this Rust project is just bindings to TensorFlow, it also could not be compiled to WebAssembly.

After working through these and discussing the problem online, I realized the tooling just wasn’t there yet, or at the very least, that I wasn’t smart enough to implement the model myself with the technology that was available.

Until…

Candle

I was introduced to huggingface/candle through the rustformers/llm project when I was building chitchat, a statically compiled Rust/Tauri LLM chatting tool. Candle is an ML framework written in pure Rust that doesn’t require any external dependencies. This was a good sign for me, since it would likely be able to compile to WebAssembly.

The README example, however, was quite daunting:

use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}

What was this Tensor business, and what’s a matmul? These low-level details are generally abstracted away when using TensorFlow. But a few days of blood, sweat, and tears later, I finally pulled it off, and I’m here to share my experience and findings in the form of a Github project that you can clone and tinker with.

In this post I’ll gloss over the ideas, show some code snippets, etc. but if you want to see the full code, check out the Github project.

So at a high level, here’s the idea:

  • Design and train a model in TensorFlow
  • Export the weights and biases of the layers of the model that are pertinent to inference (layers like Dropout are only active during training and are disabled for inference, so they don’t need to be exported).
  • Replicate the layers of the model in Rust using Candle.
  • Embed and load the weights and biases into the layers of the model.
  • Compile the Rust project to WebAssembly.
  • Embed the WebAssembly in a Go project using Scale as our Go-to-WebAssembly FFI generator.

The Model

To illustrate this project from start to finish, I’m going to use the MNIST example from the beginner’s guide of the TensorFlow project. I’m not going to go into the model development process (I’m no machine learning expert by any stretch of the imagination), except to illustrate how we’re going to export the weights and biases of the model. The MNIST example trains a neural network to classify handwritten digits. The model is a simple feed-forward network with one hidden dense layer and a dense output layer. It is trained on a collection of 28x28 pixel images of handwritten digits and can classify the digits 0-9.

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

There are a few things to note about this model that are applicable to porting it:

  • Flattening Layer The input shape is 28x28, which means the model expects a 2-dimensional array of 28 rows, each containing 28 pixel values. The model flattens these 2-dimensional arrays into 1-dimensional arrays of 784 numbers. When we port the model, we’ll either have to flatten the input data in the same way or accept it as a 1-dimensional array of 784 numbers (we can do either one; see the sketch after this list).

  • Dense Layers These layers contain the actual weights and biases, and they are the only layers that need to be ported for inference.

  • Dropout Layer This layer is used during training to prevent overfitting. It randomly drops some of the neurons in the previous layer. We don’t need this layer for inference, so we can just ignore it.
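
As a concrete illustration of the flattening step mentioned above, here’s a minimal sketch of what it could look like on the Rust side. The flatten_image helper and its row-major [[u32; 28]; 28] input type are hypothetical, purely for illustration; the actual project makes the caller responsible for this step.

// Hypothetical helper: flatten a 28x28 row-major image into the
// 1-dimensional, 784-element vector the ported model expects.
fn flatten_image(image: &[[u32; 28]; 28]) -> Vec<u32> {
    image.iter().flat_map(|row| row.iter().copied()).collect()
}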

Exporting the Model

There are a bunch of different ways to get the weights and biases out of the TensorFlow model and into our Rust model, but the simplest approach I found (after copying the numpy arrays from the console for a few days) was to use the huggingface/safetensors format, which provides a simple binary format for storing and loading tensors.

So, how do we export these? There are a couple of different ways, but unfortunately we can’t just hand the whole model over for export; safetensors only accepts actual tf.Tensor objects. So let’s start by dissecting the layers to see which ones we need to get the weights and biases from.

print(model.layers)

# [<keras.src.layers.reshaping.flatten.Flatten at 0x2a0e51910>,
#  <keras.src.layers.core.dense.Dense at 0x2a0775f40>,
#  <keras.src.layers.regularization.dropout.Dropout at 0x2a0754f10>,
#  <keras.src.layers.core.dense.Dense at 0x1102df250>]

As I mentioned earlier, we only need the weights and biases from the dense layers. The flatten layer is more of a preprocessing step, and the dropout layer is more of a training technique. To get the weights and biases, we just call .get_weights() on the layers that we care about.

dense1 = model.layers[1].get_weights()
dense_weights1 = dense1[0]
dense_biases1 = dense1[1]

dense2 = model.layers[3].get_weights()
dense_weights2 = dense2[0]
dense_biases2 = dense2[1]

And now that we have our weights and biases, we can convert these numpy arrays into tf.Tensor objects and export them with safetensors.

from safetensors.tensorflow import save_file

save_file({
    "d1_w": tf.convert_to_tensor(dense_weights1),
    "d1_b": tf.convert_to_tensor(dense_biases1),
    "d2_w": tf.convert_to_tensor(dense_weights2),
    "d2_b": tf.convert_to_tensor(dense_biases2),
}, "model.safetensors")

Porting the Model

I’ve ported a few small models so far and each time requires a little bit of tweaking and tuning to get it just right, so I hope that this section will help you get started faster than me when porting your own models. For starters, we need to import our newly exported weights and biases. Since I plan to compile this to WebAssembly, I’m going to embed the contents of that file directly into the binary so that my model doesn’t have to have filesystem access in order to load the weights and biases.

const SAFETENSORS: &[u8] = include_bytes!("../model.safetensors");

The candle library provides some utilities that we can use to parse the raw bytes into a HashMap of tensors, similar to our Python dictionary of tensors that we created when we exported the model.

let tensors = candle_core::safetensors::load_buffer(SAFETENSORS, &Device::Cpu).unwrap();
let vb = VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);

Preprocessing

With our weights and biases loaded, we can start to design our model. Conceptually, if we’re classifying hand-drawn digits, then I want to create a Rust function that accepts the 784 (28x28) 0-255 pixel values of the image and returns a 0-9 digit. We’ll make the caller responsible for converting the image to a 1-dimensional array of 784 numbers, but our function will handle converting the values into the domain expected by the model (specifically, values from 0 to 1).

fn predict(&self, pixels: &[u32]) -> u32 {
    let pixels = pixels
        .iter()
        .map(|p| *p as f32 / 255.0)
        .collect::<Vec<f32>>();
    // ...
}

Input Layer

The first layer of our model is going to be an input layer where we provide the 784 pixel values as a tensor.

let input = Tensor::from_vec(pixels, (1, 784), &Device::Cpu).unwrap();

Hidden Layer

The second layer of our model is a hidden dense layer with 128 neurons. This layer is going to accept 784 inputs and produce 128 outputs. We’ll use the VarBuilder to retrieve the weights and biases for this layer.

let dense_weight = self
    .vb
    .get((784, 128), "d1_w")
    .unwrap()
    .transpose(0, 1)
    .unwrap();
let dense_bias = self.vb.get(128, "d1_b").unwrap();
let dense = Linear::new(dense_weight, Some(dense_bias));

One of the interesting things I discovered is that TensorFlow’s weight layout is transposed from what candle expects. Keras stores a Dense layer’s kernel as (in_features, out_features), so for this layer TensorFlow provides the weights as a rank-two tensor (a two-dimensional array) with the shape (784, 128). candle’s Linear layer, on the other hand, stores its weights as (out_features, in_features), so we need to transpose them to a shape of (128, 784) before we can use them. Without the transpose, candle will complain that it cannot perform a matrix multiplication with the provided shapes.

Output Layer

The output layer maps the 128 outputs from the hidden layer into 10 outputs, one for each digit. Each of the ten outputs is a score indicating how strongly the model believes the hand-drawn digit matches that output’s index, and the highest score wins. To simplify with an example, the following outputs would indicate that the hand-drawn digit is a 4.

[0.01, 0.01, 0.01, 0.01, 0.96, 0.01, 0.01, 0.01, 0.01, 0.01]
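
Strictly speaking, the ported model emits raw logits rather than probabilities, since there’s no softmax on the final dense layer, but argmax picks the same index either way. If you do want probabilities like the ones above, candle_nn ships a softmax op; here’s a minimal sketch using made-up logits (this is not part of the ported model):

use candle_core::{Device, Tensor};
use candle_nn::ops::softmax;

// Illustrative logits for a single image, shape (1, 10); not real model output.
let logits = Tensor::new(
    &[[2.0f32, -1.0, 0.5, 0.0, 6.0, 1.0, -2.0, 0.3, 0.1, -0.5]],
    &Device::Cpu,
)?;
// Softmax over the class dimension turns the logits into probabilities that sum to 1.
let probabilities = softmax(&logits, 1)?;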

We’ll use the VarBuilder to retrieve the weights and biases for this layer as well, and note that we also have to transpose the weights of this layer like we did for the hidden layer.

let output_weight = self
    .vb
    .get((128, 10), "d2_w")
    .unwrap()
    .transpose(0, 1)
    .unwrap();
let output_bias = self.vb.get(10, "d2_b").unwrap();
let output = Linear::new(output_weight, Some(output_bias));

Forward Pass

A forward pass is a fundamental concept in neural networks. It refers to the process of passing the input data through the layers of the network to obtain an output. Here’s a simple explanation:

Imagine you have a factory production line. The input data is like the raw materials you start with at the beginning of the line. As you send these materials through the production line (the neural network), they get transformed at each station (each layer of the network). At the end of the production line, you have your final product, which in the case of a neural network, is the output.

In our model, performing the forward pass means passing the output of each layer as the input to the next layer; for the first layer, we just use the pixel tensor that we created earlier. One detail worth calling out: the Keras model applies a ReLU activation to the hidden dense layer (activation='relu'), so we need to apply that activation explicitly here, since candle’s Linear layer doesn’t do it for us.

// Hidden layer, followed by the ReLU activation from the original Keras model.
let next = dense.forward(&input).unwrap().relu().unwrap();
// Output layer maps the 128 hidden values to ten per-digit scores.
let predictions = output.forward(&next).unwrap();

The predictions tensor has a shape of (1, 10): one row of ten scores for our single input image. We can use the argmax function along dimension 1 to find the index of the highest score, and that index is the digit the model predicts.

let digit = predictions
    .argmax(1)
    .unwrap()
    .to_vec1::<u32>()
    .unwrap()
    .first()
    .cloned()
    .unwrap();

Calling Rust from Go

Because Go cross-compilation is a priority for me, I wanted a way to call Rust code from Go using WebAssembly rather than some kind of C ABI. Calling WebAssembly functions isn’t as cut-and-dried as I’d like, because you can only pass numbers to them. If you want to call a WebAssembly function and pass a more complicated data structure like a byte slice, you need to allocate that byte slice in the WebAssembly module’s memory and then pass a pointer to it into the function.
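
To make that concrete, here’s roughly what the manual approach can look like on the Rust guest side. The alloc and predict_raw exports are hypothetical, shown only to illustrate the pointer-passing dance; this project doesn’t actually ship them.

// The guest exports an allocator so the host can copy bytes into the module's
// linear memory, then calls the real function with a pointer and a length.
#[no_mangle]
pub extern "C" fn alloc(len: usize) -> *mut u8 {
    let mut buf: Vec<u8> = Vec::with_capacity(len);
    let ptr = buf.as_mut_ptr();
    std::mem::forget(buf); // hand the buffer to the host; a matching export would free it
    ptr
}

#[no_mangle]
pub extern "C" fn predict_raw(ptr: *const u8, len: usize) -> u32 {
    // Rebuild the byte slice the host wrote into our memory, then decode and run inference.
    let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
    let _ = bytes; // ... deserialize the 784 pixel values and call predict() ...
    0
}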

This is a bit of a pain, so I decided to use Scale to generate the Go-to-WebAssembly bindings for me. Scale introduces the idea of a signature, which is essentially a type definition for what can be passed to and returned from a WebAssembly function. Scale then generates the glue code for both the Go host, and the Rust guest that allows you to call the WebAssembly function as if it were a normal Go function.

Scale ships with a CLI to generate the required files and the glue code. To get started, we’ll need to install the CLI.

$ curl -fsSL https://dl.scale.sh | sh

Next, we’ll need to create a signature file that defines the inputs and outputs of our WebAssembly function.

$ scale signature new

This will generate a boilerplate scale.signature file. The file defines a Context which we can extend with any properties we want. What’s interesting about Scale is that it’s designed for building chainable middleware, so instead of defining inputs and outputs for a WebAssembly function, you define properties on a Context, and each WebAssembly function you invoke can read and write properties on that context. This means we’ll specify the inputs (the 784 pixel values) and outputs (the predicted digit) in the Context.

version = "v1alpha"
context = "context"
model Context {
  uint32_array Pixels {
    initial_size = 784
  }
  uint32 Digit {
    default = 0
  }
}

To generate the glue code, we’ll run signature generate and provide a name and tag for our signature.

$ scale signature generate bindings:latest

These signatures are generated and saved in a local registry where they can either be uploaded to a cloud registry, or exported to the local filesystem. In our case, we’ll export them to the local filesystem where they can be used in our project.

# Generate the Go host types
$ scale signature export local/bindings:latest go host signature/host --manifest=false

# Generate the Rust guest types
$ scale signature export local/bindings:latest rust guest signature/guest

The last step is to generate the Rust WebAssembly function itself using Scale. We’ll give it the name model and point it at the local signature we just created, bindings:latest. This will generate a Cargo project in the model directory, and we’ll need to copy all our Rust code into that project.

$ scale function new model:latest -s local/bindings:latest -l rust -d model

There’s quite a bit of boilerplate required to actually run the scale function, so I’ll refer you to the Github repo itself for the full code.

Summary

While it’s not exactly simple, it is possible to reliably port a TensorFlow model to Rust, then to WebAssembly, and run that model natively from other languages like Go. This is currently being used in production at my company, so far with great success.
