Python 使用pytorch进行句子分类的多类(使用nn.LSTM)

Python 使用pytorch进行句子分类的多类(使用nn.LSTM),python,machine-learning,neural-network,pytorch,torch,Python,Machine Learning,Neural Network,Pytorch,Torch,我有这个网络,我从教程中得到的,我想把句子作为输入(已经完成了),结果是一个单线张量 从教程中,这句话“约翰的狗喜欢吃东西”,得到一个1列张量返回: tensor([[-3.0462, -4.0106, -0.6096], [-4.8205, -0.0286, -3.9045], [-3.7876, -4.1355, -0.0394], [-0.0185, -4.7874, -4.6013]]) …和类列表: tag_list[ “name”, “verb”, “noun”] 每一行都有一个

我有这个网络,我从教程中得到的,我想把句子作为输入(已经完成了),结果是一个单线张量

从教程中,这句话“约翰的狗喜欢吃东西”,得到一个1列张量返回:

tensor([[-3.0462, -4.0106, -0.6096],
[-4.8205, -0.0286, -3.9045],
[-3.7876, -4.1355, -0.0394],
[-0.0185, -4.7874, -4.6013]])
…和类列表:

tag_list[ “name”, “verb”, “noun”]
每一行都有一个标记与单词关联的可能性。(第一个单词有[-3.0462,-4.0106,-0.6096]向量,其中最后一个元素对应于最大评分标记“名词”)

本教程的数据集如下所示:

training_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
我希望我的格式是这样的:

training_data = [
    ("Hello world".split(), ["ONE"]),
    ("I am dog".split(), ["TWO"]),
    ("It's Britney glitch".split(), ["THREE"])
]
参数定义为:

class LSTMTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds      = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space   = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores  = F.log_softmax(tag_space, dim=1)
        return tag_scores
到目前为止,输入和输出的大小不匹配,我得到: ValueError:预期输入批次大小(2)与目标批次大小(1)匹配

由于大小不匹配,Criteria函数不接受输入。看起来:

loss        = criterion(tag_scores, targets)
我已经读到最后一层可以定义为nn.Linear,以压缩输出,但我似乎无法得到任何结果。尝试其他损失函数


我如何更改它,使模型能够像原始教程那样对句子而不是每个单词进行分类?

我通过简单地获取最后一个单词的隐藏状态,解决了这个问题

tag_space   = self.hidden2tag(lstm_out[-1])