import torch
import torch.nn as nn
import torch.nn.functional as F
class EFC(nn.Module):
def __init__(self,
c1, c2
):
super().__init__()
self.conv1 = nn.Conv2d(c1, c2, kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(c2, c2, kernel_size=1, stride=1)
self.conv4 = nn.Conv2d(c2, c2, kernel_size=1, stride=1)
self.bn = nn.BatchNorm2d(c2)
self.sigomid = nn.Sigmoid()
self.group_num = 16
self.eps = 1e-10
self.gamma = nn.Parameter(torch.randn(c2, 1, 1))
self.beta = nn.Parameter(torch.zeros(c2, 1, 1))
self.gate_genator = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(c2, c2, 1, 1),
nn.ReLU(True),
nn.Softmax(dim=1),
)
self.dwconv = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1, groups=c2)
self.conv3 = nn.Conv2d(c2, c2, kernel_size=1, stride=1)
self.Apt = nn.AdaptiveAvgPool2d(1)
self.one = c2
self.two = c2
self.conv4_gobal = nn.Conv2d(c2, 1, kernel_size=1, stride=1)
for group_id in range(0, 4):
self.interact = nn.Conv2d(c2 // 4, c2 // 4, 1, 1, )
def forward(self, x):
x1, x2 = x
global_conv1 = self.conv1(x1)
bn_x = self.bn(global_conv1)
weight_1 = self.sigomid(bn_x)
global_conv2 = self.conv2(x2)
bn_x2 = self.bn(global_conv2)
weight_2 = self.sigomid(bn_x2)
X_GOBAL = global_conv1 + global_conv2
x_conv4 = self.conv4_gobal(X_GOBAL)
X_4_sigmoid = self.sigomid(x_conv4)
X_ = X_4_sigmoid * X_GOBAL
X_ = X_.chunk(4, dim=1)
out = []
for group_id in range(0, 4):
out_1 = self.interact(X_[group_id])
N, C, H, W = out_1.size()
x_1_map = out_1.reshape(N, 1, -1)
mean_1 = x_1_map.mean(dim=2, keepdim=True)
x_1_av = x_1_map / mean_1
x_2_2 = F.softmax(x_1_av, dim=1)
x1 = x_2_2.reshape(N, C, H, W)
x1 = X_[group_id] * x1
out.append(x1)
out = torch.cat([out[0], out[1], out[2], out[3]], dim=1)
N, C, H, W = out.size()
x_add_1 = out.reshape(N, self.group_num, -1)
N, C, H, W = X_GOBAL.size()
x_shape_1 = X_GOBAL.reshape(N, self.group_num, -1)
mean_1 = x_shape_1.mean(dim=2, keepdim=True)
std_1 = x_shape_1.std(dim=2, keepdim=True)
x_guiyi = (x_add_1 - mean_1) / (std_1 + self.eps)
x_guiyi_1 = x_guiyi.reshape(N, C, H, W)
x_gui = (x_guiyi_1 * self.gamma + self.beta)
weight_x3 = self.Apt(X_GOBAL)
reweights = self.sigomid(weight_x3)
x_up_1 = reweights >= weight_1
x_low_1 = reweights < weight_1
x_up_2 = reweights >= weight_2
x_low_2 = reweights < weight_2
x_up = x_up_1 * X_GOBAL + x_up_2 * X_GOBAL
x_low