"""
河图洛书智能体 v1.0 (HeTu LuoShu Agent)
乾-坤八卦复阵:五环闭环系统
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import copy
import os
from datetime import datetime
# ================== 安全边界 ==================
class SafetyBoundary:
    """Central registry of safety limits shared by the agent's modules."""

    # (min, max) clamp ranges enforced by the apply_safety_guard() methods.
    PARAM_BOUNDS = {
        'sheng_strength': (0.05, 0.95),
        'ke_strength': (0.01, 0.80),
        'yang_weights': (0.05, 0.95),
        'yin_weights': (0.05, 0.95),
        'center_state': (-2.0, 2.0),
        'coupling_coef': (0.10, 0.90),
    }

    # Sparsity regularization targets for the coupling coefficients.
    COUPLING_SPARSITY_TARGET = 0.3
    COUPLING_SPARSITY_WEIGHT = 0.01

    # Fixed architectural sizes used across the model.
    STRUCTURE = {
        'num_kun_units': 8,
        'num_heads': 8,
        'num_memory_slots': 8,
    }
# ================== 1. 洛书核心(修复版)==================
class LuoShuCore(nn.Module):
def __init__(self):
super().__init__()
self.state = nn.Parameter(torch.randn(9) * 0.1)
ke = torch.zeros(9, 9)
ke[8,7] = 1.0; ke[7,4] = 0.8; ke[4,5] = 0.9
ke[5,6] = 0.7; ke[6,8] = 0.6
self.register_buffer('ke_matrix', ke)
def forward(self, feedback):
influence = torch.matmul(feedback, self.ke_matrix)
new_state = self.state + 0.05 * torch.tanh(influence)
# 修复:使用 data 属性赋值,不能直接赋值 Parameter
self.state.data = new_state
dispatch = torch.sigmoid(self.state)
return dispatch, self.state
# ================== 2. 坤八卦 ==================
class KunBaGua(nn.Module):
    """Eight parallel LuoShuCore units ("Kun" trigrams) producing local dispatches.

    Each unit has a generative ("sheng") gain and a restraining ("ke") strength.
    When a global dispatch is given, each unit is damped by the unit two
    positions ahead on the ring, then modulated by the global dispatch mean.
    """

    def __init__(self, feat_dim=32):
        super().__init__()
        # feat_dim kept for interface compatibility; cores are fixed 9-state units.
        self.cores = nn.ModuleList([LuoShuCore() for _ in range(8)])
        # Per-unit gains; clamped by apply_safety_guard().
        self.sheng_strength = nn.Parameter(torch.ones(8) * 0.5)
        self.ke_strength = nn.Parameter(torch.ones(8) * 0.3)

    def apply_safety_guard(self):
        """Clamp the learned strengths into their safe ranges (no-grad)."""
        with torch.no_grad():
            self.sheng_strength.data = torch.clamp(self.sheng_strength.data, 0.05, 0.95)
            self.ke_strength.data = torch.clamp(self.ke_strength.data, 0.01, 0.80)

    def forward(self, feedbacks, global_dispatch=None):
        """Return a list of 8 local dispatch tensors, each of shape (9,).

        feedbacks: sequence of 8 per-unit feedback tensors (entries may be None).
        global_dispatch: optional (9,) tensor from the global scheduler.
        """
        local_dispatches = []
        for i, core in enumerate(self.cores):
            # Fix: index the per-unit feedback; the original passed the whole
            # list into the core, which cannot matmul a Python list.
            fb = feedbacks[i] if feedbacks[i] is not None else torch.zeros(9, device=core.state.device)
            dispatch, _ = core(fb)
            # Fix: scale by this unit's scalar gain; (9,) * (8,) is a shape error.
            dispatch = dispatch * self.sheng_strength[i]
            local_dispatches.append(dispatch)
        if global_dispatch is not None:
            # Compute ke factors from the pre-restraint dispatches so the
            # update order is irrelevant; the original multiplied the Python
            # list itself by a tensor, which raises TypeError.
            base = list(local_dispatches)
            global_scale = 0.5 + global_dispatch.mean()
            local_dispatches = []
            for i in range(8):
                ke_idx = (i + 2) % 8  # restrained by the unit two steps ahead
                ke_factor = 1 - self.ke_strength[i] * base[ke_idx].mean()
                local_dispatches.append(base[i] * ke_factor * global_scale)
        return local_dispatches
