Pytorch: reducing the number of nodes in each layer of torch.nn.lstm

Is there a simple way to halve the number of nodes in each layer? I don't see such an option on the documentation page; maybe there is a similar feature I can use instead of defining every layer by hand:

    self.lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
        dropout=0.2,
    )  # nn.LSTM: every one of the num_layers layers uses the same hidden_size
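
For reference, a stacked nn.LSTM applies a single hidden_size to every layer, which shows up directly in its parameter shapes. A quick check (the sizes here are examples, not from my actual model):

    import torch.nn as nn

    lstm = nn.LSTM(input_size=100, hidden_size=300, num_layers=3)
    for name, param in lstm.named_parameters():
        print(name, tuple(param.shape))
    # weight_ih_l0 (1200, 100)  <- 4 * hidden_size rows
    # weight_hh_l0 (1200, 300)
    # ...
    # weight_ih_l2 (1200, 300)  <- later layers also use hidden_size 300
    # weight_hh_l2 (1200, 300)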

Not as far as I know, but it is simple enough to write from scratch:

    from typing import Optional

    import torch
    from torch import Tensor
    from torch.nn import LSTM, Module, ModuleList

    def _constant_scale(initial: int, factor: int) -> int:
        return initial // factor

    class StackedLSTM(Module):
        def __init__(self, input_size: int, hidden_sizes: list[int], *args, **kwargs):
            super().__init__()
            # Chain single-layer LSTMs; each layer's input size is the
            # previous layer's hidden size.
            self.layers = ModuleList([
                LSTM(input_size=xs, hidden_size=hs, *args, **kwargs)
                for xs, hs in zip([input_size] + hidden_sizes, hidden_sizes)
            ])

        def forward(self, x: Tensor, hc: Optional[tuple[Tensor, Tensor]] = None) -> Tensor:
            for layer in self.layers:
                x, _ = layer(x, hc)
                hc = None  # the given (h, c) only matches the first layer's shape
            return x

    hidden_sizes = [_constant_scale(300, 2**i) for i in range(3)]  # [300, 150, 75]
    slstm = StackedLSTM(100, hidden_sizes)
    x = torch.rand(10, 32, 100)  # (seq_len, batch, input_size)
    h = torch.rand(1, 32, 300)   # initial states for the first layer
    c = torch.rand(1, 32, 300)
    out = slstm(x, (h, c))
    print(out.shape)
    # torch.Size([10, 32, 75])
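
One caveat: nn.LSTM only applies its dropout argument between layers, so passing dropout=0.2 to the single-layer sub-LSTMs above has no effect (PyTorch warns when dropout is set with num_layers=1). Here is a minimal sketch of restoring inter-layer dropout with an explicit nn.Dropout; the class name StackedLSTMDropout and the default rate are my own:

    from torch import Tensor
    from torch.nn import Dropout, LSTM, Module, ModuleList

    class StackedLSTMDropout(Module):
        def __init__(self, input_size: int, hidden_sizes: list[int], dropout: float = 0.2, **kwargs):
            super().__init__()
            self.layers = ModuleList([
                LSTM(input_size=xs, hidden_size=hs, **kwargs)
                for xs, hs in zip([input_size] + hidden_sizes, hidden_sizes)
            ])
            self.dropout = Dropout(dropout)

        def forward(self, x: Tensor, hc=None) -> Tensor:
            for i, layer in enumerate(self.layers):
                x, _ = layer(x, hc)
                hc = None
                if i < len(self.layers) - 1:  # like nn.LSTM: no dropout after the last layer
                    x = self.dropout(x)
            return x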