
center_nodes #5

@ywj66hh

Hello, I would like to call the _local_pattern function in the code. What value should the center_nodes parameter take?

    def _local_pattern(self, center_nodes, r=0.1, r_resolution=100, phi_resolution=360):
        assert self._model_name in ['CLCRN','CLCSTN'], 'the model does not provide the kernel visualization'
        with torch.no_grad():
            center_nodes = torch.from_numpy(np.array(center_nodes)).float().to(self._device)
            N = center_nodes.shape[0]
            angle_ratio = 1 / phi_resolution
            rs = np.linspace(0, r, r_resolution)
            phis = np.linspace(-np.pi, np.pi, phi_resolution)
            xs = torch.from_numpy(rs[:, None] * np.cos(phis)[None, :]).float().to(self._device).flatten() # r_res * phi_res
            ys = torch.from_numpy(rs[:, None] * np.sin(phis)[None, :]).float().to(self._device).flatten() # r_res * phi_res
            vs = torch.stack([xs, ys], dim=-1)[None, :, :].repeat(N, 1, 1)

            kernel = self.model.get_kernel()
            local_pattern = kernel.kernel_prattern(center_nodes, vs, angle_ratio)
        return local_pattern, center_nodes, rs, phis
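
For anyone reading along, the snippet itself gives a few hints about the expected input. center_nodes goes through np.array(...), its first dimension is taken as N, and the polar grid vs is repeated N times, so it looks like an array-like with one row per center node. Below is a minimal NumPy-only sketch of just the grid construction, to make the shapes visible. The helper name build_local_grid is hypothetical, and treating each center node as a pair of coordinates is an assumption of mine that the snippet does not confirm. What kernel_prattern actually expects would need to be checked in the repository.

```python
import numpy as np

def build_local_grid(center_nodes, r=0.1, r_resolution=100, phi_resolution=360):
    """Hypothetical helper: reproduces only the grid-building steps of
    _local_pattern with NumPy, without the model or the kernel call."""
    # Assumption: one row per center node, here two coordinates per row.
    center_nodes = np.asarray(center_nodes, dtype=np.float32)
    n = center_nodes.shape[0]

    # Radii from 0 to r and angles from -pi to pi, as in the original snippet.
    rs = np.linspace(0, r, r_resolution)
    phis = np.linspace(-np.pi, np.pi, phi_resolution)

    # Polar grid converted to Cartesian offsets, flattened to r_res * phi_res points.
    xs = (rs[:, None] * np.cos(phis)[None, :]).flatten()
    ys = (rs[:, None] * np.sin(phis)[None, :]).flatten()

    # The same offset grid is copied once per center node: shape (N, r_res * phi_res, 2).
    vs = np.tile(np.stack([xs, ys], axis=-1)[None, :, :], (n, 1, 1))
    return center_nodes, vs, rs, phis

# Three made-up center nodes, small resolutions so the shapes are easy to read.
centers, vs, rs, phis = build_local_grid(
    [[0.0, 0.0], [1.0, 0.5], [2.0, 1.0]], r_resolution=4, phi_resolution=8
)
print(centers.shape)  # (3, 2)
print(vs.shape)       # (3, 32, 2)
```

If this reading is right, passing a list of N coordinate rows would be consistent with the shapes used in the function, but the meaning of the two columns (for example, which coordinate system) is not stated in the snippet.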
