Skip to content

Commit

Permalink
Merge pull request #18 from florencejt/docs/usemnistexampledata
Browse files Browse the repository at this point in the history
Docs/usemnistexampledata: Change all examples in docs to use MNIST data
  • Loading branch information
florencejt authored Jan 12, 2024
2 parents 5bb530f + b1de556 commit 6b1b7af
Show file tree
Hide file tree
Showing 31 changed files with 3,988 additions and 1,196 deletions.
501 changes: 501 additions & 0 deletions docs/_static/mnist_data/mnist1.csv

Large diffs are not rendered by default.

501 changes: 501 additions & 0 deletions docs/_static/mnist_data/mnist1_binary.csv

Large diffs are not rendered by default.

501 changes: 501 additions & 0 deletions docs/_static/mnist_data/mnist1_regression.csv

Large diffs are not rendered by default.

101 changes: 101 additions & 0 deletions docs/_static/mnist_data/mnist1_regression_test.csv

Large diffs are not rendered by default.

501 changes: 501 additions & 0 deletions docs/_static/mnist_data/mnist2.csv

Large diffs are not rendered by default.

501 changes: 501 additions & 0 deletions docs/_static/mnist_data/mnist2_binary.csv

Large diffs are not rendered by default.

501 changes: 501 additions & 0 deletions docs/_static/mnist_data/mnist2_regression.csv

Large diffs are not rendered by default.

101 changes: 101 additions & 0 deletions docs/_static/mnist_data/mnist2_regression_test.csv

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
warnings.filterwarnings(
"ignore", message="Checkpoint directory.*exists and is not empty."
)
warnings.filterwarnings("ignore", ".*samples in targets.*")

project = "fusilli"
copyright = "2023, Florence J Townend"
Expand Down
90 changes: 0 additions & 90 deletions docs/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -1,90 +0,0 @@
import pandas as pd
from sklearn.datasets import make_classification, make_regression


