主頁 > 知識庫 > PyTorch平方根報錯的處理方案

PyTorch平方根報錯的處理方案

熱門標簽:陜西金融外呼系統(tǒng) 哈爾濱ai外呼系統(tǒng)定制 騰訊外呼線路 激戰(zhàn)2地圖標注 廣告地圖標注app 白銀外呼系統(tǒng) 海南400電話如何申請 公司電話機器人 唐山智能外呼系統(tǒng)一般多少錢

問題描述

初步使用PyTorch進行平方根計算,通過range()創(chuàng)建一個張量,然后對其求平方根。

a = torch.tensor(list(range(9)))
b = torch.sqrt(a)

報出以下錯誤:

RuntimeError: sqrt_vml_cpu not implemented for 'Long'

原因

Long類型的數(shù)據(jù)不支持log對數(shù)運算, 為什么Tensor是Long類型? 因為創(chuàng)建List數(shù)組時默認使用的是int, 所以從List轉(zhuǎn)成torch.Tensor后, 數(shù)據(jù)類型變成了Long。

print(a.dtype)

torch.int64

解決方法

提前將數(shù)據(jù)類型指定為浮點型, 重新執(zhí)行:

b = torch.sqrt(a.to(torch.double))
print(b)

tensor([0.0000, 1.0000, 1.4142, 1.7321, 2.0000, 2.2361, 2.4495, 2.6458, 2.8284], dtype=torch.float64)

補充:pytorch10 pytorch常見運算詳解

矩陣與標量

這個是矩陣(張量)每一個元素與標量進行操作。

import torch
a = torch.tensor([1,2])
print(a+1)
>>> tensor([2, 3])

哈達瑪積

這個就是兩個相同尺寸的張量相乘,然后對應(yīng)元素的相乘就是這個哈達瑪積,也成為element wise。

a = torch.tensor([1,2])
b = torch.tensor([2,3])
print(a*b)
print(torch.mul(a,b))
>>> tensor([2, 6])
>>> tensor([2, 6])

這個torch.mul()和*是等價的。

當然,除法也是類似的:

a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
print(a/b)
print(torch.div(a/b))
>>> tensor([0.5000, 0.6667])
>>> tensor([0.5000, 0.6667])

我們可以發(fā)現(xiàn)的torch.div()其實就是/, 類似的:torch.add就是+,torch.sub()就是-,不過符號的運算更簡單常用。

矩陣乘法

如果我們想實現(xiàn)線性代數(shù)中的矩陣相乘怎么辦呢?

這樣的操作有三個寫法:

torch.mm()

torch.matmul()

@,這個需要記憶,不然遇到這個可能會挺蒙蔽的

a = torch.tensor([[1.],[2.]])
b = torch.tensor([2.,3.]).view(1,2)
print(torch.mm(a, b))
print(torch.matmul(a, b))
print(a @ b)

這是對二維矩陣而言的,假如參與運算的是一個多維張量,那么只有torch.matmul()可以使用。等等,多維張量怎么進行矩陣的乘法?在多維張量中,參與矩陣運算的其實只有后兩個維度,前面的維度其實就像是索引一樣,舉個例子:

a = torch.rand((1,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([1, 2, 64, 64])

a = torch.rand((3,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([3, 2, 64, 64])

這樣也是可以相乘的,因為這里涉及一個自動傳播Broadcasting機制,這個在后面會講,這里就知道,如果這種情況下,會把b的第一維度復(fù)制3次 ,然后變成和a一樣的尺寸,進行矩陣相乘。

冪與開方

print('冪運算')
a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
c1 = a ** b
c2 = torch.pow(a, b)
print(c1,c2)
>>> tensor([1., 8.]) tensor([1., 8.])

和上面一樣,不多說了。開方運算可以用torch.sqrt(),當然也可以用a**(0.5)。

對數(shù)運算

在上學(xué)的時候,我們知道ln是以e為底的,但是在pytorch中,并不是這樣。

pytorch中l(wèi)og是以e自然數(shù)為底數(shù)的,然后log2和log10才是以2和10為底數(shù)的運算。

import numpy as np
print('對數(shù)運算')
a = torch.tensor([2,10,np.e])
print(torch.log(a))
print(torch.log2(a))
print(torch.log10(a))
>>> tensor([0.6931, 2.3026, 1.0000])
>>> tensor([1.0000, 3.3219, 1.4427])
>>> tensor([0.3010, 1.0000, 0.4343]) 

近似值運算

.ceil() 向上取整

.floor()向下取整

.trunc()取整數(shù)

.frac()取小數(shù)

.round()四舍五入

.ceil() 向上取整.floor()向下取整.trunc()取整數(shù).frac()取小數(shù).round()四舍五入

a = torch.tensor(1.2345)
print(a.ceil())
>>>tensor(2.)
print(a.floor())
>>> tensor(1.)
print(a.trunc())
>>> tensor(1.)
print(a.frac())
>>> tensor(0.2345)
print(a.round())
>>> tensor(1.)

剪裁運算

這個是讓一個數(shù),限制在你自己設(shè)置的一個范圍內(nèi)[min,max],小于min的話就被設(shè)置為min,大于max的話就被設(shè)置為max。這個操作在一些對抗生成網(wǎng)絡(luò)中,好像是WGAN-GP,通過強行限制模型的參數(shù)的值。

a = torch.rand(5)
print(a)
print(a.clamp(0.3,0.7))

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • 解決pytorch 數(shù)據(jù)類型報錯的問題
  • Pytorch Tensor基本數(shù)學(xué)運算詳解
  • pytorch masked_fill報錯的解決

標簽:惠州 黔西 常德 上海 四川 益陽 黑龍江 鷹潭

巨人網(wǎng)絡(luò)通訊聲明:本文標題《PyTorch平方根報錯的處理方案》,本文關(guān)鍵詞  PyTorch,平方根,報,錯的,處理,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《PyTorch平方根報錯的處理方案》相關(guān)的同類信息!
  • 本頁收集關(guān)于PyTorch平方根報錯的處理方案的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章