
In [1]:
# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
In [3]:
# Load the 50 Startups dataset (path is relative to the notebook's working directory).
df = pd.read_csv(filepath_or_buffer="../data/50_Startups.csv")
In [4]:
# Display the full raw dataset: 50 rows with three spend columns, State, and the Profit target.
df
Out[4]:
| | R&D Spend | Administration | Marketing Spend | State | Profit |
---|---|---|---|---|---|
0 | 165349.20 | 136897.80 | 471784.10 | New York | 192261.83 |
1 | 162597.70 | 151377.59 | 443898.53 | California | 191792.06 |
2 | 153441.51 | 101145.55 | 407934.54 | Florida | 191050.39 |
3 | 144372.41 | 118671.85 | 383199.62 | New York | 182901.99 |
4 | 142107.34 | 91391.77 | 366168.42 | Florida | 166187.94 |
5 | 131876.90 | 99814.71 | 362861.36 | New York | 156991.12 |
6 | 134615.46 | 147198.87 | 127716.82 | California | 156122.51 |
7 | 130298.13 | 145530.06 | 323876.68 | Florida | 155752.60 |
8 | 120542.52 | 148718.95 | 311613.29 | New York | 152211.77 |
9 | 123334.88 | 108679.17 | 304981.62 | California | 149759.96 |
10 | 101913.08 | 110594.11 | 229160.95 | Florida | 146121.95 |
11 | 100671.96 | 91790.61 | 249744.55 | California | 144259.40 |
12 | 93863.75 | 127320.38 | 249839.44 | Florida | 141585.52 |
13 | 91992.39 | 135495.07 | 252664.93 | California | 134307.35 |
14 | 119943.24 | 156547.42 | 256512.92 | Florida | 132602.65 |
15 | 114523.61 | 122616.84 | 261776.23 | New York | 129917.04 |
16 | 78013.11 | 121597.55 | 264346.06 | California | 126992.93 |
17 | 94657.16 | 145077.58 | 282574.31 | New York | 125370.37 |
18 | 91749.16 | 114175.79 | 294919.57 | Florida | 124266.90 |
19 | 86419.70 | 153514.11 | 0.00 | New York | 122776.86 |
20 | 76253.86 | 113867.30 | 298664.47 | California | 118474.03 |
21 | 78389.47 | 153773.43 | 299737.29 | New York | 111313.02 |
22 | 73994.56 | 122782.75 | 303319.26 | Florida | 110352.25 |
23 | 67532.53 | 105751.03 | 304768.73 | Florida | 108733.99 |
24 | 77044.01 | 99281.34 | 140574.81 | New York | 108552.04 |
25 | 64664.71 | 139553.16 | 137962.62 | California | 107404.34 |
26 | 75328.87 | 144135.98 | 134050.07 | Florida | 105733.54 |
27 | 72107.60 | 127864.55 | 353183.81 | New York | 105008.31 |
28 | 66051.52 | 182645.56 | 118148.20 | Florida | 103282.38 |
29 | 65605.48 | 153032.06 | 107138.38 | New York | 101004.64 |
30 | 61994.48 | 115641.28 | 91131.24 | Florida | 99937.59 |
31 | 61136.38 | 152701.92 | 88218.23 | New York | 97483.56 |
32 | 63408.86 | 129219.61 | 46085.25 | California | 97427.84 |
33 | 55493.95 | 103057.49 | 214634.81 | Florida | 96778.92 |
34 | 46426.07 | 157693.92 | 210797.67 | California | 96712.80 |
35 | 46014.02 | 85047.44 | 205517.64 | New York | 96479.51 |
36 | 28663.76 | 127056.21 | 201126.82 | Florida | 90708.19 |
37 | 44069.95 | 51283.14 | 197029.42 | California | 89949.14 |
38 | 20229.59 | 65947.93 | 185265.10 | New York | 81229.06 |
39 | 38558.51 | 82982.09 | 174999.30 | California | 81005.76 |
40 | 28754.33 | 118546.05 | 172795.67 | California | 78239.91 |
41 | 27892.92 | 84710.77 | 164470.71 | Florida | 77798.83 |
42 | 23640.93 | 96189.63 | 148001.11 | California | 71498.49 |
43 | 15505.73 | 127382.30 | 35534.17 | New York | 69758.98 |
44 | 22177.74 | 154806.14 | 28334.72 | California | 65200.33 |
45 | 1000.23 | 124153.04 | 1903.93 | New York | 64926.08 |
46 | 1315.46 | 115816.21 | 297114.46 | Florida | 49490.75 |
47 | 0.00 | 135426.92 | 0.00 | California | 42559.73 |
48 | 542.05 | 51743.15 | 0.00 | New York | 35673.41 |
49 | 0.00 | 116983.80 | 45173.06 | California | 14681.40 |
In [5]:
# Summary statistics for the numeric columns; the text column State is not included.
df.describe()
Out[5]:
| | R&D Spend | Administration | Marketing Spend | Profit |
---|---|---|---|---|
count | 50.000000 | 50.000000 | 50.000000 | 50.000000 |
mean | 73721.615600 | 121344.639600 | 211025.097800 | 112012.639200 |
std | 45902.256482 | 28017.802755 | 122290.310726 | 40306.180338 |
min | 0.000000 | 51283.140000 | 0.000000 | 14681.400000 |
25% | 39936.370000 | 103730.875000 | 129300.132500 | 90138.902500 |
50% | 73051.080000 | 122699.795000 | 212716.240000 | 107978.190000 |
75% | 101602.800000 | 144842.180000 | 299469.085000 | 139765.977500 |
max | 165349.200000 | 182645.560000 | 471784.100000 | 192261.830000 |
In [6]:
# Count missing values per column (all counts are 0 here, so no imputation is needed).
df.isna().sum()
Out[6]:
R&D Spend 0 Administration 0 Marketing Spend 0 State 0 Profit 0 dtype: int64
In [12]:
# Features: every column except the Profit target
# (R&D Spend, Administration, Marketing Spend, State).
# Dropping the target by name states the intent directly instead of the
# arithmetic slice `:-2+1` (== `:-1`), and does not rely on Profit being last.
X = df.drop(columns="Profit")
In [7]:
# Target: the Profit column as a Series.
y = df.loc[:, "Profit"]
In [13]:
# Display the target vector.
y
Out[13]:
0 192261.83 1 191792.06 2 191050.39 3 182901.99 4 166187.94 5 156991.12 6 156122.51 7 155752.60 8 152211.77 9 149759.96 10 146121.95 11 144259.40 12 141585.52 13 134307.35 14 132602.65 15 129917.04 16 126992.93 17 125370.37 18 124266.90 19 122776.86 20 118474.03 21 111313.02 22 110352.25 23 108733.99 24 108552.04 25 107404.34 26 105733.54 27 105008.31 28 103282.38 29 101004.64 30 99937.59 31 97483.56 32 97427.84 33 96778.92 34 96712.80 35 96479.51 36 90708.19 37 89949.14 38 81229.06 39 81005.76 40 78239.91 41 77798.83 42 71498.49 43 69758.98 44 65200.33 45 64926.08 46 49490.75 47 42559.73 48 35673.41 49 14681.40 Name: Profit, dtype: float64
In [14]:
# Display the feature matrix; State is still a raw text column at this point.
X
Out[14]:
| | R&D Spend | Administration | Marketing Spend | State |
---|---|---|---|---|
0 | 165349.20 | 136897.80 | 471784.10 | New York |
1 | 162597.70 | 151377.59 | 443898.53 | California |
2 | 153441.51 | 101145.55 | 407934.54 | Florida |
3 | 144372.41 | 118671.85 | 383199.62 | New York |
4 | 142107.34 | 91391.77 | 366168.42 | Florida |
5 | 131876.90 | 99814.71 | 362861.36 | New York |
6 | 134615.46 | 147198.87 | 127716.82 | California |
7 | 130298.13 | 145530.06 | 323876.68 | Florida |
8 | 120542.52 | 148718.95 | 311613.29 | New York |
9 | 123334.88 | 108679.17 | 304981.62 | California |
10 | 101913.08 | 110594.11 | 229160.95 | Florida |
11 | 100671.96 | 91790.61 | 249744.55 | California |
12 | 93863.75 | 127320.38 | 249839.44 | Florida |
13 | 91992.39 | 135495.07 | 252664.93 | California |
14 | 119943.24 | 156547.42 | 256512.92 | Florida |
15 | 114523.61 | 122616.84 | 261776.23 | New York |
16 | 78013.11 | 121597.55 | 264346.06 | California |
17 | 94657.16 | 145077.58 | 282574.31 | New York |
18 | 91749.16 | 114175.79 | 294919.57 | Florida |
19 | 86419.70 | 153514.11 | 0.00 | New York |
20 | 76253.86 | 113867.30 | 298664.47 | California |
21 | 78389.47 | 153773.43 | 299737.29 | New York |
22 | 73994.56 | 122782.75 | 303319.26 | Florida |
23 | 67532.53 | 105751.03 | 304768.73 | Florida |
24 | 77044.01 | 99281.34 | 140574.81 | New York |
25 | 64664.71 | 139553.16 | 137962.62 | California |
26 | 75328.87 | 144135.98 | 134050.07 | Florida |
27 | 72107.60 | 127864.55 | 353183.81 | New York |
28 | 66051.52 | 182645.56 | 118148.20 | Florida |
29 | 65605.48 | 153032.06 | 107138.38 | New York |
30 | 61994.48 | 115641.28 | 91131.24 | Florida |
31 | 61136.38 | 152701.92 | 88218.23 | New York |
32 | 63408.86 | 129219.61 | 46085.25 | California |
33 | 55493.95 | 103057.49 | 214634.81 | Florida |
34 | 46426.07 | 157693.92 | 210797.67 | California |
35 | 46014.02 | 85047.44 | 205517.64 | New York |
36 | 28663.76 | 127056.21 | 201126.82 | Florida |
37 | 44069.95 | 51283.14 | 197029.42 | California |
38 | 20229.59 | 65947.93 | 185265.10 | New York |
39 | 38558.51 | 82982.09 | 174999.30 | California |
40 | 28754.33 | 118546.05 | 172795.67 | California |
41 | 27892.92 | 84710.77 | 164470.71 | Florida |
42 | 23640.93 | 96189.63 | 148001.11 | California |
43 | 15505.73 | 127382.30 | 35534.17 | New York |
44 | 22177.74 | 154806.14 | 28334.72 | California |
45 | 1000.23 | 124153.04 | 1903.93 | New York |
46 | 1315.46 | 115816.21 | 297114.46 | Florida |
47 | 0.00 | 135426.92 | 0.00 | California |
48 | 542.05 | 51743.15 | 0.00 | New York |
49 | 0.00 | 116983.80 | 45173.06 | California |
In [18]:
# Number of distinct states, i.e. how many one-hot columns State will expand into.
X.loc[:, "State"].nunique()
Out[18]:
3
In [22]:
# Distinct state names in alphabetical order; the encoded columns later appear in this same order.
sorted(set(X["State"]))
Out[22]:
['California', 'Florida', 'New York']
In [23]:
from sklearn.preprocessing import OneHotEncoder
In [24]:
from sklearn.compose import ColumnTransformer
In [28]:
# Re-display X before encoding: still a DataFrame with the raw State column.
X
Out[28]:
| | R&D Spend | Administration | Marketing Spend | State |
---|---|---|---|---|
0 | 165349.20 | 136897.80 | 471784.10 | New York |
1 | 162597.70 | 151377.59 | 443898.53 | California |
2 | 153441.51 | 101145.55 | 407934.54 | Florida |
3 | 144372.41 | 118671.85 | 383199.62 | New York |
4 | 142107.34 | 91391.77 | 366168.42 | Florida |
5 | 131876.90 | 99814.71 | 362861.36 | New York |
6 | 134615.46 | 147198.87 | 127716.82 | California |
7 | 130298.13 | 145530.06 | 323876.68 | Florida |
8 | 120542.52 | 148718.95 | 311613.29 | New York |
9 | 123334.88 | 108679.17 | 304981.62 | California |
10 | 101913.08 | 110594.11 | 229160.95 | Florida |
11 | 100671.96 | 91790.61 | 249744.55 | California |
12 | 93863.75 | 127320.38 | 249839.44 | Florida |
13 | 91992.39 | 135495.07 | 252664.93 | California |
14 | 119943.24 | 156547.42 | 256512.92 | Florida |
15 | 114523.61 | 122616.84 | 261776.23 | New York |
16 | 78013.11 | 121597.55 | 264346.06 | California |
17 | 94657.16 | 145077.58 | 282574.31 | New York |
18 | 91749.16 | 114175.79 | 294919.57 | Florida |
19 | 86419.70 | 153514.11 | 0.00 | New York |
20 | 76253.86 | 113867.30 | 298664.47 | California |
21 | 78389.47 | 153773.43 | 299737.29 | New York |
22 | 73994.56 | 122782.75 | 303319.26 | Florida |
23 | 67532.53 | 105751.03 | 304768.73 | Florida |
24 | 77044.01 | 99281.34 | 140574.81 | New York |
25 | 64664.71 | 139553.16 | 137962.62 | California |
26 | 75328.87 | 144135.98 | 134050.07 | Florida |
27 | 72107.60 | 127864.55 | 353183.81 | New York |
28 | 66051.52 | 182645.56 | 118148.20 | Florida |
29 | 65605.48 | 153032.06 | 107138.38 | New York |
30 | 61994.48 | 115641.28 | 91131.24 | Florida |
31 | 61136.38 | 152701.92 | 88218.23 | New York |
32 | 63408.86 | 129219.61 | 46085.25 | California |
33 | 55493.95 | 103057.49 | 214634.81 | Florida |
34 | 46426.07 | 157693.92 | 210797.67 | California |
35 | 46014.02 | 85047.44 | 205517.64 | New York |
36 | 28663.76 | 127056.21 | 201126.82 | Florida |
37 | 44069.95 | 51283.14 | 197029.42 | California |
38 | 20229.59 | 65947.93 | 185265.10 | New York |
39 | 38558.51 | 82982.09 | 174999.30 | California |
40 | 28754.33 | 118546.05 | 172795.67 | California |
41 | 27892.92 | 84710.77 | 164470.71 | Florida |
42 | 23640.93 | 96189.63 | 148001.11 | California |
43 | 15505.73 | 127382.30 | 35534.17 | New York |
44 | 22177.74 | 154806.14 | 28334.72 | California |
45 | 1000.23 | 124153.04 | 1903.93 | New York |
46 | 1315.46 | 115816.21 | 297114.46 | Florida |
47 | 0.00 | 135426.92 | 0.00 | California |
48 | 542.05 | 51743.15 | 0.00 | New York |
49 | 0.00 | 116983.80 | 45173.06 | California |
In [29]:
# One-hot encode the categorical State column (positional index 3 of X);
# the remaining spend columns are passed through unchanged after the dummies.
# NOTE(review): all three State dummies are kept alongside LinearRegression's
# default intercept, so the dummy columns are perfectly collinear. The least-squares
# fit still works, but the individual State coefficients are not uniquely
# determined -- consider OneHotEncoder(drop="first") if they are interpreted.
ct = ColumnTransformer(
    transformers=[("encoder", OneHotEncoder(), [3])],
    remainder="passthrough",
)
In [30]:
# Preview the encoded output: three one-hot State columns followed by the
# three passthrough spend columns. The result is only displayed here, not stored.
ct.fit_transform( X )
Out[30]:
array([[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.6534920e+05, 1.3689780e+05, 4.7178410e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.6259770e+05, 1.5137759e+05, 4.4389853e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.5344151e+05, 1.0114555e+05, 4.0793454e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.4437241e+05, 1.1867185e+05, 3.8319962e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.4210734e+05, 9.1391770e+04, 3.6616842e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.3187690e+05, 9.9814710e+04, 3.6286136e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3461546e+05, 1.4719887e+05, 1.2771682e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3029813e+05, 1.4553006e+05, 3.2387668e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.2054252e+05, 1.4871895e+05, 3.1161329e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.2333488e+05, 1.0867917e+05, 3.0498162e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0191308e+05, 1.1059411e+05, 2.2916095e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0067196e+05, 9.1790610e+04, 2.4974455e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.3863750e+04, 1.2732038e+05, 2.4983944e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.1992390e+04, 1.3549507e+05, 2.5266493e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.1994324e+05, 1.5654742e+05, 2.5651292e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.1452361e+05, 1.2261684e+05, 2.6177623e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.8013110e+04, 1.2159755e+05, 2.6434606e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 9.4657160e+04, 1.4507758e+05, 2.8257431e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.1749160e+04, 1.1417579e+05, 2.9491957e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 8.6419700e+04, 1.5351411e+05, 0.0000000e+00], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.6253860e+04, 1.1386730e+05, 2.9866447e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.8389470e+04, 
1.5377343e+05, 2.9973729e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.3994560e+04, 1.2278275e+05, 3.0331926e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.7532530e+04, 1.0575103e+05, 3.0476873e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.7044010e+04, 9.9281340e+04, 1.4057481e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.4664710e+04, 1.3955316e+05, 1.3796262e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.5328870e+04, 1.4413598e+05, 1.3405007e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.2107600e+04, 1.2786455e+05, 3.5318381e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.6051520e+04, 1.8264556e+05, 1.1814820e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.5605480e+04, 1.5303206e+05, 1.0713838e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.1994480e+04, 1.1564128e+05, 9.1131240e+04], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.1136380e+04, 1.5270192e+05, 8.8218230e+04], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.3408860e+04, 1.2921961e+05, 4.6085250e+04], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 5.5493950e+04, 1.0305749e+05, 2.1463481e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.6426070e+04, 1.5769392e+05, 2.1079767e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 4.6014020e+04, 8.5047440e+04, 2.0551764e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.8663760e+04, 1.2705621e+05, 2.0112682e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4069950e+04, 5.1283140e+04, 1.9702942e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 2.0229590e+04, 6.5947930e+04, 1.8526510e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.8558510e+04, 8.2982090e+04, 1.7499930e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.8754330e+04, 1.1854605e+05, 1.7279567e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.7892920e+04, 8.4710770e+04, 1.6447071e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.3640930e+04, 9.6189630e+04, 1.4800111e+05], [0.0000000e+00, 0.0000000e+00, 
1.0000000e+00, 1.5505730e+04, 1.2738230e+05, 3.5534170e+04], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.2177740e+04, 1.5480614e+05, 2.8334720e+04], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.0002300e+03, 1.2415304e+05, 1.9039300e+03], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3154600e+03, 1.1581621e+05, 2.9711446e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3542692e+05, 0.0000000e+00], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 5.4205000e+02, 5.1743150e+04, 0.0000000e+00], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.1698380e+05, 4.5173060e+04]])
In [31]:
# First three raw rows, to compare with the first three encoded rows shown above.
X.iloc[:3]
Out[31]:
| | R&D Spend | Administration | Marketing Spend | State |
---|---|---|---|---|
0 | 165349.20 | 136897.80 | 471784.10 | New York |
1 | 162597.70 | 151377.59 | 443898.53 | California |
2 | 153441.51 | 101145.55 | 407934.54 | Florida |
In [32]:
# Encode the features. Build from df instead of from X itself so this cell is
# idempotent: re-running `X = ct.fit_transform(X)` would feed the already-encoded
# array back in and one-hot encode its column 3 (R&D Spend) instead of State.
X = ct.fit_transform(df.drop(columns="Profit"))
In [34]:
from sklearn.model_selection import train_test_split
In [35]:
# Preview the split: returns [X_train, X_test, y_train, y_test]. Because
# random_state=1 is fixed, the assignment in the next cell yields the same split.
train_test_split(X , y, test_size=0.2, random_state=1)
Out[35]:
[array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.3408860e+04, 1.2921961e+05, 4.6085250e+04], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.8558510e+04, 8.2982090e+04, 1.7499930e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.8389470e+04, 1.5377343e+05, 2.9973729e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.8663760e+04, 1.2705621e+05, 2.0112682e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 8.6419700e+04, 1.5351411e+05, 0.0000000e+00], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.3640930e+04, 9.6189630e+04, 1.4800111e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.1698380e+05, 4.5173060e+04], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.5328870e+04, 1.4413598e+05, 1.3405007e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 7.3994560e+04, 1.2278275e+05, 3.0331926e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.1992390e+04, 1.3549507e+05, 2.5266493e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 2.7892920e+04, 8.4710770e+04, 1.6447071e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 9.4657160e+04, 1.4507758e+05, 2.8257431e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.0002300e+03, 1.2415304e+05, 1.9039300e+03], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.7044010e+04, 9.9281340e+04, 1.4057481e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.7532530e+04, 1.0575103e+05, 3.0476873e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.4210734e+05, 9.1391770e+04, 3.6616842e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 5.5493950e+04, 1.0305749e+05, 2.1463481e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.1994324e+05, 1.5654742e+05, 2.5651292e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.1994480e+04, 1.1564128e+05, 9.1131240e+04], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0191308e+05, 1.1059411e+05, 2.2916095e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 6.6051520e+04, 1.8264556e+05, 1.1814820e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 
2.2177740e+04, 1.5480614e+05, 2.8334720e+04], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.6426070e+04, 1.5769392e+05, 2.1079767e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.1749160e+04, 1.1417579e+05, 2.9491957e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.6253860e+04, 1.1386730e+05, 2.9866447e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.4664710e+04, 1.3955316e+05, 1.3796262e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3461546e+05, 1.4719887e+05, 1.2771682e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3029813e+05, 1.4553006e+05, 3.2387668e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3542692e+05, 0.0000000e+00], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.6259770e+05, 1.5137759e+05, 4.4389853e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.8013110e+04, 1.2159755e+05, 2.6434606e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.6534920e+05, 1.3689780e+05, 4.7178410e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.1452361e+05, 1.2261684e+05, 2.6177623e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.3187690e+05, 9.9814710e+04, 3.6286136e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0067196e+05, 9.1790610e+04, 2.4974455e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.2333488e+05, 1.0867917e+05, 3.0498162e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.2054252e+05, 1.4871895e+05, 3.1161329e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 9.3863750e+04, 1.2732038e+05, 2.4983944e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.5505730e+04, 1.2738230e+05, 3.5534170e+04], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4069950e+04, 5.1283140e+04, 1.9702942e+05]]), array([[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.2107600e+04, 1.2786455e+05, 3.5318381e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 4.6014020e+04, 8.5047440e+04, 2.0551764e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.8754330e+04, 1.1854605e+05, 1.7279567e+05], 
[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 2.0229590e+04, 6.5947930e+04, 1.8526510e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.5344151e+05, 1.0114555e+05, 4.0793454e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.4437241e+05, 1.1867185e+05, 3.8319962e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 5.4205000e+02, 5.1743150e+04, 0.0000000e+00], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.5605480e+04, 1.5303206e+05, 1.0713838e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3154600e+03, 1.1581621e+05, 2.9711446e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.1136380e+04, 1.5270192e+05, 8.8218230e+04]]), 32 97427.84 39 81005.76 21 111313.02 36 90708.19 19 122776.86 42 71498.49 49 14681.40 26 105733.54 22 110352.25 13 134307.35 41 77798.83 17 125370.37 45 64926.08 24 108552.04 23 108733.99 4 166187.94 33 96778.92 14 132602.65 30 99937.59 10 146121.95 28 103282.38 44 65200.33 34 96712.80 18 124266.90 20 118474.03 25 107404.34 6 156122.51 7 155752.60 47 42559.73 1 191792.06 16 126992.93 0 192261.83 15 129917.04 5 156991.12 11 144259.40 9 149759.96 8 152211.77 12 141585.52 43 69758.98 37 89949.14 Name: Profit, dtype: float64, 27 105008.31 35 96479.51 40 78239.91 38 81229.06 2 191050.39 3 182901.99 48 35673.41 29 101004.64 46 49490.75 31 97483.56 Name: Profit, dtype: float64]
In [36]:
# Hold out 20% of the rows for testing; a fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=1,
)
In [ ]:
In [38]:
from sklearn.linear_model import LinearRegression
In [39]:
# Ordinary least-squares model; fit_intercept=True is the default, spelled out for clarity.
regressor = LinearRegression(fit_intercept=True)
In [40]:
# Fit on the training split only; fit() returns the estimator, which is displayed.
regressor.fit(X=X_train, y=y_train)
Out[40]:
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
In [45]:
# (number of test rows, number of encoded feature columns).
X_test.shape
Out[45]:
(10, 6)
In [47]:
# Predicted profit for every row of the held-out test set.
y_pred = regressor.predict(X=X_test)
In [48]:
# Predicted profit for each test row (same order as y_test).
y_pred
Out[48]:
array([114664.41715867, 90593.1553162 , 75692.84151574, 70221.88679651, 179790.25514874, 171576.92018521, 49753.58752029, 102276.65888936, 58649.37795761, 98272.02561131])
In [49]:
# Actual profit for the test rows; index labels are the original df row numbers.
y_test
Out[49]:
27 105008.31 35 96479.51 40 78239.91 38 81229.06 2 191050.39 3 182901.99 48 35673.41 29 101004.64 46 49490.75 31 97483.56 Name: Profit, dtype: float64
In [50]:
# Residuals (actual minus predicted) per test row.
y_test.sub(y_pred)
Out[50]:
27 -9656.107159 35 5886.354684 40 2547.068484 38 11007.173203 2 11260.134851 3 11325.069815 48 -14080.177520 29 -1272.018889 46 -9158.627958 31 -788.465611 Name: Profit, dtype: float64
In [52]:
# Keep the residuals (actual minus predicted) for the error metric below.
error = y_test.sub(y_pred)
In [54]:
# Mean squared error on the test set.
error.pow(2).mean()
Out[54]:
79495441.50407246
In [62]:
# Compare actual vs. predicted profit for each test row. Use the explicit
# Figure/Axes interface and label the chart so the saved image stands on its own.
fig, ax = plt.subplots()
ax.plot(y_test.values, marker="o", label="real")
ax.plot(y_pred, marker="x", label="pred")
ax.set(
    title="Actual vs. predicted profit (test set)",
    xlabel="Test sample",
    ylabel="Profit",
)
ax.legend()
fig.savefig("chart1.jpg")
plt.show()
In [ ]:
In [63]:
# What ultimately matters is deploying the trained model as a service.
In [ ]:
In [64]:
# Let's make a real prediction on new data.
In [ ]:
# Administration (operating cost) is $150K, marketing spend is $400K, R&D spend is $130K,
# and the company is located in Florida.
# How much profit can this company be expected to make?
In [69]:
# Check the column order a new sample must follow: R&D Spend, Administration, Marketing Spend, State.
df.iloc[:2]
Out[69]:
| | R&D Spend | Administration | Marketing Spend | State | Profit |
---|---|---|---|---|---|
0 | 165349.2 | 136897.80 | 471784.10 | New York | 192261.83 |
1 | 162597.7 | 151377.59 | 443898.53 | California | 191792.06 |
In [73]:
# Encoded test features: the layout (3 State dummies + 3 spend columns) a new sample must match after encoding.
X_test
Out[73]:
array([[0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 7.2107600e+04, 1.2786455e+05, 3.5318381e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 4.6014020e+04, 8.5047440e+04, 2.0551764e+05], [1.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.8754330e+04, 1.1854605e+05, 1.7279567e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 2.0229590e+04, 6.5947930e+04, 1.8526510e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.5344151e+05, 1.0114555e+05, 4.0793454e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 1.4437241e+05, 1.1867185e+05, 3.8319962e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 5.4205000e+02, 5.1743150e+04, 0.0000000e+00], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.5605480e+04, 1.5303206e+05, 1.0713838e+05], [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.3154600e+03, 1.1581621e+05, 2.9711446e+05], [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 6.1136380e+04, 1.5270192e+05, 8.8218230e+04]])
In [72]:
# New company, in the same column order as the raw features:
# R&D Spend, Administration, Marketing Spend, State.
# dtype=object keeps the three amounts as numbers; a plain np.array() over
# numbers mixed with a string silently coerces every element to str.
new_data = np.array([130000, 150000, 400000, "Florida"], dtype=object)
In [77]:
new_data = new_data.reshape(1,4) # transform expects 2-D input, so reshape into a 2-D array of 1 row x 4 columns
In [80]:
# Encode the new sample with the already-fitted transformer: call transform, not
# fit/fit_transform -- fitting is only done while training.
# Wrapping the row in a DataFrame with the column names the transformer was fitted on
# avoids sklearn's "X does not have valid feature names" UserWarning.
# astype(float) turns the combined output into a float matrix for the regressor.
new_data = ct.transform(
    pd.DataFrame(new_data, columns=ct.feature_names_in_)
).astype(float)
C:\Users\5-10\Anaconda3\envs\YH\lib\site-packages\sklearn\base.py:450: UserWarning: X does not have valid feature names, but OneHotEncoder was fitted with feature names warnings.warn(
In [81]:
# Encoded new sample: [California, Florida, New York, R&D Spend, Administration, Marketing Spend].
new_data
Out[81]:
array([[0.0e+00, 1.0e+00, 0.0e+00, 1.3e+05, 1.5e+05, 4.0e+05]])
In [82]:
regressor.predict(new_data) # predicted profit for the new company
Out[82]:
array([160947.68743064])
In [ ]:
In [84]:
# Learned weight for each column of the encoded X, in order:
# State=California, State=Florida, State=New York, R&D Spend, Administration, Marketing Spend.
# Model: y = a*x1 + b*x2 + c*x3 + d*x4 + f*x5 + h*x6 + g, where g is the intercept.
regressor.coef_
Out[84]:
array([-2.85177769e+02, 2.97560876e+02, -1.23831070e+01, 7.74342081e-01, -9.44369585e-03, 2.89183133e-02])
In [ ]:
regressor.intercept_ # constant term (intercept)
'DataScience > MachineLearning' 카테고리의 다른 글
Machine [supervised{Classification(Logisticregression)}] (0) | 2022.12.02 |
---|---|
Machine 예측 모델 실습, 배포를 위한 저장 (0) | 2022.12.01 |
Machine [supervised{Prediction(Linear Regression)}] (0) | 2022.12.01 |
Machine preprocessing, Feature Scaling, Dataset Training & Test (0) | 2022.12.01 |
Machine 원핫 인코딩 (One Hot Encoding) (0) | 2022.12.01 |