import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# 100 points drawn from 4 isotropic Gaussian blobs (std 1); the seed makes the
# sample reproducible.
X, y = make_blobs(n_samples=100, centers=4, cluster_std=1, random_state=10)
# X.T unpacks into the x-column and y-column; the trailing ';' hides the
# artist repr in the notebook.
plt.scatter(*X.T);
from sklearn.cluster import KMeans

# n_init is set explicitly: its default changed from 10 to "auto" in
# scikit-learn 1.4 (with a FutureWarning in 1.2-1.3). random_state fixes the
# centroid initialization, so cluster label numbering is the same every run.
km = KMeans(n_clusters=4, n_init=10, random_state=0)
km.fit(X)
# Cluster index assigned to each of the 100 points.
km.labels_
array([3, 1, 3, 1, 0, 2, 2, 3, 0, 2, 3, 2, 0, 2, 0, 2, 0, 1, 0, 2, 2, 2,
1, 3, 2, 2, 2, 0, 3, 1, 2, 2, 0, 2, 2, 0, 3, 1, 3, 1, 1, 1, 1, 0,
3, 3, 2, 0, 0, 0, 0, 0, 2, 3, 0, 0, 2, 1, 1, 3, 0, 1, 1, 2, 2, 1,
0, 1, 0, 3, 3, 2, 1, 0, 2, 2, 3, 1, 2, 3, 0, 0, 1, 1, 0, 3, 1, 3,
1, 3, 0, 3, 3, 1, 1, 1, 3, 3, 3, 3], dtype=int32)
# Coordinates of the 4 fitted centroids, one row per cluster.
km.cluster_centers_
array([[ 5.49855163, -9.40880959],
[ 2.79419702, 4.79694276],
[ 0.348301 , -5.45307298],
[-6.04098578, 5.06798706]])
# Raw sample coordinates: 100 rows, one (x, y) pair each.
X
array([[ -5.57785425, 5.87298826],
[ 1.62783216, 4.17806883],
[ -6.37184387, 4.41922347],
[ 1.75005543, 5.44582908],
[ 6.55010412, -7.9123388 ],
[ -0.66982236, -5.19023657],
[ 0.48085466, -5.08976945],
[ -7.45962322, 4.53166747],
[ 5.55912116, -10.06110303],
[ -1.25569573, -5.72586023],
[ -5.03188157, 4.91618824],
[ 1.31006656, -5.47475738],
[ 6.68288513, -10.31693051],
[ 0.6769707 , -6.29133602],
[ 5.69192445, -9.47641249],
[ -0.07790108, -5.98485443],
[ 6.2686376 , -9.38138022],
[ 2.61105267, 4.22218469],
[ 6.91094987, -10.6647659 ],
[ -1.15296379, -5.89279504],
[ -0.31748917, -6.86337766],
[ 1.20634557, -3.03874201],
[ 2.44078244, 4.47434875],
[ -7.06349567, 5.37101341],
[ 0.34789333, -3.88965912],
[ 0.99265635, -5.33725682],
[ 0.26308097, -5.97487434],
[ 5.15516488, -8.97175683],
[ -7.7498139 , 5.82291156],
[ 2.60711685, 2.84436554],
[ 1.35337248, -5.15783397],
[ 2.37446585, -6.24342383],
[ 5.4307043 , -9.75956122],
[ -0.32584361, -4.65585848],
[ 0.3024902 , -4.36909392],
[ 6.08664442, -9.9358329 ],
[ -5.09161663, 4.18830355],
[ 2.62413419, 5.36941887],
[ -5.79412818, 5.03331542],
[ 2.98771848, 7.44372871],
[ 3.78067293, 5.22062163],
[ 3.38492372, 5.8943468 ],
[ 2.79044036, 3.06862076],
[ 4.46134719, -8.55668693],
[ -5.68526509, 5.00333476],
[ -5.49031464, 5.81381329],
[ -0.2598064 , -6.63361828],
[ 6.73488595, -9.38994773],
[ 5.40050753, -9.29586681],
[ 5.655043 , -9.1398234 ],
[ 4.88653379, -8.87680099],
[ 3.44868458, -11.32833331],
[ 1.62685687, -4.83617748],
[ -4.87157434, 4.63863743],
[ 5.69248303, -7.19999368],
[ 5.52556208, -8.18696464],
[ -0.60297312, -6.82451464],
[ 2.56069223, 4.6138972 ],
[ 2.89022984, 2.98168388],
[ -7.16921799, 3.26931456],
[ 4.48697951, -10.07429823],
[ 2.99232112, 5.43698055],
[ 1.16464321, 5.59667831],
[ -0.34268851, -5.85294901],
[ 1.31468967, -5.01055177],
[ 5.06766836, 5.89353659],
[ 4.62182172, -9.79765865],
[ 2.59184251, 4.44678157],
[ 4.28981065, -9.44982413],
[ -5.76456992, 4.69570432],
[ -4.93861333, 5.77496677],
[ -0.26686394, -5.44678194],
[ 0.81677922, 4.75330395],
[ 5.82662285, -9.92259335],
[ -0.46888599, -5.36296292],
[ 0.24318957, -7.12263784],
[ -6.42972489, 6.46578798],
[ 4.65804929, 6.7208918 ],
[ 1.95552599, -4.05690149],
[ -4.82059385, 5.15409352],
[ 5.15909568, -10.13427003],
[ 5.85943906, -8.38192364],
[ 2.45717481, 5.96515011],
[ 2.52859794, 4.5759393 ],
[ 5.08727262, -9.27279108],
[ -3.81369307, 5.32779566],
[ 2.19087156, 5.06566526],
[ -5.5048579 , 5.95458865],
[ 3.49996332, 3.02156553],
[ -6.44447223, 5.99238943],
[ 5.99156553, -9.73238127],
[ -6.75154214, 4.94975477],
[ -8.33384603, 4.01468493],
[ 3.80174985, 4.27826762],
[ 2.31046552, 4.85417196],
[ 3.71914756, 3.55752162],
[ -7.11844009, 5.08754442],
[ -6.74581415, 5.75727908],
[ -5.0962423 , 2.23101747],
[ -5.90560521, 6.41335812]])
# Points colored by their assigned cluster, with the fitted centroids
# overlaid as large red markers.
plt.scatter(X[:, 0], X[:, 1], c=km.labels_)
plt.scatter(*km.cluster_centers_.T, c='r', marker='o', s=200);
Inertia is the sum of squared distances from each data point to the center of the cluster it is assigned to. Smaller inertia means the clusters are more compact (more tightly organized).
# Inertia of the 4-cluster model fitted above.
km.inertia_
187.49933118261387
# Same data, but only 2 clusters (the data were generated from 4 centers),
# for comparing inertia. n_init is explicit (its default changed in
# scikit-learn 1.4) and the seed is fixed for reproducible results.
km2 = KMeans(n_clusters=2, n_init=10, random_state=0)
km2.fit(X)
km2.inertia_
1691.3350029252008
Idea: fit k-means with different numbers of clusters and check how the inertia changes (the "elbow" method).
# Elbow method: fit k-means for k = 2..10 and record each model's inertia.
# n_init is set explicitly (its default changed to "auto" in scikit-learn 1.4,
# with a FutureWarning before that). The seed is fixed so the curve is
# reproducible between runs. k_values is shared by the loop and the plot, so
# the x-axis always matches.
k_values = range(2, 11)
inertia_list = []
for k in k_values:
    km = KMeans(n_clusters=k, n_init=10, random_state=0)
    km.fit(X)
    inertia_list.append(km.inertia_)
