1万回のQuaternion * Vector3の処理速度を上げてみる

Quaternion * Vector3をMultiplyPoint3x4()やBurst Direct Callで計算する記事です。

UnityでQuaternion * Vector3すると、Vector3をQuaternionで回転させた座標が得られます。
通常、Unityでこの計算をする時は以下のコードになります。

// X軸に90度回転させる四元数
Quaternion testQuaternion = Quaternion.Euler(90, 0, 0);

// ベクトル
Vector3 testVector3 = new Vector3(10, 10, 10);

// X軸に90度回転した後のベクトル
Vector3 rotatedVector = testQuaternion * testVector3;

クォータニオンなら数学関数を使わずにベクトルを求められるので、処理速度が速いです。
しかし、それでも1フレームの中で何万回も実行するような作りをするとパフォーマンス悪化の原因になります。

同じベクトルを上記とは異なる方法で求めることもできます。

// X軸に90度回転させる四元数
Quaternion testQuaternion = Quaternion.Euler(90, 0, 0);

// ベクトル
Vector3 testVector3 = new Vector3(10, 10, 10);

// 変換行列を作成
Matrix4x4 testMatrix4x4 = Matrix4x4.Rotate(testQuaternion);

// X軸に90度回転した後のベクトル
Vector3 rotatedVector = testMatrix4x4.MultiplyPoint3x4(testVector3);

通常であれば前者の方が処理速度は速いです。
ですが、以下の条件に当てはまる場合であれば、Matrix4x4構造体のMultiplyPoint3x4メソッドで処理速度を向上させることができます。

  • 同じQuaternionを使用してベクトルを計算している
  • 計算回数が7500回以上

2023/06/12 バーストコンパイルに関して追記。

目次

ベンチマーク

ベンチマーク用のコードは以下です。

Quaternion * Vector3

int n = 10000; // 計算回数

Quaternion testQuaternion = new Quaternion(1, 0, 0, 1);
Vector3 testVector3 = new Vector3(1, 2, 3);

Vector3 vt1 = Vector3.zero;
var sw1 = new System.Diagnostics.Stopwatch();
sw1.Start();

for (int i = 0; i < n; i++)
{
    vt1 = testQuaternion * testVector3;
}

sw1.Stop();
UnityEngine.Debug.Log(vt1);
UnityEngine.Debug.Log(sw1.Elapsed);

vt1をログに出力しているのはコンパイラによる最適化防止です。

MultiplyPoint3x4(Vector3)

int n = 10000; // 計算回数

Quaternion testQuaternion = new Quaternion(1, 0, 0, 1);
Vector3 testVector3 = new Vector3(1, 2, 3);

Vector3 vt2 = Vector3.zero;
var sw2 = new System.Diagnostics.Stopwatch();
sw2.Start();

Matrix4x4 testMatrix4x4 = Matrix4x4.Rotate(testQuaternion);
for (int i = 0; i < n; i++)
{
    vt2 = testMatrix4x4.MultiplyPoint3x4(testVector3);
}

sw2.Stop();
UnityEngine.Debug.Log(vt2);
UnityEngine.Debug.Log(sw2.Elapsed);

vt2については同上。
Matrix4x4構造体を生成する処理も含めないと不公平なので計測に含めています。

ベンチマーク結果

計算回数Quaternion * Vector3MultiplyPoint3x4(Vector3)
n=10000.0359ms0.1265ms
n=75000.2377ms0.2374ms
n=100000.3125ms0.2777ms
n=100000031.0942ms16.7268ms
PC (Core i7-8700Bでの計測。参考程度にどうぞ)

n=1000000など実際に計算することはないでしょうが、参考用として載せました。
n=7500辺りからMultiplyPoint3x4が処理速度で上回ります。

ベンチマーク用のコードでは同じベクトルを計算させましたが、実際には異なるベクトルを求めるでしょう。
MultiplyPoint3x4()に計算させたいVector3を与えればそのクォータニオンで計算したVector3が返ってくるので心配はありません。

計算を7500回以上実行しない場合は逆にパフォーマンスが落ちるところが難点ですね。

なんで速くなるの?

Quaternion * Vector3は乗算演算子をオーバーロードして実装されており、Unityのソースコードから実際の処理を見れます。

