diff --git a/patch2self.py b/patch2self.py index cf69d29..210d807 100644 --- a/patch2self.py +++ b/patch2self.py @@ -8,7 +8,7 @@ def patch2self(data, patch_radius=[1, 1, 1], mask=None): # If mask is not specified, use the whole volume mask = np.ones_like(data, dtype=bool)[..., 0] - def _extract_2d_patches(arr, patch_radius=[1, 1, 1]): + def _extract_3d_patches(arr, patch_radius=[1, 1, 1]): if isinstance(patch_radius, int): patch_radius = np.ones(3, dtype=int) * patch_radius @@ -18,7 +18,6 @@ def _extract_2d_patches(arr, patch_radius=[1, 1, 1]): patch_radius = np.asarray(patch_radius).astype(int) patch_size = 2 * patch_radius + 1 - dim = arr.shape[-1] all_patches = [] @@ -39,7 +38,8 @@ def _extract_2d_patches(arr, patch_radius=[1, 1, 1]): return np.array(all_patches).T - train = _extract_2d_patches(data, patch_radius=patch_radius) + train = _extract_3d_patches(np.pad(data, ((1, 1), (1, 1), (1, 1))) + , patch_radius=patch_radius) print(train.shape) print('Patch Extraction Done...') @@ -59,33 +59,18 @@ def _extract_2d_patches(arr, patch_radius=[1, 1, 1]): print('Training for resolution: ', patch_radius) cur_X = np.reshape(np.concatenate((X1, X2), axis=0), - (train.shape[0]-1, - train.shape[1]*train.shape[2])) + (-1, train.shape[2])) + + Y = train[f, train.shape[1]//2, :] - Y = np.reshape(train[f, :, :], (1, - train.shape[1]*train.shape[2])) + model.fit(cur_X.T[::100], Y.T[::100]) - model.fit(cur_X.T, Y.T) del cur_X, Y, X1, X2 print(' -> Trained to Denoise Volume: ', f) - - theta = np.zeros((data.shape[0], data.shape[1], data.shape[2], 1)) - for i in range(f, f+1): - for k in range(0, theta.shape[2]): - for j in range(0, theta.shape[1]): - for l in range(0, theta.shape[0]): - if not mask[l, j, k]: - continue - X = np.reshape(data[l, j, k, :], (data.shape[3], 1)) - X1 = X[:i, :] - X2 = X[i+1:, :] - cur_x = np.reshape(np.concatenate((X1, X2), axis=0), - (1, data.shape[3]-1)) - - cur_y = model.predict(cur_x) - theta[l, j, k, i-f] = cur_y - del model - denoised_array[..., f] = np.squeeze(theta) + + denoised_array[..., f] = model.predict(cur_X.T).reshape( + data.shape[0], data.shape[1], data.shape[2]) + print('Denoising Volume ', f, ' Complete...') denoised_array[mask == 0] = 0