A Tufts University Data Lab Tutorial
Written by Uku-Kaspar Uustalu
Contact: uku-kaspar.uustalu@tufts.edu
Last updated: 2023-03-02
We will be using the following Python data analysis and visualization libraries throughout this tutorial:

- Pandas -- usually imported under the alias pd
- Matplotlib -- we will mostly use its matplotlib.pyplot module, which is usually imported under the alias plt
- Seaborn -- usually imported under the alias sns
- HVPlot -- its hvplot.pandas module must be imported to allow for seamless integration with Pandas
- Plotly -- its plotly.express module is the easiest to use as it allows for the creation of whole plots using a single command; the module is usually imported under the alias px

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import hvplot.pandas
import plotly.express as px
For the first part of this tutorial, we will be using the following datasets from the data directory to investigate the relationship between health and wealth:

- gdp.csv -- World Bank gross domestic product (GDP) estimates (in USD) for world countries and regions from 1960 until 2021
- life-expectancy.csv -- World Bank life expectancy estimates for world countries and regions from 1960 until 2020
- m49.csv -- United Nations M49 Standard Country or Area Codes for Statistical Use
- population.csv -- World Bank population estimates for world countries and regions from 1960 until 2021

All the datasets are in IETF RFC 4180 CSV (comma-separated values) format, and the first four rows of the World Bank data files contain metadata, with the actual data table starting on row five.
Let us start by reading in the population data. Pandas can easily read CSV datasets via the pandas.read_csv()
function. The function reads the contents of the file into a pandas.DataFrame
data structure and supports various additional arguments. For example, we can utilize the skiprows
argument to tell Pandas to skip the first four rows of the dataset, as the data table does not start until row five.
population = pd.read_csv('data/population.csv', skiprows=4)
Now the World Bank population dataset is stored in a DataFrame called population
. Calling the DataFrame by its name will display the first and last five rows of the table by default.
population
Country Name | Country Code | Indicator Name | Indicator Code | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | ... | 2012 | 2013 | 2014 | 2015 | 2016 | 2017 | 2018 | 2019 | 2020 | 2021 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | Population, total | SP.POP.TOTL | 54608.0 | 55811.0 | 56682.0 | 57475.0 | 58178.0 | 58782.0 | ... | 102112.0 | 102880.0 | 103594.0 | 104257.0 | 104874.0 | 105439.0 | 105962.0 | 106442.0 | 106585.0 | 106537.0 |
1 | Africa Eastern and Southern | AFE | Population, total | SP.POP.TOTL | 130692579.0 | 134169237.0 | 137835590.0 | 141630546.0 | 145605995.0 | 149742351.0 | ... | 552530654.0 | 567891875.0 | 583650827.0 | 600008150.0 | 616377331.0 | 632746296.0 | 649756874.0 | 667242712.0 | 685112705.0 | 702976832.0 |
2 | Afghanistan | AFG | Population, total | SP.POP.TOTL | 8622466.0 | 8790140.0 | 8969047.0 | 9157465.0 | 9355514.0 | 9565147.0 | ... | 30466479.0 | 31541209.0 | 32716210.0 | 33753499.0 | 34636207.0 | 35643418.0 | 36686784.0 | 37769499.0 | 38972230.0 | 40099462.0 |
3 | Africa Western and Central | AFW | Population, total | SP.POP.TOTL | 97256290.0 | 99314028.0 | 101445032.0 | 103667517.0 | 105959979.0 | 108336203.0 | ... | 376797999.0 | 387204553.0 | 397855507.0 | 408690375.0 | 419778384.0 | 431138704.0 | 442646825.0 | 454306063.0 | 466189102.0 | 478185907.0 |
4 | Angola | AGO | Population, total | SP.POP.TOTL | 5357195.0 | 5441333.0 | 5521400.0 | 5599827.0 | 5673199.0 | 5736582.0 | ... | 25188292.0 | 26147002.0 | 27128337.0 | 28127721.0 | 29154746.0 | 30208628.0 | 31273533.0 | 32353588.0 | 33428486.0 | 34503774.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
261 | Kosovo | XKX | Population, total | SP.POP.TOTL | 947000.0 | 966000.0 | 994000.0 | 1022000.0 | 1050000.0 | 1078000.0 | ... | 1807106.0 | 1818117.0 | 1812771.0 | 1788196.0 | 1777557.0 | 1791003.0 | 1797085.0 | 1788878.0 | 1790133.0 | 1786038.0 |
262 | Yemen, Rep. | YEM | Population, total | SP.POP.TOTL | 5542459.0 | 5646668.0 | 5753386.0 | 5860197.0 | 5973803.0 | 6097298.0 | ... | 26223391.0 | 26984002.0 | 27753304.0 | 28516545.0 | 29274002.0 | 30034389.0 | 30790513.0 | 31546691.0 | 32284046.0 | 32981641.0 |
263 | South Africa | ZAF | Population, total | SP.POP.TOTL | 16520441.0 | 16989464.0 | 17503133.0 | 18042215.0 | 18603097.0 | 19187194.0 | ... | 53145033.0 | 53873616.0 | 54729551.0 | 55876504.0 | 56422274.0 | 56641209.0 | 57339635.0 | 58087055.0 | 58801927.0 | 59392255.0 |
264 | Zambia | ZMB | Population, total | SP.POP.TOTL | 3119430.0 | 3219451.0 | 3323427.0 | 3431381.0 | 3542764.0 | 3658024.0 | ... | 14744658.0 | 15234976.0 | 15737793.0 | 16248230.0 | 16767761.0 | 17298054.0 | 17835893.0 | 18380477.0 | 18927715.0 | 19473125.0 |
265 | Zimbabwe | ZWE | Population, total | SP.POP.TOTL | 3806310.0 | 3925952.0 | 4049778.0 | 4177931.0 | 4310332.0 | 4447149.0 | ... | 13265331.0 | 13555422.0 | 13855753.0 | 14154937.0 | 14452704.0 | 14751101.0 | 15052184.0 | 15354608.0 | 15669666.0 | 15993524.0 |
266 rows × 66 columns
We see that the DataFrame appears to have the following columns:

- Country Name -- English name of the country
- Country Code -- ISO 3166-1 alpha-3 country code
- Indicator Name -- name of the indicator represented by the data
- Indicator Code -- World Bank code for the indicator
- 1960 ... 2021 -- population estimates by year

We also see that the DataFrame has 266 rows and 66 columns. We can double-check this by looking at the value of the pandas.DataFrame.shape attribute.
population.shape
(266, 66)
The pandas.DataFrame.size
attribute will give us the total number of values in the table (number of columns times number of rows).
population.size
17556
pandas.DataFrame.columns
can be used to get a list of all the column names and pandas.DataFrame.dtypes
will display the datatype of each column.
population.columns
Index(['Country Name', 'Country Code', 'Indicator Name', 'Indicator Code', '1960', '1961', '1962', '1963', '1964', '1965', '1966', '1967', '1968', '1969', '1970', '1971', '1972', '1973', '1974', '1975', '1976', '1977', '1978', '1979', '1980', '1981', '1982', '1983', '1984', '1985', '1986', '1987', '1988', '1989', '1990', '1991', '1992', '1993', '1994', '1995', '1996', '1997', '1998', '1999', '2000', '2001', '2002', '2003', '2004', '2005', '2006', '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016', '2017', '2018', '2019', '2020', '2021'], dtype='object')
population.dtypes
Country Name object Country Code object Indicator Name object Indicator Code object 1960 float64 ... 2017 float64 2018 float64 2019 float64 2020 float64 2021 float64 Length: 66, dtype: object
Note how the first four columns all have the object
datatype. This could mean that the column contains textual data (string), has a mix of different datatypes (both textual and numeric for example), or contains a more complex data structure (like a list or tuple). The population columns are all float64
denoting floating-point numbers. It might feel odd to store population values as floating-point numbers, as population counts are always whole integers. However, in Pandas all numeric data is stored as floating-point numbers by default, because integer columns in Pandas do not support missing data values out of the box. The default missing data value in Pandas is numpy.nan from NumPy, which is itself a float64 value.
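As an aside (not part of our World Bank workflow), here is a minimal sketch illustrating this behavior, including the nullable Int64 extension dtype that can store integers alongside missing values:

import numpy as np

s = pd.Series([1, 2, np.nan])   # a missing value forces float64 by default
s.dtype                         # float64

s = pd.Series([1, 2, np.nan], dtype='Int64')   # nullable integer extension dtype
s.dtype                         # Int64 (the missing value is shown as <NA>)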
We know that the population
DataFrame stores population values, so the Indicator Name
and Indicator Code
columns are redundant. We can drop them from the table using the pandas.DataFrame.drop()
method.
population.drop(columns=['Indicator Name', 'Indicator Code'], inplace=True)
Note how we specified two arguments when calling the pandas.DataFrame.drop()
method. First we specified a list of columns to drop using the columns
argument. The pandas.DataFrame.drop()
method also supports dropping rows, so that is why the columns
argument is needed. Then we also specified inplace
to be True
. This ensures that the original population
DataFrame gets modified. Otherwise the method would just return a new DataFrame and keep the population
DataFrame unchanged.
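For reference, the non-inplace equivalent would have been to reassign the result back to the same variable (do not run both versions, as the columns would already be gone):

# equivalent alternative: drop() returns a new DataFrame that we reassign
population = population.drop(columns=['Indicator Name', 'Indicator Code'])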
We can validate that the desired columns have been removed by taking a quick peek at the DataFrame via the pandas.DataFrame.head()
method. It displays the first five rows of the DataFrame by default, but you can also pass the desired number of rows as an argument.
population.head()
Country Name | Country Code | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | 1966 | 1967 | ... | 2012 | 2013 | 2014 | 2015 | 2016 | 2017 | 2018 | 2019 | 2020 | 2021 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | 54608.0 | 55811.0 | 56682.0 | 57475.0 | 58178.0 | 58782.0 | 59291.0 | 59522.0 | ... | 102112.0 | 102880.0 | 103594.0 | 104257.0 | 104874.0 | 105439.0 | 105962.0 | 106442.0 | 106585.0 | 106537.0 |
1 | Africa Eastern and Southern | AFE | 130692579.0 | 134169237.0 | 137835590.0 | 141630546.0 | 145605995.0 | 149742351.0 | 153955516.0 | 158313235.0 | ... | 552530654.0 | 567891875.0 | 583650827.0 | 600008150.0 | 616377331.0 | 632746296.0 | 649756874.0 | 667242712.0 | 685112705.0 | 702976832.0 |
2 | Afghanistan | AFG | 8622466.0 | 8790140.0 | 8969047.0 | 9157465.0 | 9355514.0 | 9565147.0 | 9783147.0 | 10010030.0 | ... | 30466479.0 | 31541209.0 | 32716210.0 | 33753499.0 | 34636207.0 | 35643418.0 | 36686784.0 | 37769499.0 | 38972230.0 | 40099462.0 |
3 | Africa Western and Central | AFW | 97256290.0 | 99314028.0 | 101445032.0 | 103667517.0 | 105959979.0 | 108336203.0 | 110798486.0 | 113319950.0 | ... | 376797999.0 | 387204553.0 | 397855507.0 | 408690375.0 | 419778384.0 | 431138704.0 | 442646825.0 | 454306063.0 | 466189102.0 | 478185907.0 |
4 | Angola | AGO | 5357195.0 | 5441333.0 | 5521400.0 | 5599827.0 | 5673199.0 | 5736582.0 | 5787044.0 | 5827503.0 | ... | 25188292.0 | 26147002.0 | 27128337.0 | 28127721.0 | 29154746.0 | 30208628.0 | 31273533.0 | 32353588.0 | 33428486.0 | 34503774.0 |
5 rows × 64 columns
Knowing that the World Bank GDP dataset follows the exact same format as the World Bank population dataset, we can read it in and drop the Indicator Name
and Indicator Code
columns all in one go by chaining together the pandas.read_csv()
function and the pandas.DataFrame.drop()
method. If we want to include a line break somewhere in the chain, we need to wrap the whole thing in parentheses ()
.
gdp = (pd.read_csv('data/gdp.csv', skiprows=4)
.drop(columns=['Indicator Name', 'Indicator Code']))
Note how here we did not specify inplace=True
when dropping the columns. That is because we want the pandas.DataFrame.drop()
method to take the DataFrame generated by pandas.read_csv()
and then output a new DataFrame that we can save into the gdp
variable. We can take a look at our newly created DataFrame by using the pandas.DataFrame.head()
method again.
gdp.head()
Country Name | Country Code | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | 1966 | 1967 | ... | 2012 | 2013 | 2014 | 2015 | 2016 | 2017 | 2018 | 2019 | 2020 | 2021 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 2.615084e+09 | 2.727933e+09 | 2.791061e+09 | 2.963128e+09 | 2.983799e+09 | 3.092179e+09 | 3.202235e+09 | 3.368970e+09 | 2.610039e+09 | 3.126019e+09 |
1 | Africa Eastern and Southern | AFE | 2.129081e+10 | 2.180870e+10 | 2.370727e+10 | 2.821034e+10 | 2.611906e+10 | 2.968249e+10 | 3.223946e+10 | 3.351491e+10 | ... | 9.725734e+11 | 9.834729e+11 | 1.003768e+12 | 9.245228e+11 | 8.827213e+11 | 1.021119e+12 | 1.007240e+12 | 1.001017e+12 | 9.274845e+11 | 1.080712e+12 |
2 | Afghanistan | AFG | 5.377778e+08 | 5.488889e+08 | 5.466667e+08 | 7.511112e+08 | 8.000000e+08 | 1.006667e+09 | 1.400000e+09 | 1.673333e+09 | ... | 2.020357e+10 | 2.056449e+10 | 2.055058e+10 | 1.999816e+10 | 1.801956e+10 | 1.889635e+10 | 1.841885e+10 | 1.890449e+10 | 2.014344e+10 | 1.478686e+10 |
3 | Africa Western and Central | AFW | 1.040414e+10 | 1.112789e+10 | 1.194319e+10 | 1.267633e+10 | 1.383837e+10 | 1.486223e+10 | 1.583259e+10 | 1.442604e+10 | ... | 7.360399e+11 | 8.322169e+11 | 8.924979e+11 | 7.669580e+11 | 6.905454e+11 | 6.837480e+11 | 7.663597e+11 | 7.947191e+11 | 7.847997e+11 | 8.401873e+11 |
4 | Angola | AGO | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 1.249982e+11 | 1.334016e+11 | 1.372444e+11 | 8.721930e+10 | 4.984049e+10 | 6.897277e+10 | 7.779294e+10 | 6.930911e+10 | 5.361907e+10 | 6.740429e+10 |
5 rows × 64 columns
GDP on its own is not a good indicator of the wealth of a country as countries with more people tend to have higher GDP. But if we were to normalize GDP by population, then the resulting GDP per capita values can be compared across countries and used as a proxy for wealth. To do so, we must be able to match up the GDP and population values for each unique combination of country and year.
The GDP and population tables currently are in wide format -- each row represents a unique country and each column represents a unique year, with the cell values holding the estimates. While this wide format has many advantages and is commonly used in geospatial applications, it does complicate joining various datasets. One option would be to treat both tables as matrices and calculate GDP per capita by dividing the GDP matrix by the population matrix element-wise. However, both tables would need to have the exact same layout, with the same countries and years in the same exact order, for this to work and for the result to be reliable. Ensuring this is not a trivial task, so this method would involve a lot of work to produce reliable results.
Alternatively, the two tables could be joined by country. Then we would have an extra-wide table with two sets of year columns -- one set of year columns for population and another set for GDP. We would then need to create a new column for each year by dividing the corresponding GDP column by the corresponding population column, resulting in yet another set of year columns. As you can see, this approach would quickly lead to a very messy and difficult-to-manage dataset and would also involve a lot of work, making it far from preferred.
The easiest option for calculating GDP per capita would involve converting both datasets into a long format, where each row represents a single unique observation (estimation). Instead of having countries in rows and years in columns, each row would instead represent a unique country and year combination. This would allow us to easily combine datasets on both country and year, ensuring that the GDP and population values for each country-year combination get matched.
We can use the pandas.DataFrame.melt()
method to convert wide DataFrames to long format. We need to specify three arguments when using this method:

- id_vars – name(s) of the column(s) that define a unique observation in the original wide dataset
- var_name – name of the column in the new long dataset that stores the column names of the original wide dataset
- value_name – name of the column in the new long dataset that stores the values of the original wide dataset

Each observation in the original wide dataset represents a unique country defined either by the country name or country code. Let us include both of these as id_vars to carry both columns over to the long dataset. The columns of the wide dataset represent years, so that is the name we will pass on to the var_name argument. The values of the wide dataset represent population estimates, so that will be the name passed on to the value_name argument.
The reverse of pandas.DataFrame.melt() is pandas.DataFrame.pivot(), which converts a long format table back to wide format (a brief sketch follows the output below).
population_long = population.melt(id_vars=['Country Name', 'Country Code'],
var_name='year',
value_name='population')
population_long
Country Name | Country Code | year | population | |
---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 |
1 | Africa Eastern and Southern | AFE | 1960 | 130692579.0 |
2 | Afghanistan | AFG | 1960 | 8622466.0 |
3 | Africa Western and Central | AFW | 1960 | 97256290.0 |
4 | Angola | AGO | 1960 | 5357195.0 |
... | ... | ... | ... | ... |
16487 | Kosovo | XKX | 2021 | 1786038.0 |
16488 | Yemen, Rep. | YEM | 2021 | 32981641.0 |
16489 | South Africa | ZAF | 2021 | 59392255.0 |
16490 | Zambia | ZMB | 2021 | 19473125.0 |
16491 | Zimbabwe | ZWE | 2021 | 15993524.0 |
16492 rows × 4 columns
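As mentioned above, the melt can be undone with pandas.DataFrame.pivot(). A minimal sketch (the variable name population_wide is our own choice):

# undo the melt: back to one row per country and one column per year
population_wide = population_long.pivot(index=['Country Name', 'Country Code'],
                                        columns='year',
                                        values='population').reset_index()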
Now we have a new long population DataFrame called population_long
, where each row represents a unique country and year combination. Let us use pandas.DataFrame.dtypes
to confirm the data types of this new table.
population_long.dtypes
Country Name object Country Code object year object population float64 dtype: object
Note how the year
column is of type object
, meaning that the years are currently stored as strings. As the years were previously column names, this makes sense. However, as years are actually numbers, they should also be stored as such to allow for easy comparisons and mathematical operations.
To convert the year values to integers, we must first extract the year
column as a pandas.Series
object. This can be done by either using square brackets df["column"]
or via dot-notation df.column
. The latter requires the column name to consist of only letters, numbers, and underscores (and not start with a number), so it is only useful if the column names are neatly formatted. Using square brackets to extract columns is more robust and as the column name is passed as a string, it can contain spaces and other special characters.
Square brackets can be used to also create a new column or overwrite an existing column. Dot-notation should only be used to read columns. Attempting to write columns using dot-notation could have unexpected consequences.
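For instance, both notations below read the neatly named year column, while a column name containing a space works only with square brackets:

population_long['year']          # square brackets always work
population_long.year             # dot-notation works for simple names
population_long['Country Name']  # works despite the space in the name
# population_long.Country Name   # would be a SyntaxError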
Knowing this, let us extract the year
column as a Series object using dot-notation df.column
and then call pandas.Series.astype()
on the extracted values to convert them to integers. Then we can use square bracket notation df["column"]
to replace the values of the year
column with their integer equivalents.
population_long['year'] = population_long.year.astype(int)
population_long.head()
Country Name | Country Code | year | population | |
---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 |
1 | Africa Eastern and Southern | AFE | 1960 | 130692579.0 |
2 | Afghanistan | AFG | 1960 | 8622466.0 |
3 | Africa Western and Central | AFW | 1960 | 97256290.0 |
4 | Angola | AGO | 1960 | 5357195.0 |
population_long.dtypes
Country Name object Country Code object year int64 population float64 dtype: object
Note how the values of the year
column seemingly did not change, but the datatype of the values is now int64
, which means that the values have been converted to numeric integers.
Now let us convert the GDP dataset to long format as well. We can chain the pandas.DataFrame.melt()
method together with the pandas.DataFrame.astype()
method to convert the DataFrame from wide to long format and change the datatype of the year
column to integer all in one go. The pandas.DataFrame.astype()
method is very similar to the pandas.Series.astype()
method, but instead of taking a single datatype as an argument, it takes a dictionary that maps column names to datatypes.
gdp_long = (gdp.melt(id_vars=['Country Name', 'Country Code'],
var_name='year',
value_name='gdp')
.astype({'year': int}))
gdp_long
Country Name | Country Code | year | gdp | |
---|---|---|---|---|
0 | Aruba | ABW | 1960 | NaN |
1 | Africa Eastern and Southern | AFE | 1960 | 2.129081e+10 |
2 | Afghanistan | AFG | 1960 | 5.377778e+08 |
3 | Africa Western and Central | AFW | 1960 | 1.040414e+10 |
4 | Angola | AGO | 1960 | NaN |
... | ... | ... | ... | ... |
16487 | Kosovo | XKX | 2021 | 9.412034e+09 |
16488 | Yemen, Rep. | YEM | 2021 | NaN |
16489 | South Africa | ZAF | 2021 | 4.190150e+11 |
16490 | Zambia | ZMB | 2021 | 2.214763e+10 |
16491 | Zimbabwe | ZWE | 2021 | 2.837124e+10 |
16492 rows × 4 columns
gdp_long.dtypes
Country Name object Country Code object year int64 gdp float64 dtype: object
Finally we are ready to combine the population and GDP datasets. pandas.DataFrame.merge()
can be used to perform a join on one or more columns. The method is called on the left DataFrame and takes the right DataFrame as its first argument (this is only important to know when performing a left or right join). Additional arguments are as follows:

- on – a single column name (string) or list of column names to join on. These column names should appear in both tables. If the column names differ between datasets, the separate left_on and right_on arguments should be used instead.
- how – the type of join to perform. Here are the possible values:
  - "left" – use only keys from the left DataFrame (include all rows from the left DataFrame)
  - "right" – use only keys from the right DataFrame (include all rows from the right DataFrame)
  - "outer" – use the union of keys from both DataFrames (include all rows from both DataFrames)
  - "inner" – use the intersection of keys from both DataFrames (include only matching rows)
  - "cross" – create the Cartesian product of both DataFrames (similar to cross-tabulation)

We would like to join on each unique country and year combination. As spellings of country names might differ between datasets, it is good practice to always use the ISO 3166-1 alpha-3 country code or some other analogous unique identifier to distinguish between countries. The country code for each country is determined by an international standard and should not differ between datasets, allowing us to reliably join the data. Hence we will specify on=["Country Code", "year"] to perform the join on unique country-year combinations and how="inner" to only keep country-year combinations that are present in both datasets. Since we do not want the Country Name column repeated in the joined dataset, we should remove it from the GDP table using pandas.DataFrame.drop() before performing the join. Otherwise the Country Name column from the GDP dataset will also get joined, resulting in the joined table having two separate columns with country names.
data = population_long.merge(gdp_long.drop(columns='Country Name'),
on=['Country Code', 'year'],
how='inner')
data.head()
Country Name | Country Code | year | population | gdp | |
---|---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 | NaN |
1 | Africa Eastern and Southern | AFE | 1960 | 130692579.0 | 2.129081e+10 |
2 | Afghanistan | AFG | 1960 | 8622466.0 | 5.377778e+08 |
3 | Africa Western and Central | AFW | 1960 | 97256290.0 | 1.040414e+10 |
4 | Angola | AGO | 1960 | 5357195.0 | NaN |
Now we have a table with a population and GDP value for each country and year combination. We can easily add a new column denoting GDP per capita to this table by dividing the GDP column by the population column.
data['gdp_per_capita'] = data.gdp / data.population
data.head()
Country Name | Country Code | year | population | gdp | gdp_per_capita | |
---|---|---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 | NaN | NaN |
1 | Africa Eastern and Southern | AFE | 1960 | 130692579.0 | 2.129081e+10 | 162.907576 |
2 | Afghanistan | AFG | 1960 | 8622466.0 | 5.377778e+08 | 62.369375 |
3 | Africa Western and Central | AFW | 1960 | 97256290.0 | 1.040414e+10 | 106.976475 |
4 | Angola | AGO | 1960 | 5357195.0 | NaN | NaN |
Now we would also like to add life expectancy information to this joined dataset. Knowing that all World Bank data tables follow the same format, we can easily convert the workflow from before into a function that reads in a World Bank dataset, drops unneeded columns, converts it to long format, and ensures the year is in numeric format. That function would only need two inputs – the path of the CSV file and the name of the indicator represented by the data. (This name will be used as the column name for the values column in the long format table.) Let us define this function and use it to read in the World Bank life expectancy dataset and convert it to long format.
def read_world_bank_data(file_name, value_name):
return (pd.read_csv(file_name, skiprows=4)
.drop(columns=['Indicator Name', 'Indicator Code'])
.melt(id_vars=['Country Name', 'Country Code'],
var_name='year',
value_name=value_name)
.astype({'year': int}))
life_exp = read_world_bank_data(file_name='data/life-expectancy.csv',
value_name='life_exp')
life_exp.head()
Country Name | Country Code | year | life_exp | |
---|---|---|---|---|
0 | Aruba | ABW | 1960 | 64.152000 |
1 | Africa Eastern and Southern | AFE | 1960 | 44.085552 |
2 | Afghanistan | AFG | 1960 | 32.535000 |
3 | Africa Western and Central | AFW | 1960 | 37.845152 |
4 | Angola | AGO | 1960 | 38.211000 |
Using the same workflow from before, we can join the long format life expectancy dataset to our table containing the GDP and population data.
data = data.merge(life_exp.drop(columns='Country Name'),
on=['Country Code', 'year'],
how='inner')
data.head()
Country Name | Country Code | year | population | gdp | gdp_per_capita | life_exp | |
---|---|---|---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 | NaN | NaN | 64.152000 |
1 | Africa Eastern and Southern | AFE | 1960 | 130692579.0 | 2.129081e+10 | 162.907576 | 44.085552 |
2 | Afghanistan | AFG | 1960 | 8622466.0 | 5.377778e+08 | 62.369375 | 32.535000 |
3 | Africa Western and Central | AFW | 1960 | 97256290.0 | 1.040414e+10 | 106.976475 | 37.845152 |
4 | Angola | AGO | 1960 | 5357195.0 | NaN | NaN | 38.211000 |
Finally we would also like to know which United Nations regional geoscheme the country belongs to. Information on this is available in the United Nations M49 dataset. As this dataset is a standard CSV table, we can use pandas.read_csv()
without any additional arguments to read it into a DataFrame.
m49 = pd.read_csv('data/m49.csv')
m49.head()
Global Code | Global Name | Region Code | Region Name | Sub-region Code | Sub-region Name | Intermediate Region Code | Intermediate Region Name | Country or Area | M49 Code | ISO-alpha2 Code | ISO-alpha3 Code | Least Developed Countries (LDC) | Land Locked Developing Countries (LLDC) | Small Island Developing States (SIDS) | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | World | 2.0 | Africa | 15.0 | Northern Africa | NaN | NaN | Algeria | 12 | DZ | DZA | NaN | NaN | NaN |
1 | 1 | World | 2.0 | Africa | 15.0 | Northern Africa | NaN | NaN | Egypt | 818 | EG | EGY | NaN | NaN | NaN |
2 | 1 | World | 2.0 | Africa | 15.0 | Northern Africa | NaN | NaN | Libya | 434 | LY | LBY | NaN | NaN | NaN |
3 | 1 | World | 2.0 | Africa | 15.0 | Northern Africa | NaN | NaN | Morocco | 504 | MA | MAR | NaN | NaN | NaN |
4 | 1 | World | 2.0 | Africa | 15.0 | Northern Africa | NaN | NaN | Sudan | 729 | SD | SDN | x | NaN | NaN |
m49.columns
Index(['Global Code', 'Global Name', 'Region Code', 'Region Name', 'Sub-region Code', 'Sub-region Name', 'Intermediate Region Code', 'Intermediate Region Name', 'Country or Area', 'M49 Code', 'ISO-alpha2 Code', 'ISO-alpha3 Code', 'Least Developed Countries (LDC)', 'Land Locked Developing Countries (LLDC)', 'Small Island Developing States (SIDS)'], dtype='object')
Note how this dataset contains a lot of information on the various groups and codes assigned to each country. We are only interested in the name of the region each country belongs to and the ISO 3166-1 alpha-3 code assigned to the country. Using double square brackets [[ ]]
we can extract the desired columns as a new DataFrame. (In reality we are just passing a list of column names to the standard single square brackets indexer.)
regions = m49[['Region Name', 'ISO-alpha3 Code']]
regions.head()
Region Name | ISO-alpha3 Code | |
---|---|---|
0 | Africa | DZA |
1 | Africa | EGY |
2 | Africa | LBY |
3 | Africa | MAR |
4 | Africa | SDN |
Now we can use pandas.DataFrame.merge()
again to join the regions to the rest of our data. Since the names of the columns containing the country code information differ between the datasets, we must use the left_on
and right_on
arguments instead of the on
argument from before.
data = data.merge(regions,
left_on='Country Code',
right_on='ISO-alpha3 Code',
how='inner')
data.head()
Country Name | Country Code | year | population | gdp | gdp_per_capita | life_exp | Region Name | ISO-alpha3 Code | |
---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 | NaN | NaN | 64.152 | Americas | ABW |
1 | Aruba | ABW | 1961 | 55811.0 | NaN | NaN | 64.537 | Americas | ABW |
2 | Aruba | ABW | 1962 | 56682.0 | NaN | NaN | 64.752 | Americas | ABW |
3 | Aruba | ABW | 1963 | 57475.0 | NaN | NaN | 65.132 | Americas | ABW |
4 | Aruba | ABW | 1964 | 58178.0 | NaN | NaN | 65.294 | Americas | ABW |
Note how the new joined dataset contains both of the country code columns (because their names were different). Also, the naming convention in our table is not uniform – some column names are in snake_case
(which is preferred) while others contain spaces and a mix of uppercase and lowercase letters. Let us use pandas.DataFrame.drop()
to drop the second country code column and pandas.DataFrame.rename()
to rename some of the columns to ensure a uniform column naming convention. Remember that we can use the inplace=True
argument to apply the changes to the original DataFrame.
data.drop(columns='ISO-alpha3 Code', inplace=True)
pandas.DataFrame.rename()
takes a dictionary in the format {"old_name": "new_name"}
as an argument and you need to specify whether you would like to rename rows or columns.
data.rename(columns={'Country Name': 'country_name',
'Country Code': 'country_code',
'Region Name': 'region_name'},
inplace=True)
data.head()
country_name | country_code | year | population | gdp | gdp_per_capita | life_exp | region_name | |
---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | 1960 | 54608.0 | NaN | NaN | 64.152 | Americas |
1 | Aruba | ABW | 1961 | 55811.0 | NaN | NaN | 64.537 | Americas |
2 | Aruba | ABW | 1962 | 56682.0 | NaN | NaN | 64.752 | Americas |
3 | Aruba | ABW | 1963 | 57475.0 | NaN | NaN | 65.132 | Americas |
4 | Aruba | ABW | 1964 | 58178.0 | NaN | NaN | 65.294 | Americas |
To extract specific rows from a DataFrame, we can combine the square brackets indexing operator pandas.DataFrame[]
with a logical operation that produces a boolean array. This would select every row from the DataFrame where the corresponding element in the boolean array equals True
. For example, to extract all rows that correspond to the United States, we could use data.country_code == "USA"
. This would return an array of True
and False
values where the value of a specific element in the array is True
if the corresponding row in the data
DataFrame had the value "USA"
in its country_code
column.
data.country_code == 'USA'
0 False 1 False 2 False 3 False 4 False ... 13325 False 13326 False 13327 False 13328 False 13329 False Name: country_code, Length: 13330, dtype: bool
Combining this with the square brackets indexing operator data[]
will extract all values from the data
DataFrame where the country_code
column has the value "USA"
.
usa_data = data[data.country_code == 'USA']
usa_data.head()
country_name | country_code | year | population | gdp | gdp_per_capita | life_exp | region_name | |
---|---|---|---|---|---|---|---|---|
12524 | United States | USA | 1960 | 180671000.0 | 5.433000e+11 | 3007.123445 | 69.770732 | Americas |
12525 | United States | USA | 1961 | 183691000.0 | 5.633000e+11 | 3066.562869 | 70.270732 | Americas |
12526 | United States | USA | 1962 | 186538000.0 | 6.051000e+11 | 3243.843078 | 70.119512 | Americas |
12527 | United States | USA | 1963 | 189242000.0 | 6.386000e+11 | 3374.515171 | 69.917073 | Americas |
12528 | United States | USA | 1964 | 191889000.0 | 6.858000e+11 | 3573.941185 | 70.165854 | Americas |
We can ensure that this new usa_data
DataFrame only contains values corresponding to the United States by calling pandas.Series.unique()
on the country_name
column. This will return an array of all the unique country names present in the table.
usa_data.country_name.unique()
array(['United States'], dtype=object)
Note that even though there is only one unique value, the result is still an array. To extract the value as a string, we must extract the first element of the array using [0]
.
usa_data.country_name.unique()[0]
'United States'
Matplotlib is the primary plotting library in Python and it is designed to resemble the plotting functionality of MATLAB. While it provides all kinds of different plotting functionality, the matplotlib.pyplot
module is used the most. It is common to import this module under the alias plt
as we did before. Matplotlib works in a layered fashion. First you define your plot using matplotlib.pyplot.plot(x, y, ...)
, then you can use additional matplotlib.pyplot
methods to add more layers to your plot or modify its appearance. Finally, you use matplotlib.pyplot.show()
to display the plot or matplotlib.pyplot.savefig()
to save it to an external file.
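Schematically, the layered workflow looks like the following sketch (the data and file name are placeholders):

plt.plot([1, 2, 3, 4], [1, 4, 9, 16])  # define the plot
plt.title('Placeholder Plot')          # add layers or tweak the appearance
plt.savefig('placeholder.png')         # save the figure to an external file...
plt.show()                             # ...or display it on screen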
The x
and y
arguments in the matplotlib.pyplot.plot()
call can be either arrays or pandas.Series
objects. For example, we can visualize the population of the United States over time by extracting the year
and population
columns of the usa_data
table as pandas.Series
objects and passing them along to matplotlib.pyplot.plot()
as follows.
plt.plot(usa_data.year, usa_data.population)
plt.show()
Alternatively we could pass the pandas.DataFrame
to the matplotlib.pyplot.plot()
command using the optional data
argument. This will allow us to specify the desired column names as the x
and y
arguments instead of having to extract them as pandas.Series
objects. For example, we can visualize the GDP of the United States over time as follows.
plt.plot('year', 'gdp', data=usa_data)
plt.show()
Pandas also has built-in plotting functionality via the pandas.DataFrame.plot()
method. It takes the column names of the x
and y
columns as arguments and uses a plotting backend to generate the plot. By default, the plotting backend is Matplotlib, but this could be reconfigured to be something else instead. For example, we can create a Matplotlib visualization showing United States life expectancy over time as follows.
usa_data.plot(x='year', y='life_exp')
plt.show()
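As an aside, switching the plotting backend is a single assignment. For example, with the hvplot.pandas import from before, the following sketch should make pandas.DataFrame.plot() produce interactive plots instead (this is an illustration only; we will keep using the Matplotlib default below):

pd.options.plotting.backend = 'hvplot'      # switch to the HVPlot backend
usa_data.plot(x='year', y='life_exp')       # now returns an interactive plot
pd.options.plotting.backend = 'matplotlib'  # switch back to the default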
To create a line graph with multiple lines, we need to stack the lines using multiple matplotlib.pyplot.plot()
or pandas.DataFrame.plot()
calls. But how can we specify that we would like to stack the lines onto a single plot instead of creating a new plot for each line? This is where the matplotlib.axes.Axes
class comes into play. For simplicity, you can think of each matplotlib.axes.Axes
object as a canvas onto which one can add multiple layers of visualization. When using pandas.DataFrame.plot()
to create visualizations, we can utilize matplotlib.axes.Axes
to create multi-layered plots as follows:
- The first pandas.DataFrame.plot() command will return a matplotlib.axes.Axes object. This object should be saved into a variable. It is common to save it into a variable called ax.
- In each subsequent pandas.DataFrame.plot() call, the matplotlib.axes.Axes object from before should be passed on using the ax argument. This will ensure the new plot gets added to the same canvas.

We can combine the pandas.DataFrame.plot()
call with boolean indexing to easily visualize subsets of the data and use the color
and label
arguments to specify a color and legend label for each subset.
Once all the lines have been added to the plot, we can use matplotlib.pyplot.ylabel()
and matplotlib.pyplot.xlabel()
to label the axes and matplotlib.pyplot.title()
to specify a title for the visualization. Finally we call matplotlib.pyplot.show()
to display the plot.
Knowing all this, a visualization illustrating the GDP per capita over time for North American countries can be generated as follows.
ax = data[data.country_code == 'USA'].plot(x='year',
y='gdp_per_capita',
color='blue',
label='USA')
data[data.country_code == 'CAN'].plot(x='year',
y='gdp_per_capita',
color='red',
label='Canada',
ax=ax)
data[data.country_code == 'MEX'].plot(x='year',
y='gdp_per_capita',
color='green',
label='Mexico',
ax=ax)
plt.ylabel('GDP per capita')
plt.xlabel('Year')
plt.title('GDP per Capita Over Time for North American Countries')
plt.show()
There are many benefits to using pandas.DataFrame.plot()
over matplotlib.pyplot.plot()
when dealing with DataFrames. Most importantly, pandas.DataFrame.plot()
interacts directly with a Pandas DataFrame and has a much simpler user interface with numerous named arguments allowing for easy customization. When it comes to more advanced tasks, however, Matplotlib allows for better fine-tuning and more flexibility, though this comes at the cost of more complex commands. For example, to create a plot that displays the temporal variation of both the life expectancy and GDP per capita of the United States using two different Y axes, we must use a relatively advanced workflow.
First, we define the size of our plot using the figsize
argument of matplotlib.pyplot.subplots()
. This command allows for the creation of multiple subplots, but is also frequently used to specify the size of a single plot. It returns a tuple consisting of a matplotlib.figure.Figure
and a matplotlib.axes.Axes
object.
To add a plot layer to a specific matplotlib.axes.Axes
object, we can use matplotlib.axes.Axes.plot()
which works very similarly to the previously discussed matplotlib.pyplot.plot()
command. Both commands take an optional format string as the third positional argument that allows you to specify the line and marker style and color using a simple shorthand. For example, "g--"
means a green dashed line and "mx"
indicates magenta-colored X-shaped markers. Refer to the function documentation for a full overview of all the shorthand characters. The Matplotlib commands for adding axes labels and plot titles also have additional arguments that modify the appearance of the label or title. For example, color
usually specifies the text color and size
is used to specify the size of the font.
To add another Y axis to the plot, we can use matplotlib.axes.Axes.twinx()
to create another matplotlib.axes.Axes
object that defines a new Y axis but shares the same X axis.
Finally, we can use matplotlib.figure.Figure.legend()
to add a legend to the whole figure (including all the axes objects).
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(usa_data.year, usa_data.gdp_per_capita, 'g--', label='GDP per Capita')
plt.ylabel('GDP per Capita', color='g')
plt.xlabel('Year')
ax2 = ax.twinx()
ax2.plot(usa_data.year, usa_data.life_exp, 'mx', label='Life Expectancy')
plt.ylabel('Life Expectancy', color='m')
plt.title('United States', size=20)
fig.legend()
plt.show()
Let us return to our original goal of exploring the relationship between health and wealth. We will use GDP per capita as a proxy for wealth and life expectancy as an indicator of health. We can simplify the analysis by looking only at one point in time and focus our analysis on 2020, which is the latest year we have both GDP per capita and life expectancy data available. We shall use boolean indexing to extract 2020 data into a new DataFrame called data2020
.
data2020 = data[data.year == 2020]
data2020.head()
country_name | country_code | year | population | gdp | gdp_per_capita | life_exp | region_name | |
---|---|---|---|---|---|---|---|---|
60 | Aruba | ABW | 2020 | 106585.0 | 2.610039e+09 | 24487.863560 | 75.723 | Americas |
122 | Afghanistan | AFG | 2020 | 38972230.0 | 2.014344e+10 | 516.866552 | 62.575 | Asia |
184 | Angola | AGO | 2020 | 33428486.0 | 5.361907e+10 | 1603.993477 | 62.261 | Africa |
246 | Albania | ALB | 2020 | 2837849.0 | 1.513187e+10 | 5332.160475 | 76.989 | Europe |
308 | Andorra | AND | 2020 | 77700.0 | 2.891022e+09 | 37207.493861 | NaN | Europe |
How is wealth distributed amongst the global population? Let us get a vague idea by visualizing the distribution of GDP per capita amongst world countries in 2020. We can easily create a histogram by using the matplotlib.pyplot.hist()
command and passing it the GDP per capita pandas.Series
.
plt.hist(data2020.gdp_per_capita)
plt.xlabel('GDP per Capita')
plt.show()
Note how we were able to easily create a histogram, but the result was quite ugly. If we wanted a prettier plot, we could go through the trouble of customizing it using various additional arguments and commands, which would take quite a while. Or we could use Seaborn, which allows us to easily create beautiful visualizations with sensible defaults. For example, we could create a well-designed histogram with a smoothed kernel density estimate (KDE) overlay using the seaborn.histplot()
function along with the kde=True
flag. Knowing this, let us look at the distribution of life expectancy amongst world countries in 2020.
sns.histplot(data2020.life_exp, kde=True)
plt.show()
To easily create a scatter plot analyzing the relationship between GDP per capita and life expectancy, we can use pandas.DataFrame.plot()
and specify kind="scatter"
to ensure the result is a scatter plot.
data2020.plot(x='gdp_per_capita', y='life_exp', kind='scatter')
plt.show()
The relationship appears to be logarithmic. This is likely due to the distribution of GDP per capita being heavily skewed. We can easily confirm this by plotting a two-dimensional kernel density estimate (KDE) plot using seaborn.jointplot()
along with kind="kde"
. (To get a scatter plot with histograms, one would use kind="scatter"
.)
sns.jointplot(data=data2020,
x='gdp_per_capita',
y='life_exp',
kind='kde',
fill=True)
plt.show()
To get a better sense of the potentially logarithmic relationship between GDP per capita and life expectancy, we should apply a logarithmic transformation to the axis corresponding to GDP per capita. In our example this is the X axis and we can apply a logarithmic transformation on the X axis by passing "log"
to the matplotlib.pyplot.xscale()
function.
plt.scatter(data2020.gdp_per_capita, data2020.life_exp)
plt.xscale('log')
plt.xlabel('GDP per Capita')
plt.ylabel('Life Expectancy')
plt.show()
Note how we used matplotlib.pyplot.scatter()
instead of pandas.DataFrame.plot()
to create the scatter plot. Both functions are very similar and in reality, the latter simply calls the former. Also note how now the X axis of the scatter plot is logarithmic. This makes the relationship much clearer and we can quite definitely state that there appears to be a logarithmic relationship between life expectancy and GDP per capita.
But does the size of a country play a role in this relationship? To find out, we can scale the size of the data points proportionally to the population such that bigger points indicate countries with more population. This can be done using the s
argument in matplotlib.pyplot.scatter()
, which takes an array of point sizes. This array needs to be the same size as the x
and y
arrays with one size value for each x
and y
combination. We can easily generate an array like this using the formula $X \div \max(X) \times s$, where $X$ is the array we want to base the sizes on and $s$ is a scaling factor in arbitrary plot units. Note that the scaling factor is completely arbitrary and you might need to try different values until you find something that makes the visualization look good. We divide the input array by its maximum value to properly normalize and scale the sizes.
Scaling the point sizes by population might cause some bigger points to overlap smaller ones. To ensure we can properly see overlapping points, we can use the alpha
argument in the matplotlib.pyplot.scatter()
call to specify a transparency factor.
fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(data2020.gdp_per_capita,
data2020.life_exp,
s=data2020.population/data2020.population.max()*5000,
alpha=0.5)
plt.xscale('log')
plt.xlabel('GDP per Capita')
plt.ylabel('Life Expectancy')
plt.show()
The size of a country does not seem to be related to GDP per capita or life expectancy. But what about the region a country is in? There is a good chance a correlation exists between the geographical location of a country and other indicators. To find out, we should color the points based on their geographic region. We know from before that this requires adding multiple layers to the plot – one for each region. We can get a list of all the regions by using pandas.Series.unique()
on the region_name
column. Then we can iterate over that list using a loop, subset the data for each region, and create a scatter plot layer using the subset data.
fig, ax = plt.subplots(figsize=(7, 7))
for region in data2020.region_name.unique():
subset = data2020[data2020.region_name == region]
ax.scatter(subset.gdp_per_capita,
subset.life_exp,
s=subset.population/data2020.population.max()*5000,
label=region,
alpha=0.5)
plt.xscale('log')
plt.xlabel('GDP per Capita')
plt.ylabel('Life Expectancy')
plt.title('2020')
plt.show()
One of the main drawbacks of Matplotlib is that one needs to create multiple layers to visualize groups using different colors. This can be a tedious process and usually involves subsetting the data using a loop. To get around this, many choose to use Seaborn instead, which allows a grouping variable to be passed via the hue
argument. For example, to recreate the plot from above without having to use a loop, we can utilize seaborn.scatterplot()
with hue="region_name"
. To scale the point sizes by population, we can specify size="population" and then use the sizes argument to pass a tuple defining the smallest and largest point sizes in arbitrary plot units. As before, you might need to play around with the tuple values in sizes
until you find a combination that looks good.
fig, ax = plt.subplots(figsize=(7, 7))
sns.scatterplot(data=data2020,
x='gdp_per_capita',
y='life_exp',
hue='region_name',
size='population',
sizes=(10, 5000),
alpha=0.5,
legend=False,
ax=ax)
plt.xscale('log')
plt.xlabel('GDP per Capita')
plt.ylabel('Life Expectancy')
plt.title('2020')
plt.show()
While the static scatter plot above is quite pretty to look at, it is not the most informative. We have no idea which points represent which countries and many countries appear clustered together, which makes it harder to tell them apart. An interactive visualization would allow for better exploration and investigation of the data. The easiest way of creating an interactive visualization out of a Pandas DataFrame is to use HVPlot, which is built on top of Bokeh and HoloViews and utilizes them in the background. Importing the hvplot.pandas
module as we did before adds a new pandas.DataFrame.hvplot
interface that allows for the creation of interactive plots using a syntax very similar to that of pandas.DataFrame.plot()
.
We can easily create an interactive version of the scatter plot from before by using pandas.DataFrame.hvplot.scatter()
with the following arguments:
- x and y – the column names for the data plotted on the X and Y axes respectively
- c – the column name that defines the groups or values based on which to color the points
- s – the column name that defines values to use as point sizes
- scale – scaling factor to use when deriving point sizes from the values specified by s (we will use $1 \div \max(X) \times y$, where $X$ is the column specified in s and $y$ is an arbitrary scaling factor)
- hover_cols – fields to include in the tooltips in addition to those specified in x, y, c, and s
- alpha – transparency factor
- logx – whether to apply a logarithmic transformation on the X axis
- width and height – the size of the visualization in pixels

Take some time to explore the interactive visualization using the available controls. Experiment with panning and zooming and hover over various points to explore the tooltips.
data2020.hvplot.scatter(x='gdp_per_capita',
y='life_exp',
c='region_name',
s='population',
scale=1/data2020.population.max()*2000000,
hover_cols=['country_name', 'country_code'],
alpha=0.5,
logx=True,
width=650,
height=500)
An alternative to HVPlot is Plotly, which is a popular interactive visualization library used in many programming languages. It consists of a complex ecosystem of various modules, but the plotly.express
module is the most popular and easiest to use. The syntax of plotly.express
is very similar to that of HVPlot. The biggest difference between the two libraries is that plotly.express
does not handle missing data and expects the input DataFrame to not contain any missing values. Hence we must drop all rows with missing values from the table using pandas.DataFrame.dropna()
before passing it onto Plotly.
We can create an interactive scatter plot via Plotly using the plotly.express.scatter()
function along with the following arguments (note the similarities with HVPlot):

- data_frame – the pandas.DataFrame to use for the visualization, with rows containing missing values removed
- x and y – the column names for the data plotted on the X and Y axes respectively
- color – the column name that defines the groups or values based on which to color the points
- size – the column name that defines values to use as point sizes
- size_max – the size of the largest point in pixels (used to scale all point sizes)
- hover_name – the column name that defines the values to be used as tooltip titles
- hover_data – fields to include in the tooltip in addition to those specified in x, y, color, size, and hover_name
- opacity – transparency factor
- log_x – whether to apply a logarithmic transformation on the X axis
- width and height – the size of the visualization in pixels

As before, make sure to explore the interactive visualization using the available controls. Note how the tooltips and controls differ from those provided by HVPlot.
px.scatter(data_frame=data2020.dropna(),
x='gdp_per_capita',
y='life_exp',
color='region_name',
size='population',
size_max=40,
hover_name='country_name',
hover_data=['country_code'],
opacity=0.5,
log_x=True,
width=650,
height=600)
Thus far we have covered the basics of working with data in Python, including reading CSV files, manipulating and reshaping data, joining tables, and creating both static and interactive visualizations. This covers the majority of the most essential data analysis workflows you might need. However, there are two major topics we have yet to discuss -- working with time series and aggregating data by group. We will explore these concepts using rapid transit ridership data from the Massachusetts Bay Transportation Authority (MBTA). Once we have covered these two final topics, you should have all the skills you need to begin your Python data analysis journey.
The dataset we will use is a CSV file named mbta-gated-entries-2020.csv
located in the data
directory. Each row in the table represents a unique 30-minute service time period for a specific MBTA rapid transit station and line combination in 2020. The columns are as follows:

- service_date -- date in ISO 8601 YYYY-MM-DD format
- time_period -- timestamp denoting the start of the 30-minute time period in a somewhat unusual (HH:mm:ss) format
- stop_id -- unique identifier for the rapid transit stop
- station_name -- name of the rapid transit stop
- route_or_line -- route or line served by the stop
- gated_entries -- number of gated entries at the specified stop for the specified line or route in the specified time period

Let us read this dataset into a pandas.DataFrame called mbta using pandas.read_csv() and explore it via pandas.DataFrame.head() and pandas.DataFrame.dtypes.
mbta = pd.read_csv('data/mbta-gated-entries-2020.csv')
mbta.head()
service_date | time_period | stop_id | station_name | route_or_line | gated_entries | |
---|---|---|---|---|---|---|
0 | 2020-01-01 | (00:00:00) | place-alfcl | Alewife | Red Line | 3 |
1 | 2020-01-01 | (00:00:00) | place-andrw | Andrew | Red Line | 8 |
2 | 2020-01-01 | (00:00:00) | place-aport | Airport | Blue Line | 32 |
3 | 2020-01-01 | (00:00:00) | place-aqucl | Aquarium | Blue Line | 15 |
4 | 2020-01-01 | (00:00:00) | place-armnl | Arlington | Green Line | 3 |
mbta.dtypes
service_date object time_period object stop_id object station_name object route_or_line object gated_entries int64 dtype: object
Note how both the service_date
and time_period
have the datatype of object
, indicating that they are stored as text. This does not allow us to treat these values as proper timestamps, limiting our options for quantitative analysis. To fix this, we should combine the service_date
and time_period
into a single timestamp using pandas.to_datetime()
. But first we must clean the time_period
values, which are all in parentheses for some reason.
To strip the time_period
values of the parentheses, we can utilize the pandas.Series.str
interface that allows us to apply string methods on the whole column. This allows us to apply a vectorized version of the built-in str.strip()
method to the whole column via pandas.Series.str.strip()
.
Hence we can do the following all in one command:

- Extract the time_period column as a pandas.Series object.
- Apply pandas.Series.str.strip() to remove the parentheses from the values, creating a new pandas.Series object.
- Overwrite the original time_period column with this new pandas.Series object where the parentheses have been removed.

mbta['time_period'] = mbta.time_period.str.strip('()')
mbta.head()
service_date | time_period | stop_id | station_name | route_or_line | gated_entries | |
---|---|---|---|---|---|---|
0 | 2020-01-01 | 00:00:00 | place-alfcl | Alewife | Red Line | 3 |
1 | 2020-01-01 | 00:00:00 | place-andrw | Andrew | Red Line | 8 |
2 | 2020-01-01 | 00:00:00 | place-aport | Airport | Blue Line | 32 |
3 | 2020-01-01 | 00:00:00 | place-aqucl | Aquarium | Blue Line | 15 |
4 | 2020-01-01 | 00:00:00 | place-armnl | Arlington | Green Line | 3 |
We can concatenate pandas.Series
objects containing textual data the same way we can concatenate strings in Python. Knowing this, we can easily combine the service_date
and time_period
columns into a single timestamp.
mbta['timestamp'] = mbta.service_date + ' ' + mbta.time_period
mbta.head()
service_date | time_period | stop_id | station_name | route_or_line | gated_entries | timestamp | |
---|---|---|---|---|---|---|---|
0 | 2020-01-01 | 00:00:00 | place-alfcl | Alewife | Red Line | 3 | 2020-01-01 00:00:00 |
1 | 2020-01-01 | 00:00:00 | place-andrw | Andrew | Red Line | 8 | 2020-01-01 00:00:00 |
2 | 2020-01-01 | 00:00:00 | place-aport | Airport | Blue Line | 32 | 2020-01-01 00:00:00 |
3 | 2020-01-01 | 00:00:00 | place-aqucl | Aquarium | Blue Line | 15 | 2020-01-01 00:00:00 |
4 | 2020-01-01 | 00:00:00 | place-armnl | Arlington | Green Line | 3 | 2020-01-01 00:00:00 |
mbta.dtypes
service_date object time_period object stop_id object station_name object route_or_line object gated_entries int64 timestamp object dtype: object
Although the text in the new timestamp
column sure looks like a valid timestamp, it is still just textual data and has no meaning to Python or Pandas. To convert these textual timestamps into Pandas-aware timestamps, we can use pandas.to_datetime()
and pass the timestamp
series as input. This will produce a new series that we will use to replace the timestamp
column.
mbta['timestamp'] = pd.to_datetime(mbta.timestamp)
mbta.head()
service_date | time_period | stop_id | station_name | route_or_line | gated_entries | timestamp | |
---|---|---|---|---|---|---|---|
0 | 2020-01-01 | 00:00:00 | place-alfcl | Alewife | Red Line | 3 | 2020-01-01 |
1 | 2020-01-01 | 00:00:00 | place-andrw | Andrew | Red Line | 8 | 2020-01-01 |
2 | 2020-01-01 | 00:00:00 | place-aport | Airport | Blue Line | 32 | 2020-01-01 |
3 | 2020-01-01 | 00:00:00 | place-aqucl | Aquarium | Blue Line | 15 | 2020-01-01 |
4 | 2020-01-01 | 00:00:00 | place-armnl | Arlington | Green Line | 3 | 2020-01-01 |
mbta.dtypes
service_date object time_period object stop_id object station_name object route_or_line object gated_entries int64 timestamp datetime64[ns] dtype: object
Note how timestamp
now has a datatype of datetime64
(with nanosecond precision). This allows us to perform arithmetic and comparisons on the timestamps and also utilize various additional date-time methods (like extracting the month or weekday for example) via the pandas.Series.dt
interface.
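For example, the following sketch shows a few of the pandas.Series.dt accessors in action:

mbta.timestamp.dt.month       # month number (1 through 12)
mbta.timestamp.dt.day_name()  # weekday name, e.g. 'Wednesday'
mbta.timestamp.dt.hour        # hour at which the 30-minute period starts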
Let us simplify our further analysis by dropping the redundant service_date
and time_period
columns using the pandas.DataFrame.drop()
method.
mbta.drop(columns=['service_date', 'time_period'], inplace=True)
mbta.head()
stop_id | station_name | route_or_line | gated_entries | timestamp | |
---|---|---|---|---|---|
0 | place-alfcl | Alewife | Red Line | 3 | 2020-01-01 |
1 | place-andrw | Andrew | Red Line | 8 | 2020-01-01 |
2 | place-aport | Airport | Blue Line | 32 | 2020-01-01 |
3 | place-aqucl | Aquarium | Blue Line | 15 | 2020-01-01 |
4 | place-armnl | Arlington | Green Line | 3 | 2020-01-01 |
We can easily get the total number of gated entries across the whole MBTA system in 2020 by using pandas.Series.sum().
mbta.gated_entries.sum()
50199157
Combining pandas.Series.sum()
with boolean indexing allows us to extract the total number of gated entries for specific stations, lines, or even dates.
mbta.gated_entries[mbta.station_name == 'Davis'].sum()
996012
mbta.gated_entries[mbta.route_or_line == 'Red Line'].sum()
18947501
mbta.gated_entries[mbta.timestamp == '2020-02-24'].sum()
979
Note that because the timestamps carry a time component, comparing against a bare date like '2020-02-24' matches only the midnight (00:00:00) interval, so this sum covers just the first 30-minute period of that day rather than the full date.
Having the timestamps in datetime64
format allows us to extract specific time periods using comparisons. For example, we can get the total number of gated entries across the whole MBTA system in February 2020 as follows.
mbta.gated_entries[
(mbta.timestamp >= '2020-02-01') & (mbta.timestamp < '2020-03-01')].sum()
10776306
Alternatively, we could take advantage of pandas.Series.dt.month
to extract the month numbers of the datetype64
values and use that to get the same information.
mbta.gated_entries[mbta.timestamp.dt.month == 2].sum()
10776306
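These filters can also be combined. As an illustrative sketch (output omitted), the February total for just the Davis station could be computed by joining two boolean masks with &:
mbta.gated_entries[(mbta.station_name == 'Davis')
                   & (mbta.timestamp.dt.month == 2)].sum()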
Let us say we would like to get the number of gated entries across the whole MBTA system for each day in 2020. Pandas provides easy functionality to calculate various aggregate values by group, as long as there is a categorical column that defines the groups. Currently we only have a datetime column, which is not categorical and hence not suitable for aggregating entries by date. However, the service_date
column we removed would have been perfect for this task. Luckily we can easily recreate this column using pandas.Series.dt.date
to extract the date from the datetime64
timestamp.
mbta['date'] = mbta.timestamp.dt.date
mbta.head()
stop_id | station_name | route_or_line | gated_entries | timestamp | date | |
---|---|---|---|---|---|---|
0 | place-alfcl | Alewife | Red Line | 3 | 2020-01-01 | 2020-01-01 |
1 | place-andrw | Andrew | Red Line | 8 | 2020-01-01 | 2020-01-01 |
2 | place-aport | Airport | Blue Line | 32 | 2020-01-01 | 2020-01-01 |
3 | place-aqucl | Aquarium | Blue Line | 15 | 2020-01-01 | 2020-01-01 |
4 | place-armnl | Arlington | Green Line | 3 | 2020-01-01 | 2020-01-01 |
Now we can use pandas.DataFrame.groupby()
to convert the pandas.DataFrame
into a pandas.groupby.DataFrameGroupBy
object, where all the values of the DataFrame are grouped by the specified categorical variable and any methods called on it will apply by group. Note that the result is no longer a DataFrame, so we cannot display it as such.
mbta.groupby('date')
<pandas.core.groupby.generic.DataFrameGroupBy object at 0x7f586afe0750>
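Although the GroupBy object has no tabular display, it can still be inspected. For example, its ngroups attribute reports the number of groups, here one per date (366, since 2020 was a leap year):
mbta.groupby('date').ngroups  # number of distinct dates in the data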
We can extract the desired column from this pandas.groupby.DataFrameGroupBy
object as a pandas.groupby.SeriesGroupBy
object, where any methods called on the series will apply by the previously defined groups.
mbta.groupby('date').gated_entries
<pandas.core.groupby.generic.SeriesGroupBy object at 0x7f586afe0250>
When we call pandas.groupby.GroupBy.sum()
on this pandas.groupby.SeriesGroupBy
object, we will get a new pandas.Series
object where all the gated entries for each unique date have been added together.
mbta.groupby('date').gated_entries.sum()
date
2020-01-01    131374
2020-01-02    398109
2020-01-03    402018
2020-01-04    196810
2020-01-05    152047
               ...
2020-12-27     50923
2020-12-28     94764
2020-12-29     96679
2020-12-30     96727
2020-12-31     54861
Name: gated_entries, Length: 366, dtype: int64
We can convert this pandas.Series
into a pandas.DataFrame
using pandas.Series.to_frame()
. We can also specify a new name for the column containing the aggregated values if desired.
mbta.groupby('date').gated_entries.sum().to_frame('total_entries').head()
total_entries | |
---|---|
date | |
2020-01-01 | 131374 |
2020-01-02 | 398109 |
2020-01-03 | 402018 |
2020-01-04 | 196810 |
2020-01-05 | 152047 |
Note how the groups make up the index of the new DataFrame. We can use pandas.DataFrame.reset_index()
to convert the dates back into a column and reset the index to a numerical one ranging from zero to one less than the number of rows. We can chain all the methods from before together and create a new DataFrame called mbta_daily_sum
that contains the total number of gated entries across the MBTA system for each date in 2020.
mbta_daily_sum = (mbta.groupby('date')
.gated_entries.sum()
.to_frame('total_entries')
.reset_index())
mbta_daily_sum.head()
date | total_entries | |
---|---|---|
0 | 2020-01-01 | 131374 |
1 | 2020-01-02 | 398109 |
2 | 2020-01-03 | 402018 |
3 | 2020-01-04 | 196810 |
4 | 2020-01-05 | 152047 |
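As a side note, passing as_index=False to groupby() keeps the grouping column as a regular column, letting us skip the reset_index() step. The aggregated column then retains its original name, so a rename would be needed to match the total_entries naming above:
mbta.groupby('date', as_index=False)['gated_entries'].sum()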
If we wanted to find out which date had the most ridership, we could use pandas.Series.max()
to get the maximum number of gated entries and then utilize boolean indexing to find out which date it corresponds to.
mbta_daily_sum.total_entries.max()
495770
mbta_daily_sum.date[
mbta_daily_sum.total_entries == mbta_daily_sum.total_entries.max()]
42    2020-02-12
Name: date, dtype: object
mbta_daily_sum.date[
mbta_daily_sum.total_entries == mbta_daily_sum.total_entries.max()
].values[0]
datetime.date(2020, 2, 12)
Alternatively, we could use pandas.Series.argmax()
to extract the index of the row with the most ridership and then utilize pandas.DataFrame.loc[]
to extract said row using its index.
mbta_daily_sum.total_entries.argmax()
42
mbta_daily_sum.loc[mbta_daily_sum.total_entries.argmax()]
date             2020-02-12
total_entries        495770
Name: 42, dtype: object
mbta_daily_sum.loc[mbta_daily_sum.total_entries.argmax(), 'date']
datetime.date(2020, 2, 12)
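Since mbta_daily_sum has a default integer index, pandas.Series.idxmax(), which returns the index label (rather than the position) of the maximum, would work just as well here and returns the same datetime.date(2020, 2, 12):
mbta_daily_sum.loc[mbta_daily_sum.total_entries.idxmax(), 'date']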
Finally, we could utilize pandas.DataFrame.sort_values()
to sort the values by total number of gated entries.
mbta_daily_sum.sort_values('total_entries')
date | total_entries | |
---|---|---|
102 | 2020-04-12 | 17976 |
116 | 2020-04-26 | 18506 |
109 | 2020-04-19 | 20025 |
95 | 2020-04-05 | 20240 |
359 | 2020-12-25 | 20680 |
... | ... | ... |
29 | 2020-01-30 | 490601 |
14 | 2020-01-15 | 491462 |
35 | 2020-02-05 | 492955 |
57 | 2020-02-27 | 493547 |
42 | 2020-02-12 | 495770 |
366 rows × 2 columns
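As a shortcut, pandas.DataFrame.nlargest() offers a way to grab just the top rows by a given column without sorting the whole frame; for example, the five busiest days (output omitted):
mbta_daily_sum.nlargest(5, 'total_entries')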
Knowing all of this, we can easily take a quick look at the most and least used rapid transit stations and lines across the MBTA system in 2020 by chaining groupby(), sum(), sort_values(), to_frame(), and reset_index():
(mbta.groupby('station_name')
.gated_entries.sum()
.sort_values(ascending=False)
.to_frame('total_entries')
.reset_index())
station_name | total_entries | |
---|---|---|
0 | Downtown Crossing | 2289292 |
1 | South Station | 2017283 |
2 | North Station | 1720378 |
3 | Harvard | 1690353 |
4 | Maverick | 1661287 |
... | ... | ... |
59 | Symphony | 182537 |
60 | Riverside | 137870 |
61 | Suffolk Downs | 112710 |
62 | World Trade Center | 91707 |
63 | Science Park | 38237 |
64 rows × 2 columns
(mbta.groupby('route_or_line')
.gated_entries.sum()
.sort_values(ascending=False)
.to_frame('total_entries')
.reset_index())
route_or_line | total_entries | |
---|---|---|
0 | Red Line | 18947501 |
1 | Orange Line | 16121128 |
2 | Green Line | 7493791 |
3 | Blue Line | 6626063 |
4 | Silver Line | 1010674 |
We can also group by multiple columns. For example, we can group by date
, station_name
, and route_or_line
to create a new DataFrame mbta_daily
, where the gated entries for each station and line combination are aggregated into 24-hour intervals instead of the original 30-minute intervals.
mbta_daily = (mbta.groupby(['date', 'station_name', 'route_or_line'])
.gated_entries.sum()
.to_frame()
.reset_index())
mbta_daily.head()
date | station_name | route_or_line | gated_entries | |
---|---|---|---|---|
0 | 2020-01-01 | Airport | Blue Line | 3883 |
1 | 2020-01-01 | Alewife | Red Line | 2449 |
2 | 2020-01-01 | Andrew | Red Line | 1668 |
3 | 2020-01-01 | Aquarium | Blue Line | 2015 |
4 | 2020-01-01 | Arlington | Green Line | 1592 |
This new DataFrame allows us to perform any further analysis for which a daily temporal resolution is sufficient. For example, we could plot the daily number of gated entries at the Harvard Square MBTA station throughout 2020 and see whether the onset of the COVID-19 pandemic had an effect on ridership.
fig, ax = plt.subplots(figsize=(7, 5))
mbta_daily[(mbta_daily.station_name == 'Harvard')].plot(x='date',
y='gated_entries',
legend=False,
color='crimson',
ax=ax)
plt.xlabel('Date')
plt.ylabel('Gated Entries')
plt.title('2020 Daily Gated Entries at the Harvard Square MBTA Station')
plt.show()
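Since the hvplot.pandas module was imported at the start of this tutorial, an interactive version of the same plot can be sketched with minimal changes (an illustrative aside; the keyword arguments mirror the Matplotlib labels above):
mbta_daily[mbta_daily.station_name == 'Harvard'].hvplot.line(
    x='date',
    y='gated_entries',
    xlabel='Date',
    ylabel='Gated Entries',
    title='2020 Daily Gated Entries at the Harvard Square MBTA Station')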
Interactive Kaggle tutorials with built-in exercises:
Pandas -- https://www.kaggle.com/learn/pandas
Data Visualization -- https://www.kaggle.com/learn/data-visualization
Official Pandas resources:
Documentation -- https://pandas.pydata.org/docs/
Official Matplotlib resources:
Documentation -- https://matplotlib.org/stable/
Official Seaborn resources:
Documentation -- https://seaborn.pydata.org/
Official HVPlot resources:
Documentation -- https://hvplot.holoviz.org/
Official Plotly resources:
Documentation -- https://plotly.com/python/