PyTorch seq2seq Model Example

The following code will help you get familiar with the mechanics of the seq2seq model.

"""
    test
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Build the vocabulary: 'S' = start symbol, 'E' = end symbol, 'P' = padding
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
num_dict = {n:i for i,n in enumerate(char_arr)}
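
# For example, each character maps to an integer id, and np.eye later turns
# each id into a one-hot row (an added illustration, not in the original post):
print([num_dict[c] for c in 'man'])  # [15, 3, 16]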

# Network hyperparameters
n_step = 5                  # maximum word length; inputs are padded to this
n_hidden = 128              # RNN hidden state size
n_class = len(num_dict)     # vocabulary size (29 symbols)
batch_size = len(seq_data)

# Prepare the data: pad both words of each pair to n_step with 'P', then build
# one-hot encoder inputs, one-hot decoder inputs (prefixed with the start
# symbol 'S') and target index sequences (suffixed with the end symbol 'E').
def make_batch(seq_data):
    input_batch, output_batch, target_batch = [], [], []

    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + 'P' * (n_step - len(seq[i]))
        input = [num_dict[n] for n in seq[0]]
        output = [num_dict[n] for n in ('S' + seq[1])]
        target = [num_dict[n] for n in (seq[1] + 'E')]

        input_batch.append(np.eye(n_class)[input])
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)

    # Stack into tensors: the one-hot batches are float, the targets are class ids.
    return (torch.Tensor(np.array(input_batch)),
            torch.Tensor(np.array(output_batch)),
            torch.LongTensor(target_batch))

input_batch, output_batch, target_batch = make_batch(seq_data)
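
# Sanity check on the batch shapes: 6 word pairs, n_step = 5, and a 29-symbol
# vocabulary (an added check, not part of the original post).
print(input_batch.shape)   # torch.Size([6, 5, 29])
print(output_batch.shape)  # torch.Size([6, 6, 29]) -- the 'S' prefix adds a step
print(target_batch.shape)  # torch.Size([6, 6])     -- the 'E' suffix adds a step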


# Build the network
class Seq2Seq(nn.Module):
    """
    Key points:
    1. The network consists of an encoder and a decoder built from the same
       RNN structure; a final fully connected layer produces the predictions.
    2. Be familiar with how the RNN module works.
    3. The essence of seq2seq: the final hidden state produced by the encoder
       is used as the initial hidden state of the decoder.
    """
    def __init__(self):
        super().__init__()
        # input_size is the dimensionality of each time step's input (here the
        # one-hot vocabulary size); hidden_size is the hidden state dimension.
        # Note: with a single layer, PyTorch warns that dropout has no effect.
        self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        # nn.RNN expects input of shape (seq_len, batch_size, input_size),
        # so swap the batch and time dimensions here.
        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)
        # Encode the source sequence, keep only its final hidden state, and
        # feed that state to the decoder as its initial hidden state.
        _, enc_states = self.enc(enc_input, enc_hidden)
        outputs, _ = self.dec(dec_input, enc_states)
        pred = self.fc(outputs)  # (seq_len, batch_size, n_class)

        return pred


# training
model = Seq2Seq()
loss_fun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
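
# A quick shape check of one forward pass (an added illustration, not part of
# the original post): predictions come back as (seq_len, batch, n_class).
with torch.no_grad():
    demo_pred = model(input_batch, torch.zeros(1, batch_size, n_hidden), output_batch)
print(demo_pred.shape)  # torch.Size([6, 6, 29])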

for epoch in range(5000):
    # Initial encoder hidden state: (num_layers, batch_size, n_hidden)
    hidden = torch.zeros(1, batch_size, n_hidden)

    optimizer.zero_grad()
    pred = model(input_batch, hidden, output_batch)
    pred = pred.transpose(0, 1)  # back to (batch_size, seq_len, n_class)
    loss = 0
    # pred[i] is (seq_len, n_class) and target_batch[i] is (seq_len,),
    # which is exactly what CrossEntropyLoss expects.
    for i in range(len(seq_data)):
        loss += loss_fun(pred[i], target_batch[i])
    if (epoch + 1) % 1000 == 0:
        print('Epoch: %d   Cost: %f' % (epoch + 1, loss))
    loss.backward()
    optimizer.step()


# Testing
def translate(word):
    # Use all-padding as a dummy decoder input; make_batch pads both words
    # to n_step and prepends 'S' to the decoder input.
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
    # hidden shape: (1, 1, n_hidden)
    hidden = torch.zeros(1, 1, n_hidden)
    # output shape: (n_step + 1, 1, n_class)
    output = model(input_batch, hidden, output_batch)
    # Pick the most likely symbol at each time step (greedy decoding).
    predict = output.data.max(2, keepdim=True)[1]
    decoded = [char_arr[int(i)] for i in predict.squeeze()]
    end = decoded.index('E') if 'E' in decoded else len(decoded)
    translated = ''.join(decoded[:end])

    return translated.replace('P', '')

print('girl ->', translate('girl'))
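
The translate function above feeds a dummy all-padding sequence to the decoder in one shot. A more typical inference loop decodes step by step, feeding the decoder's own previous prediction back in as the next input. Below is a minimal sketch of such a greedy decoder (translate_greedy is an added illustration, not part of the original post; it reuses model.enc, model.dec and model.fc directly):

def translate_greedy(word):
    word = word + 'P' * (n_step - len(word))
    with torch.no_grad():
        # Encode the padded source word: (n_step, 1, n_class).
        enc_input = torch.Tensor(np.eye(n_class)[[num_dict[c] for c in word]]).unsqueeze(1)
        _, hidden = model.enc(enc_input, torch.zeros(1, 1, n_hidden))
        # Start from 'S' and feed each prediction back in as the next input.
        dec_in = torch.Tensor(np.eye(n_class)[[num_dict['S']]]).unsqueeze(1)
        decoded = []
        for _ in range(n_step + 1):
            out, hidden = model.dec(dec_in, hidden)  # one decoding step
            idx = int(model.fc(out).argmax(2))
            if char_arr[idx] == 'E':
                break
            decoded.append(char_arr[idx])
            dec_in = torch.Tensor(np.eye(n_class)[[idx]]).unsqueeze(1)
    return ''.join(decoded).replace('P', '')

print('girl ->', translate_greedy('girl'))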

Reference: https://blog.csdn.net/weixin_43632501/article/details/98525673
