def net(X):
    """Forward pass of a one-hidden-layer MLP built from module-level parameters.

    Flattens ``X`` to shape ``(-1, num_inputs)``, computes the hidden layer
    ``H = relu(X @ W1 + b1)`` and returns ``H @ W2 + b2`` as raw scores
    (no softmax is applied here).

    Relies on ``num_inputs``, ``W1``, ``b1``, ``W2``, ``b2`` and ``relu``
    being defined at module level elsewhere in this file.
    """
    # reshape rather than view: view raises on non-contiguous tensors,
    # reshape returns a view when possible and copies only when it must.
    X = X.reshape((-1, num_inputs))
    H = relu(torch.matmul(X, W1) + b1)
    return torch.matmul(H, W2) + b2
def sgd(params, lr, batch_size):
    """Apply one in-place mini-batch SGD step to every tensor in ``params``.

    Each parameter ``p`` is updated as ``p -= lr * p.grad / batch_size``.

    Args:
        params: iterable of tensors whose ``.grad`` has been populated.
        lr: learning rate.
        batch_size: divisor applied to each gradient.
    """
    for p in params:
        step = lr * p.grad / batch_size
        # Update through .data so the assignment is not recorded in the autograd graph.
        p.data -= step
# Cross-entropy over raw class scores (logits) and integer class labels.
# reduction="mean" is PyTorch's default, written out so it is explicit that
# each call already averages over the batch and returns a scalar.
# NOTE(review): train_ch3 calls .sum() on this already-averaged scalar; if the
# manual sgd path also divides by batch_size, the gradient ends up scaled by
# 1/batch_size twice — confirm this is intended before tuning lr.
loss = torch.nn.CrossEntropyLoss(reduction="mean")
def evaluate_accuracy(data_iter, net):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Args:
        data_iter: iterable of ``(X, y)`` batches, where ``y`` holds integer
            class labels and ``y.shape[0]`` is the batch size.
        net: callable mapping ``X`` to per-class scores; the prediction is
            ``argmax(dim=1)``, so scores are assumed to be ``(batch, classes)``.
            If ``net`` is a ``torch.nn.Module`` it is put in eval mode for the
            duration of the call and its previous training mode is restored.

    Returns:
        Accuracy as a Python float.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)
    was_training = is_module and net.training
    if is_module:
        # Eval mode disables dropout and makes BatchNorm use running statistics.
        net.eval()
    acc_sum, n = 0.0, 0
    try:
        # Accuracy needs no gradients; skip building an autograd graph.
        with torch.no_grad():
            for X, y in data_iter:
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if was_training:
            net.train()
    return acc_sum / n
deftrain_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None): for epoch in range(num_epochs): train_l_sum, train_acc_sum, n = 0.0, 0.0, 0 for X, y in train_iter: y_hat = net(X) ## 前向得到网络输出 l = loss(y_hat, y).sum() ## 计算loss
# 梯度清零 if optimizer isnotNone: optimizer.zero_grad() elif params isnotNoneand params[0].grad isnotNone: for param in params: param.grad.data.zero_()