plt.plot(k_values, inertia_list, 'ro-');
# New data set: 3 blobs with a much larger spread (std 5 instead of 1).
# Points are colored by the blob that generated them (y).
X, y = make_blobs(n_samples=100, centers=3, cluster_std=5, random_state=10)
plt.scatter(*X.T, c=y);
# Elbow method on the wider blobs: inertia for k = 2..10.
# n_init is explicit (default changed in scikit-learn 1.4) and the seed is
# fixed for reproducibility. After the loop, `km` stays bound to the k=10
# model, which the following cells inspect.
k_values = range(2, 11)
inertia_list = []
for k in k_values:
    km = KMeans(n_clusters=k, n_init=10, random_state=0)
    km.fit(X)
    inertia_list.append(km.inertia_)
plt.plot(k_values, inertia_list, 'ro-');
# The last model fitted by the loop above (n_clusters=10).
km
KMeans(n_clusters=10)
# Distance from every point to each of the 10 centroids: one row per point,
# one column per cluster.
km.transform(X)
array([[16.75494251, 16.61493925, 21.18778656, 8.13994105, 15.68793953,
26.63742076, 25.92456949, 22.17095862, 14.02963681, 34.1182112 ],
[ 2.3157333 , 8.07063441, 8.61876223, 8.72593814, 4.95748316,
10.8464402 , 9.5273581 , 12.74197984, 17.06343243, 19.4935828 ],
[ 4.47937032, 10.17266727, 6.60877727, 9.10117287, 3.02683895,
12.16827611, 8.99505325, 14.90505377, 18.73362464, 18.19580115],
[ 7.65586021, 2.59484061, 18.4452566 , 12.30480412, 14.59562841,
10.73229532, 16.50174014, 3.42868185, 10.8999757 , 27.40252879],
[11.89859662, 17.66184516, 1.34394996, 14.68756066, 6.16526106,
16.72063105, 8.96216263, 22.04537497, 25.90500139, 13.19430766],
[ 6.11260483, 11.30750146, 7.37586384, 13.65212662, 7.52868542,
8.14407907, 4.67584359, 14.2104449 , 21.28362064, 15.24959959],
[ 9.86428016, 4.57317535, 20.63275151, 14.00619139, 16.79151853,
11.84027405, 18.3901423 , 2.23203731, 10.43893286, 29.29113585],
[ 7.82967995, 13.04709064, 5.74093881, 8.73395242, 0.76908902,
15.87682599, 11.23056181, 18.29140105, 20.23249302, 18.61332564],
[11.64468611, 16.70808545, 7.57867022, 18.67691513, 11.12229182,
10.18070802, 1.00267019, 18.81697008, 26.80431059, 9.91328417],
[ 7.52405332, 12.6634751 , 7.06653246, 14.91742426, 8.27723244,
8.32761781, 3.26957316, 15.30214342, 22.68725283, 13.8842851 ],
[ 7.78113931, 2.33353491, 18.53324941, 9.63722131, 13.8500999 ,
13.67535713, 18.04157957, 6.55550242, 8.06993321, 28.78742321],
[15.29669057, 20.81015414, 5.9318592 , 15.72993324, 8.44360482,
21.23054179, 13.51246094, 25.73925536, 27.87919788, 15.65013275],
[ 1.4692998 , 7.16603228, 9.6151993 , 9.51254041, 6.40532711,
9.54279605, 9.36411625, 11.44282681, 16.70358223, 19.74864293],
[ 2.61639431, 4.53452243, 12.86269636, 10.55441275, 9.63113906,
8.20081405, 11.22302902, 8.09583506, 14.81306876, 22.03530323],
[ 9.33760394, 6.85535561, 18.30015916, 5.4674635 , 12.72999526,
17.95965352, 20.06206457, 12.11312354, 7.22577515, 30.12655551],
[14.75262842, 19.44063474, 10.60100487, 22.09718238, 14.57228375,
11.08973338, 4.06934619, 20.77815145, 29.71752998, 8.14863178],
[ 5.91584385, 5.47389159, 14.81417638, 4.26538156, 9.39524809,
15.24454821, 16.52209703, 11.46446872, 10.44910024, 26.52536985],
[ 3.06619974, 5.68804587, 12.3283301 , 11.49487103, 9.64276447,
7.14213948, 10.1159814 , 8.5702647 , 16.02251859, 20.97604454],
[ 6.9282631 , 1.20238025, 17.78292218, 9.93124432, 13.33326979,
12.46382779, 16.95820636, 5.87622132, 9.2686376 , 27.74954595],
[ 5.92903199, 11.73041225, 4.97905887, 11.09425518, 3.74195178,
11.78186351, 7.1853025 , 15.97977401, 20.68165516, 16.13878642],
[ 1.80502361, 4.7744912 , 12.10328284, 7.23482577, 7.62115417,
11.40782772, 12.56805237, 10.19049384, 13.56261755, 22.90667222],
[ 8.79381739, 12.8835073 , 10.14343972, 17.03666709, 11.26940602,
5.56928323, 3.80740727, 14.21168221, 23.24446118, 14.43543179],
[ 4.91393005, 10.06330505, 7.5499342 , 7.3005052 , 2.33227439,
13.77683405, 10.92498939, 15.32355732, 17.72836289, 19.77724619],
[ 3.31187616, 7.35637639, 10.77523855, 11.94548124, 8.70848008,
6.88958514, 8.40206856, 10.21858758, 17.59333589, 19.23368073],
[ 2.92678837, 7.66309999, 9.98460856, 11.42627445, 7.82633802,
7.63764941, 8.16319034, 10.93711445, 17.73624579, 18.88672308],
[13.34934843, 18.28028862, 7.5226446 , 11.72732689, 6.32470186,
21.13581217, 15.00239495, 23.74986234, 24.16672032, 19.49779875],
[ 7.74518046, 13.54659002, 3.18493548, 12.33546035, 4.1275173 ,
12.81908545, 6.76096211, 17.71515594, 22.40308765, 14.70262858],
[11.12518845, 16.24972467, 7.14622711, 18.10296312, 10.55708115,
10.08961491, 0.81789558, 18.49266955, 26.30666017, 10.32516588],
[14.0103669 , 19.80786851, 3.38490178, 17.74013074, 9.13474285,
17.01093159, 8.19181007, 23.64807148, 28.58957119, 10.09053221],
[ 5.39703077, 2.05668171, 16.12423484, 11.17136852, 12.42990077,
9.46411107, 14.33328853, 5.17841156, 12.26498639, 25.21317132],
[ 6.13264407, 0.53931656, 16.98720636, 9.51190613, 12.56964739,
12.03384745, 16.23011839, 6.21820612, 9.87129305, 26.99879089],
[11.08706968, 16.452585 , 4.56275556, 11.54177932, 4.06072632,
18.15482587, 11.8660662 , 21.5731658 , 23.49135135, 17.1559407 ],
[ 7.19018674, 12.86468868, 4.77869721, 13.30530124, 5.86047102,
10.70652261, 4.92715744, 16.42587855, 22.34001253, 14.19236464],
[11.78092505, 7.3082788 , 22.09521592, 16.98864612, 18.87045166,
10.93778769, 18.75456038, 1.32253474, 13.10815694, 29.5182774 ],
[ 6.52133961, 11.17534661, 7.94512769, 6.45874147, 2.16207412,
15.64138377, 12.41917265, 16.71542401, 17.88869294, 20.64486815],
[ 7.37266018, 13.17249988, 3.52722748, 11.93853655, 3.8126542 ,
12.73624054, 6.99129148, 17.4076786 , 21.98955161, 15.10830596],
[ 2.72394449, 3.23378716, 13.57482499, 9.35738275, 9.72264254,
9.67107708, 12.62685887, 7.86641072, 13.30836228, 23.35514039],
[14.63897707, 20.42723692, 4.15477405, 18.53497969, 9.92357643,
17.22113372, 8.22734453, 24.14217463, 29.31095816, 9.31982636],
[ 1.80156746, 5.27971787, 11.6851714 , 6.95461895, 7.11290129,
11.66749112, 12.42561679, 10.71606767, 13.85072464, 22.65937897],
[ 0.43822395, 5.3735614 , 11.31232142, 8.51024335, 7.41511798,
10.121057 , 11.17865327, 10.15466165, 14.8125573 , 21.64223557],
[12.51630044, 15.70140411, 12.42280443, 6.71885974, 7.83352275,
22.14392206, 18.56708233, 21.68223788, 19.01411847, 25.26295068],
[ 3.96552034, 7.10573234, 11.72966223, 12.59702596, 9.74601072,
6.0686339 , 8.80484119, 9.39994171, 17.46559641, 19.69963648],
[ 9.31130733, 11.25821436, 13.44443318, 1.87751558, 7.74223082,
19.33560135, 17.82200743, 17.27798896, 14.34486577, 26.26421826],
[ 4.26672497, 4.28675467, 13.97399326, 5.6870683 , 8.90638861,
13.44578789, 15.01086574, 10.28328106, 11.34030891, 25.2406113 ],
[12.68790306, 17.51237689, 8.98784536, 19.99025968, 12.56690677,
9.99400366, 1.95836665, 19.19652795, 27.72960087, 9.45396986],
[ 7.76838297, 2.03559781, 18.65728104, 11.28084056, 14.42813754,
12.08521942, 17.31542227, 4.58089001, 9.50338359, 28.18042038],
[ 7.31350461, 3.19353046, 17.89930501, 12.86379084, 14.37878717,
9.48150457, 15.5259591 , 3.20828412, 12.1585483 , 26.42965806],
[ 3.15204971, 7.28925338, 10.32142702, 5.87668742, 5.22839569,
13.10786407, 12.36418249, 12.85024085, 14.91642531, 22.03906277],
[ 8.90586059, 11.6044272 , 12.06212322, 3.05293849, 6.41839842,
18.8456153 , 16.73466452, 17.59990681, 15.56683909, 24.92394012],
[13.83765433, 9.13542423, 24.17195713, 18.68170717, 20.91983614,
12.53948369, 20.68634526, 3.35131049, 13.46693515, 31.38330831],
[22.44814038, 27.77248089, 14.07141858, 28.21375786, 19.69224414,
20.31945707, 12.11017813, 29.81676171, 37.68310475, 1.23104611],
[ 7.35316019, 8.02599185, 15.02912826, 15.66322295, 13.49620243,
3.4767599 , 10.49166523, 7.55560757, 18.13858111, 21.1754528 ],
[16.01118037, 12.49682298, 24.7119701 , 10.6610505 , 18.99566276,
23.99107032, 26.76304732, 16.19641182, 4.69147157, 36.79162869],
[ 8.02045818, 10.79763399, 12.59656122, 16.66063315, 12.45138092,
2.80423591, 6.97383778, 11.24557079, 21.17231412, 17.50394 ],
[ 5.23390362, 5.80630637, 13.81575068, 4.06791452, 8.39920417,
14.86153911, 15.68285653, 11.82545704, 11.45024928, 25.58792574],
[13.65595254, 13.39483177, 19.9795499 , 21.8894412 , 19.46802541,
4.60125418, 13.64743411, 10.06639342, 22.47175216, 22.87761014],
[11.57629061, 17.35301624, 2.08216464, 16.07406828, 7.50459152,
14.57302598, 6.15627358, 21.07057414, 26.39674704, 11.02856187],
[ 1.66505126, 7.23619793, 9.49906454, 8.04232181, 5.43150822,
11.07735722, 10.40904958, 12.14246405, 16.10127207, 20.45138536],
[ 8.41162097, 3.83705026, 19.02216239, 13.57032233, 15.45836092,
10.10447962, 16.54767288, 2.22024545, 11.76694425, 27.445522 ],
[ 5.63841332, 1.44679112, 16.34514235, 8.34332743, 11.70282431,
12.5956293 , 16.1054639 , 7.38625499, 9.80483679, 26.76244095],
[21.87769306, 27.53586019, 11.99192383, 26.388371 , 17.77274035,
21.86135049, 12.7818789 , 30.45411588, 36.89886426, 4.48553137],
[12.40289419, 14.99351633, 14.78809241, 21.01163733, 16.0139168 ,
3.91241251, 7.47353047, 14.32419488, 25.26562863, 16.04301261],
[ 9.83082333, 11.05649737, 14.73328328, 1.33033734, 8.99674074,
19.84005811, 18.8544438 , 17.04323469, 13.22261203, 27.50887418],
[ 5.94837975, 11.59650899, 5.36811347, 9.54436571, 1.93423522,
13.26071862, 9.0167029 , 16.38242474, 19.86854285, 17.49660242],
[ 4.3254928 , 1.99037328, 15.14983733, 10.16093919, 11.30266812,
9.7420203 , 13.80120198, 6.33931525, 12.37649716, 24.62488478],
[11.22936082, 6.65748009, 21.63259828, 16.32967713, 18.30993658,
10.89732162, 18.48133326, 0.74724869, 12.63078926, 29.28697731],
[ 3.97397632, 3.22500419, 14.50459141, 10.90246805, 11.06901405,
8.53263328, 12.68629955, 6.52853655, 13.62157379, 23.54980924],
[ 2.98797809, 8.76625684, 7.91866954, 9.10349155, 4.51540811,
10.94030934, 9.0144785 , 13.34356495, 17.7488928 , 18.84612396],
[ 7.32842019, 6.79732995, 15.51549478, 3.20119135, 9.89727621,
16.7820244 , 17.76649555, 12.69870327, 10.05436754, 27.54926301],
[ 5.18170142, 10.93191579, 5.79258321, 9.82694086, 2.93987472,
12.21449917, 8.38168715, 15.51655042, 19.58194885, 17.37840609],
[20.28856186, 25.06968069, 14.28990482, 27.23995493, 19.21502728,
16.19324361, 9.50642497, 26.23332191, 35.33263349, 4.69510099],
[ 9.16246193, 7.0909262 , 18.39134792, 16.32676525, 16.05179991,
6.34764025, 14.28369674, 4.07189177, 16.00053699, 24.94871507],
[ 8.22535156, 14.01005685, 2.68737013, 12.15732642, 3.69402253,
13.69112097, 7.48978687, 18.36186199, 22.60567262, 14.87805343],
[10.74979987, 16.5387096 , 0.25428688, 14.16388686, 5.54498893,
15.41004298, 7.87102593, 20.7961501 , 25.03623443, 13.19420005],
[ 3.14671773, 5.43518 , 12.28080852, 5.66239165, 7.25916255,
12.97375675, 13.6169771 , 11.23209836, 13.00654286, 23.67375113],
[ 8.39469598, 13.538931 , 5.84123625, 8.80864439, 1.37430074,
16.4767666 , 11.67544693, 18.84032321, 20.51156009, 18.76802749],
[11.1942473 , 15.92692774, 8.76554561, 18.72607945, 11.65637816,
8.66541172, 1.04141849, 17.62881386, 26.16860881, 11.01265305],
[15.87166919, 10.19956115, 26.58899352, 15.62188356, 21.67147293,
20.01341109, 25.86480812, 9.9302742 , 4.69147157, 36.71377508],
[ 1.98698701, 3.82115295, 12.86261445, 8.47411625, 8.7787422 ,
10.31841885, 12.47408917, 8.86089847, 13.43008316, 23.06710662],
[ 6.13935751, 0.46009776, 17.02865386, 10.17467106, 12.80617588,
11.38768615, 15.92541062, 5.56863518, 10.37272544, 26.74611531],
[ 4.2567513 , 3.52576123, 14.62492089, 11.35818657, 11.34213862,
8.11820089, 12.55278823, 6.31991525, 13.91079726, 23.43703385],
[13.17580858, 18.88077685, 2.95917364, 15.15569745, 6.92669784,
18.33279942, 10.49255918, 23.44886127, 26.75555972, 13.6938363 ],
[ 6.52747686, 3.68342161, 16.54385552, 6.47341071, 11.40102915,
14.65350731, 17.30107144, 9.4165567 , 8.7990449 , 27.69644991],
[14.34151903, 19.81162375, 6.70049197, 20.1917446 , 11.84135326,
14.08683119, 4.84343756, 22.46506502, 29.56036272, 6.94227407],
[ 7.68200432, 9.2988088 , 13.37386806, 1.11134345, 7.59244066,
17.694021 , 16.92840814, 15.31869804, 13.07904022, 25.9408323 ],
[10.57381853, 16.35324913, 0.59001851, 13.84898891, 5.23345281,
15.46024016, 8.0748444 , 20.67909062, 24.77642048, 13.52971795],
[11.71900826, 12.26512645, 16.69671864, 3.07386633, 11.00320802,
21.67019837, 20.93221223, 18.13154733, 12.71811233, 29.52518175],
[ 3.46642667, 2.55823971, 14.24137928, 8.06903147, 9.83446702,
11.29147819, 14.00599199, 8.21425614, 11.89233714, 24.60960691],
[ 2.54454597, 7.53835816, 9.82928169, 10.98341118, 7.44809779,
8.09083785, 8.38979486, 11.06216636, 17.51035294, 19.04134468],
[ 2.78851561, 8.48758785, 8.23391282, 8.56192676, 4.42951521,
11.28740024, 9.56136022, 13.24694956, 17.29569854, 19.3494358 ],
[ 7.55339962, 2.04501775, 18.33349453, 9.67936296, 13.70161239,
13.38212551, 17.76939742, 6.38909321, 8.36360246, 28.5254254 ],
[13.94957912, 8.55343262, 24.68005111, 17.49997907, 20.86893824,
14.50625537, 22.00438253, 4.21748611, 10.88085787, 32.85008306],
[ 7.77803663, 2.02233748, 18.6316379 , 10.41585676, 14.14842144,
12.97254416, 17.74719557, 5.62791675, 8.64958943, 28.55812033],
[ 2.82586532, 4.95023228, 12.46666644, 6.16589665, 7.6005984 ,
12.50410725, 13.47007043, 10.69441694, 12.91441036, 23.65339578],
[ 5.00960027, 8.01876873, 11.24932303, 3.94650077, 5.64645658,
15.02790106, 14.13476136, 13.89569488, 14.19627285, 23.45973075],
[ 7.53135963, 11.39035955, 9.61652149, 4.93041644, 3.92168107,
17.12664885, 14.32448217, 17.19856768, 16.98366065, 22.43872952],
[ 0.49073567, 5.92825498, 10.76290914, 8.19089298, 6.7572908 ,
10.50960704, 11.00447248, 10.81270715, 15.12246836, 21.33720063],
[ 8.06167598, 4.69905553, 17.90963566, 6.69782556, 12.62350104,
16.00385092, 18.83862743, 9.92407917, 7.34641747, 29.20843045],
[ 7.12215682, 4.86268216, 17.01807982, 13.99741222, 14.15882844,
7.25365578, 13.83226343, 4.02122515, 14.390401 , 24.69602647],
[ 3.90108446, 2.08344485, 14.77069326, 9.63900198, 10.81115071,
10.02306928, 13.68452686, 6.87470049, 12.35397117, 24.4646711 ]])
import gzip
from pathlib import Path