# ================== 3. 乾总调度器 ==================
class QianZong(nn.Module):
    """Global ("Qian") scheduler: maps the 8 Kun losses to a 9-way dispatch."""

    def __init__(self):
        super().__init__()
        self.core = LuoShuCore()
        # Projects the 8 per-unit losses into the core's 9-dim feedback space.
        self.in_proj = nn.Linear(8, 9)

    def forward(self, kun_losses):
        """kun_losses: (8,) tensor -> (9,) global dispatch in (0, 1)."""
        projected = self.in_proj(kun_losses.unsqueeze(0))
        dispatch, _ = self.core(projected.squeeze(0))
        return dispatch
# ================== 4. 河图三界 ==================
class XiaJie(nn.Module):
    """Lower realm: dual-scale (3x3 "sheng" / 5x5 "cheng") grouped convolutions.

    Eight parallel branches per scale, each modulated by the corresponding
    local dispatch, concatenated and fused back to ``out_ch`` channels by a
    1x1 convolution.
    """

    def __init__(self, in_ch=1, out_ch=32):
        super().__init__()
        self.group_ch = out_ch // 8
        self.conv_sheng = nn.ModuleList([
            nn.Conv2d(in_ch, self.group_ch, 3, padding=1) for _ in range(8)
        ])
        self.conv_cheng = nn.ModuleList([
            nn.Conv2d(in_ch, self.group_ch, 5, padding=2) for _ in range(8)
        ])
        self.fusion = nn.Conv2d(out_ch * 2, out_ch, 1)
        self.act = nn.ReLU()

    def forward(self, x, local_dispatches):
        """x: (B, in_ch, H, W); local_dispatches: sequence of 8 dispatch tensors."""
        sheng_parts = []
        cheng_parts = []
        for i in range(8):
            # Fix: use the i-th dispatch; the original called .mean() on the list.
            scale = 0.5 + local_dispatches[i].mean()
            # Fix: index the ModuleLists — a ModuleList itself is not callable.
            sheng = self.act(self.conv_sheng[i](x)) * scale
            cheng = self.act(self.conv_cheng[i](x)) * scale
            sheng_parts.append(sheng)
            cheng_parts.append(cheng)
        sheng_out = torch.cat(sheng_parts, dim=1)
        cheng_out = torch.cat(cheng_parts, dim=1)
        combined = torch.cat([sheng_out, cheng_out], dim=1)
        return self.act(self.fusion(combined))