// Rotates the point /point/ with /rotation/.
public static Vector3 operator*(Quaternion rotation, Vector3 point)
{
    float x = rotation.x * 2F;
    float y = rotation.y * 2F;
    float z = rotation.z * 2F;
    float xx = rotation.x * x;
    float yy = rotation.y * y;
    float zz = rotation.z * z;
    float xy = rotation.x * y;
    float xz = rotation.x * z;
    float yz = rotation.y * z;
    float wx = rotation.w * x;
    float wy = rotation.w * y;
    float wz = rotation.w * z;

    Vector3 res;
    res.x = (1F - (yy + zz)) * point.x + (xy - wz) * point.y + (xz + wy) * point.z;
    res.y = (xy + wz) * point.x + (1F - (xx + zz)) * point.y + (yz - wx) * point.z;
    res.z = (xz - wy) * point.x + (yz + wx) * point.y + (1F - (xx + yy)) * point.z;
    return res;
}

n=10000なら上記の処理が10000回動くわけですね。

Matrix4x4.Rotate()もほぼ同じ処理を行っていますが、大事なのは最初の1回だけ行うということです。
以降は以下の処理でベクトルを計算します。

// Transforms a position by this matrix, without a perspective divide. (fast)
public Vector3 MultiplyPoint3x4(Vector3 point)
{
    Vector3 res;
    res.x = this.m00 * point.x + this.m01 * point.y + this.m02 * point.z + this.m03;
    res.y = this.m10 * point.x + this.m11 * point.y + this.m12 * point.z + this.m13;
    res.z = this.m20 * point.x + this.m21 * point.y + this.m22 * point.z + this.m23;
    return res;
}

こちらの方が明らかに速く計算できることが分かります。
3~4回呼び出すだけでパフォーマンスが逆転しそうなものですが、構造体を作るコストの関係などで、計算回数が7500回とかなり多くなければパフォーマンスが逆転しないようです。
(コスト重くない……?)

MultiplyPoint()じゃなくていいの?

MultiplyPoint3x4()の他にMultiplyPoint()もあります。
Unityのドキュメントを見ると「MultiplyPoint()は速度が遅いが射影変換を扱うことができる」と書いてあります。

Matrix4x4.Rotate()を使って変換行列を作る場合はMultiplyPoint3x4()でOKです。

コードから紐解いてみましょう。
まずはMatrix4x4.Rotateの処理を見てみます。

public static Matrix4x4 Rotate(Quaternion q)
{
    // Precalculate coordinate products
    float x = q.x * 2.0F;
    float y = q.y * 2.0F;
    float z = q.z * 2.0F;
    float xx = q.x * x;
    float yy = q.y * y;
    float zz = q.z * z;
    float xy = q.x * y;
    float xz = q.x * z;
    float yz = q.y * z;
    float wx = q.w * x;
    float wy = q.w * y;
    float wz = q.w * z;

    // Calculate 3x3 matrix from orthonormal basis
    Matrix4x4 m;
    m.m00 = 1.0f - (yy + zz); m.m10 = xy + wz; m.m20 = xz - wy; m.m30 = 0.0F;
    m.m01 = xy - wz; m.m11 = 1.0f - (xx + zz); m.m21 = yz + wx; m.m31 = 0.0F;
    m.m02 = xz + wy; m.m12 = yz - wx; m.m22 = 1.0f - (xx + yy); m.m32 = 0.0F;
    m.m03 = 0.0F; m.m13 = 0.0F; m.m23 = 0.0F; m.m33 = 1.0F;
    return m;
}

注目すべきは19-22行目のそれぞれの最後の代入です。
次にMatrix4x4.MultiplyPointの処理を見てみます。

// Transforms a position by this matrix, with a perspective divide. (generic)
public Vector3 MultiplyPoint(Vector3 point)
{
    Vector3 res;
    float w;
    res.x = this.m00 * point.x + this.m01 * point.y + this.m02 * point.z + this.m03;
    res.y = this.m10 * point.x + this.m11 * point.y + this.m12 * point.z + this.m13;
    res.z = this.m20 * point.x + this.m21 * point.y + this.m22 * point.z + this.m23;
    w = this.m30 * point.x + this.m31 * point.y + this.m32 * point.z + this.m33;

    w = 1F / w;
    res.x *= w;
    res.y *= w;
    res.z *= w;
    return res;
}

this.m30とthis.m31とthis.m32は0、this.m33は1です。なので9行目の変数wの結果は必ず1になります。
ということは、11行目は1÷1なので変数wの結果は必ず1になります。

残りはwを掛ける計算なので、以降の計算は行わなくても戻り値が同じになることが分かります。
この計算を省いたのがMultiplyPoint3x4()です。

