-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from florencejt/docs/usemnistexampledata
Docs/usemnistexampledata: Change all examples in docs to use MNIST data
- Loading branch information
Showing
31 changed files
with
3,988 additions
and
1,196 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +0,0 @@ | ||
import pandas as pd | ||
from sklearn.datasets import make_classification, make_regression | ||
|
||
|
||
def generate_sklearn_simulated_data( | ||
prediction_task, num_samples, num_tab1_features, num_tab2_features, external=False | ||
): | ||
""" | ||
Generate simulated data for all modalities, adds a prediction label and returns the params dictionary | ||
Parameters | ||
---------- | ||
prediction_task : str | ||
The type of prediction to be performed. This is either ``regression``, ``binary``, or ``classification``. | ||
num_samples : int | ||
Number of samples to generate | ||
num_tab1_features : int | ||
Number of features to generate for tabular1 data | ||
num_tab2_features : int | ||
Number of features to generate for tabular2 data | ||
Returns | ||
------- | ||
params : dict | ||
Dictionary of parameters with the paths to the simulated data added | ||
""" | ||
if external: | ||
tabular1_path = "../../../fusilli/utils/simulated_data/external_tabular1data.csv" | ||
tabular2_path = "../../../fusilli/utils/simulated_data/external_tabular2data.csv" | ||
else: | ||
tabular1_path = "../../../fusilli/utils/simulated_data/tabular1data.csv" | ||
tabular2_path = "../../../fusilli/utils/simulated_data/tabular2data.csv" | ||
|
||
if prediction_task == "binary": | ||
# Creating a simulated feature matrix and output vector with 100 samples | ||
all_tab_features, labels = make_classification( | ||
n_samples=num_samples, | ||
n_features=num_tab1_features + num_tab2_features, # taking features | ||
n_informative=(num_tab1_features + num_tab2_features) | ||
// 3, # features that predict the output's classes | ||
n_classes=2, # three output classes | ||
weights=None, # equal number of samples per class) | ||
flip_y=0.1, # flip 10% of the labels | ||
) | ||
elif prediction_task == "multiclass": | ||
num_classes = 3 | ||
all_tab_features, labels = make_classification( | ||
n_samples=num_samples, | ||
n_features=num_tab1_features + num_tab2_features, # taking features | ||
n_informative=(num_tab1_features + num_tab2_features) | ||
// 2, # features that predict the output's classes | ||
n_classes=num_classes, # three output classes | ||
weights=None, # equal number of samples per class) | ||
flip_y=0.1, # flip 10% of the labels | ||
) | ||
elif prediction_task == "regression": | ||
all_tab_features, labels = make_regression( | ||
n_samples=num_samples, | ||
n_features=num_tab1_features + num_tab2_features, # taking features | ||
n_informative=(num_tab1_features + num_tab2_features) | ||
// 2, # features that predict the output's classes | ||
noise=3, | ||
effective_rank=3, | ||
) | ||
else: | ||
raise ValueError( | ||
f"pred_type must be one of 'binary', 'multiclass' or 'regression', not {prediction_task}" | ||
) | ||
|
||
tabular1_data = pd.DataFrame() | ||
tabular1_data["ID"] = [f"{i}" for i in range(num_samples)] | ||
for i in range(num_tab1_features): | ||
feature_name = f"feature{i + 1}" | ||
tabular1_data[feature_name] = all_tab_features[:, i] | ||
tabular1_data.set_index("ID", inplace=True) | ||
tabular1_data["prediction_label"] = labels | ||
|
||
tabular2_data = pd.DataFrame() | ||
tabular2_data["ID"] = [f"{i}" for i in range(num_samples)] | ||
for i in range(num_tab2_features): | ||
feature_name = f"feature{i + 1}" | ||
tabular2_data[feature_name] = all_tab_features[:, num_tab1_features + i] | ||
tabular2_data.set_index("ID", inplace=True) | ||
tabular2_data["prediction_label"] = labels | ||
|
||
# save to csv and pt | ||
tabular1_data.to_csv(tabular1_path) | ||
tabular2_data.to_csv(tabular2_path) | ||
|
||
return tabular1_path, tabular2_path | ||
Binary file added
BIN
+39.9 KB
docs/examples/customising_behaviour/loss_figures/losses/AttentionWeightedGNN.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+34.3 KB
...examples/customising_behaviour/loss_figures/losses/ConcatTabularFeatureMaps.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions
1
.../examples/customising_behaviour/loss_logs/modify_layers/AttentionWeightedGNN/hparams.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
12 changes: 12 additions & 0 deletions
12
docs/examples/customising_behaviour/loss_logs/modify_layers/AttentionWeightedGNN/metrics.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
epoch,val_loss,MAE_val,R2_val,step,train_loss,MAE_train,R2_train | ||
0,91613.78125,255.58047485351562,-10884.7041015625,0,,, | ||
0,,,,0,126736.65625,283.7513427734375,-14838.0244140625 | ||
1,22806.869140625,133.0036163330078,-2708.950927734375,1,,, | ||
1,,,,1,95430.828125,245.3120574951172,-11172.5654296875 | ||
2,2364.317138671875,40.63412857055664,-279.9321594238281,2,,, | ||
2,,,,2,39681.44921875,155.34410095214844,-4645.12158203125 | ||
3,948.3421630859375,24.115896224975586,-111.68362426757812,3,,, | ||
3,,,,3,24457.408203125,121.0082778930664,-2862.607666015625 | ||
4,516.4700927734375,18.50086212158203,-60.367855072021484,4,,, | ||
4,,,,4,13347.4072265625,91.87945556640625,-1561.7877197265625 | ||
5,516.4700927734375,18.50086212158203,-60.367855072021484,5,,, |
1 change: 1 addition & 0 deletions
1
...mples/customising_behaviour/loss_logs/modify_layers/ConcatTabularFeatureMaps/hparams.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
12 changes: 12 additions & 0 deletions
12
...amples/customising_behaviour/loss_logs/modify_layers/ConcatTabularFeatureMaps/metrics.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
epoch,val_loss,MAE_val,R2_val,step,train_loss,MAE_train,R2_train | ||
0,4.192819595336914,1.7022111415863037,0.5379555225372314,49,,, | ||
0,,,,49,8.381141662597656,2.3306796550750732,-0.24445590376853943 | ||
1,3.2308895587921143,1.4975028038024902,0.6439592838287354,99,,, | ||
1,,,,99,3.6040618419647217,1.4797532558441162,0.46048277616500854 | ||
2,2.6154050827026367,1.2608485221862793,0.711785078048706,149,,, | ||
2,,,,149,2.725576877593994,1.2617511749267578,0.5758683681488037 | ||
3,2.830514430999756,1.331453561782837,0.6880801916122437,199,,, | ||
3,,,,199,1.5875434875488281,0.9392735958099365,0.7461423277854919 | ||
4,3.048844337463379,1.3958659172058105,0.6640203595161438,249,,, | ||
4,,,,249,1.254799246788025,0.8462849259376526,0.7862377762794495 | ||
5,3.048844337463379,1.3958659172058105,0.6640203595161438,250,,, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
...les/model_comparison/loss_logs/two_models_traintest/ConcatTabularFeatureMaps/hparams.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
130 changes: 130 additions & 0 deletions
130
...ples/model_comparison/loss_logs/two_models_traintest/ConcatTabularFeatureMaps/metrics.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
MAE_val,epoch,val_loss,R2_val,step,R2_train,train_loss,MAE_train | ||
2.550821542739868,0,9.993765830993652,-0.36750197410583496,7,,, | ||
,0,,,7,-0.9941763877868652,16.622907638549805,3.3130204677581787 | ||
2.0250279903411865,1,6.120434284210205,0.1625072956085205,15,,, | ||
,1,,,15,0.14741721749305725,7.222578048706055,2.1738603115081787 | ||
2.1401681900024414,2,7.6228156089782715,-0.04307210445404053,23,,, | ||
,2,,,23,0.3598036766052246,5.378986835479736,1.850309133529663 | ||
1.7970805168151855,3,5.474188804626465,0.25093650817871094,31,,, | ||
,3,,,31,0.46779435873031616,4.45926570892334,1.6522819995880127 | ||
1.7589600086212158,4,5.531983375549316,0.24302823841571808,39,,, | ||
,4,,,39,0.549390971660614,3.7244315147399902,1.50660240650177 | ||
1.7531583309173584,5,5.516550540924072,0.24513985216617584,47,,, | ||
,5,,,47,0.6313987970352173,3.1030783653259277,1.3298112154006958 | ||
1.8356657028198242,6,6.117806434631348,0.1628667712211609,55,,, | ||
,6,,,55,0.7357838749885559,2.1222023963928223,1.1111221313476562 | ||
1.6618714332580566,7,5.2861833572387695,0.27666234970092773,63,,, | ||
,7,,,63,0.7709552049636841,1.9139409065246582,1.0187078714370728 | ||
1.57106614112854,8,4.7738189697265625,0.3467719554901123,71,,, | ||
,8,,,71,0.7927483320236206,1.7200523614883423,0.9630319476127625 | ||
1.604772925376892,9,4.912045001983643,0.3278576731681824,79,,, | ||
,9,,,79,0.8458226919174194,1.3106739521026611,0.8364636898040771 | ||
1.6100950241088867,10,4.99435567855835,0.316594660282135,87,,, | ||
,10,,,87,0.8695825934410095,1.0703457593917847,0.7448607683181763 | ||
1.5111275911331177,11,4.433798789978027,0.39329880475997925,95,,, | ||
,11,,,95,0.8898072838783264,0.9113877415657043,0.6880857348442078 | ||
1.5012693405151367,12,4.484768867492676,0.38632428646087646,103,,, | ||
,12,,,103,0.9058310985565186,0.7781545519828796,0.623124361038208 | ||
1.4845701456069946,13,4.483181476593018,0.386541485786438,111,,, | ||
,13,,,111,0.9015885591506958,0.8185599446296692,0.6399004459381104 | ||
1.4515013694763184,14,4.19098424911499,0.42652440071105957,119,,, | ||
,14,,,119,0.917405366897583,0.7098179459571838,0.5883429646492004 | ||
1.6063990592956543,15,5.107571601867676,0.30110275745391846,127,,, | ||
,15,,,127,0.9220234751701355,0.6577357053756714,0.5817375183105469 | ||
1.6577519178390503,16,5.300245761871338,0.274738073348999,135,,, | ||
,16,,,135,0.910961925983429,0.750971257686615,0.6317367553710938 | ||
1.6055717468261719,17,5.1569318771362305,0.29434847831726074,143,,, | ||
,17,,,143,0.9021233916282654,0.8221578598022461,0.6475740075111389 | ||
1.639248013496399,18,5.171427249908447,0.29236501455307007,151,,, | ||
,18,,,151,0.900653064250946,0.8402024507522583,0.6508421897888184 | ||
1.5495505332946777,19,4.559011459350586,0.3761652708053589,159,,, | ||
,19,,,159,0.908098042011261,0.7756197452545166,0.6134839653968811 | ||
1.5707274675369263,20,4.734753608703613,0.35211747884750366,167,,, | ||
,20,,,167,0.924924910068512,0.6461987495422363,0.5735353827476501 | ||
1.4228588342666626,21,3.9855732917785645,0.45463207364082336,175,,, | ||
,21,,,175,0.8851704001426697,0.9396422505378723,0.6774986386299133 | ||
1.5047160387039185,22,4.698852062225342,0.35703009366989136,183,,, | ||
,22,,,183,0.9096055626869202,0.7635362148284912,0.6242097616195679 | ||
1.3928208351135254,23,4.0358099937438965,0.4477577209472656,191,,, | ||
,23,,,191,0.9363532662391663,0.5386157631874084,0.5349072217941284 | ||
1.431113600730896,24,4.470632076263428,0.38825875520706177,199,,, | ||
,24,,,199,0.9415258169174194,0.4951440393924713,0.48046642541885376 | ||
1.4176338911056519,25,4.323647499084473,0.4083714485168457,207,,, | ||
,25,,,207,0.9421520829200745,0.4700969457626343,0.47758549451828003 | ||
1.3952497243881226,26,4.078487873077393,0.44191786646842957,215,,, | ||
,26,,,215,0.9555037021636963,0.3723544776439667,0.41501185297966003 | ||
1.3771188259124756,27,3.975924491882324,0.45595234632492065,223,,, | ||
,27,,,223,0.9446684718132019,0.47763383388519287,0.46678513288497925 | ||
1.39987313747406,28,4.1419172286987305,0.43323853611946106,231,,, | ||
,28,,,231,0.950677752494812,0.4249856173992157,0.46403563022613525 | ||
1.396775245666504,29,4.147250652313232,0.4325087368488312,239,,, | ||
,29,,,239,0.9434041976928711,0.48246464133262634,0.4782574772834778 | ||
1.390718936920166,30,4.0993194580078125,0.43906742334365845,247,,, | ||
,30,,,247,0.9475470185279846,0.4307110905647278,0.45598751306533813 | ||
1.3702670335769653,31,3.924558162689209,0.4629810154438019,255,,, | ||
,31,,,255,0.9597123265266418,0.33705028891563416,0.40675410628318787 | ||
1.4085280895233154,32,4.261821269989014,0.41683149337768555,263,,, | ||
,32,,,263,0.9405664205551147,0.4967845678329468,0.47750386595726013 | ||
1.3695627450942993,33,3.9851067066192627,0.4546957314014435,271,,, | ||
,33,,,271,0.9426913857460022,0.46704617142677307,0.47678470611572266 | ||
1.4094480276107788,34,3.998617649078369,0.45284709334373474,279,,, | ||
,34,,,279,0.9507781863212585,0.41333433985710144,0.4429933428764343 | ||
1.3696682453155518,35,3.881700038909912,0.4688454568386078,287,,, | ||
,35,,,287,0.939416766166687,0.4943542778491974,0.49081888794898987 | ||
1.4207769632339478,36,4.0211310386657715,0.4497663974761963,295,,, | ||
,36,,,295,0.9628061056137085,0.3204095661640167,0.3762396574020386 | ||
1.3708957433700562,37,3.8153915405273438,0.4779188930988312,303,,, | ||
,37,,,303,0.945135235786438,0.4457513093948364,0.4563809335231781 | ||
1.4144178628921509,38,4.181662559509277,0.4277999997138977,311,,, | ||
,38,,,311,0.953737199306488,0.37626779079437256,0.41615381836891174 | ||
1.4202336072921753,39,4.486833095550537,0.38604187965393066,319,,, | ||
,39,,,319,0.9369199872016907,0.5130748152732849,0.48724305629730225 | ||
1.341239094734192,40,3.9252257347106934,0.4628896713256836,327,,, | ||
,40,,,327,0.9515608549118042,0.40434104204177856,0.4478156864643097 | ||
1.539413332939148,41,4.673248291015625,0.36053353548049927,335,,, | ||
,41,,,335,0.9387563467025757,0.5167527794837952,0.5156168937683105 | ||
1.7450042963027954,42,5.426409721374512,0.2574743628501892,343,,, | ||
,42,,,343,0.8893561959266663,0.9203564524650574,0.6920171976089478 | ||
1.5992671251296997,43,4.80446195602417,0.3425789475440979,351,,, | ||
,43,,,351,0.90372234582901,0.773030161857605,0.6137067675590515 | ||
1.4908099174499512,44,4.300416946411133,0.41155022382736206,359,,, | ||
,44,,,359,0.920629620552063,0.6477848887443542,0.574256956577301 | ||
1.44012451171875,45,4.305934429168701,0.4107951521873474,367,,, | ||
,45,,,367,0.9128478169441223,0.7299042344093323,0.6047486066818237 | ||
1.4437092542648315,46,4.356184005737305,0.4039192795753479,375,,, | ||
,46,,,375,0.9185203909873962,0.6919479966163635,0.5823092460632324 | ||
1.3632363080978394,47,4.0035600662231445,0.4521707594394684,383,,, | ||
,47,,,383,0.9386501908302307,0.5217953324317932,0.5004262328147888 | ||
1.3210381269454956,48,3.677823543548584,0.49674296379089355,391,,, | ||
,48,,,391,0.9318889379501343,0.5586847066879272,0.5229012370109558 | ||
1.3599282503128052,49,4.0863728523254395,0.4408389925956726,399,,, | ||
,49,,,399,0.9411574602127075,0.480117529630661,0.45792442560195923 | ||
1.330924391746521,50,3.81209397315979,0.4783701002597809,407,,, | ||
,50,,,407,0.9523176550865173,0.38065025210380554,0.42401397228240967 | ||
1.332220435142517,51,3.883242130279541,0.46863454580307007,415,,, | ||
,51,,,415,0.9544154405593872,0.3866163492202759,0.4231767952442169 | ||
1.3906828165054321,52,4.229849815368652,0.4212062954902649,423,,, | ||
,52,,,423,0.9538908004760742,0.3787587583065033,0.42101413011550903 | ||
1.3811572790145874,53,3.968116044998169,0.45702075958251953,431,,, | ||
,53,,,431,0.9562690258026123,0.3682621121406555,0.40828174352645874 | ||
1.366807460784912,54,4.064504146575928,0.4438314437866211,439,,, | ||
,54,,,439,0.955906093120575,0.365860253572464,0.421358585357666 | ||
1.3128173351287842,55,3.7178304195404053,0.49126872420310974,447,,, | ||
,55,,,447,0.9617348909378052,0.3125576674938202,0.399527370929718 | ||
1.3332505226135254,56,3.8579626083374023,0.4720936119556427,455,,, | ||
,56,,,455,0.9516554474830627,0.39683592319488525,0.42932966351509094 | ||
1.3601473569869995,57,4.019503593444824,0.44998905062675476,463,,, | ||
,57,,,463,0.9577350616455078,0.3561760187149048,0.3812394440174103 | ||
1.3970615863800049,58,4.052127838134766,0.44552499055862427,471,,, | ||
,58,,,471,0.9404559135437012,0.4958237111568451,0.46953076124191284 | ||
1.614027500152588,59,4.767022609710693,0.3477018475532532,479,,, | ||
,59,,,479,0.9305849671363831,0.5771317481994629,0.541130781173706 | ||
1.4933671951293945,60,4.511728286743164,0.38263529539108276,487,,, | ||
,60,,,487,0.9263445138931274,0.5848267078399658,0.5564913153648376 | ||
1.3761061429977417,61,3.990039110183716,0.4540208578109741,495,,, | ||
,61,,,495,0.9608601331710815,0.31816357374191284,0.3896094262599945 | ||
1.4019064903259277,62,4.246546268463135,0.4189215302467346,503,,, | ||
,62,,,503,0.9589810371398926,0.34225142002105713,0.40126940608024597 | ||
1.441614031791687,63,4.2292351722717285,0.42129039764404297,511,,, | ||
,63,,,511,0.9644243717193604,0.29825612902641296,0.37975960969924927 | ||
1.441614031791687,64,4.2292351722717285,0.42129039764404297,512,,, |
1 change: 1 addition & 0 deletions
1
docs/examples/model_comparison/loss_logs/two_models_traintest/Tabular1Unimodal/hparams.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
Oops, something went wrong.