import numpy as np
import requests

# NOTE(review): yann.lecun.com has been reported to refuse scripted downloads
# (HTTP 403). If that happens, point mnist_url at a mirror of the same files.
mnist_url = "http://yann.lecun.com/exdb/mnist/"
img_file = "train-images-idx3-ubyte.gz"
labels_file = "train-labels-idx1-ubyte.gz"

# Download each archive only if it is not already in the working directory.
for fname in [img_file, labels_file]:
    if Path(fname).is_file():
        print(f"Found: {fname}")
        continue
    print(f"Downloading: {fname}")
    # The timeout keeps an unresponsive server from hanging the kernel forever.
    r = requests.get(mnist_url + fname, timeout=60)
    # Never cache an HTTP error page under the .gz name. Otherwise every later
    # run would print "Found" and then fail inside gzip.
    r.raise_for_status()
    Path(fname).write_bytes(r.content)

# Skip the fixed-size headers (16 bytes for images, 8 for labels). The rest is
# one unsigned byte per pixel / label. np.frombuffer decodes it in C instead of
# a Python loop over every byte. The arrays are read-only and uint8, so cast
# them before arithmetic that could go negative or overflow.
with gzip.open(img_file, 'rb') as fh:
    raw = fh.read()
images = np.frombuffer(raw, dtype=np.uint8, offset=16).reshape(-1, 28*28)
with gzip.open(labels_file, 'rb') as fh:
    raw = fh.read()
