4242 # taken directly from atom.cpp -> extract()
4343}
4444
45- # Lammps executable
46- args = ["-screen" , "none" ]
47- if len (sys .argv ) > 3 :
48- args .extend (sys .argv [3 :])
49- job = lammps (cmdargs = args )
5045
51-
52- def extract_compute (funct_args ):
46+ def extract_compute (job , funct_args ):
5347 def convert_data (val , type , length , width ):
5448 data = []
5549 if type == 2 :
@@ -88,42 +82,42 @@ def convert_data(val, type, length, width):
8882 raise ValueError ("Local style is currently not supported" )
8983
9084
91- def get_version (funct_args ):
85+ def get_version (job , funct_args ):
9286 if MPI .COMM_WORLD .rank == 0 :
9387 return job .version ()
9488
9589
96- def get_file (funct_args ):
90+ def get_file (job , funct_args ):
9791 job .file (* funct_args )
9892 return 1
9993
10094
101- def commands_list (funct_args ):
95+ def commands_list (job , funct_args ):
10296 job .commands_list (* funct_args )
10397 return 1
10498
10599
106- def commands_string (funct_args ):
100+ def commands_string (job , funct_args ):
107101 job .commands_string (* funct_args )
108102 return 1
109103
110104
111- def extract_setting (funct_args ):
105+ def extract_setting (job , funct_args ):
112106 if MPI .COMM_WORLD .rank == 0 :
113107 return job .extract_setting (* funct_args )
114108
115109
116- def extract_global (funct_args ):
110+ def extract_global (job , funct_args ):
117111 if MPI .COMM_WORLD .rank == 0 :
118112 return job .extract_global (* funct_args )
119113
120114
121- def extract_box (funct_args ):
115+ def extract_box (job , funct_args ):
122116 if MPI .COMM_WORLD .rank == 0 :
123117 return job .extract_box (* funct_args )
124118
125119
126- def extract_atom (funct_args ):
120+ def extract_atom (job , funct_args ):
127121 if MPI .COMM_WORLD .rank == 0 :
128122 # extract atoms return an internal data type
129123 # this has to be reformatted
@@ -153,12 +147,12 @@ def extract_atom(funct_args):
153147 return np .array (data )
154148
155149
156- def extract_fix (funct_args ):
150+ def extract_fix (job , funct_args ):
157151 if MPI .COMM_WORLD .rank == 0 :
158152 return job .extract_fix (* funct_args )
159153
160154
161- def extract_variable (funct_args ):
155+ def extract_variable (job , funct_args ):
162156 # in the args - if the third one,
163157 # which is the type is 1 - a lammps array is returned
164158 if funct_args [2 ] == 1 :
@@ -177,21 +171,21 @@ def extract_variable(funct_args):
177171 return data
178172
179173
180- def get_natoms (funct_args ):
174+ def get_natoms (job , funct_args ):
181175 if MPI .COMM_WORLD .rank == 0 :
182176 return job .get_natoms ()
183177
184178
185- def set_variable (funct_args ):
179+ def set_variable (job , funct_args ):
186180 return job .set_variable (* funct_args )
187181
188182
189- def reset_box (funct_args ):
183+ def reset_box (job , funct_args ):
190184 job .reset_box (* funct_args )
191185 return 1
192186
193187
194- def gather_atoms (funct_args ):
188+ def gather_atoms (job , funct_args ):
195189 # extract atoms return an internal data type
196190 # this has to be reformatted
197191 name = str (funct_args [0 ])
@@ -217,7 +211,7 @@ def gather_atoms(funct_args):
217211 return np .array (data )
218212
219213
220- def gather_atoms_concat (funct_args ):
214+ def gather_atoms_concat (job , funct_args ):
221215 # extract atoms return an internal data type
222216 # this has to be reformatted
223217 name = str (funct_args [0 ])
@@ -243,7 +237,7 @@ def gather_atoms_concat(funct_args):
243237 return np .array (data )
244238
245239
246- def gather_atoms_subset (funct_args ):
240+ def gather_atoms_subset (job , funct_args ):
247241 # convert to ctypes
248242 name = str (funct_args [0 ])
249243 lenids = int (funct_args [1 ])
@@ -280,76 +274,76 @@ def gather_atoms_subset(funct_args):
280274 return np .array (data )
281275
282276
283- def create_atoms (funct_args ):
277+ def create_atoms (job , funct_args ):
284278 job .create_atoms (* funct_args )
285279 return 1
286280
287281
288- def has_exceptions (funct_args ):
282+ def has_exceptions (job , funct_args ):
289283 return job .has_exceptions
290284
291285
292- def has_gzip_support (funct_args ):
286+ def has_gzip_support (job , funct_args ):
293287 return job .has_gzip_support
294288
295289
296- def has_png_support (funct_args ):
290+ def has_png_support (job , funct_args ):
297291 return job .has_png_support
298292
299293
300- def has_jpeg_support (funct_args ):
294+ def has_jpeg_support (job , funct_args ):
301295 return job .has_jpeg_support
302296
303297
304- def has_ffmpeg_support (funct_args ):
298+ def has_ffmpeg_support (job , funct_args ):
305299 return job .has_ffmpeg_support
306300
307301
308- def installed_packages (funct_args ):
302+ def installed_packages (job , funct_args ):
309303 return job .installed_packages
310304
311305
312- def set_fix_external_callback (funct_args ):
306+ def set_fix_external_callback (job , funct_args ):
313307 job .set_fix_external_callback (* funct_args )
314308 return 1
315309
316310
317- def get_neighlist (funct_args ):
311+ def get_neighlist (job , funct_args ):
318312 if MPI .COMM_WORLD .rank == 0 :
319313 return job .get_neighlist (* funct_args )
320314
321315
322- def find_pair_neighlist (funct_args ):
316+ def find_pair_neighlist (job , funct_args ):
323317 if MPI .COMM_WORLD .rank == 0 :
324318 return job .find_pair_neighlist (* funct_args )
325319
326320
327- def find_fix_neighlist (funct_args ):
321+ def find_fix_neighlist (job , funct_args ):
328322 if MPI .COMM_WORLD .rank == 0 :
329323 return job .find_fix_neighlist (* funct_args )
330324
331325
332- def find_compute_neighlist (funct_args ):
326+ def find_compute_neighlist (job , funct_args ):
333327 if MPI .COMM_WORLD .rank == 0 :
334328 return job .find_compute_neighlist (* funct_args )
335329
336330
337- def get_neighlist_size (funct_args ):
331+ def get_neighlist_size (job , funct_args ):
338332 if MPI .COMM_WORLD .rank == 0 :
339333 return job .get_neighlist_size (* funct_args )
340334
341335
342- def get_neighlist_element_neighbors (funct_args ):
336+ def get_neighlist_element_neighbors (job , funct_args ):
343337 if MPI .COMM_WORLD .rank == 0 :
344338 return job .get_neighlist_element_neighbors (* funct_args )
345339
346340
347- def get_thermo (funct_args ):
341+ def get_thermo (job , funct_args ):
348342 if MPI .COMM_WORLD .rank == 0 :
349343 return np .array (job .get_thermo (* funct_args ))
350344
351345
352- def scatter_atoms (funct_args ):
346+ def scatter_atoms (job , funct_args ):
353347 name = str (funct_args [0 ])
354348 py_vector = funct_args [1 ]
355349 # now see if its an integer or double type- but before flatten
@@ -366,7 +360,7 @@ def scatter_atoms(funct_args):
366360 return 1
367361
368362
369- def scatter_atoms_subset (funct_args ):
363+ def scatter_atoms_subset (job , funct_args ):
370364 name = str (funct_args [0 ])
371365 lenids = int (funct_args [2 ])
372366 ids = funct_args [3 ]
@@ -396,7 +390,7 @@ def scatter_atoms_subset(funct_args):
396390 return 1
397391
398392
399- def command (funct_args ):
393+ def command (job , funct_args ):
400394 job .command (funct_args )
401395 return 1
402396
@@ -464,13 +458,19 @@ def _gather_data_from_all_processors(data):
464458 return data
465459
466460
467- if __name__ == "__main__" :
461+ def _run_lammps_mpi ( argument_lst ) :
468462 if MPI .COMM_WORLD .rank == 0 :
469463 context = zmq .Context ()
470464 socket = context .socket (zmq .PAIR )
471- argument_lst = sys .argv
472465 port_selected = argument_lst [argument_lst .index ("--zmqport" ) + 1 ]
473466 socket .connect ("tcp://localhost:" + port_selected )
467+ else :
468+ context , socket = None , None
469+ # Lammps executable
470+ args = ["-screen" , "none" ]
471+ if len (argument_lst ) > 3 :
472+ args .extend (argument_lst [3 :])
473+ job = lammps (cmdargs = args )
474474 while True :
475475 if MPI .COMM_WORLD .rank == 0 :
476476 input_dict = cloudpickle .loads (socket .recv ())
@@ -485,8 +485,12 @@ def _gather_data_from_all_processors(data):
485485 context .term ()
486486 job .close ()
487487 break
488- output = select_cmd (input_dict ["c" ])(input_dict ["d" ])
488+ output = select_cmd (input_dict ["c" ])(job = job , funct_args = input_dict ["d" ])
489489 if MPI .COMM_WORLD .rank == 0 and output is not None :
490490 # with open('process.txt', 'a') as file:
491491 # print('Output:', output, file=file)
492492 socket .send (cloudpickle .dumps ({"r" : output }))
493+
494+
495+ if __name__ == "__main__" :
496+ _run_lammps_mpi (argument_lst = sys .argv )
0 commit comments