class ZhongJie(nn.Module):
    """Middle realm: multi-head self-attention with a learned minimum-distance mask.

    Positions closer than each head's ``min_dist`` are masked out, so heads
    attend only to spatially distant locations. Head outputs are modulated by
    the local dispatches and added residually to the input.
    """

    def __init__(self, dim=32, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Per-head minimum attention distance (truncated to int at use time).
        self.min_dist = nn.Parameter(torch.ones(num_heads) * 2)

    def forward(self, x, local_dispatches):
        """x: (B, C, H, W); local_dispatches: sequence of >= 8 dispatch tensors."""
        B, C, H, W = x.shape
        N = H * W
        x_flat = x.flatten(2).permute(0, 2, 1)  # (B, N, C)
        qkv = self.qkv(x_flat).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Pairwise pixel distances depend only on (H, W): compute once, not
        # once per head as the original did.
        with torch.no_grad():
            coords_y, coords_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
            coords = torch.stack([coords_x.flatten(), coords_y.flatten()], dim=1).float().to(x.device)
            dist = torch.cdist(coords, coords)
        attn_out = []
        for h in range(self.num_heads):
            min_d = int(self.min_dist[h].item())
            mask = (dist < min_d).bool()
            attn = torch.matmul(q[:, h], k[:, h].transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn = attn.masked_fill(mask.unsqueeze(0), -float('inf'))
            attn = F.softmax(attn, dim=-1)
            attn_out.append(torch.matmul(attn, v[:, h]))
        out = torch.stack(attn_out, dim=2).reshape(B, N, C)
        out = self.proj(out)
        # Fix: the original per-head slice scaling was a syntax error
        # (`out[:, :, h * self.head_dim h + 1) * self.head_dim]`). Build a
        # per-channel scale vector and multiply out-of-place, which is also
        # safer for autograd than in-place slice updates.
        head_scales = torch.cat([
            (0.5 + local_dispatches[h % 8].mean()).expand(self.head_dim)
            for h in range(self.num_heads)
        ])
        out = out * head_scales
        out = out.permute(0, 2, 1).view(B, C, H, W)
        return x + out
class ShangJie(nn.Module):
    """Upper realm: global-average-pool classification head."""

    def __init__(self, in_ch=32, num_classes=10):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_ch, num_classes)

    def forward(self, x):
        """x: (B, in_ch, H, W) -> logits of shape (B, num_classes)."""
        pooled = self.gap(x)
        flat = pooled.flatten(1)
        return self.fc(flat)
# ================== 5. 记忆库 ==================
class HTLSMemory(nn.Module):
    """Slot memory: one GRUCell per slot, hidden states carried across batches.

    States are detached after every step (truncated BPTT) and dropped by
    reset_states(), which the trainer calls once per epoch / eval pass.
    """

    def __init__(self, input_dim=32, hidden_dim=64, num_slots=8, reset_mode='epoch'):
        super().__init__()
        self.num_slots = num_slots
        self.hidden_dim = hidden_dim
        self.reset_mode = reset_mode
        self.rnns = nn.ModuleList([nn.GRUCell(input_dim, hidden_dim) for _ in range(num_slots)])
        self.proj = nn.Linear(hidden_dim * num_slots, input_dim)
        # Per-slot hidden states; lazily (re)allocated when the batch size changes.
        self.states = [None] * num_slots

    def reset_states(self):
        """Drop all carried hidden states."""
        self.states = [None] * self.num_slots

    def forward(self, x, local_dispatches):
        """x: (B, input_dim); local_dispatches: sequence of dispatch tensors."""
        batch_size = x.size(0)
        device = x.device
        slot_outputs = []
        for i in range(self.num_slots):
            # Fix: every per-slot access below was missing its [i] index in the
            # original (lists have no .size(); ModuleList is not callable).
            if self.states[i] is None or self.states[i].size(0) != batch_size:
                self.states[i] = torch.zeros(batch_size, self.hidden_dim, device=device)
            scale = 0.5 + local_dispatches[i % len(local_dispatches)].mean()
            x_mod = x * scale
            new_state = self.rnns[i](x_mod, self.states[i])
            # Detach so gradients do not flow across batches (truncated BPTT).
            self.states[i] = new_state.detach()
            slot_outputs.append(new_state)
        out = torch.cat(slot_outputs, dim=1)
        return self.proj(out)
# ================== 6. 完整智能体 ==================
class HeTuLuoShuAgent(nn.Module):
    """Full agent: Qian global scheduler + Kun units driving the three realms.

    The forward pass first runs a no-grad probe (Kun -> Xia -> Shang) to
    estimate per-unit "losses" from prediction entropy, feeds those to the
    global scheduler, and then runs the real Xia -> Zhong -> memory -> Shang
    pipeline under the resulting dispatches.
    """

    def __init__(self, num_classes=10, memory_reset_mode='epoch'):
        super().__init__()
        self.qian = QianZong()
        self.kun = KunBaGua(feat_dim=32)
        self.xia = XiaJie(in_ch=1, out_ch=32)
        self.zhong = ZhongJie(dim=32, num_heads=8)
        self.memory = HTLSMemory(input_dim=32, hidden_dim=64, num_slots=8, reset_mode=memory_reset_mode)
        self.shang = ShangJie(in_ch=32, num_classes=num_classes)
        # Per-unit coupling coefficients; clamped to [0.10, 0.90] by the guard.
        self.coupling = nn.Parameter(torch.ones(8) * 0.5)

    def apply_safety_guard(self):
        """Clamp coupling and delegate to the Kun units' own guard."""
        with torch.no_grad():
            self.coupling.data = torch.clamp(self.coupling.data, 0.10, 0.90)
        self.kun.apply_safety_guard()

    def reset_memory(self):
        """Clear the slot memory's carried hidden states."""
        self.memory.reset_states()

    def forward(self, x):
        """x: (B, 1, H, W) -> (logits, aux dict with dispatch diagnostics)."""
        B = x.shape[0]
        device = x.device
        # Probe pass (no grad): estimate per-unit losses from prediction entropy.
        with torch.no_grad():
            dummy_feedbacks = [None] * 8
            dummy_local = self.kun(dummy_feedbacks, None)
            dummy_x = self.xia(x, dummy_local)
            dummy_aux = self.shang(dummy_x)
            dummy_probs = F.softmax(dummy_aux, dim=-1)
            entropy = -(dummy_probs * torch.log(dummy_probs + 1e-8)).sum(dim=-1)
            # First 8 sample entropies serve as proxy losses; zero-pad small batches.
            kun_losses = entropy[:8] if B >= 8 else torch.cat([entropy, torch.zeros(8 - B, device=device)])
        global_dispatch = self.qian(kun_losses)
        feedbacks = [None] * 8
        local_dispatches = self.kun(feedbacks, global_dispatch)
        # Fix: apply each unit's scalar coupling to its own dispatch. The
        # original multiplied the Python list by the whole coupling tensor
        # inside a loop, which raises TypeError.
        local_dispatches = [local_dispatches[i] * self.coupling[i] for i in range(8)]
        x = self.xia(x, local_dispatches)
        x = self.zhong(x, local_dispatches)
        x_pool = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
        mem_out = self.memory(x_pool, local_dispatches)
        mem_out = mem_out.view(-1, 32, 1, 1)
        # Blend the memory readout in residually with a fixed 0.3 weight.
        x = x + mem_out * 0.3
        logits = self.shang(x)
        return logits, {
            'global_dispatch': global_dispatch,
            'local_dispatches': local_dispatches,
            'coupling': self.coupling,
            'kun_losses': kun_losses,
            'entropy': entropy.mean()
        }
# ================== 7. 训练器 ==================
class Trainer:
    """Training loop: linear warmup + cosine LR, safety guards, snapshot/rollback."""

    def __init__(self, model, device, lr=0.001):
        self.model = model
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.base_lr = 0.0001   # warmup start LR
        self.max_lr = lr        # LR reached at the end of warmup
        self.warmup_epochs = 3
        self.current_epoch = 0
        self.optimizer = optim.AdamW(model.parameters(), lr=self.base_lr, weight_decay=1e-4)
        # Rebuilt at the end of warmup so annealing starts from max_lr (see below).
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=50)
        self.snapshot = None

    def update_learning_rate(self):
        """Linear warmup to max_lr, then cosine annealing from max_lr."""
        if self.current_epoch < self.warmup_epochs:
            lr = self.base_lr + (self.max_lr - self.base_lr) * (self.current_epoch / self.warmup_epochs)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        else:
            if self.current_epoch == self.warmup_epochs:
                # Fix: the scheduler built in __init__ captured base_lr (1e-4)
                # as its base, so the first post-warmup step collapsed the LR
                # back toward 1e-4. Rebuild it with max_lr as the base instead.
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.max_lr
                self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=50)
            self.scheduler.step()

    def take_snapshot(self):
        """Save a deep copy of the model weights for rollback()."""
        self.snapshot = copy.deepcopy(self.model.state_dict())

    def rollback(self):
        """Restore the last snapshot, if one was taken."""
        if self.snapshot is not None:
            self.model.load_state_dict(self.snapshot)
            print("⚠️ 已回滚")

    def train_epoch(self, train_loader, epoch):
        """Run one training epoch; returns (avg_loss, accuracy in percent)."""
        self.model.train()
        self.model.reset_memory()
        self.current_epoch = epoch
        self.update_learning_rate()
        total_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            logits, aux = self.model(data)
            loss = self.criterion(logits, target)
            # Encourage diverse local dispatches. Fix: the original combined
            # .item() with Python max(), detaching both regularizers from
            # autograd and turning them into no-op constants; keep them as
            # tensors so they actually contribute gradients.
            dispatch_std = torch.std(torch.stack([d.mean() for d in aux['local_dispatches']]))
            diversity_loss = 0.05 * torch.clamp(1.0 - dispatch_std, min=0.0)
            # Encourage spread in the coupling coefficients.
            coupling_std = torch.std(self.model.coupling)
            sparsity_loss = SafetyBoundary.COUPLING_SPARSITY_WEIGHT * torch.clamp(
                SafetyBoundary.COUPLING_SPARSITY_TARGET - coupling_std, min=0.0
            )
            total_loss_val = loss + diversity_loss + sparsity_loss
            total_loss_val.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.model.apply_safety_guard()
            total_loss += loss.item()
            pred = logits.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            if batch_idx % 100 == 0:
                acc = 100. * correct / total if total > 0 else 0
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch} [{batch_idx}/{len(train_loader)}] "
                      f"Loss: {loss.item():.4f} | Acc: {acc:.2f}% | LR: {current_lr:.5f}")
        avg_loss = total_loss / len(train_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def evaluate(self, test_loader):
        """Evaluate on test_loader; returns (avg_loss, accuracy in percent)."""
        self.model.eval()
        self.model.reset_memory()
        correct = 0
        total = 0
        test_loss = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                logits, _ = self.model(data)
                loss = self.criterion(logits, target)
                test_loss += loss.item()
                pred = logits.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)
        accuracy = 100. * correct / total
        avg_loss = test_loss / len(test_loader)
        return avg_loss, accuracy