// Transforms a position by this matrix, without a perspective divide. (fast)
public Vector3 MultiplyPoint3x4(Vector3 point)
{
    Vector3 res;
    res.x = this.m00 * point.x + this.m01 * point.y + this.m02 * point.z + this.m03;
    res.y = this.m10 * point.x + this.m11 * point.y + this.m12 * point.z + this.m13;
    res.z = this.m20 * point.x + this.m21 * point.y + this.m22 * point.z + this.m23;
    return res;
}

つまりMultiplyPoint()を使おうがMultiplyPoint3x4()を使おうが結果は変わらない、ということです。
(あくまで今回のケースでは、ですが)

また、MultiplyPoint()に対してMultiplyPoint3x4()は2倍近く高速です。
乗算・除算の回数がほぼ半減することからも分かります。
ドキュメントでも言及されていますし、わざわざMultiplyPoint3x4()メソッドが用意されているということはこういうニーズに目を向けているのでしょうね。

ちなみにMultiplyPoint3x4()ではなくMultiplyPoint()を使った場合、n=10000の条件で計測すると0.4032msでした。
0.3125msを下回っていないので1万回同じQuaternionで計算しても元よりパフォーマンスが悪化する始末です。
素直にMultiplyPoint3x4()を使いましょう。

バーストコンパイル時の速度

Burst 1.5以降、「Burst Direct Call」という仕組みが導入されました。
Burst CompilerはC# Job Systemからのみ使用できましたが、特定の条件下に限り、通常のUnity C#コードにも使用できるというものです。

条件には「使える型に制限がある」「staticメソッドじゃないとNG」などがあります。
Burst InspectorにはBurst Compileされたメソッドが一覧されるので、それを確認しながらコーディングしましょう。
詳細は公式マニュアルを参照。

使い方はかんたん。メソッドの上に[BurstCompile]と書くだけです。
Burst Direct Callを使ったメソッドはBurst Compileされます。

爆速と噂されているBurstですが、果たして前述の結果が灰塵と化す夢のような機能なのか?

ということでBurst Direct Callを使って計測してみました。
計測環境は下記の通り。

  • Unity 2021.3.11f1
  • Burst 1.8.4 (Safety Checks: Off)
  • Mathematics 1.2.6
  • Unityエディタ上で実行。CPUはCore i7-8700B
計算回数math.mul(quaternion, float3)math.mul(float4x4, float4).xyzMultiplyPoint3x4(Vector3)
n=10000.0036ms0.0035ms0.0028ms
n=75000.0053ms0.0051ms0.0046ms
n=100000.0055ms0.0053ms0.0051ms
n=10000000.0269ms0.0263ms0.0262ms
3パターンの処理で計測

3パターンともBurst Direct Callです。
非Burstと比べるとBurst Direct Callの圧勝。10~1000倍くらい速い。いや速すぎんだろ。

念のため補足しますが、Burst内のコードではStopwatchが使えないので関数呼び出しなどのコストも含まれていて、前述の非Burst版コードよりも不利な条件で計測していることになります。
つよい。

ベンチマーク用のコードも置いておきます。

ベンチマーク用のコード

BurstパッケージとMathematicsパッケージが必須なので注意。

  • Testメソッド:math.mul(quaternion, float3)
  • Test2メソッド:math.mul(float4x4, float4).xyz
  • Test3メソッド:MultiplyPoint3x4(Vector3)
using Unity.Burst;
using Unity.Mathematics;
using UnityEngine;

[BurstCompile]
public class Test : MonoBehaviour
{
    void Start()
    {
        // 初回呼び出し
        float3 pvt1;
        float3 pvt2;
        Vector3 pvt3;
        Test(out pvt1);
        Test2(out pvt2);
        Test3(out pvt3);

        // math.mul(quaternion, float3)
        float3 vt1;
        var sw1 = new System.Diagnostics.Stopwatch();
        sw1.Start();

        Test(out vt1);

        sw1.Stop();

        UnityEngine.Debug.Log(vt1);
        UnityEngine.Debug.Log(sw1.Elapsed);

        // math.mul(float4x4, float4).xyz
        float3 vt2;
        var sw2 = new System.Diagnostics.Stopwatch();
        sw2.Start();

        Test2(out vt2);

        sw2.Stop();

        UnityEngine.Debug.Log(vt2);
        UnityEngine.Debug.Log(sw2.Elapsed);

        // MultiplyPoint3x4(Vector3)
        Vector3 vt3;
        var sw3 = new System.Diagnostics.Stopwatch();
        sw3.Start();

        Test3(out vt3);

        sw3.Stop();

        UnityEngine.Debug.Log(vt3);
        UnityEngine.Debug.Log(sw3.Elapsed);
    }

