我尝试使用以下代码从四元数计算欧拉角:
def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
"""
Convert a batch of quaternions to Euler angles.
Args:
quaternions: Tensor of shape (batch_size, 4). Batch of quaternions.
sequence: String specifying the rotation sequence. Default is 'xyz'.
Returns:
euler_angles: Tensor of shape (batch_size, 3). Batch of Euler angles.
"""
batch_size = quaternions.shape[0]
q = quaternions.detach().cpu().numpy() # Convert to NumPy array
rotations = Rotation.from_quat(q)
euler_angles = rotations.as_euler(sequence, degrees=False)
euler_angles = torch.tensor(euler_angles, device=quaternions.device)
euler_angles = euler_angles.view(batch_size, 3)
return euler_angles
字符串
但我有个问题
RuntimeError: shape '[4, 3]' is invalid for input of size 3
型
我尝试使用以下代码修复它,因为我找到了一些解决方案here:
def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
batch_size = quaternions.shape[0] # Ensure that batch_size is correctly calculated
print("Batch size:", batch_size)
q = quaternions.detach().cpu().numpy() # Convert to NumPy array
rotations = Rotation.from_quat(q)
euler_angles = rotations.as_euler(sequence, degrees=False)
print("Shape of euler_angles:", euler_angles.shape)
if batch_size == 1:
euler_angles = euler_angles.reshape(1, -1) # Reshape single quaternion to (1, 3)
euler_angles = torch.tensor(euler_angles, device=quaternions.device)
assert batch_size == euler_angles.shape[0], "Mismatch in batch_size and euler_angles shape"
euler_angles = euler_angles.view(batch_size, 3) # Reshape to [batch_size, 3]
return euler_angles
型
当我输入一些批量大小时,例如1或2或16,我遇到了这个问题:
Batch size: 4
Shape of euler_angles: (3,)
Traceback (most recent call last):
File "test_quat.py", line 147, in <module>
euler = utils.compute_euler_angles_from_quaternion(
File "/home/redhwan/2/HPE/quat/utils.py", line 318, in compute_euler_angles_from_quaternion
assert batch_size == euler_angles.shape[0], "Mismatch in batch_size and euler_angles shape"
AssertionError: Mismatch in batch_size and euler_angles shape
型
1条答案
按热度按时间t9aqgxwy1#
由于代码是为期望一个批处理维度而编写的,因此请确保您有一个批处理维度,或者如果您永远不会有批处理维度,请删除
batch_size = quaternions.shape[0]
和euler_angles = euler_angles.view(batch_size, 3)
行。如果要保留批维度,可以执行以下操作
字符串
您也可以在将
quaternions
传递给函数之前执行unsqueeze
,这样就不需要检查和unsqueeze
。