ch 08 Exercise 8.5.5: is something missing from the code? #99

Open
YueZhengMeng opened this issue Jun 25, 2024 · 1 comment


@YueZhengMeng
Contributor

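Solution:
First move the inputs and labels to the device and reshape them, then perform the forward pass and backpropagation. The hidden-state detach operation is placed before the update, which avoids computing on the hidden state during the update; this way the hidden state does not need to be modified, and the hidden state is not detached from the computational graph.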

However, the provided solution code does not contain any gradient-detaching `detach_()` call.
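For context on what the missing call does: under sequential partitioning, the hidden state left at the end of one minibatch is the initial state of the next. The toy sketch below uses hypothetical parameters `W_xh`, `W_hh` and a hand-written `rnn` loop in the style of the from-scratch model (it is not the d2l code). Without `detach_()`, backpropagating the second minibatch runs back into the first minibatch's graph, whose buffers the first `backward()` has already freed, so PyTorch raises a `RuntimeError`.

```python
import torch

torch.manual_seed(0)
num_inputs, num_hiddens, batch_size, num_steps = 4, 8, 2, 5
# Hypothetical toy parameters of a from-scratch-style RNN (not the d2l model)
W_xh = torch.randn(num_inputs, num_hiddens, requires_grad=True)
W_hh = torch.randn(num_hiddens, num_hiddens, requires_grad=True)


def rnn(X, H):
    # X: (num_steps, batch_size, num_inputs); H: (batch_size, num_hiddens)
    for x in X:
        H = torch.tanh(x @ W_xh + H @ W_hh)
    return H


def two_batches(detach):
    # Two consecutive minibatches: the final state of the first
    # is the initial state of the second, as in sequential partitioning
    X1 = torch.randn(num_steps, batch_size, num_inputs)
    X2 = torch.randn(num_steps, batch_size, num_inputs)
    H = torch.zeros(batch_size, num_hiddens)
    H = rnn(X1, H)
    H.sum().backward()       # frees the saved buffers of minibatch 1's graph
    if detach:
        H.detach_()          # cut H off from minibatch 1's graph in place
    H = rnn(X2, H)
    try:
        H.sum().backward()   # without detach_, this reaches the freed graph
        return "ok"
    except RuntimeError:
        return "error"


print(two_batches(detach=True))   # ok
print(two_batches(detach=False))  # error
```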

@YueZhengMeng
Contributor Author

YueZhengMeng commented Jun 25, 2024

```python
# After modification
import math

import torch
from torch import nn
from d2l import torch as d2l


def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
    """Train a net within one epoch (defined in Chapter 8)."""
    state, timer = None, d2l.Timer()
    metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
    for X, Y in train_iter:
        X, Y = X.to(device), Y.to(device)
        # Flatten labels in time-major order to match the rows of y_hat
        y = Y.T.reshape(-1)
        # A fresh hidden state is created for every minibatch
        # (use_random_iter is not used in this version)
        state = net.begin_state(batch_size=X.shape[0], device=device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.backward(retain_graph=True)
            d2l.grad_clipping(net, 1)
            updater.step()
        else:
            l.backward(retain_graph=True)
            d2l.grad_clipping(net, 1)
            # Since the `mean` function has been invoked
            updater(batch_size=1)
        metric.add(l * y.numel(), y.numel())
    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()
```
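The line `y = Y.T.reshape(-1)` flattens the labels in time-major order. Assuming, as in the d2l models, that the rows of `y_hat` are ordered by time step first and by sequence within the batch second, this small check with made-up token ids shows that the two orders agree:

```python
import torch

batch_size, num_steps = 2, 3
# Made-up labels: row b holds the target token ids of sequence b,
# shape (batch_size, num_steps)
Y = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# Transpose to (num_steps, batch_size), then flatten row by row
y = Y.T.reshape(-1)
print(y.tolist())  # [1, 4, 2, 5, 3, 6]

# Assumed time-major order of output rows: all sequences at step 0, then step 1, ...
rows = [(t, b) for t in range(num_steps) for b in range(batch_size)]
targets = [Y[b, t].item() for t, b in rows]
assert targets == y.tolist()

# Flattening without the transpose gives batch-major order instead
print(Y.reshape(-1).tolist())  # [1, 2, 3, 4, 5, 6]
```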

```python
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,
              use_random_iter=False):
    """Train a model (defined in Chapter 8)."""
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',
                            legend=['train'], xlim=[10, num_epochs])
    # Initialize
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    predict = lambda prefix: d2l.predict_ch8(prefix, 50, net, vocab, device)
    # Train and predict
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(
            net, train_iter, loss, updater, device, use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(predict('time traveller'))
            animator.add(epoch + 1, [ppl])
    print(f'perplexity {ppl:.1f}, {speed:.1f} tokens/sec on {str(device)}')
    print(predict('time traveller'))
    print(predict('traveller'))
```
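On the exercise itself (keeping the hidden state attached across minibatches under sequential partitioning), `retain_graph=True` by itself is not enough in this setup once parameters are updated between minibatches. `d2l.sgd` updates parameters in place (`param -= ...` under `torch.no_grad()`), as does `torch.optim.SGD`, so when a later `backward()` reaches an earlier minibatch's graph, autograd finds that a weight it saved there has been modified. The toy sketch below (a hypothetical toy RNN and a hand-written in-place `sgd`, not the d2l code) contrasts carrying the state across minibatches with re-creating it every minibatch, as `train_epoch_ch8` above does.

```python
import torch

torch.manual_seed(0)
num_inputs, num_hiddens, batch_size, num_steps, lr = 4, 8, 2, 5, 0.1
# Hypothetical toy parameters (not the d2l model)
W_xh = (torch.randn(num_inputs, num_hiddens) * 0.1).requires_grad_()
W_hh = (torch.randn(num_hiddens, num_hiddens) * 0.1).requires_grad_()
params = [W_xh, W_hh]


def rnn(X, H):
    # X: (num_steps, batch_size, num_inputs); H: (batch_size, num_hiddens)
    for x in X:
        H = torch.tanh(x @ W_xh + H @ W_hh)
    return H


def sgd(params, lr):
    # In-place update: bumps each parameter's autograd version counter
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad
            p.grad.zero_()


def train(num_batches, reset_state):
    H = None
    for _ in range(num_batches):
        X = torch.randn(num_steps, batch_size, num_inputs)
        if H is None or reset_state:
            H = torch.zeros(batch_size, num_hiddens)  # fresh state
        H = rnn(X, H)  # otherwise H still carries the previous graph
        loss = H.pow(2).mean()
        try:
            # Earlier graphs are kept, but their saved W_hh is now stale
            loss.backward(retain_graph=True)
        except RuntimeError:
            return "error"
        sgd(params, lr)
    return "ok"


print(train(3, reset_state=True))   # ok
print(train(3, reset_state=False))  # error
```

Note that re-creating the state every minibatch also means no hidden state, and no gradient, passes from one minibatch to the next.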
