The following code should help you get more familiar with how the seq2seq model works.
""" test """ import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.autograd import Variable # 建立字典 seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']] char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz'] num_dict = {n:i for i,n in enumerate(char_arr)} # 網絡參數 n_step = 5 n_hidden = 128 n_class = len(num_dict) batch_size = len(seq_data) # 準備數據 def make_batch(seq_data): input_batch, output_batch, target_batch =[], [], [] for seq in seq_data: for i in range(2): seq[i] = seq[i] + 'P' * (n_step-len(seq[i])) input = [num_dict[n] for n in seq[0]] ouput = [num_dict[n] for n in ('S'+ seq[1])] target = [num_dict[n] for n in (seq[1]) + 'E'] input_batch.append(np.eye(n_class)[input]) output_batch.append(np.eye(n_class)[ouput]) target_batch.append(target) return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch)) input_batch, output_batch, target_batch = make_batch(seq_data) # 建立網絡 class Seq2Seq(nn.Module): """ 要點: 1.該網絡包含一個encoder和一個decoder,使用的RNN的結構相同,最後使用全鏈接接預測結果 2.RNN網絡結構要熟知 3.seq2seq的精髓:encoder層生成的參數做爲decoder層的輸入 """ def __init__(self): super().__init__() # 此處的input_size是每個節點可接納的狀態,hidden_size是隱藏節點的維度 self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5) self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5) self.fc = nn.Linear(n_hidden, n_class) def forward(self, enc_input, enc_hidden, dec_input): # RNN要求輸入:(seq_len, batch_size, n_class),這裏須要轉置一下 enc_input = enc_input.transpose(0,1) dec_input = dec_input.transpose(0,1) _, enc_states = self.enc(enc_input, enc_hidden) outputs, _ = self.dec(dec_input, enc_states) pred = self.fc(outputs) return pred # training model = Seq2Seq() loss_fun = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) for epoch in range(5000): hidden = Variable(torch.zeros(1, batch_size, n_hidden)) optimizer.zero_grad() pred = model(input_batch, hidden, output_batch) pred = pred.transpose(0, 1) loss = 0 for i in range(len(seq_data)): temp = pred[i] tar = target_batch[i] loss += loss_fun(pred[i], target_batch[i]) if (epoch + 1) % 1000 == 0: print('Epoch: %d Cost: %f' % (epoch + 1, loss)) loss.backward() optimizer.step() # 測試 def translate(word): input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]]) # hidden 形狀 (1, 1, n_class) hidden = Variable(torch.zeros(1, 1, n_hidden)) # output 形狀(6,1, n_class) output = model(input_batch, hidden, output_batch) predict = output.data.max(2, keepdim=True)[1] decoded = [char_arr[i] for i in predict] end = decoded.index('E') translated = ''.join(decoded[:end]) return translated.replace('P', '') print('girl ->', translate('girl'))
Reference: https://blog.csdn.net/weixin_43632501/article/details/98525673