labels = np.frombuffer(raw, dtype=np.uint8, offset=8)
Found: train-images-idx3-ubyte.gz Found: train-labels-idx1-ubyte.gz
# One row per image: 28*28 = 784 pixel values (see the reshape above).
images
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]])
# One label per image; used below as the per-image digit class (hue).
labels
array([5, 0, 4, ..., 5, 6, 8])
# Keep only the first 5000 images/labels (presumably to keep fitting fast).
images = images[:5000]
labels = labels[:5000]
# Two clusters. transform() below then gives each image's distances to the
# 2 centroids, i.e. a 2-D point per image. n_init is explicit (default changed
# in scikit-learn 1.4) and the seed is fixed for reproducible results.
km = KMeans(n_clusters=2, n_init=10, random_state=0)
km.fit(images)
KMeans(n_clusters=2)
# Distances from each image to the 2 centroids: shape (n_images, 2).
reduced = km.transform(images)
reduced
array([[1881.19641716, 1826.21606123],
[2154.46745948, 1740.15736307],
[1929.54455919, 2069.64256587],
...,
[2041.06512567, 1827.45759362],
[1302.81517557, 1820.48677329],
[1764.73986247, 1867.56544165]])
# Import first so this cell also works on a fresh top-to-bottom run.
# Previously `sns` was only defined by the plotting cell that follows, so this
# raised NameError. Re-importing there is harmless.
import seaborn as sns
sns.__version__
'0.11.2'
import seaborn as sns

