Let's look at the difference between torch.sum() and torch.sum().item()

https://docs.pytorch.org/docs/stable/generated/torch.sum.html


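torch.sum() returns a tensor (a 0-dimensional tensor holding the sum).

torch.sum().item() returns that value as a plain Python scalar (for example an int or a float).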
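To make the difference more concrete, here is a small sketch. The float values and `requires_grad=True` are illustrative choices, not part of the original example. The sum is a 0-dimensional tensor that still carries tensor information such as autograd tracking, while `.item()` gives an ordinary Python number with no graph attached.

```python
import torch

# Illustrative input: float32 values, tracked by autograd
x = torch.tensor([1.5, 2.5, 3.0], requires_grad=True)

s = torch.sum(x)            # 0-dimensional tensor: still a torch.Tensor
print(s.dim(), s.shape)     # 0 torch.Size([])
print(s.requires_grad)      # True: s is part of the autograd graph

v = s.item()                # plain Python float, no graph attached
print(type(v), v)           # <class 'float'> 7.0
```

Because a float tensor was used here, `.item()` returns a Python `float`; with the integer tensor in the example below it returns an `int`.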


```python
import torch

x = torch.tensor([1, 2, 3, 4])  # int64 tensor

print(torch.sum(x))               # tensor(10)
print(type(torch.sum(x)))         # <class 'torch.Tensor'>
print('----------------------')

print(torch.sum(x).item())        # 10
print(type(torch.sum(x).item()))  # <class 'int'>
```
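One thing to keep in mind: `.item()` only works on a tensor with exactly one element. A full sum like the one above always qualifies, but summing along a dimension (the `dim` argument of `torch.sum`) leaves several elements. In that case `.tolist()` is one way to get plain Python numbers. A minimal sketch, with the 2x2 matrix chosen purely for illustration:

```python
import torch

m = torch.tensor([[1, 2], [3, 4]])  # illustrative 2x2 int64 tensor

col_sums = torch.sum(m, dim=0)      # sum over rows -> tensor([4, 6]), 2 elements
print(col_sums.tolist())            # [4, 6] as a Python list of ints

# .item() needs exactly one element, so this call fails
error_raised = False
try:
    col_sums.item()
except RuntimeError as e:
    error_raised = True
    print("RuntimeError:", e)

print(torch.sum(m).item())          # 10: the whole-tensor sum has one element
```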