model = model.cuda()
model.cuda()
model = create_a_model()
tensor = torch.zeros([2,3,10,10])
model.cuda()
tensor.cuda()
model(tensor) # 会报错
tensor = tensor.cuda()
model(tensor) # 正常运行
total_loss += loss.item()
# torch.device object used throughout this script
device = torch.device("cuda" if use_cuda else "cpu")
model = MyRNN().to(device)
# train
total_loss= 0
for input, target in train_loader:
input, target = input.to(device), target.to(device)
hidden = input.new_zeros(*h_shape) # has the same device & dtype as `input`
... # get loss and optimize
total_loss += loss.item()
# test
with torch.no_grad(): # operations inside don't track history
for input, targetin test_loader:
...
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
input_B = output_A.detach()
以CrossEntropyLoss为例:
CrossEntropyLoss(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='elementwise_mean')
若 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss,即batch中每个元素对应的loss.
若 reduce = True,那么 loss 返回的是标量:
如果 size_average = True,返回 loss.mean().
如果 size_average = False,返回 loss.sum().
weight : 输入一个1D的权值向量,为各个类别的loss加权,如下公式所示:
ignore_index : 选择要忽视的目标值,使其对输入梯度不作贡献。如果 size_average = True,那么只计算不被忽视的目标的loss的均值。
reduction : 可选的参数有:‘none’ | ‘elementwise_mean’ | ‘sum’, 正如参数的字面意思,不解释。
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
load_state_dict(torch.load(weight_path), strict=False)
import numpy as np
# 判断输入数据是否存在nan
if np.any(np.isnan(input.cpu().numpy())):
print('Input data has NaN!')
# 判断损失是否为nan
if np.isnan(loss.item()):
print('Loss value is NaN!')
raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
来源 | CSDN博客
免责声明:本文系网络转载,版权归原作者所有。本文所用视频、图片、文字如涉及作品版权问题,请第一时间告知,我们将根据您提供的证明材料确认版权并按国家标准支付稿酬或立即删除内容!本文内容为原作者观点,并不代表本公众号赞同其观点和对其真实性负责。