import numpy as np
from dipy.utils.optpkg import optional_package
from scilpy.ml.utils import IMPORT_ERROR_MSG
torch, have_torch, _ = optional_package('torch', trip_msg=IMPORT_ERROR_MSG)
# From https://github.com/tatp22/multidim-positional-encoding
[docs]
def get_emb(sin_inp):
    """
    Build a base sinusoidal embedding for one dimension.

    For every element ``v`` of ``sin_inp`` the pair ``(sin(v), cos(v))`` is
    produced, and the pairs are laid out interleaved along the last axis:
    ``[sin(v0), cos(v0), sin(v1), cos(v1), ...]``.

    :param sin_inp: Tensor of sine/cosine arguments, shape ``(..., n)``.
    :return: Tensor of shape ``(..., 2 * n)``.
    """
    sines = torch.sin(sin_inp)
    cosines = torch.cos(sin_inp)
    # Stack on a new trailing axis -> (..., n, 2), then merge the last two
    # axes so each sin is immediately followed by its matching cos.
    paired = torch.stack([sines, cosines], dim=-1)
    return paired.flatten(start_dim=-2, end_dim=-1)
class PositionalEncoding3D(torch.nn.Module):
    """
    Sinusoidal positional encoding for channels-last 5D tensors.

    The encoding is split into three groups of ``self.channels`` channels.
    The first group depends only on the x index, the second only on y and
    the third only on z. Each group is built from interleaved sine/cosine
    values (see ``get_emb``). The result is truncated to the input's channel
    count and repeated over the batch dimension.
    """

    def __init__(self, channels, dtype_override=None):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        :param dtype_override: If set, overrides the dtype of the output embedding.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        # Channels per spatial axis: ceil(channels / 6) sine/cosine pairs.
        # This value is always even, and 3 * it >= the requested channel count,
        # so no odd-count adjustment is needed.
        channels = int(np.ceil(channels / 6) * 2)
        # One inverse frequency 1 / 10000^(2i / channels) per sine/cosine pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        # Cache of the last computed encoding; not saved in the state dict.
        self.register_buffer("cached_penc", None, persistent=False)
        self.dtype_override = dtype_override
        self.channels = channels

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch).
            Its dtype is ``dtype_override`` (or ``tensor.dtype`` if unset),
            and it lives on ``tensor.device``. The returned tensor is also kept
            as a cache and may be returned again by later calls, so it should
            not be modified in place.
        :raises RuntimeError: If ``tensor`` is not 5d.
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        out_dtype = (
            self.dtype_override if self.dtype_override is not None else tensor.dtype
        )
        cached = self.cached_penc
        # Reuse the cache only if it matches the requested shape, dtype AND
        # device. Checking the shape alone returns a stale encoding of the
        # wrong dtype, e.g. after `.half()` converts the buffer, or when
        # inputs of different dtypes share a shape.
        if (
            cached is not None
            and cached.shape == tensor.shape
            and cached.dtype == out_dtype
            and cached.device == tensor.device
        ):
            return cached
        self.cached_penc = None
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_z = torch.arange(z, device=tensor.device, dtype=self.inv_freq.dtype)
        # Outer products position x inverse frequency: (n, self.channels // 2).
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        # Shape each axis embedding so it broadcasts over the other two
        # spatial axes: (x, 1, 1, c), (y, 1, c) and (z, c).
        emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_z = get_emb(sin_inp_z)
        emb = torch.zeros(
            (x, y, z, self.channels * 3),
            device=tensor.device,
            dtype=out_dtype,
        )
        emb[:, :, :, : self.channels] = emb_x
        emb[:, :, :, self.channels : 2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels :] = emb_z
        # NOTE(review): if orig_ch > 3 * self.channels (input wider than the
        # configured channels), this slice yields only 3 * self.channels
        # channels instead of orig_ch.
        self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
        return self.cached_penc
class PositionalEncodingPermute3D(torch.nn.Module):
    """
    Channels-first wrapper around ``PositionalEncoding3D``.

    Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
    and returns the encoding in that same channels-first layout.
    """

    def __init__(self, channels, dtype_override=None):
        """
        :param channels: Number of channels (dim 1) of the inputs to encode.
        :param dtype_override: If set, overrides the dtype of the output embedding.
        """
        super().__init__()
        self.penc = PositionalEncoding3D(channels, dtype_override)

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, ch, x, y, z)
        :return: Positional Encoding Matrix of size (batch_size, ch, x, y, z)
        """
        # Move channels last for the wrapped encoder...
        channels_last = tensor.permute(0, 2, 3, 4, 1)
        encoding = self.penc(channels_last)
        # ...then move the channel axis back to position 1.
        return encoding.permute(0, 4, 1, 2, 3)

    @property
    def org_channels(self):
        """Channel count originally requested, delegated to the wrapped encoder."""
        return self.penc.org_channels