I want to improve my Rust skills and help you hone yours as well. So I've decided to write a series of articles about the Rust programming language.

By actually building stuff with Rust, we'll learn about a wide range of technological concepts in the process. In this installment, we’ll learn how to implement the Naive Bayes classifier with Rust.

You may encounter a few unfamiliar terms or concepts in this article. Don’t be discouraged. Look these up if you have the time, but regardless, this article’s main ideas will not be lost on you.

What is a Naive Bayes Classifier?

The Naive Bayes classifier is a machine learning algorithm based on Bayes’ Theorem. Bayes’ Theorem gives us a way to update the probability of a hypothesis \(H\), given some data \(D\).

Expressed mathematically, we have:

\[P(H|D) = \frac{P(D|H)P(H)}{P(D)}\]

where \(P(H|D) =\) the probability of \(H\) given \(D\).

If we accumulate more data, we can update \(P(H|D)\) accordingly.
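As a quick numerical illustration, here's a minimal sketch that applies Bayes' Theorem, expanding \(P(D)\) with the law of total probability, \(P(D) = P(D|H)P(H) + P(D|\neg H)P(\neg H)\). The function name posterior and the numbers used below are illustrative assumptions, not part of any library or of the classifier we'll build.

```rust
/// Bayes' Theorem: P(H|D) = P(D|H) P(H) / P(D), where the evidence is
/// expanded as P(D) = P(D|H) P(H) + P(D|not H) P(not H).
/// Hypothetical helper for illustration only.
pub fn posterior(p_d_given_h: f64, p_h: f64, p_d_given_not_h: f64) -> f64 {
    let p_d = p_d_given_h * p_h + p_d_given_not_h * (1.0 - p_h);
    p_d_given_h * p_h / p_d
}
```

Starting from a prior of 0.5, data that is four times as likely under \(H\) (0.8 versus 0.2) gives a posterior of 0.8. If a second, similar piece of data arrives, and we assume it is independent of the first given \(H\), feeding 0.8 back in as the new prior raises the posterior to about 0.94. That is the "update" described above.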

Naive Bayesian models rest on a big assumption: given the class of a data point (for a spam filter, spam or not spam), each of its features (for us, whether a particular word appears) is independent of every other feature (source). That is, once we know the class, learning whether one word is present tells us nothing about whether any other word is present.
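Written out for data \(D\) made up of features \(x_1, \dots, x_n\) (for a spam filter, \(x_i\) records whether the \(i\)-th word appears), the assumption lets the likelihood in Bayes' Theorem factor into a product of per-feature terms:

\[P(D|H) = P(x_1, \dots, x_n|H) = \prod_{i=1}^{n} P(x_i|H)\]

Each factor can be estimated separately from counts, which is what makes the classifier cheap to train.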

We don't expect this assumption to hold in real data; it is a strong simplification. But it's still useful, allowing us to create efficient classifiers that work quite well (source).

We’ll leave our description of Naive Bayes there. A lot more could be said, but the main point of this article is to practice Rust.

If you’d like to learn more about the algorithm, the references listed at the end of this article are a good place to start.

The canonical application of the Naive Bayes classifier is a spam classifier. That is what we’ll build. You can find all the code here: https://github.com/josht-jpg/shaking-off-the-rust

We’ll begin by creating a new library with Cargo.

cargo new naive_bayes --lib
cd naive_bayes

Now let's dive into it.

Tokenization in Rust

Our classifier will take in a message as input and return a classification of spam or not spam.

To work with the message we’re given, we’ll want to tokenize it. Our tokenized representation will be a set of lowercase words that ignores word order and repeated words. Rust’s std::collections::HashSet structure is a great way to achieve this.

The function we’ll write to perform tokenization requires the regex crate. Make sure you include this dependency in your Cargo.toml file:

[dependencies]
regex = "^1.5.4"

And here’s the tokenize function:

// lib.rs

// We'll need HashMap later
use std::collections::{HashMap, HashSet};

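use regex::Regex;

/// Splits already-lowercased text into the set of unique words it contains.
pub fn tokenize(lower_case_text: &str) -> HashSet<&str> {
    // Match maximal runs of lowercase ASCII letters, digits, and apostrophes.
    // Note: the regex is compiled on every call, which is simple but not fast.
    Regex::new(r"[a-z0-9']+")
        .unwrap()
        .find_iter(lower_case_text)
        .map(|mat| mat.as_str())
        .collect()
}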

This function uses a regular expression to match runs of lowercase letters, digits, and apostrophes. Any other character (often whitespace or punctuation) ends the current run, so each match is a maximal group of those characters (you can read more about regex in Rust here). That is, we're identifying and isolating words in the input text.
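If you'd like to check what tokenize keeps and drops without pulling in the regex crate, here's a std-only sketch. The name tokenize_std is hypothetical, and the function assumes the same rule as the pattern [a-z0-9']+: ASCII lowercase letters, digits, and apostrophes form words, and every other character ends the current word.

```rust
use std::collections::HashSet;

/// Hypothetical std-only equivalent of `tokenize`, assuming the same
/// character class as the regex `[a-z0-9']+`.
pub fn tokenize_std(lower_case_text: &str) -> HashSet<&str> {
    lower_case_text
        // Any character outside [a-z0-9'] acts as a separator.
        .split(|c: char| !(c.is_ascii_lowercase() || c.is_ascii_digit() || c == '\''))
        // Adjacent separators produce empty pieces; drop them.
        .filter(|word| !word.is_empty())
        // Collecting into a HashSet discards order and duplicates.
        .collect()
}
```

For example, "free bitcoin, free deals! 😻😻😻" yields just three tokens: free, bitcoin, and deals. The emoji, the punctuation, and the repeated "free" all disappear. One consequence of this rule, shared by the regex version: non-ASCII letters count as separators, so "café" becomes "caf".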

Some Handy Structures

Using a struct to represent a message will be helpful. This struct will contain a string slice for the message’s text, and a Boolean value to indicate whether or not the message is spam:

pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}

The 'a is a lifetime parameter annotation. If you’re unfamiliar with lifetimes, and want to learn about them, I recommend reading section 10.3 of The Rust Programming Language Book.

A struct will also be useful to represent our classifier. Before creating the struct, we need a short digression on Laplace Smoothing.

What is Laplace Smoothing?

Assume that – in our training data – the word fubar appears in some non-spam messages, but does not appear in any spam messages. Then, the Naive Bayes classifier will assign a probability 0 of spam to any message that contains the word fubar (source).

Unless we’re talking about my success with online dating, it’s not smart to assign a probability of 0 to an event just because it hasn’t happened yet.

Enter Laplace Smoothing. This is the technique of adding \(\alpha\) to the number of observations of each token (source). Let’s see this mathematically: without Laplace Smoothing, the probability of seeing a word \(w\) in a spam message is:

\[P(w|S) = \frac{number\ of\ spam\ messages\ containing\ w}{total\ number\ of\ spams}\]

With Laplace Smoothing, it’s:

\[P(w|S) = \frac{\alpha + number\ of\ spam\ messages\ containing\ w}{2\alpha + total\ number\ of\ spams}\]
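To make the formula concrete, here's a minimal sketch of it as a standalone function. The name smoothed_probability is hypothetical; in the classifier, the same arithmetic ends up inside probabilites_of_token. With \(\alpha = 0\) it reduces to the unsmoothed estimate, and with \(\alpha = 1\) a word like fubar that appears in none of 2 spam messages gets probability \(1/4\) instead of 0.

```rust
/// Laplace-smoothed estimate of P(w|S) = (alpha + k) / (2 * alpha + n), where
/// k = number of spam messages containing w and n = total number of spam messages.
/// Hypothetical helper for illustration.
pub fn smoothed_probability(messages_containing_word: u32, total_messages: u32, alpha: f64) -> f64 {
    (alpha + messages_containing_word as f64) / (2.0 * alpha + total_messages as f64)
}
```

The denominator gains \(2\alpha\) rather than \(\alpha\) because each token has two outcomes, present or absent, and each outcome receives \(\alpha\) pseudo-observations. That way the smoothed probabilities of a token being present and being absent still sum to 1.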

Back to our classifier struct. Following common spam-filtering convention, we’ll call non-spam messages “ham”:

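pub struct NaiveBayesClassifier {
    pub alpha: f64,                              // Laplace smoothing parameter
    pub tokens: HashSet<String>,                 // every token seen during training
    pub token_ham_counts: HashMap<String, i32>,  // token -> number of ham messages containing it
    pub token_spam_counts: HashMap<String, i32>, // token -> number of spam messages containing it
    pub spam_messages_count: i32,                // number of spam messages seen
    pub ham_messages_count: i32,                 // number of ham messages seen
}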

The implementation block for NaiveBayesClassifier will center around a train method and a predict method.

How to Train Our Classifier

The train method will take in a slice of Messages and loop through each Message, doing the following:

  • Check whether the message is spam and update spam_messages_count or ham_messages_count accordingly. We’ll create the helper function increment_message_classifications_count for this.
  • Convert the message’s text to lowercase and tokenize it with our tokenize function.
  • Loop through each token in the message and:
      • Insert the token into the tokens HashSet, then update token_spam_counts or token_ham_counts. We’ll create the helper function increment_token_count for this.

Here’s the pseudocode for our train method. If you feel like it, try converting the pseudocode into Rust before looking at my implementation below. Don't hesitate to send me your implementation; I’d love to see it!

implementation block for NaiveBayesClassifier {

	train(self, messages) {
		for each message in messages {
			self.increment_message_classifications_count(message)
			
			lowercase_text = to_lowercase(message.text)
			for each token in tokenize(lowercase_text) {
				self.tokens.insert(token)
				self.increment_token_count(token, message.is_spam)
			}			
		}
	}

	increment_message_classifications_count(self, message) {
		if message.is_spam {
			self.spam_messages_count = self.spam_messages_count + 1
		} else {
			self.ham_messages_count = self.ham_messages_count + 1
		}
	}

	increment_token_count(self, token, is_spam) {
		if token is not a key of self.token_spam_counts {
			insert record with key=token and value=0 into self.token_spam_counts
		}

		if token is not a key of self.token_ham_counts {
			insert record with key=token and value=0 into self.token_ham_counts
		}

		if is_spam {
			self.token_spam_counts[token] = self.token_spam_counts[token] + 1
		} else {
			self.token_ham_counts[token] = self.token_ham_counts[token] + 1
		}
	}

}

And here’s the Rust implementation:

impl NaiveBayesClassifier {
    pub fn train(&mut self, messages: &[Message]) {
        for message in messages.iter() {
            self.increment_message_classifications_count(message);
            for token in tokenize(&message.text.to_lowercase()) {
                self.tokens.insert(token.to_string());
                self.increment_token_count(token, message.is_spam)
            }
        }
    }

    fn increment_message_classifications_count(&mut self, message: &Message) {
        if message.is_spam {
            self.spam_messages_count += 1;
        } else {
            self.ham_messages_count += 1;
        }
    }

    fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        if !self.token_spam_counts.contains_key(token) {
            self.token_spam_counts.insert(token.to_string(), 0);
        }

        if !self.token_ham_counts.contains_key(token) {
            self.token_ham_counts.insert(token.to_string(), 0);
        }

        if is_spam {
            self.increment_spam_count(token);
        } else {
            self.increment_ham_count(token);
        }
    }

    fn increment_spam_count(&mut self, token: &str) {
        *self.token_spam_counts.get_mut(token).unwrap() += 1;
    }

    fn increment_ham_count(&mut self, token: &str) {
        *self.token_ham_counts.get_mut(token).unwrap() += 1;
    }
}

Notice that incrementing a value in a HashMap is pretty cumbersome. A novice Rust programmer would have difficulty understanding what

*self.token_spam_counts.get_mut(token).unwrap() += 1

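is doing: get_mut returns an Option<&mut i32>, unwrap pulls out the mutable reference (panicking if the key is missing), and the leading * dereferences it so that += 1 updates the stored count.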

In an attempt to make the code more explicit, I’ve created the increment_spam_count and increment_ham_count functions. But I’m not happy with that – it still feels cumbersome. Reach out to me if you have suggestions for a better approach.
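One possible answer, sketched here rather than taken from the repository, is the HashMap entry API: entry(key) looks up the key's slot, and or_insert(0) stores 0 if the slot is empty before handing back a &mut i32. TokenCounts is a hypothetical, trimmed-down stand-in for NaiveBayesClassifier that keeps only the two count maps, so the snippet compiles on its own.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for NaiveBayesClassifier holding only the two count maps.
pub struct TokenCounts {
    pub token_spam_counts: HashMap<String, i32>,
    pub token_ham_counts: HashMap<String, i32>,
}

impl TokenCounts {
    pub fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        // entry() finds the slot for `token`; or_insert(0) fills an empty slot with 0
        // and returns a &mut i32 either way. Touching BOTH maps keeps every token
        // present in each, which the indexing in probabilites_of_token relies on.
        let spam_count = self.token_spam_counts.entry(token.to_string()).or_insert(0);
        let ham_count = self.token_ham_counts.entry(token.to_string()).or_insert(0);

        if is_spam {
            *spam_count += 1;
        } else {
            *ham_count += 1;
        }
    }
}
```

This removes the need for increment_spam_count and increment_ham_count. One trade-off: entry takes an owned String, so this version allocates on every call, even for tokens already in the maps, whereas the original only allocates when a key is missing.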

How to Predict with Our Classifier

The predict method will take a string slice and return the model’s calculated probability of spam.

We’ll create two helper functions, probabilities_of_message and probabilites_of_token, to do the heavy lifting for predict.

probabilities_of_message returns \(P(Message|Spam)\) and \(P(Message|Ham)\).

probabilites_of_token returns \(P(Token|Spam)\) and \(P(Token|Ham)\).

Calculating \(P(Message|Spam)\) involves multiplying together, for every token seen during training, the probability of that token being present in a spam message (if it appears in the input) or absent from one (if it doesn't).

Since probabilities are floating point numbers between 0 and 1, multiplying many probabilities together can result in underflow (source). This is when an operation results in a number smaller than what the computer can accurately store (see here and here). Thus, we’ll use logarithms and exponentials to transform the task into a series of additions:

\[ \prod_{i=0}^{n}p_i = \exp\left(\sum_{i=0}^{n} \log(p_i)\right) \]

We can do this because, for any positive real numbers \(a\) and \(b\),

\[ab = \exp(\log(ab)) = \exp(\log(a) + \log(b))\]
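To see the underflow concretely, here's a small sketch; the helper names product_direct and log_product are hypothetical. Multiplying 1,000 probabilities of 0.01 should give \(10^{-2000}\), far below the smallest positive f64 (roughly \(5 \times 10^{-324}\)), while the sum of their natural logarithms is an ordinary number.

```rust
/// Multiplies the probabilities directly; this can underflow to 0.0.
pub fn product_direct(probs: &[f64]) -> f64 {
    probs.iter().product()
}

/// Adds up the natural logarithms instead. exp() of this sum equals the
/// product mathematically, but the sum itself stays in a comfortable range.
pub fn log_product(probs: &[f64]) -> f64 {
    probs.iter().map(|p| p.ln()).sum()
}
```

With probs = vec![0.01; 1000], product_direct(&probs) returns exactly 0.0, while log_product(&probs) returns about -4605.17.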

Once again I’ll start with pseudocode for the predict method:

implementation block for NaiveBayesClassifier {
	/*...*/

	predict(self, text) {
		lower_case_text = to_lowercase(text)
		message_tokens = tokenize(lower_case_text)
		(prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens)
		return prob_if_spam / (prob_if_spam + prob_if_ham)
	}
	
	probabilities_of_message(self, message_tokens) {
		log_prob_if_spam = 0
		log_prob_if_ham = 0

		for each token in self.tokens {
			(prob_if_spam, prob_if_ham) = self.probabilites_of_token(token)

			if message_tokens contains token {
				log_prob_if_spam = log_prob_if_spam + ln(prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(prob_if_ham)
			} else {
				log_prob_if_spam = log_prob_if_spam + ln(1 - prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(1 - prob_if_ham)
			}
		}

		prob_if_spam = exp(log_prob_if_spam)
		prob_if_ham = exp(log_prob_if_ham)

		return (prob_if_spam, prob_if_ham)
	}

	probabilites_of_token(self, token) {
		prob_of_token_spam = (self.token_spam_counts[token] + self.alpha) 
						/ (self.spam_messages_count + 2 * self.alpha)
        
		prob_of_token_ham = (self.token_ham_counts[token] + self.alpha) 
						/ (self.ham_messages_count + 2 * self.alpha)

		return (prob_of_token_spam, prob_of_token_ham)
	}
	
	
}

And here’s the Rust code:

impl NaiveBayesClassifier {

    /*...*/

    pub fn predict(&self, text: &str) -> f64 {
        let lower_case_text = text.to_lowercase();
        let message_tokens = tokenize(&lower_case_text);
        let (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens);

        prob_if_spam / (prob_if_spam + prob_if_ham)
    }

    fn probabilities_of_message(&self, message_tokens: HashSet<&str>) -> (f64, f64) {
        let mut log_prob_if_spam = 0.;
        let mut log_prob_if_ham = 0.;

        for token in self.tokens.iter() {
            let (prob_if_spam, prob_if_ham) = self.probabilites_of_token(token);

            if message_tokens.contains(token.as_str()) {
                log_prob_if_spam += prob_if_spam.ln();
                log_prob_if_ham += prob_if_ham.ln();
            } else {
                log_prob_if_spam += (1. - prob_if_spam).ln();
                log_prob_if_ham += (1. - prob_if_ham).ln();
            }
        }

        let prob_if_spam = log_prob_if_spam.exp();
        let prob_if_ham = log_prob_if_ham.exp();

        (prob_if_spam, prob_if_ham)
    }

    fn probabilites_of_token(&self, token: &str) -> (f64, f64) {
        let prob_of_token_spam = (self.token_spam_counts[token] as f64 + self.alpha)
            / (self.spam_messages_count as f64 + 2. * self.alpha);

        let prob_of_token_ham = (self.token_ham_counts[token] as f64 + self.alpha)
            / (self.ham_messages_count as f64 + 2. * self.alpha);

        (prob_of_token_spam, prob_of_token_ham)
    }
}
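One caveat: probabilities_of_message still calls exp on the summed logarithms, so with a large enough vocabulary both results can underflow to 0.0, and predict then computes 0.0 / 0.0, which is NaN. A minimal sketch of one fix, assuming probabilities_of_message were changed to return the log-probabilities themselves, is to do the final division in log space using \(p_s / (p_s + p_h) = 1 / (1 + \exp(\log p_h - \log p_s))\). The helper spam_probability_from_logs is hypothetical. Like predict, it assumes spam and ham are equally likely before seeing the message: predict's formula is Bayes' Theorem with \(P(Spam) = P(Ham)\).

```rust
/// Hypothetical helper: converts the two log-likelihoods (the sums computed in
/// probabilities_of_message, before exponentiating) into P(spam | message),
/// assuming equal priors as predict does.
pub fn spam_probability_from_logs(log_prob_if_spam: f64, log_prob_if_ham: f64) -> f64 {
    // p_s / (p_s + p_h) = 1 / (1 + p_h / p_s) = 1 / (1 + exp(log p_h - log p_s)).
    // Only the difference of the logs is exponentiated, so tiny likelihoods are fine;
    // an overflow to infinity simply yields 0.0.
    1.0 / (1.0 + (log_prob_if_ham - log_prob_if_spam).exp())
}
```

For example, with log-likelihoods of -2000 and -2005, exponentiating each one first gives 0.0 / 0.0 = NaN, while this helper returns about 0.993.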

How to Test Our Classifier

Let’s give our model a test. The test below works through the Naive Bayes calculation by hand, then checks that our model gives the same result.

You may find it worthwhile to go through the test’s logic, or you may just want to paste the code at the bottom of your lib.rs file to check that your code works.

// ...lib.rs

pub fn new_classifier(alpha: f64) -> NaiveBayesClassifier {
    NaiveBayesClassifier {
        alpha,
        tokens: HashSet::new(),
        token_ham_counts: HashMap::new(),
        token_spam_counts: HashMap::new(),
        spam_messages_count: 0,
        ham_messages_count: 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn naive_bayes() {
        let train_messages = [
            Message {
                text: "Free Bitcoin viagra XXX christmas deals 😻😻😻",
                is_spam: true,
            },
            Message {
                text: "My dear Granddaughter, please explain Bitcoin over Christmas dinner",
                is_spam: false,
            },
            Message {
                text: "Here in my garage...",
                is_spam: true,
            },
        ];

        let alpha = 1.;
        let num_spam_messages = 2.;
        let num_ham_messages = 1.;

        let mut model = new_classifier(alpha);
        model.train(&train_messages);

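        let mut expected_tokens: HashSet<String> = HashSet::new();
        for message in train_messages.iter() {
            for token in tokenize(&message.text.to_lowercase()) {
                expected_tokens.insert(token.to_string());
            }
        }
        // Training should have recorded every token from every message.
        assert_eq!(model.tokens, expected_tokens);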

        let input_text = "Bitcoin crypto academy Christmas deals";

        let probs_if_spam = [
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "viagra"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "christmas"  (present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "my"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dear"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "please"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "explain"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "over"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dinner"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "here"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "in"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "garage"  (not present)
        ];

        let probs_if_ham = [
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "viagra"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "christmas"  (present)
            (0. + alpha) / (num_ham_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "my"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dear"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "please"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "explain"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "over"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dinner"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "here"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "in"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "garage"  (not present)
        ];

        let p_if_spam_log: f64 = probs_if_spam.iter().map(|p| p.ln()).sum();
        let p_if_spam = p_if_spam_log.exp();

        let p_if_ham_log: f64 = probs_if_ham.iter().map(|p| p.ln()).sum();
        let p_if_ham = p_if_ham_log.exp();

        // P(message | spam) / (P(message | spam) + P(message | ham)) rounds to 0.97
        assert!((model.predict(input_text) - p_if_spam / (p_if_spam + p_if_ham)).abs() < 0.000001);
    }
}

Now run cargo test. If it passes, well done: you’ve implemented a Naive Bayes classifier in Rust!

Thank you for coding with me, friends. Feel free to reach out if you have any questions or suggestions.

References

  1. Grus, J. (2019). Data Science from Scratch: First Principles with Python, 2nd edition. O’Reilly Media.
  2. Downey, A. (2021). Think Bayes: Bayesian Statistics in Python, 2nd edition. O’Reilly Media.
  3. Murphy, K. (2012). Machine Learning: A Probabilistic Perspective. MIT Press.
  4. Dhinakaran, V. (2017). Rust Cookbook. Packt.
  5. Ng, A. (2018). Stanford CS229: Lecture 5 - GDA & Naive Bayes.
  6. Burden, R., Faires, J., & Burden, A. (2015). Numerical Analysis, 10th edition. Brooks Cole.
  7. Underflow. Techopedia.