diff --git a/package/AUTHORS b/package/AUTHORS index e23f2afcf2..8db9489238 100644 --- a/package/AUTHORS +++ b/package/AUTHORS @@ -265,6 +265,7 @@ Chronological list of authors - Raúl Lois-Cuns - Pranay Pelapkar - Shreejan Dolai + - Pardhav Maradani External code ------------- diff --git a/package/CHANGELOG b/package/CHANGELOG index 642e717b85..ab64174402 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -15,7 +15,7 @@ The rules for this file: ------------------------------------------------------------------------------- ??/??/?? IAlibay, orbeckst, marinegor, tylerjereddy, ljwoods2, marinegor, - spyke7, talagayev + spyke7, talagayev, PardhavMaradani * 2.11.0 @@ -30,6 +30,9 @@ Fixes DSSP by porting upstream PyDSSP 0.9.1 fix (Issue #4913) Enhancements + * Added a current frame iterator `StreamFrameIteratorCurrent` for streamed + trajectories to enable `AnalysisBase.run` on per-frame streamed data + (Issue #5183, PR #5184) * Enables parallelization for analysis.diffusionmap.DistanceMatrix (Issue #4679, PR #4745) diff --git a/package/MDAnalysis/coordinates/IMD.py b/package/MDAnalysis/coordinates/IMD.py index 4945d43adf..16b6bf93dc 100644 --- a/package/MDAnalysis/coordinates/IMD.py +++ b/package/MDAnalysis/coordinates/IMD.py @@ -286,6 +286,9 @@ def __init__( raise RuntimeError(f"IMDReader: Read error: {e}") from e def _read_frame(self, frame): + if frame == self._frame: + logger.debug("IMDReader: Using current frame %d", self._frame) + return self.ts imdf = self._imdclient.get_imdframe() diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index d2bfbe63d2..eeb89521dd 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -52,6 +52,8 @@ .. autoclass:: StreamFrameIteratorSliced +.. autoclass:: StreamFrameIteratorCurrent + .. _ReadersBase: Readers @@ -1882,10 +1884,11 @@ class StreamReaderBase(ReaderBase): See Also -------- StreamFrameIteratorSliced : Iterator for stepped streaming access + StreamFrameIteratorCurrent : Iterator for current frame streaming access ReaderBase : Base class for standard trajectory readers - .. versionadded:: 2.10.0 + .. versionadded:: 2.10.0 """ def __init__(self, filename, convert_units=True, **kwargs): @@ -2005,7 +2008,9 @@ def _reopen(self): raise RuntimeError( "{}: Cannot reopen stream".format(self.__class__.__name__) ) - self._frame = -1 + if self._frame == 0: + # only reset when stream hasn't been iterated using next + self._frame = -1 self._reopen_called = True def timeseries(self, **kwargs): @@ -2040,7 +2045,7 @@ def __getitem__(self, frame): Returns ------- - FrameIteratorAll or StreamFrameIteratorSliced + FrameIteratorAll or StreamFrameIteratorSliced or StreamFrameIteratorCurrent Iterator for the requested slice. Raises @@ -2060,8 +2065,16 @@ def __getitem__(self, frame): See Also -------- StreamFrameIteratorSliced + StreamFrameIteratorCurrent """ - if isinstance(frame, slice): + if isinstance(frame, numbers.Integral): + if frame == self.trajectory.frame: + return self._read_frame(frame) + else: + raise ValueError( + "Streamed trajectories must specify current frame value" + ) + elif isinstance(frame, slice): _, _, step = self.check_slice_indices( frame.start, frame.stop, frame.step ) @@ -2069,6 +2082,13 @@ def __getitem__(self, frame): return FrameIteratorAll(self) else: return StreamFrameIteratorSliced(self, step) + elif isinstance(frame, (list, np.ndarray)): + if len(frame) == 1 and frame[0] == self.trajectory.frame: + return StreamFrameIteratorCurrent(self) + else: + raise ValueError( + "Streamed trajectories must have single current frame value" + ) else: raise TypeError( "Streamed trajectories must be an indexed using a slice" @@ -2310,4 +2330,39 @@ def step(self): Step size for iteration. Always a positive integer greater than 0. """ - return self._step \ No newline at end of file + return self._step + + +class StreamFrameIteratorCurrent(FrameIteratorBase): + """Iterator for current frame access in a streamed trajectory. + + Created when an array with a single current frame value is passed. + + Parameters + ---------- + trajectory : StreamReaderBase + The streaming trajectory reader to iterate over. Must be a + stream-based reader that supports continuous data reading. + + See Also + -------- + StreamReaderBase + FrameIteratorBase + + .. versionadded:: 2.11.0 + """ + + def __init__(self, trajectory): + super(StreamFrameIteratorCurrent, self).__init__(trajectory) + + def __len__(self): + return 1 + + def __iter__(self): + yield self.trajectory._read_frame(self.trajectory.frame) + + def __next__(self): + raise StopIteration from None + + def __getitem__(self, frame): + raise RuntimeError("Current frame iterator does not support indexing") diff --git a/testsuite/MDAnalysisTests/coordinates/test_imd.py b/testsuite/MDAnalysisTests/coordinates/test_imd.py index 2fed2f761f..1feb1d1498 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_imd.py +++ b/testsuite/MDAnalysisTests/coordinates/test_imd.py @@ -469,8 +469,12 @@ def test_iterate_twice_fi_all_raises_error(self, reader): pass def test_index_stream_raises_error(self, reader): - with pytest.raises(TypeError, match="Streamed trajectories must be"): - reader[0] + reader[0] + with pytest.raises( + ValueError, + match="Streamed trajectories must specify current frame value", + ): + reader[1] def test_iterate_backwards_raises_error(self, reader): with pytest.raises(ValueError, match="Cannot go backwards"): @@ -507,6 +511,335 @@ def test_step_property(self, reader): sliced_reader_step5 = reader[::5] assert sliced_reader_step5.step == 5 + def test_iterate_current_frame_raises_error(self, reader): + with pytest.raises( + ValueError, + match="must have single current frame value", + ): + for ts in reader[[1]]: + pass + ts = reader[[0]] + with pytest.raises(StopIteration): + next(ts) + with pytest.raises( + RuntimeError, + match="Current frame iterator does not support indexing", + ): + ts[0] + + def test_iterate_current_frame(self, reader): + cts = reader.ts + # test iterator length + assert len(reader[[reader.frame]]) == 1 + # test list iterator + for ts in reader[[reader.frame]]: + assert ts == cts + assert ts.frame == reader.frame + # test np.ndarray iterator + reader[np.array([reader.frame])] + # test same timestep + assert reader[reader.frame] == cts + assert reader[reader.frame] == reader[reader.frame] + # should be able to iterate all 5 frames in reader + # due to server.send_frames(1, 5) in reader setup + for i in range(5): + ts = reader[i] + if i < 4: + reader.next() + else: + with pytest.raises(StopIteration): + reader.next() + + def test_iterate_current_frame_no_transformations(self, reader): + reader.add_transformations( + translate([1, 1, 1]), translate([0, 0, 0.33]) + ) + p1 = reader[reader.frame].positions.copy() + p2 = reader[reader.frame].positions + assert_allclose(p1, p2) + + def test_iterate_continuity_1(self, reader): + step = -1 + for ts in reader: + step += 1 + assert ts.data["step"] == step + if step == 4: + break + + def test_iterate_continuity_2(self, reader): + ts = reader[0] + assert ts.data["step"] == 0 + reader.next() + ts = reader[1] + assert ts.data["step"] == 1 + step = 1 + for ts in reader: + step += 1 + assert ts.data["step"] == step + if step == 4: + break + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +class TestAnalysisClasses: + """ + Tests for AnalysisBase-based classes + + Following classes currently do not work: + align.AlignTraj + align.AverageStructure + diffusionmap.DiffusionMap + gnm.GNMAnalysis + gnm.closeContactGNMAnalysis + pca.PCA + polymer.PersistenceLength + rms.RMSD + hydrogenbonds.HydrogenBondAutoCorrel + + """ + + @pytest.fixture + def create_imd_universe(self): + server = None + + def _create_imd_universe( + topo, + traj, + velocities=True, + forces=True, + box=True, + frames=5, + ): + nonlocal server + u = mda.Universe(topo, traj) + server = InThreadIMDServer(u.trajectory) + info = create_default_imdsinfo_v3() + info.velocities = velocities + info.forces = forces + info.box = box + server.set_imdsessioninfo(info) + server.handshake_sequence("localhost", first_frame=True) + u_imd = mda.Universe( + topo, + f"imd://localhost:{server.port}", + n_atoms=u.trajectory.n_atoms, + ) + server.send_frames(1, frames) + return u_imd + + yield _create_imd_universe + server.cleanup() + + @pytest.fixture + def u1(self, create_imd_universe): + from MDAnalysisTests.datafiles import TPR, XTC + + return create_imd_universe( + topo=TPR, + traj=XTC, + velocities=False, + forces=False, + ) + + @pytest.fixture + def u2(self, create_imd_universe): + from MDAnalysisTests.datafiles import PSF_TRICLINIC, DCD_TRICLINIC + + return create_imd_universe( + topo=PSF_TRICLINIC, + traj=DCD_TRICLINIC, + velocities=False, + forces=False, + ) + + @pytest.fixture + def u3(self, create_imd_universe): + from MDAnalysisTests.datafiles import RNA_PSF, RNA_PDB + + return create_imd_universe( + topo=RNA_PSF, + traj=RNA_PDB, + velocities=False, + forces=False, + box=False, + frames=1, + ) + + @pytest.fixture + def u4(self, create_imd_universe): + from MDAnalysisTests.datafiles import waterPSF, waterDCD + + return create_imd_universe( + topo=waterPSF, + traj=waterDCD, + velocities=False, + forces=False, + ) + + def test_atomicdistances(self, u1): + from MDAnalysis.analysis.atomicdistances import AtomicDistances + + ag1 = u1.atoms[10:20] + ag2 = u1.atoms[70:80] + ad = AtomicDistances(ag1, ag2) + for i in range(3): + ad.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_bat(self, u1): + from MDAnalysis.analysis.bat import BAT + + selected_residues = u1.select_atoms("resid 1:10") + bat = BAT(selected_residues) + for i in range(3): + bat.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_contacts(self, u1): + from MDAnalysis.analysis.contacts import Contacts + + sel_basic = "(resname ARG LYS) and (name NH* NZ)" + sel_acidic = "(resname ASP GLU) and (name OE* OD*)" + acidic = u1.select_atoms(sel_acidic) + basic = u1.select_atoms(sel_basic) + ca1 = Contacts( + u1, + select=(sel_acidic, sel_basic), + refgroup=(acidic, basic), + radius=6.0, + ) + for i in range(3): + ca1.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_density(self, u1): + from MDAnalysis.analysis.density import DensityAnalysis + + ow = u1.select_atoms("protein") + da = DensityAnalysis(ow, delta=1.0) + for i in range(3): + da.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_dielectric(self, u2): + from MDAnalysis.analysis.dielectric import DielectricConstant + + diel = DielectricConstant(u2.atoms) + for i in range(3): + diel.run(frames=[u2.trajectory.frame]) + u2.trajectory.next() + + def test_diffusionmap(self, u1): + import MDAnalysis.analysis.diffusionmap as diffusionmap + + dm = diffusionmap.DistanceMatrix(u1, select="backbone") + for i in range(3): + dm.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_dihedrals(self, u1): + from MDAnalysis.analysis.dihedrals import Dihedral, Janin, Ramachandran + + ags = [res.phi_selection() for res in u1.residues[4:9]] + d = Dihedral(ags) + rama = Ramachandran(u1.select_atoms("protein")) + janin = Janin(u1.select_atoms("protein")) + for i in range(3): + d.run(frames=[u1.trajectory.frame]) + rama.run(frames=[u1.trajectory.frame]) + janin.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_helix_analysis(self, u1): + from MDAnalysis.analysis import helix_analysis as hel + + ha = hel.HELANAL( + u1.select_atoms("name CA"), + select="resnum 1-9", + flatten_single_helix=True, + ) + for i in range(3): + ha.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_lineardensity(self, u1): + from MDAnalysis.analysis.lineardensity import LinearDensity + + ld = LinearDensity(u1.select_atoms("all"), binsize=5) + for i in range(3): + ld.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_msd(self, u1): + from MDAnalysis.analysis import msd + + m = msd.EinsteinMSD(u1, "all", non_linear=True) + for i in range(3): + m.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_nucleicacids(self, u3): + from MDAnalysis.analysis import nucleicacids + from MDAnalysis.core.groups import ResidueGroup + + strand = u3.select_atoms("resid 1-100") + strand1 = ResidueGroup([strand.residues[0], strand.residues[21]]) + strand2 = ResidueGroup([strand.residues[2], strand.residues[22]]) + wcd = nucleicacids.WatsonCrickDist(strand1, strand2) + wcd.run(frames=[u3.trajectory.frame]) + strand1 = ResidueGroup([strand.residues[2], strand.residues[19]]) + strand2 = ResidueGroup([strand.residues[16], strand.residues[4]]) + mi = nucleicacids.MinorPairDist(strand1, strand2) + mi.run(frames=[u3.trajectory.frame]) + strand1 = ResidueGroup([strand.residues[1], strand.residues[4]]) + strand2 = ResidueGroup([strand.residues[11], strand.residues[8]]) + ma = nucleicacids.MajorPairDist(strand1, strand2) + ma.run(frames=[u3.trajectory.frame]) + + def test_rdf(self, u1): + from MDAnalysis.analysis import rdf + + ag1 = u1.select_atoms("resid 1:10") + ag2 = u1.select_atoms("resid 30:40") + r = rdf.InterRDF(ag1, ag2) + r_s = rdf.InterRDF_s(u1, [[ag1, ag1], [ag2, ag2]]) + for i in range(3): + r.run(frames=[u1.trajectory.frame]) + r_s.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_rmsf(self, u1): + from MDAnalysis.analysis.rms import RMSF + + rmsf = RMSF(u1.select_atoms("name CA")) + for i in range(3): + rmsf.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_dssp(self, u1): + from MDAnalysis.analysis.dssp import DSSP + + dssp = DSSP(u1) + for i in range(3): + dssp.run(frames=[u1.trajectory.frame]) + u1.trajectory.next() + + def test_hba(self, u4): + from MDAnalysis.analysis.hydrogenbonds import hbond_analysis + + hba = hbond_analysis.HydrogenBondAnalysis(universe=u4) + for i in range(3): + hba.run(frames=[u4.trajectory.frame]) + u4.trajectory.next() + + def test_wba(self, u4): + from MDAnalysis.analysis.hydrogenbonds import WaterBridgeAnalysis + + wba = WaterBridgeAnalysis(u4, "resid 1:10", "resid 10:20") + for i in range(3): + wba.run(frames=[u4.trajectory.frame]) + u4.trajectory.next() + @pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") def test_n_atoms_not_specified(universe, imdsinfo):