# Each image plotted at (distance to centroid 0, distance to centroid 1),
# colored by its digit label.
plt.figure(figsize=(10, 10))
sns.scatterplot(x=reduced[:, 0], y=reduced[:, 1], hue=labels, palette='tab10', s=100);
# Boolean mask: True where the label is 1 or 2.
selection = np.isin(labels, [1, 2])
selection
array([False, False, False, ..., True, True, True])
# Number of selected images (first distance column).
reduced[selection, 0].shape
(1051,)
# Same count for the second distance column.
reduced[selection, 1].shape
(1051,)
# Plot only the images whose label is 0 or 1.
selection = np.isin(labels, [0, 1])
plt.figure(figsize=(10, 10))
sns.scatterplot(
    x=reduced[selection, 0],
    y=reduced[selection, 1],
    hue=labels[selection],
    palette='tab10',
    s=100,
);
import pandas as pd

# Per-planet reference values. The four lists are aligned by position.
# Diameters are presumably in km (converted to miles with /1.61 below).
# Temperatures are converted to Fahrenheit with *1.8 + 32 below, so they are
# degrees Celsius. Gravity is presumably in m/s^2 (Earth = 9.8).
planets = ["Mercury", "Venus", "Earth", "Mars", "Jupiter", "Saturn", "Uranus", "Neptune"]
diameters = [4879, 12104, 12756, 6792, 142984, 120536, 51118, 49528]
temperatures = [167, 464, 15, -65, -110, -140, -195, -200]
gravity = [3.7, 8.9, 9.8, 3.7, 23.1, 9.0, 8.7, 11.0]
# Without an explicit index, the Series is labeled 0..7.
s = pd.Series(diameters)
s
0 4879 1 12104 2 12756 3 6792 4 142984 5 120536 6 51118 7 49528 dtype: int64
# Same diameters, now labeled by planet name instead of 0..7.
s = pd.Series(diameters, index=planets)
s
Mercury 4879 Venus 12104 Earth 12756 Mars 6792 Jupyter 142984 Saturn 120536 Uranus 51118 Neptune 49528 dtype: int64
# Label-based lookup returns a single value.
s['Mars']
6792
# A list of labels returns a sub-Series in the requested order.
s[['Mars', 'Earth']]
Mars 6792 Earth 12756 dtype: int64
# Label slicing includes BOTH endpoints (Saturn appears in the result).
s['Earth':'Saturn']
Earth 12756 Mars 6792 Jupyter 142984 Saturn 120536 dtype: int64
# Assigning to a label that does not exist yet appends a new entry.
s['Pluto'] = 2370
s
Mercury 4879 Venus 12104 Earth 12756 Mars 6792 Jupyter 142984 Saturn 120536 Uranus 51118 Neptune 49528 Pluto 2370 dtype: int64
# Mean over all entries, including the newly added Pluto.
s.mean()
44785.22222222222
# Largest value.
s.max()
142984
# Smallest value.
s.min()
2370
# Integer POSITION of the largest value.
s.argmax()
4
# LABEL of the largest value (compare argmax above).
s.idxmax()
'Jupyter'
# Element-wise division. 1.61 looks like the km-per-mile factor, i.e. this
# converts the diameters to miles (assumes the diameters are in km).
s/1.61
Mercury 3030.434783 Venus 7518.012422 Earth 7922.981366 Mars 4218.633540 Jupyter 88809.937888 Saturn 74867.080745 Uranus 31750.310559 Neptune 30762.732919 Pluto 1472.049689 dtype: float64
def size(x):
    """Return "small" for x below 10000, otherwise "large"."""
    return "small" if x < 10000 else "large"
