I want to improve my Rust skills and help you hone yours as well, so I've decided to write a series of articles about the Rust programming language.

By actually building stuff with Rust, we'll learn about a wide range of technological concepts in the process. In this installment, we’ll learn how to implement the Naive Bayes classifier with Rust.

You may encounter a few unfamiliar terms or concepts in this article. Don’t be discouraged. Look these up if you have the time, but regardless, this article’s main ideas will not be lost on you.

## What is a Naive Bayes Classifier?

The Naive Bayes classifier is a machine learning algorithm based on Bayes’ Theorem. Bayes’ Theorem gives us a way to update the probability of a hypothesis *H*, given some data *D*.

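Expressed mathematically, we have:

$$P(H \mid D) = \frac{P(D \mid H)\,P(H)}{P(D)}$$

where *P(H|D)* is the probability of *H* given *D*, *P(D|H)* is the probability of seeing *D* if *H* is true, *P(H)* is the probability of *H* before we see any data, and *P(D)* is the overall probability of the data.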

If we accumulate more data, we can update *P(H|D)* accordingly.
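To make the update concrete, here is a minimal Rust sketch for the special case of a single hypothesis *H* versus its complement "not *H*". That restriction is our assumption; Bayes' Theorem itself is more general. The sketch expands *P(D)* with the law of total probability. The function name `posterior` is hypothetical and does not appear in the project code.

```rust
/// Bayes' Theorem for a hypothesis H and its complement "not H".
/// The evidence P(D) is expanded with the law of total probability:
/// P(D) = P(D | H) P(H) + P(D | not H) P(not H).
pub fn posterior(p_d_given_h: f64, p_d_given_not_h: f64, prior_h: f64) -> f64 {
    let evidence = p_d_given_h * prior_h + p_d_given_not_h * (1.0 - prior_h);
    p_d_given_h * prior_h / evidence
}
```

One way to picture "updating" is to feed each posterior back in as the next prior. Start from a prior of 0.5 and make two observations that each have *P(D|H) = 0.5* and *P(D|not H) = 0.1*. Then `posterior(0.5, 0.1, 0.5)` gives 5/6, and `posterior(0.5, 0.1, 5.0 / 6.0)` gives 25/26. Chaining updates like this assumes the observations are independent given *H*, which is exactly the kind of assumption discussed next.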

Naive Bayes models rest on a big assumption. Once we know the class (spam or not spam), the presence or absence of each piece of data, such as a word, is independent of the presence or absence of every other piece (source). In other words, within a class, each piece of data conveys no information about any other piece.

We do not expect this assumption to hold in practice; it is a strong simplification. It is still useful, though, because it lets us build efficient classifiers that work quite well (source).

We’ll leave our description of Naive Bayes there. A lot more could be said, but the main point of this article is to practice Rust.

If you’d like to learn more about the algorithm, here are some resources:

- I can't say enough good things about this video from Josh Starmer.
- Joel Grus has written a chapter on Naive Bayes in his great book *Data Science from Scratch*, which was the main inspiration for this implementation.
- If mathematical notation is your thing, try section 6.6.3 of *The Elements of Statistical Learning*.
- And here's a helpful article on the basics of how the algorithm works.

The canonical application of the Naive Bayes classifier is a spam classifier. That is what we’ll build. You can find all the code here: https://github.com/josht-jpg/shaking-off-the-rust

We’ll begin by creating a new library with Cargo.

```
cargo new naive_bayes --lib
cd naive_bayes
```

Now let's dive into it.

## Tokenization in Rust

Our classifier will take in a message as input and return a classification of spam or not spam.

To work with the message we're given, we'll want to *tokenize* it. Our tokenized representation will be a set of lowercase words that ignores both word order and repeated entries. Rust's `std::collections::HashSet` structure is a great way to achieve this.

The function we'll write to perform tokenization needs the `regex` crate. Make sure you include this dependency in your `Cargo.toml` file:

```
[dependencies]
regex = "^1.5.4"
```

And here's the `tokenize` function:

```
// lib.rs
// We'll need HashMap later
use std::collections::{HashMap, HashSet};

use regex::Regex;

// Returns the set of distinct words in `lower_case_text`.
// (The regex is compiled on every call, which is fine for this tutorial.)
pub fn tokenize(lower_case_text: &str) -> HashSet<&str> {
    Regex::new(r"[a-z0-9']+")
        .unwrap()
        .find_iter(lower_case_text)
        .map(|mat| mat.as_str())
        .collect()
}
```

This function uses a regular expression to match runs of lowercase letters, digits, and apostrophes. Whenever we reach any other character (often whitespace or punctuation), the current run ends. Each match therefore groups together the letters, digits, and apostrophes seen since the last such character (you can read more about regex in Rust here). In other words, we're identifying and isolating the words in the input text. Collecting the matches into a `HashSet` then drops any repeated words.
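To see what the tokenizer produces without pulling in the `regex` crate, here is a std-only sketch under a stated assumption: splitting on every character outside the class `[a-z0-9']` and discarding empty pieces leaves the same maximal runs that `find_iter` matches. The name `tokenize_without_regex` is hypothetical. Like `tokenize`, it expects input that has already been lowercased.

```rust
use std::collections::HashSet;

/// Hypothetical std-only variant of `tokenize`: split on any character that is
/// not an ASCII lowercase letter, an ASCII digit, or an apostrophe, then drop
/// the empty pieces left between consecutive separators.
pub fn tokenize_without_regex(lower_case_text: &str) -> HashSet<&str> {
    lower_case_text
        .split(|c: char| !(c.is_ascii_lowercase() || c.is_ascii_digit() || c == '\''))
        .filter(|token| !token.is_empty())
        .collect()
}
```

For example, `"free bitcoin, free deals! it's 2021"` yields the set `{"free", "bitcoin", "deals", "it's", "2021"}`. The repeated "free" collapses into one entry, and the comma and exclamation mark act as separators.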

### Some Handy Structures

Using a `struct` to represent a message will be helpful. This `struct` will contain a *string slice* for the message's text, and a Boolean value to indicate whether or not the message is spam:

```
pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}
```

The `'a` is a lifetime parameter annotation. Here it tells the compiler that a `Message` cannot outlive the string its `text` borrows from. If you're unfamiliar with lifetimes and want to learn about them, I recommend reading section 10.3 of *The Rust Programming Language* book.
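As a small illustration of what `'a` buys us, the sketch below repeats the `Message` struct and adds a hypothetical helper, `spam_texts`, which is not part of the classifier. The `'a` in its signature says the returned slices borrow from the original text, not from the `messages` slice that was passed in.

```rust
pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}

/// Hypothetical helper: collects the text of every spam message.
/// The returned `&'a str` slices are tied to the lifetime of the underlying
/// text, not to the (possibly shorter) borrow of the `messages` slice.
pub fn spam_texts<'a>(messages: &[Message<'a>]) -> Vec<&'a str> {
    messages
        .iter()
        .filter(|m| m.is_spam)
        .map(|m| m.text)
        .collect()
}
```

The result borrows from the original `String`s rather than from the `Message` values. It therefore stays valid even after the `Vec<Message>` that produced it has been dropped, as long as those `String`s are still alive.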

A `struct` will also be useful to represent our classifier. Before creating the `struct`, we need a short digression on Laplace smoothing.

### What is Laplace Smoothing?

Suppose that, in our training data, the word *fubar* appears in some non-spam messages but in no spam messages. The Naive Bayes classifier will then assign a spam probability of **0** to any message that contains the word *fubar* (source).

Unless we’re talking about my success with online dating, it’s not smart to assign a probability of **0** to an event just because it hasn’t happened yet.

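Enter Laplace smoothing. This is the technique of adding a pseudocount *α* to the number of observations of each token (source). Let's see this mathematically. Without Laplace smoothing, the probability of seeing a word *w* in a spam message is:

$$P(w \mid \text{Spam}) = \frac{\text{number of spam messages containing } w}{\text{total number of spam messages}}$$

With Laplace smoothing, it's:

$$P(w \mid \text{Spam}) = \frac{\alpha + \text{number of spam messages containing } w}{2\alpha + \text{total number of spam messages}}$$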
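The two formulas above translate directly into a pair of small functions. This is a sketch: the names `p_word_given_spam_unsmoothed` and `p_word_given_spam_smoothed` and their parameters are hypothetical. The classifier below stores the same counts in its own fields instead.

```rust
/// P(w | spam) without smoothing: the fraction of spam messages containing w.
/// Any word never seen in spam gets exactly 0.0.
pub fn p_word_given_spam_unsmoothed(spam_with_word: u32, spam_total: u32) -> f64 {
    spam_with_word as f64 / spam_total as f64
}

/// P(w | spam) with Laplace smoothing: add `alpha` to the count of spam
/// messages containing w, and `2 * alpha` to the total (one pseudocount for
/// "contains w" and one for "does not contain w").
pub fn p_word_given_spam_smoothed(spam_with_word: u32, spam_total: u32, alpha: f64) -> f64 {
    (spam_with_word as f64 + alpha) / (spam_total as f64 + 2.0 * alpha)
}
```

Take *α = 1* and two spam messages. A word never seen in spam now gets 0.25 instead of 0, and a word seen in every spam message gets 0.75 instead of 1. Neither *ln(p)* nor *ln(1 − p)* can then become negative infinity.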

Back to our classifier `struct`:

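```
pub struct NaiveBayesClassifier {
    // Laplace smoothing pseudocount
    pub alpha: f64,
    // Every token seen during training (the vocabulary)
    pub tokens: HashSet<String>,
    // Number of ham (non-spam) messages containing each token
    pub token_ham_counts: HashMap<String, i32>,
    // Number of spam messages containing each token
    pub token_spam_counts: HashMap<String, i32>,
    // Number of spam and ham messages seen during training
    pub spam_messages_count: i32,
    pub ham_messages_count: i32,
}
```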

The implementation block for `NaiveBayesClassifier` will center around a `train` method and a `predict` method.

## How to Train Our Classifier

The `train` method will take in a slice of `Message`s and loop through each `Message`, doing the following:

- Check whether the message is spam and update `spam_messages_count` or `ham_messages_count` accordingly. We'll create the helper function `increment_message_classifications_count` for this.
- Lowercase the message's contents and tokenize them with our `tokenize` function.
- Loop through each token in the message and:
  - Insert the token into the `tokens` `HashSet`, then update `token_spam_counts` or `token_ham_counts`. We'll create the helper function `increment_token_count` for this.

Here's the pseudocode for our `train` method. If you feel like it, try converting the pseudocode into Rust before looking at my implementation below. Don't hesitate to send me your implementation; I'd love to see it!

```
implementation block for NaiveBayesClassifier {
    train(&mut self, messages) {
        for each message in messages {
            self.increment_message_classifications_count(message)
            lowercase_text = to_lowercase(message.text)
            for each token in tokenize(lowercase_text) {
                self.tokens.insert(token)
                self.increment_token_count(token, message.is_spam)
            }
        }
    }

    increment_message_classifications_count(&mut self, message) {
        if message.is_spam {
            self.spam_messages_count = self.spam_messages_count + 1
        } else {
            self.ham_messages_count = self.ham_messages_count + 1
        }
    }

    increment_token_count(&mut self, token, is_spam) {
        if token is not a key of self.token_spam_counts {
            insert record with key=token and value=0 into self.token_spam_counts
        }
        if token is not a key of self.token_ham_counts {
            insert record with key=token and value=0 into self.token_ham_counts
        }
        if is_spam {
            self.token_spam_counts[token] = self.token_spam_counts[token] + 1
        } else {
            self.token_ham_counts[token] = self.token_ham_counts[token] + 1
        }
    }
}
```

And here’s the Rust implementation:

```
impl NaiveBayesClassifier {
    pub fn train(&mut self, messages: &[Message]) {
        for message in messages.iter() {
            self.increment_message_classifications_count(message);
            // tokenize returns a set, so each token counts at most once per message
            for token in tokenize(&message.text.to_lowercase()) {
                self.tokens.insert(token.to_string());
                self.increment_token_count(token, message.is_spam);
            }
        }
    }

    fn increment_message_classifications_count(&mut self, message: &Message) {
        if message.is_spam {
            self.spam_messages_count += 1;
        } else {
            self.ham_messages_count += 1;
        }
    }

    fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        // Give the token an entry (starting at 0) in both maps,
        // so it can later be looked up in either one
        if !self.token_spam_counts.contains_key(token) {
            self.token_spam_counts.insert(token.to_string(), 0);
        }
        if !self.token_ham_counts.contains_key(token) {
            self.token_ham_counts.insert(token.to_string(), 0);
        }
        if is_spam {
            self.increment_spam_count(token);
        } else {
            self.increment_ham_count(token);
        }
    }

    fn increment_spam_count(&mut self, token: &str) {
        *self.token_spam_counts.get_mut(token).unwrap() += 1;
    }

    fn increment_ham_count(&mut self, token: &str) {
        *self.token_ham_counts.get_mut(token).unwrap() += 1;
    }
}
```

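Notice that incrementing a value in a `HashMap` is pretty cumbersome. A novice Rust programmer could have trouble understanding what `*self.token_spam_counts.get_mut(token).unwrap() += 1` is doing. `get_mut` returns an `Option<&mut i32>`, and `unwrap` extracts the mutable reference (panicking if the key is missing). The leading `*` then dereferences it, so `+= 1` updates the count stored in the map.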

To make the code more explicit, I've created the `increment_spam_count` and `increment_ham_count` functions. I'm still not happy with that, though; it feels cumbersome. Reach out to me if you have suggestions for a better approach.
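One common alternative is the standard library's `HashMap::entry` API. `entry(key).or_insert(0)` inserts a zero if the key is missing and returns a `&mut` to the stored value either way. That single call replaces the `contains_key`/`insert`/`get_mut` sequence. The sketch below is not the author's code. So that it compiles on its own, it uses a hypothetical `TokenCounts` struct that holds only the two count maps, with the same field names as `NaiveBayesClassifier`.

```rust
use std::collections::HashMap;

// Hypothetical stand-in holding only the two count maps from
// NaiveBayesClassifier, so this sketch compiles on its own.
#[derive(Default)]
pub struct TokenCounts {
    pub token_spam_counts: HashMap<String, i32>,
    pub token_ham_counts: HashMap<String, i32>,
}

impl TokenCounts {
    pub fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        // Ensure both maps have an entry (starting at 0) and borrow each count.
        // Borrowing two different fields mutably at once is allowed.
        let spam = self.token_spam_counts.entry(token.to_string()).or_insert(0);
        let ham = self.token_ham_counts.entry(token.to_string()).or_insert(0);
        if is_spam {
            *spam += 1;
        } else {
            *ham += 1;
        }
    }
}
```

One trade-off: `entry` takes an owned `String`, so this version allocates the key on every call, even when the token is already present. For a tutorial-sized dataset, the shorter code is probably worth that cost.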

## How to Predict with Our Classifier

The `predict` method will take a *string slice* and return the model's calculated probability of spam.

We'll create two helper functions, `probabilities_of_message` and `probabilites_of_token`, to do the heavy lifting for `predict`.

`probabilities_of_message` returns *P(Message|Spam)* and *P(Message|Ham)*.

`probabilites_of_token` returns *P(Token|Spam)* and *P(Token|Ham)*.

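Calculating *P(Message|Spam)* means multiplying together one probability per token in our vocabulary. If the input contains the token, we use the probability of that token appearing in a spam message; if it doesn't, we use the probability of it not appearing. *P(Message|Ham)* works the same way.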
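Written out under the naive independence assumption, that product looks like the following. Here *V* (our notation) is the set of tokens seen during training, i.e. `self.tokens`:

$$P(\text{Message} \mid \text{Spam}) = \prod_{\substack{w \in V \\ w \in \text{Message}}} P(w \mid \text{Spam}) \prod_{\substack{w \in V \\ w \notin \text{Message}}} \bigl(1 - P(w \mid \text{Spam})\bigr)$$

Tokens in the input that are not in *V* contribute nothing. Replacing Spam with Ham throughout gives *P(Message|Ham)*.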

Probabilities are floating point numbers between 0 and 1, so multiplying many of them together can result in **underflow** (source). Underflow happens when an operation produces a number smaller than what the computer can accurately store (see here and here). To avoid it, we'll use logarithms and exponentials to turn the product into a series of additions:

$$p_1 \cdot p_2 \cdots p_n = \exp\bigl(\ln p_1 + \ln p_2 + \cdots + \ln p_n\bigr)$$

This works because, for any positive real numbers *a* and *b*,

$$\ln(ab) = \ln a + \ln b \quad \text{and} \quad \exp(\ln a) = a.$$
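To see the underflow concretely, here is a small sketch (the function names are hypothetical) that multiplies 400 probabilities of 0.1. The true product, 10⁻⁴⁰⁰, is far below the smallest positive `f64` (about 4.9 × 10⁻³²⁴), so the direct product rounds to exactly `0.0`. The sum of the logarithms, 400 · ln(0.1) ≈ −921.03, is an ordinary, easily stored number.

```rust
/// Multiplies the probabilities directly; underflows to 0.0 for long inputs.
pub fn direct_product(probs: &[f64]) -> f64 {
    probs.iter().product()
}

/// Adds the natural logarithms instead. This is ln of the product,
/// computed without ever forming the tiny product itself.
pub fn log_of_product(probs: &[f64]) -> f64 {
    probs.iter().map(|p| p.ln()).sum()
}
```

For example, with `let probs = vec![0.1_f64; 400];`, `direct_product(&probs)` is `0.0`, while `log_of_product(&probs)` is about −921.03.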

Once again I’ll start with pseudocode for the predict method:

```
implementation block for NaiveBayesClassifier {
    /*...*/
    predict(self, text) {
        lower_case_text = to_lowercase(text)
        message_tokens = tokenize(lower_case_text)
        (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens)
        return prob_if_spam / (prob_if_spam + prob_if_ham)
    }

    probabilities_of_message(self, message_tokens) {
        log_prob_if_spam = 0
        log_prob_if_ham = 0
        for each token in self.tokens {
            (prob_if_spam, prob_if_ham) = self.probabilites_of_token(token)
            if message_tokens contains token {
                log_prob_if_spam = log_prob_if_spam + ln(prob_if_spam)
                log_prob_if_ham = log_prob_if_ham + ln(prob_if_ham)
            } else {
                log_prob_if_spam = log_prob_if_spam + ln(1 - prob_if_spam)
                log_prob_if_ham = log_prob_if_ham + ln(1 - prob_if_ham)
            }
        }
        prob_if_spam = exp(log_prob_if_spam)
        prob_if_ham = exp(log_prob_if_ham)
        return (prob_if_spam, prob_if_ham)
    }

    probabilites_of_token(self, token) {
        prob_of_token_spam = (self.token_spam_counts[token] + self.alpha)
                             / (self.spam_messages_count + 2 * self.alpha)
        prob_of_token_ham = (self.token_ham_counts[token] + self.alpha)
                            / (self.ham_messages_count + 2 * self.alpha)
        return (prob_of_token_spam, prob_of_token_ham)
    }
}
```

And here’s the Rust code:

```
impl NaiveBayesClassifier {
    /*...*/
    pub fn predict(&self, text: &str) -> f64 {
        let lower_case_text = text.to_lowercase();
        let message_tokens = tokenize(&lower_case_text);
        let (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens);
        // Equals P(spam | message) when spam and ham are assumed equally likely a priori
        prob_if_spam / (prob_if_spam + prob_if_ham)
    }

    fn probabilities_of_message(&self, message_tokens: HashSet<&str>) -> (f64, f64) {
        let mut log_prob_if_spam = 0.;
        let mut log_prob_if_ham = 0.;
        // Iterate over the whole vocabulary: absent tokens contribute ln(1 - p)
        for token in self.tokens.iter() {
            let (prob_if_spam, prob_if_ham) = self.probabilites_of_token(token);
            if message_tokens.contains(token.as_str()) {
                log_prob_if_spam += prob_if_spam.ln();
                log_prob_if_ham += prob_if_ham.ln();
            } else {
                log_prob_if_spam += (1. - prob_if_spam).ln();
                log_prob_if_ham += (1. - prob_if_ham).ln();
            }
        }
        (log_prob_if_spam.exp(), log_prob_if_ham.exp())
    }

    // Laplace-smoothed P(token | spam) and P(token | ham)
    fn probabilites_of_token(&self, token: &str) -> (f64, f64) {
        let prob_of_token_spam = (self.token_spam_counts[token] as f64 + self.alpha)
            / (self.spam_messages_count as f64 + 2. * self.alpha);
        let prob_of_token_ham = (self.token_ham_counts[token] as f64 + self.alpha)
            / (self.ham_messages_count as f64 + 2. * self.alpha);
        (prob_of_token_spam, prob_of_token_ham)
    }
}
```
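One caveat: calling `exp()` on the log sums at the end of `probabilities_of_message` can reintroduce underflow. With a large vocabulary, both log sums can fall below roughly −745. `exp()` then returns `0.0` for each of them, and `predict` computes `0.0 / 0.0`, which is `NaN`. A possible fix, sketched here as a hypothetical helper rather than part of the original code, is to have `probabilities_of_message` return the log sums and combine them in log space. Dividing the numerator and denominator of *p_spam / (p_spam + p_ham)* by *p_spam* gives *1 / (1 + exp(ln p_ham − ln p_spam))*. That expression only exponentiates the difference.

```rust
/// Hypothetical helper: turns the two log-likelihoods into
/// p_spam / (p_spam + p_ham) without exponentiating either one on its own,
/// using p_spam / (p_spam + p_ham) = 1 / (1 + exp(ln p_ham - ln p_spam)).
pub fn spam_probability_from_logs(log_prob_if_spam: f64, log_prob_if_ham: f64) -> f64 {
    1.0 / (1.0 + (log_prob_if_ham - log_prob_if_spam).exp())
}
```

For moderate values this agrees with the original formula. For log sums like −1000 and −1001, the original formula produces `NaN`, while this version returns 1 / (1 + e⁻¹) ≈ 0.73.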

## How to Test Our Classifier

Let's give our model a test. The test below works through the Naive Bayes calculation by hand, then checks that our model gives the same result.

You may find it worthwhile to go through the test's logic, or you may just want to paste the code at the bottom of your `lib.rs` file to check that your code works.

```
// ...lib.rs
pub fn new_classifier(alpha: f64) -> NaiveBayesClassifier {
    NaiveBayesClassifier {
        alpha,
        tokens: HashSet::new(),
        token_ham_counts: HashMap::new(),
        token_spam_counts: HashMap::new(),
        spam_messages_count: 0,
        ham_messages_count: 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn naive_bayes() {
        let train_messages = [
            Message {
                text: "Free Bitcoin viagra XXX christmas deals 😻😻😻",
                is_spam: true,
            },
            Message {
                text: "My dear Granddaughter, please explain Bitcoin over Christmas dinner",
                is_spam: false,
            },
            Message {
                text: "Here in my garage...",
                is_spam: true,
            },
        ];

        let alpha = 1.;
        let num_spam_messages = 2.;
        let num_ham_messages = 1.;

        let mut model = new_classifier(alpha);
        model.train(&train_messages);

        // The vocabulary should contain every token from the training messages
        let mut expected_tokens: HashSet<String> = HashSet::new();
        for message in train_messages.iter() {
            for token in tokenize(&message.text.to_lowercase()) {
                expected_tokens.insert(token.to_string());
            }
        }
        assert_eq!(model.tokens, expected_tokens);

        // "crypto" and "academy" never appeared in training, so they are ignored
        let input_text = "Bitcoin crypto academy Christmas deals";

        let probs_if_spam = [
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "free" (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "bitcoin" (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "viagra" (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "xxx" (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "christmas" (present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "deals" (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "my" (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dear" (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "granddaughter" (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "please" (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "explain" (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "over" (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dinner" (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "here" (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "in" (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "garage" (not present)
        ];

        let probs_if_ham = [
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "free" (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "bitcoin" (present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "viagra" (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "xxx" (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "christmas" (present)
            (0. + alpha) / (num_ham_messages + 2. * alpha),      // "deals" (present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "my" (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dear" (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "granddaughter" (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "please" (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "explain" (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "over" (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dinner" (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "here" (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "in" (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "garage" (not present)
        ];

        let p_if_spam_log: f64 = probs_if_spam.iter().map(|p| p.ln()).sum();
        let p_if_spam = p_if_spam_log.exp();
        let p_if_ham_log: f64 = probs_if_ham.iter().map(|p| p.ln()).sum();
        let p_if_ham = p_if_ham_log.exp();

        // P(message | spam) / (P(message | spam) + P(message | ham)) rounds to 0.97
        assert!((model.predict(input_text) - p_if_spam / (p_if_spam + p_if_ham)).abs() < 0.000001);
    }
}
```

Now run `cargo test`. If it passes, well done: you've implemented a Naive Bayes classifier in Rust!

Thank you for coding with me, friends. Feel free to reach out if you have any questions or suggestions.

### References

- Grus, J. (2019). *Data Science from Scratch: First Principles with Python*, 2nd edition. O’Reilly Media.
- Downey, A. (2021). *Think Bayes: Bayesian Statistics in Python*, 2nd edition. O’Reilly Media.
- Murphy, K. (2012). *Machine Learning: A Probabilistic Perspective*. MIT Press.
- Dhinakaran, V. (2017). *Rust Cookbook*. Packt.
- Ng, A. (2018). *Stanford CS229: Lecture 5 - GDA & Naive Bayes*.
- Burden, R., Faires, J., & Burden, A. (2015). *Numerical Analysis*, 10th edition. Brooks Cole.
- *Underflow*. Techopedia.