Josh Walters

I'm a software engineer working with big data.


Simple Linear Regression in Haskell

April 10, 2014

As part of my journey of learning Haskell, I figured I would implement some statistics / machine learning algorithms. As an added benefit, I can write some posts about the code and describe the algorithms as I go.

First up, the easiest regression method out there (drum roll)… Simple Linear Regression!

First, I will show my full implementation, then I will break the code down and walk you through it. You may want to read the Wikipedia page on simple linear regression first, so that you know what it is we are trying to do.

import Data.List

data RegressionLine = Line {m :: Double, b :: Double}
                      deriving (Show, Eq)

mean :: (Fractional a) => [a] -> a
mean xs = sum xs / fromIntegral (length xs)

meanOfPoints :: (Fractional a) => [(a, a)] -> (a, a)
meanOfPoints x = let (a, b) = unzip x
                 in (mean a, mean b)

correlation :: (Floating a) => [(a, a)] -> a
correlation ps = xy / sqrt (xx * yy)
                 where (xs, ys) = unzip ps
                       (mx, my) = (mean xs, mean ys)
                       -- center both variables at their means (Pearson correlation)
                       xy = sum [(x - mx) * (y - my) | (x, y) <- ps]
                       xx = sum [(x - mx) ^ 2 | (x, _) <- ps]
                       yy = sum [(y - my) ^ 2 | (_, y) <- ps]

stdDev :: Floating a => [a] -> a
stdDev xs = sqrt $ sum (xMinusMean xs) / lengthList
            where mu = mean xs
                  xMinusMean = map (\x -> (x - mu) ^ 2)
                  lengthList = fromIntegral (length xs)

stdDevOfPoints :: (Floating a) => [(a, a)] -> (a, a)
stdDevOfPoints x = let (a, b) = unzip x
                   in (stdDev a, stdDev b)

simpleLinear :: [(Double, Double)] -> RegressionLine
simpleLinear xs = Line m b
                  where r = correlation xs
                        (sx, sy) = stdDevOfPoints xs
                        (mx, my) = meanOfPoints xs
                        m = r * sy / sx
                        b = my - m * mx

predict :: Double -> RegressionLine -> Double
predict x (Line m b) = b + m * x

rmse :: [(Double, Double)] -> RegressionLine -> Double
rmse xs line = let diff = [y - predict x line | (x, y) <- xs]
                   diffSquared = map (^ 2) diff
                   sumDiffSquared = sum diffSquared
               in sqrt $ sumDiffSquared / fromIntegral (length xs)

Mean calculates the mean of a list, very simple.
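For example (restating mean here so the snippet runs on its own):

```haskell
mean :: Fractional a => [a] -> a
mean xs = sum xs / fromIntegral (length xs)

main :: IO ()
main = print (mean [1, 2, 3, 4])  -- prints 2.5
```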

Mean of points applies mean to a list of points. It gets the mean for all the x’s and all the y’s.

Correlation measures how strongly the x and y variables are linearly related. A correlation of 1 (or -1) means the points lie exactly on a line; a correlation of 0 means there is no linear relationship. If the correlation is strong, we should be able to get a good model with simple linear regression.
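As a sanity check, points that lie exactly on a rising line should give a correlation of exactly 1. This snippet restates the centered Pearson formula so it runs on its own:

```haskell
mean :: Fractional a => [a] -> a
mean xs = sum xs / fromIntegral (length xs)

-- Pearson correlation: center x and y at their means before summing.
correlation :: Floating a => [(a, a)] -> a
correlation ps = xy / sqrt (xx * yy)
  where (xs, ys) = unzip ps
        (mx, my) = (mean xs, mean ys)
        xy = sum [(x - mx) * (y - my) | (x, y) <- ps]
        xx = sum [(x - mx) ^ 2 | (x, _) <- ps]
        yy = sum [(y - my) ^ 2 | (_, y) <- ps]

main :: IO ()
main = print (correlation [(1, 2), (2, 4), (3, 6)])  -- prints 1.0
```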

The standard deviation function calculates the standard deviation of a list. The formula is very simple.

Standard deviation of points just calculates standard deviation for both x and y.
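For example, the classic data set [2, 4, 4, 4, 5, 5, 7, 9] has a population standard deviation (dividing by n, as the code above does) of exactly 2:

```haskell
mean :: Fractional a => [a] -> a
mean xs = sum xs / fromIntegral (length xs)

-- Population standard deviation: divide by n, not n - 1.
stdDev :: Floating a => [a] -> a
stdDev xs = sqrt $ sum (map (\x -> (x - mu) ^ 2) xs) / fromIntegral (length xs)
  where mu = mean xs

main :: IO ()
main = print (stdDev [2, 4, 4, 4, 5, 5, 7, 9])  -- prints 2.0
```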

Finally, the simple linear function calculates the simple linear regression line for the given data. It should be very easy to understand what this function does, as it is just composed of the previous functions.
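To see it end to end: fitting points that lie exactly on y = 2x + 1 should recover a slope of 2 and an intercept of 1. This sketch bundles the pieces into one standalone program (using the centered Pearson form of the correlation):

```haskell
data RegressionLine = Line { m :: Double, b :: Double } deriving (Show, Eq)

mean :: Fractional a => [a] -> a
mean xs = sum xs / fromIntegral (length xs)

correlation :: Floating a => [(a, a)] -> a
correlation ps = xy / sqrt (xx * yy)
  where (xs, ys) = unzip ps
        (mx, my) = (mean xs, mean ys)
        xy = sum [(x - mx) * (y - my) | (x, y) <- ps]
        xx = sum [(x - mx) ^ 2 | (x, _) <- ps]
        yy = sum [(y - my) ^ 2 | (_, y) <- ps]

stdDev :: Floating a => [a] -> a
stdDev xs = sqrt $ sum (map (\x -> (x - mean xs) ^ 2) xs) / fromIntegral (length xs)

simpleLinear :: [(Double, Double)] -> RegressionLine
simpleLinear ps = Line slope intercept
  where r = correlation ps
        (xs, ys) = unzip ps
        (sx, sy) = (stdDev xs, stdDev ys)
        (mx, my) = (mean xs, mean ys)
        slope = r * sy / sx
        intercept = my - slope * mx

main :: IO ()
main = print (simpleLinear [(1, 3), (2, 5), (3, 7), (4, 9)])
-- slope comes out as 2 and intercept as 1, up to floating-point rounding
```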

The predict function predicts a y given an x. Very simple: it just plugs x into the equation of the line, y = mx + b.
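For example, with a slope of 2 and an intercept of 1, an x of 4 predicts a y of 9:

```haskell
data RegressionLine = Line { m :: Double, b :: Double } deriving (Show, Eq)

-- y = mx + b
predict :: Double -> RegressionLine -> Double
predict x (Line m b) = b + m * x

main :: IO ()
main = print (predict 4 (Line 2 1))  -- prints 9.0
```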

The RMSE (root mean squared error) function is a way to measure how accurate our predictions are. The lower the better. List comprehensions are awesome.
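As a check, a line that fits its data perfectly should have an RMSE of 0, and shifting every point up by one unit should give an RMSE of exactly 1. This snippet restates predict together with an RMSE that compares each observed y against the prediction at its x:

```haskell
data RegressionLine = Line { m :: Double, b :: Double } deriving (Show, Eq)

predict :: Double -> RegressionLine -> Double
predict x (Line m b) = b + m * x

-- Root mean squared error: compare each observed y with the
-- prediction at the matching x, then average the squared errors.
rmse :: [(Double, Double)] -> RegressionLine -> Double
rmse ps line = sqrt $ sum diffSquared / fromIntegral (length ps)
  where diffSquared = [(y - predict x line) ^ 2 | (x, y) <- ps]

main :: IO ()
main = do
  print (rmse [(1, 3), (2, 5), (3, 7)] (Line 2 1))  -- prints 0.0
  print (rmse [(1, 4), (2, 6), (3, 8)] (Line 2 1))  -- prints 1.0
```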

Now, let's do something cool with this!

I am a huge fan of the Dresden Files books, and I love that the author, Jim Butcher, pumps them out like clockwork. But, I am a very impatient fan, and I want the books now!

So, how many days pass from one book release to the next? So far, 15 books have been published, which gives us 14 release-to-release gaps. Is there a trend? Are books published at a constant rate? Is it taking longer for books to come out? Let's see.

import Graphics.EasyPlot

-- x counts release gaps: x = 1 is the gap between books #1 and #2
-- y is the number of days since the previous book
daysBetweenDresdenBooks = [(1.0,275.0),  (2.0,243.0),  (3.0,367.0),
                           (4.0,336.0),  (5.0,363.0),  (6.0,274.0),
                           (7.0,364.0),  (8.0,363.0),  (9.0,364.0),
                           (10.0,371.0), (11.0,364.0), (12.0,476.0),
                           (13.0,490.0), (14.0,546.0)]

plotPred = let regressionLine = simpleLinear daysBetweenDresdenBooks
           in plot X11 [Data2D [Title "Dresden Files Days Between Books",
                                Color Red]
                               []
                               daysBetweenDresdenBooks,
                        Function2D [Title "Simple Linear Regression Line",
                                    Color Blue]
                                   [Range 0 23]
                                   (\x -> predict x regressionLine)]

I am using EasyPlot, a wrapper around GNUPlot. It is a very simple way to get a plot up and running in Haskell.