WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 070db5f

Browse files
authored
Merge pull request #105 from pyiron/function_access
Make MPI functions accessible for external scripts
2 parents c78d6d0 + d4a2afe commit 070db5f

File tree

1 file changed

+48
-44
lines changed

1 file changed

+48
-44
lines changed

pylammpsmpi/mpi/lmpmpi.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,8 @@
4242
# taken directly from atom.cpp -> extract()
4343
}
4444

45-
# Lammps executable
46-
args = ["-screen", "none"]
47-
if len(sys.argv) > 3:
48-
args.extend(sys.argv[3:])
49-
job = lammps(cmdargs=args)
5045

51-
52-
def extract_compute(funct_args):
46+
def extract_compute(job, funct_args):
5347
def convert_data(val, type, length, width):
5448
data = []
5549
if type == 2:
@@ -88,42 +82,42 @@ def convert_data(val, type, length, width):
8882
raise ValueError("Local style is currently not supported")
8983

9084

91-
def get_version(funct_args):
85+
def get_version(job, funct_args):
9286
if MPI.COMM_WORLD.rank == 0:
9387
return job.version()
9488

9589

96-
def get_file(funct_args):
90+
def get_file(job, funct_args):
9791
job.file(*funct_args)
9892
return 1
9993

10094

101-
def commands_list(funct_args):
95+
def commands_list(job, funct_args):
10296
job.commands_list(*funct_args)
10397
return 1
10498

10599

106-
def commands_string(funct_args):
100+
def commands_string(job, funct_args):
107101
job.commands_string(*funct_args)
108102
return 1
109103

110104

111-
def extract_setting(funct_args):
105+
def extract_setting(job, funct_args):
112106
if MPI.COMM_WORLD.rank == 0:
113107
return job.extract_setting(*funct_args)
114108

115109

116-
def extract_global(funct_args):
110+
def extract_global(job, funct_args):
117111
if MPI.COMM_WORLD.rank == 0:
118112
return job.extract_global(*funct_args)
119113

120114

121-
def extract_box(funct_args):
115+
def extract_box(job, funct_args):
122116
if MPI.COMM_WORLD.rank == 0:
123117
return job.extract_box(*funct_args)
124118

125119

126-
def extract_atom(funct_args):
120+
def extract_atom(job, funct_args):
127121
if MPI.COMM_WORLD.rank == 0:
128122
# extract atoms return an internal data type
129123
# this has to be reformatted
@@ -153,12 +147,12 @@ def extract_atom(funct_args):
153147
return np.array(data)
154148

155149

156-
def extract_fix(funct_args):
150+
def extract_fix(job, funct_args):
157151
if MPI.COMM_WORLD.rank == 0:
158152
return job.extract_fix(*funct_args)
159153

160154

161-
def extract_variable(funct_args):
155+
def extract_variable(job, funct_args):
162156
# in the args - if the third one,
163157
# which is the type is 1 - a lammps array is returned
164158
if funct_args[2] == 1:
@@ -177,21 +171,21 @@ def extract_variable(funct_args):
177171
return data
178172

179173

180-
def get_natoms(funct_args):
174+
def get_natoms(job, funct_args):
181175
if MPI.COMM_WORLD.rank == 0:
182176
return job.get_natoms()
183177

184178

185-
def set_variable(funct_args):
179+
def set_variable(job, funct_args):
186180
return job.set_variable(*funct_args)
187181

188182

189-
def reset_box(funct_args):
183+
def reset_box(job, funct_args):
190184
job.reset_box(*funct_args)
191185
return 1
192186

193187

194-
def gather_atoms(funct_args):
188+
def gather_atoms(job, funct_args):
195189
# extract atoms return an internal data type
196190
# this has to be reformatted
197191
name = str(funct_args[0])
@@ -217,7 +211,7 @@ def gather_atoms(funct_args):
217211
return np.array(data)
218212

219213

220-
def gather_atoms_concat(funct_args):
214+
def gather_atoms_concat(job, funct_args):
221215
# extract atoms return an internal data type
222216
# this has to be reformatted
223217
name = str(funct_args[0])
@@ -243,7 +237,7 @@ def gather_atoms_concat(funct_args):
243237
return np.array(data)
244238

245239

246-
def gather_atoms_subset(funct_args):
240+
def gather_atoms_subset(job, funct_args):
247241
# convert to ctypes
248242
name = str(funct_args[0])
249243
lenids = int(funct_args[1])
@@ -280,76 +274,76 @@ def gather_atoms_subset(funct_args):
280274
return np.array(data)
281275

282276

283-
def create_atoms(funct_args):
277+
def create_atoms(job, funct_args):
284278
job.create_atoms(*funct_args)
285279
return 1
286280

287281

288-
def has_exceptions(funct_args):
282+
def has_exceptions(job, funct_args):
289283
return job.has_exceptions
290284

291285