# ================== 8. 主训练 ==================
def main(epochs=30, use_tensorboard=False):
    """Train the HeTu-LuoShu agent on MNIST and return the trained model.

    epochs: number of training epochs.
    use_tensorboard: accepted for interface compatibility; currently unused.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("=" * 70)
    print("河图洛书智能体 v1.0")
    print("乾-坤八卦复阵:五环闭环系统")
    print("=" * 70)
    transform_train = transforms.Compose([
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform_train)
    test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=0)
    model = HeTuLuoShuAgent(num_classes=10, memory_reset_mode='epoch').to(device)
    trainer = Trainer(model, device, lr=0.001)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n模型参数量: {total_params:,}")
    print(f"设备: {device}")
    best_acc = 0
    print("\n开始训练...")
    print("-" * 70)
    for epoch in range(epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, epoch)
        test_loss, test_acc = trainer.evaluate(test_loader)
        print(f"\nEpoch {epoch+1}/{epochs} 完成")
        print(f"  训练 - Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
        print(f"  测试 - Loss: {test_loss:.4f} | Acc: {test_acc:.2f}%")
        with torch.no_grad():
            # Fix: coupling is a Parameter (requires_grad=True); on a CPU
            # device .cpu() returns the same tensor, and .numpy() raises
            # "Can't call numpy() on Tensor that requires grad" — detach first.
            coupling = model.coupling.detach().cpu().numpy()
        print(f"  耦合系数: [{coupling[0]:.2f}, {coupling[1]:.2f}, {coupling[2]:.2f}, {coupling[3]:.2f}, "
              f"{coupling[4]:.2f}, {coupling[5]:.2f}, {coupling[6]:.2f}, {coupling[7]:.2f}]")
        print(f"  耦合标准差: {coupling.std():.4f}")
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'heluo_v1.0_best.pth')
            print(f"  ★ 新最佳模型!准确率: {best_acc:.2f}%")
        print("-" * 70)
    print(f"\n{'='*70}")
    print(f"训练完成!最佳测试准确率: {best_acc:.2f}%")
    print("=" * 70)
    return model
# Script entry point. (Fix: removed a stray trailing "|" character after this
# block that made the whole file a SyntaxError.)
if __name__ == '__main__':
    main(epochs=30, use_tensorboard=False)