Fig 7.1, An Introduction To Statistical Learning
I am currently working through the book An Introduction to Statistical Learning with Applications in R and converting the solutions to Python.
I can't figure out how to compute the confidence intervals and plot them as the dashed lines shown in the image above.
I have plotted the fitted line. Here's my code for that:
(I am using polynomial regression of degree 4 with predictor 'age' and response 'wage'.)
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm
from sklearn.preprocessing import PolynomialFeatures

# Build the degree-4 polynomial design matrix (includes the intercept column)
poly = PolynomialFeatures(4)
X = poly.fit_transform(data['age'].to_frame())
y = data['wage']
# X.shape
model = sm.OLS(y, X).fit()
print(model.summary())
# What we want here is not only the fitted line, but also the standard error of the fit
# To find that, we need to calculate the predictions over a grid of age values
test_ages = np.linspace(data['age'].min(), data['age'].max(), 100)
X_test = poly.transform(test_ages.reshape(-1, 1))
pred = model.predict(X_test)
plt.figure(figsize=(12, 8))
plt.scatter(data['age'], data['wage'], facecolors='none', edgecolors='darkgray')
plt.plot(test_ages, pred)
Here data is the Wage data, which is available in R. This is the resulting graph I get: