
I am using Plotly for a linear regression model and I am trying to embed the results of the OLS trendline in the chart legend. I can see the linear regression details when I hover over the regression line, but I would like those results to be always visible in the chart legend.

Is there a way to do this? Here is my code:

# Importing plotly dependency
import plotly.express as px

# Plotting the graph
fig = px.scatter(df_linear, x="Days_ct", y="Conf_ct", trendline="ols")
fig.update_traces(name = "OLS trendline")


fig.update_layout(template="ggplot2",title_text = '<b>Linear Regression Model</b>',
                  font=dict(family="Arial, Balto, Courier New, Droid Sans",color='black'), showlegend=True)
fig.update_layout(
    legend=dict(
        x=0.01,
        y=.98,
        traceorder="normal",
        font=dict(
            family="sans-serif",
            size=12,
            color="Black"
        ),
        bgcolor="LightSteelBlue",
        bordercolor="dimgray",
        borderwidth=2
    ))
fig.show()



1 Answer


With your setup and some synthetic data you can retrieve px OLS estimates using:

model = px.get_trendline_results(fig)
alpha = model.iloc[0]["px_fit_results"].params[0]
beta = model.iloc[0]["px_fit_results"].params[1]
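
Here `params[0]` is the intercept and `params[1]` is the slope of the fitted line. As a sanity check on the retrieved values, the same estimates can be computed by hand. The following is only a minimal pure-Python sketch of the closed-form least-squares formulas; the function name `ols_fit` and the sample data are hypothetical and not part of plotly or statsmodels:

```python
def ols_fit(xs, ys):
    """Return (intercept, slope) of the least-squares line through the points."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two or more paired observations")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sum of squared deviations of x, and cross-deviations of x and y
    sxx = sum((x - mean_x) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("x values must not all be equal")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return intercept, slope


# Illustrative data lying exactly on y = 1 + 2x
alpha_check, beta_check = ols_fit([1, 2, 3, 4], [3, 5, 7, 9])
print(alpha_check, beta_check)
```

Passing the plotted columns as plain lists (for example `df_linear["Days_ct"].tolist()`) should give values close to `alpha` and `beta` above, up to floating-point rounding.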

You can then include those estimates in your legend and make the necessary layout adjustments directly using:

fig.data[0].name = 'observations'
fig.data[0].showlegend = True
fig.data[1].name = fig.data[1].name  + ' y = ' + str(round(alpha, 2)) + ' + ' + str(round(beta, 2)) + 'x'
fig.data[1].showlegend = True
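
One detail of the string concatenation above: a negative slope is rendered as `+ -0.25x`. A small formatting helper can tidy that up. This is just a sketch; the helper name `trendline_label` and its parameters are hypothetical, not a plotly API:

```python
def trendline_label(alpha, beta, prefix="OLS trendline", digits=2):
    """Build a legend label such as 'OLS trendline y = 1.5 + 2.0x'."""
    a = round(alpha, digits)
    b = round(beta, digits)
    # Show the sign as an operator so a negative slope does not read "+ -0.25x"
    sign = "-" if b < 0 else "+"
    return f"{prefix} y = {a} {sign} {abs(b)}x"


# Illustrative values only
print(trendline_label(82.8734, -0.2512))
```

With the estimates retrieved earlier, the result could be assigned with `fig.data[1].name = trendline_label(alpha, beta)`.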

Plot 1: fitted equation included in legend

Edit: R-squared

Following up on your comment, here is how to include other values of interest from the regression analysis. At this point the legend is arguably no longer the best place for the estimates, but that is exactly what the following addition does:

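import plotly.graph_objects as go

rsq = model.iloc[0]["px_fit_results"].rsquared
# Add a dummy, fully transparent marker whose only purpose is to carry a legend entry.
# Note that the point at (100, 100) still counts towards the axis autorange.
fig.add_trace(go.Scatter(x=[100], y=[100],
                         name = "R-squared" + ' = ' + str(round(rsq, 2)),
                         showlegend=True,
                         mode='markers',
                         marker=dict(color='rgba(0,0,0,0)')
                         ))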
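
For reference, with an intercept in the model R-squared is one minus the ratio of the residual sum of squares to the total sum of squares. A minimal pure-Python sketch of that formula, useful for cross-checking the value shown in the legend (the function name `r_squared` and the sample data are hypothetical):

```python
def r_squared(xs, ys, alpha, beta):
    """Coefficient of determination for the line y = alpha + beta * x."""
    mean_y = sum(ys) / len(ys)
    # Residual sum of squares: squared distance of each y from the fitted line
    ss_res = sum((y - (alpha + beta * x)) ** 2 for x, y in zip(xs, ys))
    # Total sum of squares: squared distance of each y from the mean of y
    ss_tot = sum((y - mean_y) ** 2 for y in ys)
    if ss_tot == 0:
        raise ValueError("y values must not all be equal")
    return 1 - ss_res / ss_tot


# Illustrative data whose least-squares line is y = 0 + 1.9x
rsq_check = r_squared([1, 2, 3, 4], [2, 4, 5, 8], alpha=0.0, beta=1.9)
label = "R-squared = " + str(round(rsq_check, 2))
print(label)
```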

Plot 2: R-squared included in legend


Complete code with synthetic data:

import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
# statsmodels must be installed for trendline="ols"; px imports it internally

# data
np.random.seed(123)
numdays=20

X = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
Y = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()

df_linear = pd.DataFrame({'Days_ct': X, 'Conf_ct':Y})

# Plotting the graph
fig = px.scatter(df_linear, x="Days_ct", y="Conf_ct", trendline="ols")
fig.update_traces(name = "OLS trendline")



fig.update_layout(template="ggplot2",title_text = '<b>Linear Regression Model</b>',
                  font=dict(family="Arial, Balto, Courier New, Droid Sans",color='black'), showlegend=True)
fig.update_layout(
    legend=dict(
        x=0.01,
        y=.98,
        traceorder="normal",
        font=dict(
            family="sans-serif",
            size=12,
            color="Black"
        ),
        bgcolor="LightSteelBlue",
        bordercolor="dimgray",
        borderwidth=2
    ))

# retrieve model estimates
model = px.get_trendline_results(fig)
alpha = model.iloc[0]["px_fit_results"].params[0]
beta = model.iloc[0]["px_fit_results"].params[1]

# restyle figure
fig.data[0].name = 'observations'
fig.data[0].showlegend = True
fig.data[1].name = fig.data[1].name  + ' y = ' + str(round(alpha, 2)) + ' + ' + str(round(beta, 2)) + 'x'
fig.data[1].showlegend = True

# addition for r-squared
rsq = model.iloc[0]["px_fit_results"].rsquared
fig.add_trace(go.Scatter(x=[100], y=[100],
                         name = "R-squared" + ' = ' + str(round(rsq, 2)),
                         showlegend=True,
                         mode='markers',
                         marker=dict(color='rgba(0,0,0,0)')
                         ))

fig.show()