3
votes

I'm following along with Tibshirani's ISL text and trying to plot the results of an SVM in ggplot2. I can get the points and the support vectors, but I can't figure out how to draw the margins and the hyperplane for the 2D case. I have Googled and checked the e1071 readme without luck. A general, dynamic solution (applicable to a variety of SVM kernels, costs, etc.) would be great. Here is my MWE:

set.seed(1)
N=20
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1  # shift class 1 so the two groups are partly separated
dat = data.frame(x=x, y=as.factor(y))
library(e1071)
library(ggplot2)
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)

df = dat; df
df = cbind(df, sv=rep(0,nrow(df)))
df[svmfit$index,]$sv = 1

ggplot(data=df,aes(x=x.1,y=x.2,group=y,color=y)) +     
    geom_point(aes(shape=factor(sv)))

I'm after something like this (from Python's scikit-learn): [image]

1
There's a base graphics plotting method already defined for svm in e1071. Did you look at the results of plot(svmfit, dat)? Are you trying to just replicate that in ggplot? - arvi1000
Yes, I am trying to replicate that in ggplot, plus add lines for the hyperplane and dashed lines for the margins for the 2D case (K=2). - user2205916
You might want to accept @user21359's answer as it works like a charm - OganM

1 Answer

5
votes

So you don't want to plot the support vectors, right? Here's something very basic that works for your example, based on the plot.svm source code:

https://github.com/cran/e1071/blob/master/R/svm.R

You can construct something much richer by taking a look at that source code.

library(e1071)
library(ggplot2)
set.seed(1)
N=20
x=matrix(rnorm(n=N*2), ncol=2)
y=c(rep(-1,N/2), rep(1,N/2))
x[y==1,] = x[y==1,] + 1  # shift class 1 so the two groups are partly separated
dat = data.frame(x=x, y=as.factor(y))
svmfit=svm(y~., data=dat, kernel="linear", cost=10, scale=FALSE)

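# Build a 100 x 100 grid covering the observed range of both predictors
grid <- expand.grid(seq(min(dat[, 1]), max(dat[, 1]), length.out = 100),
                    seq(min(dat[, 2]), max(dat[, 2]), length.out = 100))
names(grid) <- names(dat)[1:2]
# Predict the class at every grid point and draw one filled tile per prediction
preds <- predict(svmfit, grid)
df <- data.frame(grid, preds)
ggplot(df, aes(x = x.2, y = x.1, fill = preds)) + geom_tile()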

Should output this:

enter image description here

Compare this to the plot.svm output:

plot(svmfit, dat)

enter image description here
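
The tiles show the predicted regions, but not the hyperplane or the margins the question asks for. One way that might work for any kernel is to ask predict for decision values on the same kind of grid and contour them: level 0 would be the decision boundary and levels -1 and +1 the margins. This is only a sketch under those assumptions; the column name dv and the object p are my own, and it refits the model so that it runs on its own.

```r
library(e1071)
library(ggplot2)

# Same simulated data and fit as in the question
set.seed(1)
N <- 20
x <- matrix(rnorm(n = N * 2), ncol = 2)
y <- c(rep(-1, N / 2), rep(1, N / 2))
x[y == 1, ] <- x[y == 1, ] + 1
dat <- data.frame(x = x, y = as.factor(y))
svmfit <- svm(y ~ ., data = dat, kernel = "linear", cost = 10, scale = FALSE)

# Grid over the observed range of both predictors
grid <- expand.grid(
  x.1 = seq(min(dat$x.1), max(dat$x.1), length.out = 100),
  x.2 = seq(min(dat$x.2), max(dat$x.2), length.out = 100)
)

# Decision values come back as an attribute of the predictions
pred <- predict(svmfit, grid, decision.values = TRUE)
grid$dv <- as.numeric(attr(pred, "decision.values"))  # dv: hypothetical column name

# Contour at 0 = decision boundary (solid), at -1 and 1 = margins (dashed)
p <- ggplot(grid, aes(x = x.2, y = x.1)) +
  geom_contour(aes(z = dv), breaks = 0, colour = "black") +
  geom_contour(aes(z = dv), breaks = c(-1, 1), colour = "black",
               linetype = "dashed") +
  geom_point(data = dat, aes(colour = y))
print(p)
```

Because nothing here depends on the kernel being linear, swapping kernel = "radial" into the svm call should give curved contours from the same plotting code, at the cost of a boundary that is only as smooth as the grid resolution.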

EDIT:

If you want to reproduce the points as well, I've altered the above code slightly:

cols <- c('1' = 'red', '-1' = 'black')
tiles <- c('1' = 'magenta', '-1' = 'cyan')
shapes <- c('support' = 4, 'notsupport' = 1)
dat$support <- 'notsupport'
dat[svmfit$index, 'support'] <- 'support'

ggplot(df, aes(x = x.2, y = x.1)) + geom_tile(aes(fill = preds)) + 
  scale_fill_manual(values = tiles) +
  geom_point(data = dat, aes(color = y, shape = support), size = 2) +
  scale_color_manual(values = cols) +
  scale_shape_manual(values = shapes) +
  ggtitle('SVM classification plot')

enter image description here
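
For the linear kernel specifically, the hyperplane and margins can probably be drawn as exact straight lines instead of being read off a grid. The sketch below assumes that, with scale = FALSE, the fitted decision function is f(x) = w . x + b with w = t(coefs) %*% SV and b = -rho; that is my reading of how e1071 stores the fit, not something stated above. The names w, b, lines_df and p are hypothetical, and the axes follow the plots above (x.2 horizontal, x.1 vertical).

```r
library(e1071)
library(ggplot2)

# Same simulated data and fit as in the question
set.seed(1)
N <- 20
x <- matrix(rnorm(n = N * 2), ncol = 2)
y <- c(rep(-1, N / 2), rep(1, N / 2))
x[y == 1, ] <- x[y == 1, ] + 1
dat <- data.frame(x = x, y = as.factor(y))
svmfit <- svm(y ~ ., data = dat, kernel = "linear", cost = 10, scale = FALSE)

# Assumed form of the linear decision function: f(x) = w . x + b
w <- drop(t(svmfit$coefs) %*% svmfit$SV)  # named vector: x.1, x.2
b <- -svmfit$rho

# Solve w1 * x.1 + w2 * x.2 + b = level for x.1 (needs w1 != 0):
#   x.1 = -(w2 / w1) * x.2 + (level - b) / w1
lines_df <- data.frame(
  level     = c(-1, 0, 1),
  slope     = -w[["x.2"]] / w[["x.1"]],
  intercept = (c(-1, 0, 1) - b) / w[["x.1"]],
  type      = c("margin", "hyperplane", "margin")
)

p <- ggplot(dat, aes(x = x.2, y = x.1)) +
  geom_point(aes(colour = y)) +
  geom_abline(data = lines_df,
              aes(slope = slope, intercept = intercept, linetype = type)) +
  scale_linetype_manual(values = c(hyperplane = "solid", margin = "dashed"))
print(p)
```

This only applies to kernel = "linear" with unscaled inputs; for other kernels there is no single w in the input space, so a grid-based approach would be needed instead.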