292-
def has_gzip_support(funct_args):
286+
def has_gzip_support(job, funct_args):
293287
return job.has_gzip_support
294288

295289

296-
def has_png_support(funct_args):
290+
def has_png_support(job, funct_args):
297291
return job.has_png_support
298292

299293

300-
def has_jpeg_support(funct_args):
294+
def has_jpeg_support(job, funct_args):
301295
return job.has_jpeg_support
302296

303297

304-
def has_ffmpeg_support(funct_args):
298+
def has_ffmpeg_support(job, funct_args):
305299
return job.has_ffmpeg_support
306300

307301

308-
def installed_packages(funct_args):
302+
def installed_packages(job, funct_args):
309303
return job.installed_packages
310304

311305

312-
def set_fix_external_callback(funct_args):
306+
def set_fix_external_callback(job, funct_args):
313307
job.set_fix_external_callback(*funct_args)
314308
return 1
315309

316310

317-
def get_neighlist(funct_args):
311+
def get_neighlist(job, funct_args):
318312
if MPI.COMM_WORLD.rank == 0:
319313
return job.get_neighlist(*funct_args)
320314

321315

322-
def find_pair_neighlist(funct_args):
316+
def find_pair_neighlist(job, funct_args):
323317
if MPI.COMM_WORLD.rank == 0:
324318
return job.find_pair_neighlist(*funct_args)
325319

326320

327-
def find_fix_neighlist(funct_args):
321+
def find_fix_neighlist(job, funct_args):
328322
if MPI.COMM_WORLD.rank == 0:
329323
return job.find_fix_neighlist(*funct_args)
330324

331325

332-
def find_compute_neighlist(funct_args):
326+
def find_compute_neighlist(job, funct_args):
333327
if MPI.COMM_WORLD.rank == 0:
334328
return job.find_compute_neighlist(*funct_args)
335329

336330

337-
def get_neighlist_size(funct_args):
331+
def get_neighlist_size(job, funct_args):
338332
if MPI.COMM_WORLD.rank == 0:
339333
return job.get_neighlist_size(*funct_args)
340334

341335

342-
def get_neighlist_element_neighbors(funct_args):
336+
def get_neighlist_element_neighbors(job, funct_args):
343337
if MPI.COMM_WORLD.rank == 0:
344338
return job.get_neighlist_element_neighbors(*funct_args)
345339

346340

347-
def get_thermo(funct_args):
341+
def get_thermo(job, funct_args):
348342
if MPI.COMM_WORLD.rank == 0:
349343
return np.array(job.get_thermo(*funct_args))
350344

351345

352-
def scatter_atoms(funct_args):
346+
def scatter_atoms(job, funct_args):
353347
name = str(funct_args[0])
354348
py_vector = funct_args[1]
355349
# now see if its an integer or double type- but before flatten
@@ -366,7 +360,7 @@ def scatter_atoms(funct_args):
366360
return 1
367361

368362

369-
def scatter_atoms_subset(funct_args):
363+
def scatter_atoms_subset(job, funct_args):
370364
name = str(funct_args[0])
371365
lenids = int(funct_args[2])
372366
ids = funct_args[3]
@@ -396,7 +390,7 @@ def scatter_atoms_subset(funct_args):
396390
return 1
397391

398392

399-
def command(funct_args):
393+
def command(job, funct_args):
400394
job.command(funct_args)
401395
return 1
402396

@@ -464,13 +458,19 @@ def _gather_data_from_all_processors(data):
464458
return data
465459

466460

467-
if __name__ == "__main__":
461+
def _run_lammps_mpi(argument_lst):
468462
if MPI.COMM_WORLD.rank == 0:
469463
context = zmq.Context()
470464
socket = context.socket(zmq.PAIR)
471-
argument_lst = sys.argv
472465
port_selected = argument_lst[argument_lst.index("--zmqport") + 1]
473466
socket.connect("tcp://localhost:" + port_selected)
467+
else:
468+
context, socket = None, None
469+
# Lammps executable
470+
args = ["-screen", "none"]
471+
if len(argument_lst) > 3:
472+
args.extend(argument_lst[3:])
473+
job = lammps(cmdargs=args)
474474
while True:
475475
if MPI.COMM_WORLD.rank == 0:
476476
input_dict = cloudpickle.loads(socket.recv())
@@ -485,8 +485,12 @@ def _gather_data_from_all_processors(data):
485485
context.term()
486486
job.close()
487487
break
488-
output = select_cmd(input_dict["c"])(input_dict["d"])
488+
output = select_cmd(input_dict["c"])(job=job, funct_args=input_dict["d"])
489489
if MPI.COMM_WORLD.rank == 0 and output is not None:
490490
# with open('process.txt', 'a') as file:
491491
# print('Output:', output, file=file)
492492
socket.send(cloudpickle.dumps({"r": output}))
493+
494+
495+
if __name__ == "__main__":
496+
_run_lammps_mpi(argument_lst=sys.argv)

0 commit comments

Comments
 (0)