def generate_sklearn_simulated_data(
    prediction_task, num_samples, num_tab1_features, num_tab2_features, external=False
):
    """
    Generate simulated tabular data for two modalities, add a prediction label,
    save both modalities to CSV, and return the paths to the saved files.

    Parameters
    ----------
    prediction_task : str
        The type of prediction to be performed. One of ``"regression"``,
        ``"binary"``, or ``"multiclass"``.
    num_samples : int
        Number of samples to generate.
    num_tab1_features : int
        Number of features to generate for tabular1 data.
    num_tab2_features : int
        Number of features to generate for tabular2 data.
    external : bool, optional
        If True, save to the ``external_*`` simulated-data paths (used for
        external validation data). Default is False.

    Returns
    -------
    tabular1_path : str
        Path to the saved CSV for the first tabular modality.
    tabular2_path : str
        Path to the saved CSV for the second tabular modality.

    Raises
    ------
    ValueError
        If ``prediction_task`` is not one of the supported values.
    """
    if external:
        tabular1_path = "../../../fusilli/utils/simulated_data/external_tabular1data.csv"
        tabular2_path = "../../../fusilli/utils/simulated_data/external_tabular2data.csv"
    else:
        tabular1_path = "../../../fusilli/utils/simulated_data/tabular1data.csv"
        tabular2_path = "../../../fusilli/utils/simulated_data/tabular2data.csv"

    # Both modalities are drawn from a single feature matrix, which is later
    # split column-wise: the first num_tab1_features columns become tabular1,
    # the remainder become tabular2.
    if prediction_task == "binary":
        all_tab_features, labels = make_classification(
            n_samples=num_samples,
            n_features=num_tab1_features + num_tab2_features,
            n_informative=(num_tab1_features + num_tab2_features)
            // 3,  # a third of the features are informative for the label
            n_classes=2,  # binary classification: two output classes
            weights=None,  # equal number of samples per class
            flip_y=0.1,  # flip 10% of the labels to add noise
        )
    elif prediction_task == "multiclass":
        num_classes = 3
        all_tab_features, labels = make_classification(
            n_samples=num_samples,
            n_features=num_tab1_features + num_tab2_features,
            n_informative=(num_tab1_features + num_tab2_features)
            // 2,  # half of the features are informative for the label
            n_classes=num_classes,  # three output classes
            weights=None,  # equal number of samples per class
            flip_y=0.1,  # flip 10% of the labels to add noise
        )
    elif prediction_task == "regression":
        all_tab_features, labels = make_regression(
            n_samples=num_samples,
            n_features=num_tab1_features + num_tab2_features,
            n_informative=(num_tab1_features + num_tab2_features)
            // 2,  # half of the features are informative for the target
            noise=3,
            effective_rank=3,
        )
    else:
        raise ValueError(
            f"prediction_task must be one of 'binary', 'multiclass' or 'regression', not {prediction_task}"
        )

    # First modality: the first num_tab1_features columns, indexed by string IDs.
    tabular1_data = pd.DataFrame()
    tabular1_data["ID"] = [f"{i}" for i in range(num_samples)]
    for i in range(num_tab1_features):
        feature_name = f"feature{i + 1}"
        tabular1_data[feature_name] = all_tab_features[:, i]
    tabular1_data.set_index("ID", inplace=True)
    tabular1_data["prediction_label"] = labels

    # Second modality: the remaining columns, sharing the same IDs and labels.
    tabular2_data = pd.DataFrame()
    tabular2_data["ID"] = [f"{i}" for i in range(num_samples)]
    for i in range(num_tab2_features):
        feature_name = f"feature{i + 1}"
        tabular2_data[feature_name] = all_tab_features[:, num_tab1_features + i]
    tabular2_data.set_index("ID", inplace=True)
    tabular2_data["prediction_label"] = labels

    # Persist both modalities to CSV so examples can load them by path.
    tabular1_data.to_csv(tabular1_path)
    tabular2_data.to_csv(tabular2_path)

    return tabular1_path, tabular2_path
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
epoch,val_loss,MAE_val,R2_val,step,train_loss,MAE_train,R2_train
0,91613.78125,255.58047485351562,-10884.7041015625,0,,,
0,,,,0,126736.65625,283.7513427734375,-14838.0244140625
1,22806.869140625,133.0036163330078,-2708.950927734375,1,,,
1,,,,1,95430.828125,245.3120574951172,-11172.5654296875
2,2364.317138671875,40.63412857055664,-279.9321594238281,2,,,
2,,,,2,39681.44921875,155.34410095214844,-4645.12158203125
3,948.3421630859375,24.115896224975586,-111.68362426757812,3,,,
3,,,,3,24457.408203125,121.0082778930664,-2862.607666015625
4,516.4700927734375,18.50086212158203,-60.367855072021484,4,,,
4,,,,4,13347.4072265625,91.87945556640625,-1561.7877197265625
5,516.4700927734375,18.50086212158203,-60.367855072021484,5,,,
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
epoch,val_loss,MAE_val,R2_val,step,train_loss,MAE_train,R2_train
0,4.192819595336914,1.7022111415863037,0.5379555225372314,49,,,
0,,,,49,8.381141662597656,2.3306796550750732,-0.24445590376853943
1,3.2308895587921143,1.4975028038024902,0.6439592838287354,99,,,
1,,,,99,3.6040618419647217,1.4797532558441162,0.46048277616500854
2,2.6154050827026367,1.2608485221862793,0.711785078048706,149,,,
2,,,,149,2.725576877593994,1.2617511749267578,0.5758683681488037
3,2.830514430999756,1.331453561782837,0.6880801916122437,199,,,
3,,,,199,1.5875434875488281,0.9392735958099365,0.7461423277854919
4,3.048844337463379,1.3958659172058105,0.6640203595161438,249,,,
4,,,,249,1.254799246788025,0.8462849259376526,0.7862377762794495
5,3.048844337463379,1.3958659172058105,0.6640203595161438,250,,,
41 changes: 17 additions & 24 deletions docs/examples/customising_behaviour/plot_modify_layer_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Setting up the experiment
# -------------------------
#
# First, we will set up the experiment by importing the necessary packages, creating the simulated data, and setting the parameters for the experiment.
# First, we will set up the experiment by importing the necessary packages, specifying MNIST data paths, and setting the parameters for the experiment.
#
# For a more detailed explanation of this process, please see the example tutorials.
#
Expand All @@ -29,9 +29,7 @@
import torch.nn as nn
from torch_geometric.nn import GCNConv, ChebConv

from docs.examples import generate_sklearn_simulated_data
from fusilli.data import prepare_fusion_data
from fusilli.eval import RealsVsPreds
from fusilli.train import train_and_save_models

from fusilli.fusionmodels.tabularfusion.attention_weighted_GNN import AttentionWeightedGNN
Expand All @@ -54,14 +52,9 @@
# remove dir
os.rmdir(os.path.join(output_paths["losses"], dir))

tabular1_path, tabular2_path = generate_sklearn_simulated_data(prediction_task,
num_samples=500,
num_tab1_features=10,
num_tab2_features=20)

