https://docs.pytorch.org/docs/stable/generated/torch.sum.html
torch.sum — PyTorch 2.9 documentation
torch.sum(input, *, dtype=None) → Tensor
Returns the sum of all elements in the input tensor.
Parameters: input (Tensor): the input tensor.
Keyword arguments: dtype (torch.dtype, optional): the desired data type of the returned tensor.
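The dtype keyword from the signature above can be illustrated with a small sketch. This assumes the documented behavior that, when dtype is given, the input is cast to that type before summing. The variable names here are only for illustration.

```python
import torch

x = torch.tensor([1, 2, 3, 4])  # integer tensor (int64 by default)

s_default = torch.sum(x)                     # result keeps an integer dtype
s_float = torch.sum(x, dtype=torch.float64)  # input is cast to float64 before summing

print(s_default, s_default.dtype)  # tensor(10) torch.int64
print(s_float, s_float.dtype)      # tensor(10., dtype=torch.float64) torch.float64
```

In both cases the result is still a Tensor. Only its dtype changes.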
torch.sum() returns its result as a tensor (a 0-dimensional tensor when summing all elements).
torch.sum().item() returns that value as a plain Python scalar.
import torch
x = torch.tensor([1, 2, 3, 4])
print(torch.sum(x))  # tensor(10)
print(type(torch.sum(x)))  # <class 'torch.Tensor'>
print('----------------------')
print(torch.sum(x).item())  # 10
print(type(torch.sum(x).item()))  # <class 'int'>
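A further sketch of how .item() behaves. The Python type it returns follows the tensor's dtype (int for integer tensors, float for floating-point ones). It also only works on tensors with exactly one element, which is why it pairs naturally with a full torch.sum(). The names below are illustrative.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
total = torch.sum(x)

# item() maps the tensor dtype to the matching Python type
int_value = total.item()                                 # integer tensor -> int
float_value = torch.sum(x, dtype=torch.float32).item()   # float tensor -> float
print(type(int_value), type(float_value))  # <class 'int'> <class 'float'>

# item() only works on tensors with exactly one element
try:
    x.item()
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```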
