-
Notifications
You must be signed in to change notification settings - Fork 4
/
connect.py
145 lines (115 loc) · 5.26 KB
/
connect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# The first command line argument is a keyword specifying which action to perform.
# The implemented keywords are 'create' and 'train', which correspond to creating
# data and training a network, respectively. Training as a job, sampling of the
# posterior distribution, and plotting of the trained models and the inferred
# parameters are not available as keywords in this script.
# The code needs a parameter file as the second command line argument, which
# contains the parameters and hyperparameters of the model and the architecture
# of the network.
# -_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_
import sys
import os
import argparse
CONNECT_PATH = os.path.realpath(os.path.dirname(__file__))
CURRENT_PATH = os.path.abspath(os.getcwd())
sys.path.insert(1, CONNECT_PATH)
from source.default_module import Parameters
# Keywords handled further down in this script. Anything else is rejected with a
# usage message instead of silently doing nothing.
_KEYWORDS = ('create', 'train', 'animate', 'procrastinate')
_USAGE = ('Usage: python connect.py <keyword> [parameter_file] [epochs]\n'
          f"  keyword is one of: {', '.join(_KEYWORDS)}")
if len(sys.argv) < 2 or sys.argv[1] not in _KEYWORDS:
    sys.exit(_USAGE)
keyword = sys.argv[1]
if keyword in ['create', 'train']:
    # Both data creation and training are driven by a parameter file.
    if len(sys.argv) < 3:
        sys.exit(f"The '{keyword}' keyword needs a parameter file as the "
                 "second argument.\n" + _USAGE)
    param_file = sys.argv[2]
    param = Parameters(param_file)
    parameters = param.parameters
    # Job folder for this parameter file; the trailing slash is relied upon
    # by the string concatenations further down.
    path = CONNECT_PATH + f'/data/{param.jobname}/'
#####################################
# ____________ create _____________ #
#####################################
if keyword == 'create':
    # Generate training data with the sampling method chosen in the parameter
    # file, writing progress to an output.log file.
    import contextlib

    main_log = path + 'output.log'
    if not param.resume_iterations:
        # A fresh run starts without an old log; a missing file is fine
        # (same as `rm -f`, but without going through a shell).
        with contextlib.suppress(FileNotFoundError):
            os.remove(main_log)
    else:
        # Mark where the resumed run starts. Best effort: if the job folder
        # does not exist yet, there is no log to annotate.
        with contextlib.suppress(FileNotFoundError):
            with open(main_log, 'a') as log:
                log.write('#' * 62 + '\n' + 'Resuming ' * 7 + '\n' + '#' * 62 + '\n')
    from source.tools import create_output_folders
    create_output_folders(param, resume=param.resume_iterations)
    from source.data_sampling import Sampling
    s = Sampling(param_file, CONNECT_PATH)
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(CONNECT_PATH, 'source/assets/logo_colour.txt'), 'r') as f:
        logo = f.read()
    log_string = ('-' * 62 + '\n\n\n'
                  + logo + '\n'
                  + '-' * 62 + '\n\n'
                  + 'Running CONNECT\n'
                  + f'Parameter file : {param_file}\n'
                  + 'Mode : Create')

    # sampling -> (label written to the log, name of the Sampling method to run)
    sampling_methods = {
        'iterative': ('Iterative', 'create_iterative_data'),
        'lhc': ('Latin Hypercube', 'create_lhc_data'),
        'hypersphere': ('Hypersphere', 'create_hypersphere_data'),
        'pickle': ('From Pickle file', 'create_pickle_data'),
    }
    if param.sampling not in sampling_methods:
        raise ValueError(
            f"Unknown sampling method {param.sampling!r}; "
            f"expected one of {', '.join(sampling_methods)}")
    label, method_name = sampling_methods[param.sampling]
    iterative = param.sampling == 'iterative'
    # Iterative sampling logs to the job folder. The other methods log into
    # their N-<N> subfolder and have their data files joined afterwards.
    log_file = main_log if iterative else path + f'N-{param.N}/output.log'
    # Decide on truthiness so that a non-bool resume flag still gives a valid mode.
    mode = 'a+' if param.resume_iterations else 'w'
    # redirect_stdout restores sys.stdout on exit. Rebinding sys.stdout to the
    # file directly would leave it pointing at a closed file afterwards.
    with open(log_file, mode) as log, contextlib.redirect_stdout(log):
        print(log_string, flush=True)
        print(f'Sampling method : {label}', flush=True)
        print('\n' + '-' * 62 + '\n', flush=True)
        getattr(s, method_name)()
    if not iterative:
        from source.tools import join_data_files
        join_data_files(param)
