Eu estava examinando este exemplo de um modelo de linguagem LSTM no github (link) . O que ele faz em geral é bastante claro para mim. Mas ainda estou lutando para entender o que a chamada contiguous()
faz, o que ocorre várias vezes no código.
Por exemplo, na linha 74/75 da entrada de código e as sequências de destino do LSTM são criadas. Os dados (armazenados em ids
) são bidimensionais, onde a primeira dimensão é o tamanho do lote.
for i in range(0, ids.size(1) - seq_length, seq_length):
# Get batch inputs and targets
inputs = Variable(ids[:, i:i+seq_length])
targets = Variable(ids[:, (i+1):(i+1)+seq_length].contiguous())
Portanto, como um exemplo simples, ao usar os tamanhos de lote 1 e seq_length
10 inputs
e se targets
parecer com isto:
inputs Variable containing:
0 1 2 3 4 5 6 7 8 9
[torch.LongTensor of size 1x10]
targets Variable containing:
1 2 3 4 5 6 7 8 9 10
[torch.LongTensor of size 1x10]
Então, em geral, minha pergunta é: o que é contiguous()
e por que preciso disso?
Além disso, não entendo por que o método é chamado para a sequência de destino e não a sequência de entrada, pois ambas as variáveis são compostas dos mesmos dados.
Como pode targets
ser não contíguo e inputs
ainda assim ser contíguo?
EDIT:
Tentei deixar de ligar contiguous()
, mas isso leva a uma mensagem de erro ao calcular a perda.
RuntimeError: invalid argument 1: input is not contiguous at .../src/torch/lib/TH/generic/THTensor.c:231
Então, obviamente, chamar contiguous()
neste exemplo é necessário.
(Para manter isso legível, evitei postar o código completo aqui, ele pode ser encontrado usando o link do GitHub acima.)
Desde já, obrigado!
tldr; to the point summary
com um resumo conciso e direto ao ponto.Respostas:
Existem poucas operações no Tensor em PyTorch que não mudam realmente o conteúdo do tensor, mas apenas como converter índices em tensor para localização de byte. Essas operações incluem:
Por exemplo: quando você chama
transpose()
, o PyTorch não gera um novo tensor com o novo layout, ele apenas modifica as metainformações no objeto Tensor para que o deslocamento e o passo sejam para a nova forma. O tensor transposto e o tensor original estão de fato compartilhando a memória!x = torch.randn(3,2) y = torch.transpose(x, 0, 1) x[0, 0] = 42 print(y[0,0]) # prints 42
É aqui que entra o conceito de contíguo . Acima
x
é contíguo, masy
não porque seu layout de memória seja diferente de um tensor de mesmo formato feito do zero. Observe que a palavra "contíguo" é um pouco enganosa porque não é que o conteúdo do tensor esteja espalhado em torno de blocos desconectados de memória. Aqui, os bytes ainda são alocados em um bloco de memória, mas a ordem dos elementos é diferente!Quando você chama
contiguous()
, ele realmente faz uma cópia do tensor, de forma que a ordem dos elementos seria a mesma como se o tensor da mesma forma fosse criado do zero.Normalmente você não precisa se preocupar com isso. Se o PyTorch espera um tensor contíguo, mas se não, você obterá
RuntimeError: input is not contiguous
e apenas adicionará uma chamada acontiguous()
.fonte
contiguous()
sozinho?permute
, que também pode retornar tensores não "contíguos".Da [documentação do pytorch] [1]:
Onde
contiguous
aqui significa não apenas contíguo na memória, mas também na mesma ordem na memória que a ordem dos índices: por exemplo, fazer uma transposição não muda os dados na memória, simplesmente muda o mapa de índices para ponteiros de memória, se você então aplicá-contiguous()
lo mudará os dados na memória de forma que o mapa dos índices para a localização da memória seja o canônico. [1]: http://pytorch.org/docs/master/tensors.htmlfonte
tensor.contiguous () criará uma cópia do tensor, e o elemento na cópia será armazenado na memória de forma contígua. A função contiguous () é normalmente necessária quando primeiro transpomos () um tensor e depois o remodelamos (visualizamos). Primeiro, vamos criar um tensor contíguo:
O stride () return (3,1) significa que: ao percorrer a primeira dimensão a cada passo (linha por linha), precisamos mover 3 passos na memória. Ao mover ao longo da segunda dimensão (coluna por coluna), precisamos mover 1 passo na memória. Isso indica que os elementos no tensor são armazenados de forma contígua.
Agora tentamos aplicar as funções de vir ao tensor:
Ok, podemos descobrir que transpose (), narrow () e fatiamento de tensor, e expand () farão com que o tensor gerado não seja contíguo. Curiosamente, repeat () e view () não o tornam descontínuo. Portanto, agora a pergunta é: o que acontece se eu usar um tensor descontíguo?
A resposta é que a função view () não pode ser aplicada a um tensor descontíguo. Provavelmente, isso ocorre porque view () requer que o tensor seja armazenado de forma contígua para que possa fazer uma remodelagem rápida na memória. por exemplo:
obteremos o erro:
Para resolver isso, basta adicionar contiguous () a um tensor descontíguo, para criar uma cópia contígua e, em seguida, aplicar view ()
fonte
Como na resposta anterior, contigous () aloca blocos de memória contíguos , será útil quando estivermos passando tensor para código de back-end c ou c ++ onde tensores são passados como ponteiros
fonte
As respostas aceitas foram tão boas, e tentei enganar o
transpose()
efeito da função. Criei as duas funções que podem verificar osamestorage()
e ocontiguous
.def samestorage(x,y): if x.storage().data_ptr()==y.storage().data_ptr(): print("same storage") else: print("different storage") def contiguous(y): if True==y.is_contiguous(): print("contiguous") else: print("non contiguous")
Eu verifiquei e obtive este resultado como uma tabela:
Você pode revisar o código do verificador abaixo, mas vamos dar um exemplo quando o tensor não é contíguo . Não podemos simplesmente chamar
view()
esse tensor, precisaríamosreshape()
disso ou também poderíamos chamar.contiguous().view()
.Além disso, há métodos que criam tensores contíguos e não contíguos no final. Existem métodos que podem operar em um mesmo armazenamento , e alguns métodos
flip()
que criarão um novo armazenamento (leia-se: clonar o tensor) antes do retorno.O código do verificador:
fonte
Pelo que entendi, esta é uma resposta mais resumida:
Na minha opinião, a palavra contíguo é um termo confuso / enganoso, pois em contextos normais significa quando a memória não está espalhada em blocos desconectados (ou seja, seu "contíguo / conectado / contínuo").
Algumas operações podem precisar dessa propriedade contígua por algum motivo (mais provavelmente, eficiência em gpu, etc.).
Observe que
.view
é outra operação que pode causar esse problema. Veja o seguinte código que corrigi simplesmente chamando contíguo (em vez do típico problema de transposição que o causa, aqui está um exemplo que causa quando um RNN não está satisfeito com sua entrada):Erro que costumava obter:
Fontes / recursos:
fonte