# Apply size() to every value, giving a Series of "small"/"large".
# NOTE(review): the captured output below shows a 0..7 index and only 8 rows,
# so it predates the planet index and the Pluto entry. Re-run to refresh it.
s.apply(size)
0 small 1 large 2 large 3 small 4 large 5 large 6 large 7 large dtype: object
# One row per planet. The keyword names become the column names, and the
# lists are aligned by position.
df = pd.DataFrame(dict(diameter=diameters, temperature=temperatures, gravity=gravity))
df
| | diameter | temperature | gravity |
|---|---|---|---|
| 0 | 4879 | 167 | 3.7 |
| 1 | 12104 | 464 | 8.9 |
| 2 | 12756 | 15 | 9.8 |
| 3 | 6792 | -65 | 3.7 |
| 4 | 142984 | -110 | 23.1 |
| 5 | 120536 | -140 | 9.0 |
| 6 | 51118 | -195 | 8.7 |
| 7 | 49528 | -200 | 11.0 |
# Without explicit labels, rows get a RangeIndex 0..7.
df.index
RangeIndex(start=0, stop=8, step=1)
# Replace the row labels with the planet names (one per row).
df.index = planets
df
| | diameter | temperature | gravity |
|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 |
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
| Mars | 6792 | -65 | 3.7 |
| Jupyter | 142984 | -110 | 23.1 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# The index now holds the planet names.
df.index
Index(['Mercury', 'Venus', 'Earth', 'Mars', 'Jupyter', 'Saturn', 'Uranus',
'Neptune'],
dtype='object')
# Column labels.
df.columns
Index(['diameter', 'temperature', 'gravity'], dtype='object')
# First 3 rows.
df.head(3)
| | diameter | temperature | gravity |
|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 |
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
# Last 3 rows.
df.tail(3)
| | diameter | temperature | gravity |
|---|---|---|---|
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# 3 randomly chosen rows. No random_state is given, so the result differs
# from run to run.
df.sample(3)
| | diameter | temperature | gravity |
|---|---|---|---|
| Neptune | 49528 | -200 | 11.0 |
| Venus | 12104 | 464 | 8.9 |
| Mercury | 4879 | 167 | 3.7 |
# A single column label returns a Series.
df['gravity']
Mercury 3.7 Venus 8.9 Earth 9.8 Mars 3.7 Jupyter 23.1 Saturn 9.0 Uranus 8.7 Neptune 11.0 Name: gravity, dtype: float64
# A list of column labels returns a DataFrame with the columns in that order.
df[['gravity', 'diameter']]
| | gravity | diameter |
|---|---|---|
| Mercury | 3.7 | 4879 |
| Venus | 8.9 | 12104 |
| Earth | 9.8 | 12756 |
| Mars | 3.7 | 6792 |
| Jupyter | 23.1 | 142984 |
| Saturn | 9.0 | 120536 |
| Uranus | 8.7 | 51118 |
| Neptune | 11.0 | 49528 |
# .loc[row label, column label] returns a single value.
df.loc['Earth', 'gravity']
9.8
# Several row labels with one column label returns a Series.
df.loc[['Earth', 'Mars'], 'gravity']
Earth 9.8 Mars 3.7 Name: gravity, dtype: float64
# Lists of row and column labels return a sub-DataFrame.
df.loc[['Earth', 'Mars'], ['gravity', 'temperature']]
| | gravity | temperature |
|---|---|---|
| Earth | 9.8 | 15 |
| Mars | 3.7 | -65 |
# Positional access: row 0 (Mercury), column 1 (temperature).
df.iloc[0, 1]
167
# The whole first row as a Series. Mixed int/float columns are upcast to
# float64 (see the dtype in the output).
df.iloc[0]
diameter 4879.0 temperature 167.0 gravity 3.7 Name: Mercury, dtype: float64
# An integer slice in [] selects rows by position, end-exclusive (rows 2..4).
df[2:5]
| | diameter | temperature | gravity |
|---|---|---|---|
| Earth | 12756 | 15 | 9.8 |
| Mars | 6792 | -65 | 3.7 |
| Jupyter | 142984 | -110 | 23.1 |
# Rows 2..4 and columns 0 and 1, all by position.
df.iloc[2:5, [0, 1]]
| | diameter | temperature |
|---|---|---|
| Earth | 12756 | 15 |
| Mars | 6792 | -65 |
| Jupyter | 142984 | -110 |
# Full table again, for reference before filtering.
df
| | diameter | temperature | gravity |
|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 |
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
| Mars | 6792 | -65 | 3.7 |
| Jupyter | 142984 | -110 | 23.1 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# A comparison on a column yields a boolean Series aligned with the rows.
df['diameter'] > 10000
Mercury False Venus True Earth True Mars False Jupyter True Saturn True Uranus True Neptune True Name: diameter, dtype: bool
# Indexing with a boolean Series keeps only the rows where it is True.
df[df['diameter'] > 10000]
| | diameter | temperature | gravity |
|---|---|---|---|
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
| Jupyter | 142984 | -110 | 23.1 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# Element-wise AND. The parentheses are required because & binds tighter
# than the comparisons.
df[(df['temperature'] > 0) & (df['gravity'] > 5)]
| | diameter | temperature | gravity |
|---|---|---|---|
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
# Element-wise OR of two conditions.
df[(df['temperature'] > 0) | (df['temperature'] < -100)]
| | diameter | temperature | gravity |
|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 |
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
| Jupyter | 142984 | -110 | 23.1 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# ~ negates a boolean mask element-wise.
df[~(df['temperature'] > 0)]
| | diameter | temperature | gravity |
|---|---|---|---|
| Mars | 6792 | -65 | 3.7 |
| Jupyter | 142984 | -110 | 23.1 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# Filtering above returned new objects; df itself is unchanged.
df
| | diameter | temperature | gravity |
|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 |
| Venus | 12104 | 464 | 8.9 |
| Earth | 12756 | 15 | 9.8 |
| Mars | 6792 | -65 | 3.7 |
| Jupyter | 142984 | -110 | 23.1 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Neptune | 49528 | -200 | 11.0 |
# Rows ordered by gravity, largest first. Returns a new DataFrame.
df.sort_values(by='gravity', ascending=False)
| | diameter | temperature | gravity |
|---|---|---|---|
| Jupyter | 142984 | -110 | 23.1 |
| Neptune | 49528 | -200 | 11.0 |
| Earth | 12756 | 15 | 9.8 |
| Saturn | 120536 | -140 | 9.0 |
| Venus | 12104 | 464 | 8.9 |
| Uranus | 51118 | -195 | 8.7 |
| Mercury | 4879 | 167 | 3.7 |
| Mars | 6792 | -65 | 3.7 |
# Rows ordered alphabetically by their index labels. Returns a new DataFrame.
df.sort_index()
| | diameter | temperature | gravity |
|---|---|---|---|
| Earth | 12756 | 15 | 9.8 |
| Jupyter | 142984 | -110 | 23.1 |
| Mars | 6792 | -65 | 3.7 |
| Mercury | 4879 | 167 | 3.7 |
| Neptune | 49528 | -200 | 11.0 |
| Saturn | 120536 | -140 | 9.0 |
| Uranus | 51118 | -195 | 8.7 |
| Venus | 12104 | 464 | 8.9 |
# New column derived element-wise from temperature: F = C * 1.8 + 32.
df['temp_F'] = df['temperature']*1.8 + 32
df
| | diameter | temperature | gravity | temp_F |
|---|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 | 332.6 |
| Venus | 12104 | 464 | 8.9 | 867.2 |
| Earth | 12756 | 15 | 9.8 | 59.0 |
| Mars | 6792 | -65 | 3.7 | -85.0 |
| Jupyter | 142984 | -110 | 23.1 | -166.0 |
| Saturn | 120536 | -140 | 9.0 | -220.0 |
| Uranus | 51118 | -195 | 8.7 | -319.0 |
| Neptune | 49528 | -200 | 11.0 | -328.0 |
# New columns are appended at the end.
df.columns
Index(['diameter', 'temperature', 'gravity', 'temp_F'], dtype='object')
# Reorder the columns by selecting them in the desired order into a new
# DataFrame. df itself keeps its original column order.
df1 = df[['diameter', 'temperature', 'temp_F', 'gravity']]
df1
| | diameter | temperature | temp_F | gravity |
|---|---|---|---|---|
| Mercury | 4879 | 167 | 332.6 | 3.7 |
| Venus | 12104 | 464 | 867.2 | 8.9 |
| Earth | 12756 | 15 | 59.0 | 9.8 |
| Mars | 6792 | -65 | -85.0 | 3.7 |
| Jupyter | 142984 | -110 | -166.0 | 23.1 |
| Saturn | 120536 | -140 | -220.0 | 9.0 |
| Uranus | 51118 | -195 | -319.0 | 8.7 |
| Neptune | 49528 | -200 | -328.0 | 11.0 |
# df still has temp_F as its last column.
df
| | diameter | temperature | gravity | temp_F |
|---|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 | 332.6 |
| Venus | 12104 | 464 | 8.9 | 867.2 |
| Earth | 12756 | 15 | 9.8 | 59.0 |
| Mars | 6792 | -65 | 3.7 | -85.0 |
| Jupyter | 142984 | -110 | 23.1 | -166.0 |
| Saturn | 120536 | -140 | 9.0 | -220.0 |
| Uranus | 51118 | -195 | 8.7 | -319.0 |
| Neptune | 49528 | -200 | 11.0 | -328.0 |
# Per-column means.
df.mean()
diameter 50087.1250 temperature -8.0000 gravity 9.7375 temp_F 17.6000 dtype: float64
# Per-column minima.
df.min()
diameter 4879.0 temperature -200.0 gravity 3.7 temp_F -328.0 dtype: float64
# Label of the smallest gravity. Mercury and Mars tie at 3.7; the first
# occurrence is returned.
df['gravity'].idxmin()
'Mercury'
# Fetch that whole row by label.
df.loc[df['gravity'].idxmin()]
diameter 4879.0 temperature 167.0 gravity 3.7 temp_F 332.6 Name: Mercury, dtype: float64
# Same minimum, but as an integer position.
df['gravity'].argmin()
0
# Fetch the same row by position.
df.iloc[df['gravity'].argmin()]
diameter 4879.0 temperature 167.0 gravity 3.7 temp_F 332.6 Name: Mercury, dtype: float64
def habitable(p):
    """Classify one planet record as habitable ("Yes") or not ("No").

    *p* is anything indexable by 'temperature' and 'gravity'. Below it is a
    DataFrame row passed by ``df.apply(habitable, axis=1)``. A planet counts
    as habitable when its temperature is strictly between -100 and 50 and its
    gravity is strictly below 12. Gravity is only read if the temperature
    test passes.
    """
    in_temperature_window = -100 < p['temperature'] < 50
    if not in_temperature_window:
        return "No"
    return "Yes" if p['gravity'] < 12 else "No"