data_paths = {
"tabular1": tabular1_path,
"tabular2": tabular2_path,
"tabular1": "../../_static/mnist_data/mnist1_regression.csv",
"tabular2": "../../_static/mnist_data/mnist2_regression.csv",
"image": "",
}

Expand Down Expand Up @@ -115,7 +108,7 @@
layer_mods = {
"AttentionWeightedGNN": {
"graph_conv_layers": nn.Sequential(
ChebConv(20, 50, K=3),
ChebConv(392, 50, K=3),
ChebConv(50, 100, K=3),
ChebConv(100, 130, K=3),
),
Expand All @@ -127,19 +120,19 @@
"AttentionWeightingMLPInstance.weighting_layers": nn.ModuleDict(
{
"Layer 1": nn.Sequential(
nn.Linear(30, 100),
nn.Linear(784, 500),
nn.ReLU()),
"Layer 2": nn.Sequential(
nn.Linear(100, 75),
nn.Linear(500, 128),
nn.ReLU()),
"Layer 3": nn.Sequential(
nn.Linear(75, 75),
nn.Linear(128, 128),
nn.ReLU()),
"Layer 4": nn.Sequential(
nn.Linear(75, 100),
nn.Linear(128, 500),
nn.ReLU()),
"Layer 5": nn.Sequential(
nn.Linear(100, 30),
nn.Linear(500, 784),
nn.ReLU()),
}
)},
Expand Down Expand Up @@ -211,7 +204,7 @@
layer_mods = {
"AttentionWeightedGraphMaker": {
"AttentionWeightingMLPInstance.weighting_layers": nn.Sequential(
nn.Linear(30, 75),
nn.Linear(392, 75),
nn.ReLU(),
nn.Linear(75, 75),
nn.ReLU(),
Expand Down Expand Up @@ -268,37 +261,37 @@
"mod1_layers": nn.ModuleDict(
{
"layer 1": nn.Sequential(
nn.Linear(10, 32),
nn.Linear(392, 300),
nn.ReLU(),
),
"layer 2": nn.Sequential(
nn.Linear(32, 66),
nn.Linear(300, 128),
nn.ReLU(),
),
"layer 3": nn.Sequential(
nn.Linear(66, 128),
nn.Linear(128, 128),
nn.ReLU(),
),
}
),
"mod2_layers": nn.ModuleDict(
{
"layer 1": nn.Sequential(
nn.Linear(20, 45),
nn.Linear(392, 300),
nn.ReLU(),
),
"layer 2": nn.Sequential(
nn.Linear(45, 70),
nn.Linear(300, 128),
nn.ReLU(),
),
"layer 3": nn.Sequential(
nn.Linear(70, 100),
nn.Linear(128, 100),
nn.ReLU(),
),
}
),
"fused_layers": nn.Sequential(
nn.Linear(30, 150),
nn.Linear(25, 150),
nn.ReLU(),
nn.Linear(150, 75),
nn.ReLU(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
MAE_val,epoch,val_loss,R2_val,step,R2_train,train_loss,MAE_train
2.550821542739868,0,9.993765830993652,-0.36750197410583496,7,,,
,0,,,7,-0.9941763877868652,16.622907638549805,3.3130204677581787
2.0250279903411865,1,6.120434284210205,0.1625072956085205,15,,,
,1,,,15,0.14741721749305725,7.222578048706055,2.1738603115081787
2.1401681900024414,2,7.6228156089782715,-0.04307210445404053,23,,,
,2,,,23,0.3598036766052246,5.378986835479736,1.850309133529663
1.7970805168151855,3,5.474188804626465,0.25093650817871094,31,,,
,3,,,31,0.46779435873031616,4.45926570892334,1.6522819995880127
1.7589600086212158,4,5.531983375549316,0.24302823841571808,39,,,
,4,,,39,0.549390971660614,3.7244315147399902,1.50660240650177
1.7531583309173584,5,5.516550540924072,0.24513985216617584,47,,,
,5,,,47,0.6313987970352173,3.1030783653259277,1.3298112154006958
1.8356657028198242,6,6.117806434631348,0.1628667712211609,55,,,
,6,,,55,0.7357838749885559,2.1222023963928223,1.1111221313476562
1.6618714332580566,7,5.2861833572387695,0.27666234970092773,63,,,
,7,,,63,0.7709552049636841,1.9139409065246582,1.0187078714370728
1.57106614112854,8,4.7738189697265625,0.3467719554901123,71,,,
,8,,,71,0.7927483320236206,1.7200523614883423,0.9630319476127625
1.604772925376892,9,4.912045001983643,0.3278576731681824,79,,,
,9,,,79,0.8458226919174194,1.3106739521026611,0.8364636898040771
1.6100950241088867,10,4.99435567855835,0.316594660282135,87,,,
,10,,,87,0.8695825934410095,1.0703457593917847,0.7448607683181763
1.5111275911331177,11,4.433798789978027,0.39329880475997925,95,,,
,11,,,95,0.8898072838783264,0.9113877415657043,0.6880857348442078
1.5012693405151367,12,4.484768867492676,0.38632428646087646,103,,,
,12,,,103,0.9058310985565186,0.7781545519828796,0.623124361038208
1.4845701456069946,13,4.483181476593018,0.386541485786438,111,,,
,13,,,111,0.9015885591506958,0.8185599446296692,0.6399004459381104
1.4515013694763184,14,4.19098424911499,0.42652440071105957,119,,,
,14,,,119,0.917405366897583,0.7098179459571838,0.5883429646492004
1.6063990592956543,15,5.107571601867676,0.30110275745391846,127,,,
,15,,,127,0.9220234751701355,0.6577357053756714,0.5817375183105469
1.6577519178390503,16,5.300245761871338,0.274738073348999,135,,,
,16,,,135,0.910961925983429,0.750971257686615,0.6317367553710938
1.6055717468261719,17,5.1569318771362305,0.29434847831726074,143,,,
,17,,,143,0.9021233916282654,0.8221578598022461,0.6475740075111389
1.639248013496399,18,5.171427249908447,0.29236501455307007,151,,,
,18,,,151,0.900653064250946,0.8402024507522583,0.6508421897888184
1.5495505332946777,19,4.559011459350586,0.3761652708053589,159,,,
,19,,,159,0.908098042011261,0.7756197452545166,0.6134839653968811
1.5707274675369263,20,4.734753608703613,0.35211747884750366,167,,,
,20,,,167,0.924924910068512,0.6461987495422363,0.5735353827476501
1.4228588342666626,21,3.9855732917785645,0.45463207364082336,175,,,
,21,,,175,0.8851704001426697,0.9396422505378723,0.6774986386299133
1.5047160387039185,22,4.698852062225342,0.35703009366989136,183,,,
,22,,,183,0.9096055626869202,0.7635362148284912,0.6242097616195679
1.3928208351135254,23,4.0358099937438965,0.4477577209472656,191,,,
,23,,,191,0.9363532662391663,0.5386157631874084,0.5349072217941284
1.431113600730896,24,4.470632076263428,0.38825875520706177,199,,,
,24,,,199,0.9415258169174194,0.4951440393924713,0.48046642541885376
1.4176338911056519,25,4.323647499084473,0.4083714485168457,207,,,
,25,,,207,0.9421520829200745,0.4700969457626343,0.47758549451828003
1.3952497243881226,26,4.078487873077393,0.44191786646842957,215,,,
,26,,,215,0.9555037021636963,0.3723544776439667,0.41501185297966003
1.3771188259124756,27,3.975924491882324,0.45595234632492065,223,,,
,27,,,223,0.9446684718132019,0.47763383388519287,0.46678513288497925
1.39987313747406,28,4.1419172286987305,0.43323853611946106,231,,,
,28,,,231,0.950677752494812,0.4249856173992157,0.46403563022613525
1.396775245666504,29,4.147250652313232,0.4325087368488312,239,,,
,29,,,239,0.9434041976928711,0.48246464133262634,0.4782574772834778
1.390718936920166,30,4.0993194580078125,0.43906742334365845,247,,,
,30,,,247,0.9475470185279846,0.4307110905647278,0.45598751306533813
1.3702670335769653,31,3.924558162689209,0.4629810154438019,255,,,
,31,,,255,0.9597123265266418,0.33705028891563416,0.40675410628318787
1.4085280895233154,32,4.261821269989014,0.41683149337768555,263,,,
,32,,,263,0.9405664205551147,0.4967845678329468,0.47750386595726013
1.3695627450942993,33,3.9851067066192627,0.4546957314014435,271,,,
,33,,,271,0.9426913857460022,0.46704617142677307,0.47678470611572266
1.4094480276107788,34,3.998617649078369,0.45284709334373474,279,,,
,34,,,279,0.9507781863212585,0.41333433985710144,0.4429933428764343
1.3696682453155518,35,3.881700038909912,0.4688454568386078,287,,,
,35,,,287,0.939416766166687,0.4943542778491974,0.49081888794898987
1.4207769632339478,36,4.0211310386657715,0.4497663974761963,295,,,
,36,,,295,0.9628061056137085,0.3204095661640167,0.3762396574020386
1.3708957433700562,37,3.8153915405273438,0.4779188930988312,303,,,
,37,,,303,0.945135235786438,0.4457513093948364,0.4563809335231781
1.4144178628921509,38,4.181662559509277,0.4277999997138977,311,,,
,38,,,311,0.953737199306488,0.37626779079437256,0.41615381836891174
1.4202336072921753,39,4.486833095550537,0.38604187965393066,319,,,
,39,,,319,0.9369199872016907,0.5130748152732849,0.48724305629730225
1.341239094734192,40,3.9252257347106934,0.4628896713256836,327,,,
,40,,,327,0.9515608549118042,0.40434104204177856,0.4478156864643097
1.539413332939148,41,4.673248291015625,0.36053353548049927,335,,,
,41,,,335,0.9387563467025757,0.5167527794837952,0.5156168937683105
1.7450042963027954,42,5.426409721374512,0.2574743628501892,343,,,
,42,,,343,0.8893561959266663,0.9203564524650574,0.6920171976089478
1.5992671251296997,43,4.80446195602417,0.3425789475440979,351,,,
,43,,,351,0.90372234582901,0.773030161857605,0.6137067675590515
1.4908099174499512,44,4.300416946411133,0.41155022382736206,359,,,
,44,,,359,0.920629620552063,0.6477848887443542,0.574256956577301
1.44012451171875,45,4.305934429168701,0.4107951521873474,367,,,
,45,,,367,0.9128478169441223,0.7299042344093323,0.6047486066818237
1.4437092542648315,46,4.356184005737305,0.4039192795753479,375,,,
,46,,,375,0.9185203909873962,0.6919479966163635,0.5823092460632324
1.3632363080978394,47,4.0035600662231445,0.4521707594394684,383,,,
,47,,,383,0.9386501908302307,0.5217953324317932,0.5004262328147888
1.3210381269454956,48,3.677823543548584,0.49674296379089355,391,,,
,48,,,391,0.9318889379501343,0.5586847066879272,0.5229012370109558
1.3599282503128052,49,4.0863728523254395,0.4408389925956726,399,,,
,49,,,399,0.9411574602127075,0.480117529630661,0.45792442560195923
1.330924391746521,50,3.81209397315979,0.4783701002597809,407,,,
,50,,,407,0.9523176550865173,0.38065025210380554,0.42401397228240967
1.332220435142517,51,3.883242130279541,0.46863454580307007,415,,,
,51,,,415,0.9544154405593872,0.3866163492202759,0.4231767952442169
1.3906828165054321,52,4.229849815368652,0.4212062954902649,423,,,
,52,,,423,0.9538908004760742,0.3787587583065033,0.42101413011550903
1.3811572790145874,53,3.968116044998169,0.45702075958251953,431,,,
,53,,,431,0.9562690258026123,0.3682621121406555,0.40828174352645874
1.366807460784912,54,4.064504146575928,0.4438314437866211,439,,,
,54,,,439,0.955906093120575,0.365860253572464,0.421358585357666
1.3128173351287842,55,3.7178304195404053,0.49126872420310974,447,,,
,55,,,447,0.9617348909378052,0.3125576674938202,0.399527370929718
1.3332505226135254,56,3.8579626083374023,0.4720936119556427,455,,,
,56,,,455,0.9516554474830627,0.39683592319488525,0.42932966351509094
1.3601473569869995,57,4.019503593444824,0.44998905062675476,463,,,
,57,,,463,0.9577350616455078,0.3561760187149048,0.3812394440174103
1.3970615863800049,58,4.052127838134766,0.44552499055862427,471,,,
,58,,,471,0.9404559135437012,0.4958237111568451,0.46953076124191284
1.614027500152588,59,4.767022609710693,0.3477018475532532,479,,,
,59,,,479,0.9305849671363831,0.5771317481994629,0.541130781173706
1.4933671951293945,60,4.511728286743164,0.38263529539108276,487,,,
,60,,,487,0.9263445138931274,0.5848267078399658,0.5564913153648376
1.3761061429977417,61,3.990039110183716,0.4540208578109741,495,,,
,61,,,495,0.9608601331710815,0.31816357374191284,0.3896094262599945
1.4019064903259277,62,4.246546268463135,0.4189215302467346,503,,,
,62,,,503,0.9589810371398926,0.34225142002105713,0.40126940608024597
1.441614031791687,63,4.2292351722717285,0.42129039764404297,511,,,
,63,,,511,0.9644243717193604,0.29825612902641296,0.37975960969924927
1.441614031791687,64,4.2292351722717285,0.42129039764404297,512,,,
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Loading

0 comments on commit 6b1b7af

Please sign in to comment.