-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
sunqingxiao
committed
May 27, 2022
1 parent
455be96
commit 2ec9fac
Showing
14 changed files
with
196 additions
and
1 deletion.
There are no files selected for viewing
Binary file not shown.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# MPS | ||
|
||
This folder contains `Default`, no modification needed for the PyTorch code (PyTorch v1.8.1). | ||
|
||
## Files | ||
|
||
- default\_main.py: execute one training task. | ||
|
||
- run\_default.sh: execute training tasks in order. | ||
|
||
## Usage | ||
|
||
``` | ||
./run_default.sh | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import sys | ||
import time | ||
import importlib | ||
|
||
import torch | ||
from cognn.policy import * | ||
|
||
def read_list(model_list_file_name): | ||
model_list = [] | ||
with open(model_list_file_name) as f: | ||
for line in f.readlines(): | ||
if len(line.split()) != 3: | ||
continue | ||
model_list.append([line.split()[0], line.split()[1], line.split()[2]]) | ||
return model_list | ||
|
||
|
||
def main(): | ||
# Load model list (task & data) | ||
model_list = read_list(sys.argv[1]) | ||
model_id = int(sys.argv[2]) | ||
|
||
task_name = model_list[model_id][0] | ||
data_name = model_list[model_id][1] | ||
num_layers = int(model_list[model_id][2]) | ||
|
||
model_module = importlib.import_module('task.' + task_name) | ||
model, func, _ = model_module.import_task(data_name, num_layers) | ||
_, data = model_module.import_model(data_name, num_layers) | ||
output = func(model, data) | ||
# print('Training time: {} ms'.format(output)) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/bash | ||
|
||
network=("GCN" "GraphSAGE" "GAT" "GIN" "mix") | ||
for i in `seq 0 19`; | ||
do | ||
CUDA_VISIBLE_DEVICES=0 python default_main.py ../data/model/${network[4]}_model.txt ${i} | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import sys | ||
import numpy as np | ||
|
||
|
||
def list2txt(fileName="", myfile=[]): | ||
fileout = open(fileName, 'w') | ||
for i in range(len(myfile)): | ||
for j in range(len(myfile[i])): | ||
fileout.write(str(myfile[i][j]) + ' , ') | ||
fileout.write('\r\n') | ||
fileout.close() | ||
|
||
|
||
def main(): | ||
convs = ['GCN', 'GraphSAGE', 'GAT', 'GIN', 'mix'] | ||
all_qt, all_jct = [], [] | ||
|
||
ne_list = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] | ||
|
||
for convId in range(0, 5): | ||
dur_list = [] | ||
group_dur_list = [] | ||
|
||
logfile = open('{}.log'.format(convs[convId])) | ||
for line in logfile: | ||
if line.find('Training') > -1: | ||
seg = line.split() | ||
dur_list.append(float(seg[-2])) | ||
logfile.close() | ||
|
||
eleCounter = 0 | ||
for groupId in range(0, len(ne_list)): | ||
group_dur_list.append([]) | ||
for _ in range(ne_list[groupId]): | ||
group_dur_list[groupId].append(dur_list[eleCounter]) | ||
eleCounter += 1 | ||
|
||
group_max_dur, group_qt = [], [] | ||
tmp_qt = 0 | ||
for groupId in range(0, len(group_dur_list)-1): | ||
tmp_max_dur = 0 | ||
for eleId in range(0, len(group_dur_list[groupId])): | ||
if group_dur_list[groupId][eleId] > tmp_max_dur: | ||
tmp_max_dur = group_dur_list[groupId][eleId] | ||
group_max_dur.append(tmp_max_dur) | ||
tmp_qt += tmp_max_dur | ||
group_qt.append(tmp_qt) | ||
|
||
qt, jct = [], [] | ||
for groupId in range(0, len(group_dur_list)): | ||
for eleId in range(0, len(group_dur_list[groupId])): | ||
ele_qt = group_qt[groupId-1] | ||
if groupId == 0: ele_qt = 0 | ||
ele_jct = ele_qt + group_dur_list[groupId][eleId] | ||
qt.append(ele_qt) | ||
jct.append(ele_jct) | ||
|
||
all_qt.append(qt) | ||
all_jct.append(jct) | ||
|
||
list2txt('default_all_qt.csv', all_qt) | ||
list2txt('default_all_jct.csv', all_jct) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
import sys | ||
import numpy as np | ||
|
||
|
||
def list2txt(fileName="", myfile=[]): | ||
fileout = open(fileName, 'w') | ||
for i in range(len(myfile)): | ||
for j in range(len(myfile[i])): | ||
fileout.write(str(myfile[i][j]) + ' , ') | ||
fileout.write('\r\n') | ||
fileout.close() | ||
|
||
|
||
def main(): | ||
convs = ['GCN', 'GraphSAGE', 'GAT', 'GIN', 'mix'] | ||
makespan_total, ave_qt_total, ave_jct_total = [], [], [] | ||
|
||
ne_list = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] | ||
for convId in range(0, 5): | ||
dur_list = [] | ||
group_dur_list = [] | ||
|
||
logfile = open('{}.log'.format(convs[convId])) | ||
for line in logfile: | ||
if line.find('Training') > -1: | ||
seg = line.split() | ||
dur_list.append(float(seg[-2])) | ||
logfile.close() | ||
|
||
eleCounter = 0 | ||
for groupId in range(0, len(ne_list)): | ||
group_dur_list.append([]) | ||
for _ in range(ne_list[groupId]): | ||
group_dur_list[groupId].append(dur_list[eleCounter]) | ||
eleCounter += 1 | ||
|
||
group_max_dur, group_qt = [], [] | ||
tmp_qt = 0 | ||
for groupId in range(0, len(group_dur_list)-1): | ||
tmp_max_dur = 0 | ||
for eleId in range(0, len(group_dur_list[groupId])): | ||
if group_dur_list[groupId][eleId] > tmp_max_dur: | ||
tmp_max_dur = group_dur_list[groupId][eleId] | ||
group_max_dur.append(tmp_max_dur) | ||
tmp_qt += tmp_max_dur | ||
group_qt.append(tmp_qt) | ||
|
||
qt, jct = [], [] | ||
for groupId in range(0, len(group_dur_list)): | ||
for eleId in range(0, len(group_dur_list[groupId])): | ||
ele_qt = group_qt[groupId-1] | ||
if groupId == 0: ele_qt = 0 | ||
ele_jct = ele_qt + group_dur_list[groupId][eleId] | ||
qt.append(ele_qt) | ||
jct.append(ele_jct) | ||
|
||
ave_qt, ave_jct = np.mean(qt), np.mean(jct) | ||
makespan_total.append(jct[-1]) | ||
ave_qt_total.append(ave_qt) | ||
ave_jct_total.append(ave_jct) | ||
|
||
output = [] | ||
for i in range(0, len(makespan_total)): | ||
output.append([makespan_total[i], ave_qt_total[i], ave_jct_total[i]]) | ||
list2txt('default.csv', output) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |