import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
# Synthetic classification data: 2000 random 300-dim feature vectors with
# labels drawn uniformly from 10 classes.
# NOTE(review): labels are independent of the features, so any classifier
# trained on this data should hover around chance (~10%) accuracy.
num_samples = 2000
num_features = 300
num_classes = 10
X = np.random.randn(num_samples, num_features)
y = np.random.randint(0, num_classes, num_samples)
# Hold out 20% for evaluation; fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
class TextDataset(Dataset):
    """Map-style dataset wrapping parallel feature/label arrays.

    Each item is a (float32 feature tensor, int64 label tensor) pair,
    ready for cross-entropy training.
    """

    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        # One label per sample, so the label array defines the length.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = torch.tensor(self.features[idx], dtype=torch.float)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample, target
# Wrap the train/test splits in DataLoaders; shuffle only the training set
# so evaluation order (and thus metric aggregation) stays deterministic.
train_dataset = TextDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = TextDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
class TopKGating(nn.Module):
    """Linear gating network that picks the top-k experts per sample.

    A single linear layer scores every expert; softmax turns the scores
    into a probability distribution, and the k largest probabilities
    (with their expert indices) are returned.
    """

    def __init__(self, input_dim, num_experts, top_k=2):
        super(TopKGating, self).__init__()
        self.gate = nn.Linear(input_dim, num_experts)
        self.top_k = top_k

    def forward(self, x):
        # Softmax over the expert axis, then keep the k best experts.
        probs = F.softmax(self.gate(x), dim=1)
        top_vals, top_idx = torch.topk(probs, self.top_k)
        # Returned as (indices, weights) — note the order callers rely on.
        return top_idx, top_vals
class Expert(nn.Module):
    """A single expert: a two-layer MLP with a 128-unit ReLU hidden layer."""

    def __init__(self, input_dim, output_dim):
        super(Expert, self).__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Delegate straight to the sequential stack.
        return self.net(x)
class MoE(nn.Module):
    """Sparse mixture-of-experts classifier.

    A top-k gating network routes each sample to k experts; the final
    logits are the gate-weighted sum of the selected experts' outputs.

    forward() returns a tuple:
      - output: (batch, num_classes) combined logits
      - gates.sum(0): (top_k, num_classes) summed (expanded) gate weights,
        consumed by the load-balancing term of the loss.
    """

    def __init__(self, input_dim, num_classes, num_experts, top_k=2):
        super(MoE, self).__init__()
        self.num_experts = num_experts
        self.num_classes = num_classes
        self.gating = TopKGating(input_dim, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(input_dim, num_classes) for _ in range(num_experts)])

    def forward(self, x):
        batch_size = x.size(0)
        indices, gates = self.gating(x)  # both (batch, top_k)
        top_k = indices.size(1)

        # Flatten (sample, slot) pairs so each pair can be routed to its
        # chosen expert in a single batched call per expert, instead of the
        # O(batch * top_k) one-sample forward passes a naive loop would do.
        flat_indices = indices.reshape(-1)  # (batch * top_k,)
        flat_inputs = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, x.size(1))
        # new_zeros inherits x's dtype and device.
        flat_outputs = x.new_zeros(batch_size * top_k, self.num_classes)
        for expert_idx, expert in enumerate(self.experts):
            slot_mask = flat_indices == expert_idx
            if slot_mask.any():
                flat_outputs[slot_mask] = expert(flat_inputs[slot_mask])
        expert_outputs = flat_outputs.view(batch_size, top_k, self.num_classes)

        # Broadcast gate weights across the class dimension and combine.
        gates = gates.unsqueeze(-1).expand(-1, -1, self.num_classes)
        output = (gates * expert_outputs).sum(1)
        return output, gates.sum(0)
def moe_loss(output, target, gating_weights, lambda_balance=0.1):
standard_loss = F.cross_entropy(output, target)
balance_loss = torch.std(gating_weights)