# axis=1 calls habitable() once per row, giving a Series of "Yes"/"No".
df.apply(habitable, axis=1)
Mercury No Venus No Earth Yes Mars Yes Jupyter No Saturn No Uranus No Neptune No dtype: object
# Store the per-row classification as a new column.
df['habitable'] = df.apply(habitable, axis=1)
df
| | diameter | temperature | gravity | temp_F | habitable |
|---|---|---|---|---|---|
| Mercury | 4879 | 167 | 3.7 | 332.6 | No |
| Venus | 12104 | 464 | 8.9 | 867.2 | No |
| Earth | 12756 | 15 | 9.8 | 59.0 | Yes |
| Mars | 6792 | -65 | 3.7 | -85.0 | Yes |
| Jupyter | 142984 | -110 | 23.1 | -166.0 | No |
| Saturn | 120536 | -140 | 9.0 | -220.0 | No |
| Uranus | 51118 | -195 | 8.7 | -319.0 | No |
| Neptune | 49528 | -200 | 11.0 | -328.0 | No |
import pandas as pd
import seaborn as sns
# Titanic passenger table from seaborn's example datasets. load_dataset
# downloads it from seaborn's online data repository (needs network access
# the first time). This rebinds df, replacing the planets table above.
df = sns.load_dataset("titanic")
df
| | survived | pclass | sex | age | sibsp | parch | fare | embarked | class | who | adult_male | deck | embark_town | alive | alone |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | Third | man | True | NaN | Southampton | no | False |
| 1 | 1 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | First | woman | False | C | Cherbourg | yes | False |
| 2 | 1 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | Third | woman | False | NaN | Southampton | yes | True |
| 3 | 1 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | First | woman | False | C | Southampton | yes | False |
| 4 | 0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | Third | man | True | NaN | Southampton | no | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 0 | 2 | male | 27.0 | 0 | 0 | 13.0000 | S | Second | man | True | NaN | Southampton | no | True |
| 887 | 1 | 1 | female | 19.0 | 0 | 0 | 30.0000 | S | First | woman | False | B | Southampton | yes | True |
| 888 | 0 | 3 | female | NaN | 1 | 2 | 23.4500 | S | Third | woman | False | NaN | Southampton | no | False |
| 889 | 1 | 1 | male | 26.0 | 0 | 0 | 30.0000 | C | First | man | True | C | Cherbourg | yes | True |
| 890 | 0 | 3 | male | 32.0 | 0 | 0 | 7.7500 | Q | Third | man | True | NaN | Queenstown | no | True |
891 rows × 15 columns
# Categorical column with many missing values (NaN).
df['deck']
0 NaN
1 C
2 NaN
3 C
4 NaN
...
886 NaN
887 B
888 NaN
889 C
890 NaN
Name: deck, Length: 891, dtype: category
Categories (7, object): ['A', 'B', 'C', 'D', 'E', 'F', 'G']
# len() counts every row, missing or not.
len(df['deck'])
891
# count() counts only non-missing values (compare len() above).
df['deck'].count()
203