Estou confuso sobre o método view()
no seguinte trecho de código.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
Minha confusão está em relação à seguinte linha.
x = x.view(-1, 16*5*5)
O que a tensor.view()
função faz? Vi seu uso em muitos lugares, mas não consigo entender como ele interpreta seus parâmetros.
O que acontece se eu der valores negativos como parâmetros para a view()
função? Por exemplo, o que acontece se eu ligar tensor_variable.view(1, 1, -1)
?
Alguém pode explicar o princípio principal da view()
função com alguns exemplos?
reshape
no PyTorch ?!Vamos fazer alguns exemplos, do mais simples ao mais difícil.
O
view
método retorna um tensor com os mesmos dados que oself
tensor (o que significa que o tensor retornado tem o mesmo número de elementos), mas com uma forma diferente. Por exemplo:Supondo que esse
-1
não seja um dos parâmetros, quando você os multiplica, o resultado deve ser igual ao número de elementos no tensor. Se você fizer:,a.view(3, 3)
ele aumentará umRuntimeError
formato porque (3 x 3) é inválido para entrada com 16 elementos. Em outras palavras: 3 x 3 não é igual a 16, mas 9.Você pode usar
-1
como um dos parâmetros que você passa para a função, mas apenas uma vez. Tudo o que acontece é que o método fará as contas para você sobre como preencher essa dimensão. Por exemplo,a.view(2, -1, 4)
é equivalente aa.view(2, 2, 4)
. [16 / (2 x 4) = 2]Observe que o tensor retornado compartilha os mesmos dados . Se você fizer uma alteração na "visualização", está alterando os dados do tensor original:
Agora, para um caso de uso mais complexo. A documentação diz que cada nova dimensão de visualização deve ser um subespaço de uma dimensão original ou abranger apenas d, d + 1, ..., d + k que atendam à seguinte condição de contiguidade que, para todos os i = 0,. .., k - 1, passada [i] = passada [i + 1] x tamanho [i + 1] . Caso contrário,
contiguous()
precisa ser chamado antes que o tensor possa ser visualizado. Por exemplo:Observe que para
a_t
, passo [0]! = Passo [1] x tamanho [1] desde 24! = 2 x 3fonte
torch.Tensor.view()
Simplificando,
torch.Tensor.view()
inspirado emnumpy.ndarray.reshape()
ounumpy.reshape()
, cria uma nova visualização do tensor, desde que a nova forma seja compatível com a forma do tensor original.Vamos entender isso em detalhes usando um exemplo concreto.
Com esse tensor
t
de forma(18,)
, novas vistas podem ser criadas apenas para as seguintes formas:(1, 18)
ou equivalentemente(1, -1)
ou ou equivalentemente ou ou equivalentemente ou ou equivalentemente ou ou equivalentemente ou ou equivalentemente ou(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Como já podemos observar pelas tuplas de forma acima, a multiplicação dos elementos da tupla de forma (por exemplo
2*9
,3*6
etc.) deve sempre ser igual ao número total de elementos no tensor original (18
no nosso exemplo).Outra coisa a observar é que usamos um
-1
em um dos lugares em cada uma das tuplas de forma. Usando a-1
, estamos sendo preguiçosos ao fazer o cálculo e delegar a tarefa ao PyTorch para fazer o cálculo desse valor para a forma quando ela cria a nova exibição . Uma coisa importante a ser observada é que só podemos usar uma única-1
na tupla de forma. Os valores restantes devem ser explicitamente fornecidos por nós. Outro PyTorch irá reclamar, lançando umRuntimeError
:Portanto, com todas as formas mencionadas acima, o PyTorch sempre retornará uma nova visualização do tensor original
t
. Isso basicamente significa que apenas altera as informações de passada do tensor para cada uma das novas visualizações solicitadas.Abaixo estão alguns exemplos que ilustram como as passadas dos tensores são alteradas a cada nova vista .
Agora, veremos os avanços para as novas visualizações :
Então essa é a mágica da
view()
função. Ele apenas altera os passos do tensor (original) para cada uma das novas vistas , desde que a forma da nova vista seja compatível com a forma original.Outra coisa interessante pode observar a partir dos tuplos Strides é que o valor do elemento no 0 ª posição é igual ao valor do elemento no 1 st posição da tupla forma.
Isto é porque:
o passo
(6, 1)
diz que para ir de um elemento para o próximo elemento ao longo da 0ª dimensão, temos que pular ou dar 6 passos. (ou seja, para ir de0
para6
, alguém tem que tomar 6 passos.) Mas, para ir de um elemento para o próximo elemento no 1 st dimensão, só precisamos de apenas um passo (por exemplo, para ir a partir2
de3
).Assim, as informações de passada estão no centro de como os elementos são acessados da memória para realizar o cálculo.
torch.reshape ()
Essa função retornaria uma vista e é exatamente a mesma que usar
torch.Tensor.view()
, desde que a nova forma seja compatível com a forma do tensor original. Caso contrário, ele retornará uma cópia.No entanto, as notas de
torch.reshape()
adverte que:fonte
Eu descobri que
x.view(-1, 16 * 5 * 5)
é equivalente ax.flatten(1)
, onde o parâmetro 1 indica que o processo de nivelamento começa na 1ª dimensão (não nivelando a dimensão 'amostra') Como você pode ver, o último uso é semanticamente mais claro e fácil de usar, então eu preferirflatten()
.fonte
Você pode ler
-1
como número dinâmico de parâmetros ou "qualquer coisa". Por causa de que não pode haver apenas um parâmetro-1
emview()
.Se você perguntar,
x.view(-1,1)
isso produzirá a forma do tensor,[anything, 1]
dependendo do número de elementos emx
. Por exemplo:Saída:
fonte
weights.reshape(a, b)
retornará um novo tensor com os mesmos dados que pesos com tamanho (a, b), pois ele copia os dados para outra parte da memória.weights.resize_(a, b)
retorna o mesmo tensor com uma forma diferente. No entanto, se a nova forma resultar em menos elementos que o tensor original, alguns elementos serão removidos do tensor (mas não da memória). Se a nova forma resultar em mais elementos que o tensor original, novos elementos serão não inicializados na memória.weights.view(a, b)
retornará um novo tensor com os mesmos dados que pesos com tamanho (a, b)fonte
Gostei muito dos exemplos de @Jadiel de Armas.
Gostaria de adicionar um pequeno insight sobre como os elementos são ordenados para .view (...)
fonte
Vamos tentar entender a visualização pelos seguintes exemplos:
-1 como um valor de argumento é uma maneira fácil de calcular o valor de say x, desde que conheçamos os valores de y, z ou o inverso no caso de 3d e para 2d novamente uma maneira fácil de calcular o valor de say x, desde que conhecer valores de y ou vice-versa ..
fonte