신경망 모델의 매개 변수를 업데이트 / 변경하려고 시도한 후 업데이트 된 신경망의 전달 패스를 계산 그래프에 넣었습니다 (얼마나 많은 변경 / 업데이트가 있더라도).
나는이 아이디어를 시도했지만 그것을 할 때마다 pytorch는 업데이트 된 텐서 (모델 내부)를 리프로 설정하여 그라디언트를 받고 싶은 네트워크의 그라디언트 흐름을 죽입니다. 리프 노드는 내가 원하는 방식으로 계산 그래프의 일부가 아니기 때문에 그라디언트의 흐름을 죽입니다 (정말로 리프가 아니기 때문에).
여러 가지를 시도했지만 아무것도 작동하지 않는 것 같습니다. 나는 그라디언트를 갖고 싶은 네트워크의 그라디언트를 인쇄하는 자체 포함 된 더미 코드를 만들었습니다.
import torch
import torch.nn as nn
import copy
from collections import OrderedDict
# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()
#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
print(f'i = {i}')
new_params = copy.deepcopy( loss_net.state_dict() )
## w^<t> := f(w^<t-1>,delta^<t-1>)
for (name, w) in loss_net.named_parameters():
print(f'name = {name}')
print(w.size())
hidden = updater_net(hidden).view(1)
print(hidden.size())
#delta = ((hidden**2)*w/2)
delta = w + hidden
wt = w + delta
print(wt.size())
new_params[name] = wt
#del loss_net.fc0.weight
#setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
#setattr(loss_net.fc0, 'weight', wt)
#loss_net.fc0.weight = wt
#loss_net.fc0.weight = nn.Parameter( wt )
##
loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
누구 든지이 작업을 수행하는 방법을 알고 있다면 ping을주십시오 ... 업데이트 작업은 계산 그래프에 임의의 횟수로 있어야하기 때문에 업데이트 횟수를 2로 설정했습니다 ... 그래야합니다. 2.
관련성이 높은 게시물 :
- 그래서 : pytorch 모델의 매개 변수는 어떻게 잎이 아니고 계산 그래프에 있습니까?
- 파이 토치 포럼 : https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076
교차 게시 :
backward
? 즉retain_graph=True
및 / 또는create_graph=True
?