#####################################
# _____________ train _____________ #
#####################################
def _latest_iteration(directory):
    """Return the largest ``i`` among ``number_<i>`` entries of *directory*, or None.

    Entries whose suffix is not a plain integer are ignored. A missing
    *directory* yields None.
    """
    prefix = 'number_'
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        return None
    iterations = [int(name[len(prefix):]) for name in entries
                  if name.startswith(prefix) and name[len(prefix):].isdecimal()]
    return max(iterations, default=None)


def join_output_files():
    """Join the raw sampling output into one data file unless already done.

    Where ``model_params.txt`` is looked for depends on the sampling method:

    * ``param.sampling == 'iterative'``: the ``number_<i>`` folder with the
      largest ``i`` inside the job folder ``path``;
    * any other sampling method, or an iterative run with no ``number_<i>``
      folder yet: the ``N-<param.N>`` folder.

    If ``model_params.txt`` is missing there, the joining is done by
    ``CreateSingleDataFile(param, CONNECT_PATH).join()``. Errors raised while
    joining propagate to the caller instead of being swallowed.

    Uses the module-level ``param``, ``path`` and ``CONNECT_PATH``.
    """
    target_dir = None
    if param.sampling == 'iterative':
        latest = _latest_iteration(path)
        if latest is not None:
            target_dir = os.path.join(path, f'number_{latest}')
    if target_dir is None:
        target_dir = os.path.join(path, f'N-{param.N}')
    if not os.path.isfile(os.path.join(target_dir, 'model_params.txt')):
        from source.join_output import CreateSingleDataFile
        CSDF = CreateSingleDataFile(param, CONNECT_PATH)
        CSDF.join()
if keyword == 'train':
    # Make sure the sampled data has been joined into a single file first.
    join_output_files()
    from source.train_network import Training
    tr = Training(param, CONNECT_PATH)
    # Optional third argument: the number of epochs. Only a missing argument
    # falls back to the default. A failure inside train_model is reported,
    # not answered by silently retraining from scratch.
    if len(sys.argv) > 3:
        try:
            epochs = int(sys.argv[3])
        except ValueError:
            sys.exit(f"Number of epochs must be an integer, got {sys.argv[3]!r}")
        tr.train_model(epochs=epochs)
    else:
        tr.train_model()
    tr.save_model()
    tr.save_history()
    tr.save_test_data()
#####################################
# ____________ animate ____________ #
#####################################
if keyword == 'animate':
    # Hand over to the animation player bundled with CONNECT's assets.
    from source.assets import animate as animation
    animation.play()
#####################################
# _________ procrastinate _________ #
#####################################
if keyword == 'procrastinate':
    import base64
    # Resolve the asset relative to the CONNECT installation, like the logo
    # used by 'create', so this keyword works from any working directory.
    surprise_file = os.path.join(CONNECT_PATH, 'source/assets/surprise.txt')
    with open(surprise_file, 'r') as f:
        # Only the first line holds the payload. Surrounding whitespace (e.g. a
        # trailing newline) is not part of the base85 alphabet and would make
        # b85decode fail, so strip it.
        obfuscated_code = f.readline().strip()
    # SECURITY NOTE(review): this executes base85-encoded Python read from a
    # file, in this module's global namespace. Anyone able to modify
    # surprise.txt can run arbitrary code here.
    exec(base64.b85decode(obfuscated_code.encode('utf-8')))