    [BurstCompile]
    static void Test(out float3 result)
    {
        int n = 1000000; // 計算回数

        quaternion testQuaternion = new quaternion(1, 0, 0, 1);
        float3 testVector3 = new float3(1, 2, 3);

        float3 vt = float3.zero;

        for (int i = 0; i < n; i++)
        {
            vt = math.mul(testQuaternion, testVector3);
        }

        result = vt;
    }

    [BurstCompile]
    static void Test2(out float3 result)
    {
        int n = 1000000; // 計算回数

        quaternion testQuaternion = new quaternion(1, 0, 0, 1);
        float3 testVector3 = new float3(1, 2, 3);

        float3 vt = float3.zero;

        float4x4 testMatrix4x4 = new float4x4(testQuaternion, float3.zero);
        float4 testFloat4 = new float4(testVector3, 1);
        for (int i = 0; i < n; i++)
        {
            vt = math.mul(testMatrix4x4, testFloat4).xyz;
        }

        result = vt;
    }

    [BurstCompile]
    static void Test3(out Vector3 result)
    {
        int n = 1000000; // 計算回数

        Quaternion testQuaternion = new Quaternion(1, 0, 0, 1);
        Vector3 testVector3 = new Vector3(1, 2, 3);

        Vector3 vt = Vector3.zero;

        Matrix4x4 testMatrix4x4 = Matrix4x4.Rotate(testQuaternion);
        for (int i = 0; i < n; i++)
        {
            vt = testMatrix4x4.MultiplyPoint3x4(testVector3);
        }

        result = vt;
    }
}

ちなみに10~16行目の行は、ベンチマーク上ではとても重要です。

        // 初回呼び出し
        float3 pvt1;
        float3 pvt2;
        Vector3 pvt3;
        Test(out pvt1);
        Test2(out pvt2);
        Test3(out pvt3);

Burst Direct Callは初回呼び出しに時間がかかります。
この初回呼び出しをコメントアウトすると、だいぶ結果が変わります。

計算回数math.mul(quaternion, float3)math.mul(float4x4, float4).xyzMultiplyPoint3x4(Vector3)
n=10000.7189ms0.6722ms0.3733ms
n=75000.7345ms0.6780ms0.3825ms
n=100000.7425ms0.6814ms0.3734ms
n=10000000.9598ms0.9015ms0.6119ms
初回呼び出しのコストが重いことが分かる

そういうことなので、最初の1フレーム目はカクつくかもしれませんね。
Burstしたいようなコードは毎フレーム実行するコードだと思いますので、特に問題にはならないでしょう。

この初回呼び出し時のコストは元々もっと高かったようで、Burst 1.7から改善の手が加わっています。

Burst 1.7 では、直接呼び出しのパフォーマンスも大幅に改善されました。直接呼び出しは Burst 1.5 で追加された機能で、マネージド C# コードが BurstCompiler.CompileFunctionPointer を介さずに Burst コンパイルされたメソッドを直接呼び出すことを許可するものです。ドメインリロード時には、直接呼び出しのメソッドを配線するための初期化作業が必要ですが、Burst 1.7 では、この初期化作業を最大 33 倍高速化することに成功しました。

Burst 1.7 でゲームの品質を向上させる | Unity Blog

※「直接呼び出し」はBurst Direct Callのことです。うーん、この直訳感

Mathematicsくん、きみ遅くない?

float3、quaternion、float4x4はMathematicsパッケージの構造体です。
それぞれVector3、Quaternion、Matrix4x4に相当し、Burst時にSIMD命令による最適化が有効になります。

なので3つ目のTest3メソッドが一番速くなることは通常ないのですが、今回の計測では最も優秀でした。
Test3メソッドはMathematicsとの比較のために用意してみただけなんですけどね。
この結果は意外です。

考察ですが、現時点でfloat4x4にはMultiplyPoint3x4に相当するメソッドがなく、回りくどい方法で計算しているので演算量が多くなってTest3>Test2になっており、Quaternion * Vector3よりもMultiplyPoint3x4()の方が速いのでTest3>Testになっているのだと思います。
「Quaternion * Vector3」のMathematics版が「math.mul(quaternion, float3)」に相当するのでBurst時は後者の方が速く演算できるものの、今回はMultiplyPoint3x4()で単純に演算量を減らした方が良かった、ということですね。

奥が深いです。
まあここまで切り詰めてパフォーマンスチューニングをしてもゲームが完成しなければ意味がない得られる効果は少ないケースがほとんどなので、Burstの恩恵を受けた時点で十分でしょう。
とはいえこの結果、どこかに需要はあるんじゃないかなぁと思いました。

  • URLをコピーしました!
目次