Files
FSI.BT.IR.Tools/Kalk/Kalk.Core/Modules/Vectors/KalkMatrix.cs
Maier Stephan SI b684704bf8 Sicherung
2023-01-20 16:09:00 +01:00

1201 lines
45 KiB
C#

using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using MathNet.Numerics.LinearAlgebra;
using Scriban;
using Scriban.Functions;
using Scriban.Helpers;
using Scriban.Parsing;
using Scriban.Runtime;
using Scriban.Syntax;
namespace Kalk.Core
{
[ScriptTypeName("matrix")]
public abstract class KalkMatrix : KalkValue, IKalkDisplayable
{
protected KalkMatrix(int row, int column)
{
if (row <= 0) throw new ArgumentOutOfRangeException(nameof(row));
if (column <= 0) throw new ArgumentOutOfRangeException(nameof(column));
RowCount = row;
ColumnCount = column;
}
public int RowCount { get; }
public int ColumnCount { get; }
protected abstract KalkMatrix Transpose();
protected abstract KalkMatrix Identity();
protected abstract KalkVector Diagonal();
public abstract KalkVector GetColumn(int index);
public abstract KalkVector GetRow(int index);
protected abstract object GenericDeterminant();
protected abstract KalkMatrix GenericInverse();
public abstract Array Values { get; }
public bool IsSquare => RowCount == ColumnCount;
protected abstract KalkVector GenericMultiplyLeft(KalkVector x);
protected abstract KalkVector GenericMultiplyRight(KalkVector y);
protected abstract KalkMatrix GenericMultiply(KalkMatrix y);
protected static void AssertSquareMatrix(KalkMatrix m)
{
if (m == null) throw new ArgumentNullException(nameof(m));
if (!m.IsSquare) throw new ArgumentException($"Matrix must be square nxn instead of {m.TypeName}", nameof(m));
}
public static KalkMatrix Transpose(KalkMatrix m)
{
if (m == null) throw new ArgumentNullException(nameof(m));
return m.Transpose();
}
public static KalkMatrix Identity(KalkMatrix m)
{
AssertSquareMatrix(m);
return m.Identity();
}
public static KalkVector Diagonal(KalkMatrix x)
{
AssertSquareMatrix(x);
return x.Diagonal();
}
public static object Determinant(KalkMatrix m)
{
AssertSquareMatrix(m);
return m.GenericDeterminant();
}
public static KalkMatrix Inverse(KalkMatrix m)
{
AssertSquareMatrix(m);
return m.GenericInverse();
}
public static object Multiply(KalkVector x, KalkMatrix y)
{
if (x == null) throw new ArgumentNullException(nameof(x));
if (y == null) throw new ArgumentNullException(nameof(y));
CheckElementType(x, y);
if (x.Length != y.RowCount)
throw new ArgumentException($"Invalid size between the vector type length {x.Length} and the matrix row count {y.RowCount}. They Must be equal.", nameof(x));
return y.GenericMultiplyLeft(x);
}
public static object Multiply(KalkMatrix x, KalkVector y)
{
if (x == null) throw new ArgumentNullException(nameof(x));
if (y == null) throw new ArgumentNullException(nameof(y));
CheckElementType(x, y);
if (x.ColumnCount != y.Length)
throw new ArgumentException($"Invalid size between the vector type length {y.Length} and the matrix column count {x.ColumnCount}. They Must be equal.", nameof(x));
return x.GenericMultiplyRight(y);
}
public static object Multiply(KalkMatrix x, KalkMatrix y)
{
if (x == null) throw new ArgumentNullException(nameof(x));
if (y == null) throw new ArgumentNullException(nameof(y));
CheckElementType(x, y);
if (x.ColumnCount != y.RowCount)
throw new ArgumentException($"Invalid size not between the matrix x {x.TypeName} with a column count {x.ColumnCount} and the matrix y {y.TypeName} with a row count of {y.RowCount}. They Must be equal.", nameof(x));
return x.GenericMultiply(y);
}
private static void CheckElementType(KalkValue x, KalkValue y)
{
if (x == null) throw new ArgumentNullException(nameof(x));
if (y == null) throw new ArgumentNullException(nameof(y));
if (x.ElementType != y.ElementType)
{
throw new ArgumentException($"Unsupported type for matrix multiplication. The combination of {x.GetType().ScriptPrettyName()} * {y.GetType().ScriptPrettyName()} is not supported.", nameof(x));
}
}
public abstract void Display(KalkEngine engine, KalkDisplayMode mode);
}
public class KalkMatrix<T> : KalkMatrix, IFormattable, IList, IScriptCustomType where T: unmanaged
{
private readonly T[] _values;
public KalkMatrix(int row, int column) : base(row, column)
{
_values = new T[row * column];
VerifyElementType();
}
private KalkMatrix(int row, int column, T[] values) : base(row, column)
{
if (values == null) throw new ArgumentNullException(nameof(values));
if (values.Length != row * column) throw new ArgumentException($"Matrix values must have {row*column} elements instead of {values.Length}");
_values = values;
VerifyElementType();
}
public KalkMatrix(KalkMatrix<T> values) : base(values.RowCount, values.ColumnCount)
{
_values = (T[])values._values.Clone();
VerifyElementType();
}
private static void VerifyElementType()
{
if (typeof(T) == typeof(KalkBool) || typeof(T) == typeof(KalkHalf) || typeof(T) == typeof(int) || typeof(T) == typeof(float) || typeof(T) == typeof(double))
{
return;
}
throw new InvalidOperationException($"The type {typeof(T)} is not supported for matrices. Only bool, int, float, half and double are supported.");
}
public override Type ElementType => typeof(T);
public override Array Values => _values;
public T this[int index]
{
get => _values[index];
set => _values[index] = value;
}
public override KalkVector GetRow(int index)
{
if ((uint)index >= (uint)RowCount) throw new ArgumentOutOfRangeException(nameof(index), $"The row index {index} is out of range [0, {RowCount-1}]");
var row = new KalkVector<T>(ColumnCount);
for (int i = 0; i < ColumnCount; i++)
{
row[i] = _values[ColumnCount * index + i];
}
return row;
}
public void SetRow(int index, KalkVector<T> value)
{
if ((uint)index >= (uint)RowCount) throw new ArgumentOutOfRangeException(nameof(index));
if (value == null) throw new ArgumentNullException(nameof(value));
if (value.Length != ColumnCount) throw new ArgumentOutOfRangeException(nameof(value), $"Invalid vector size. The vector has a length of {value.Length} while the row of this matrix is expecting {ColumnCount} elements.");
for (int i = 0; i < ColumnCount; i++)
{
_values[ColumnCount * index + i] = value[i];
}
}
public override KalkVector GetColumn(int index)
{
if ((uint)index >= (uint)ColumnCount) throw new ArgumentOutOfRangeException(nameof(index), $"The column index {index} is out of range [0, {ColumnCount-1}]");
var column = new KalkVector<T>(RowCount);
for (int i = 0; i < RowCount; i++)
{
column[i] = _values[ColumnCount * i + index];
}
return column;
}
public void SetColumn(int index, KalkVector<T> value)
{
if ((uint)index >= (uint)RowCount) throw new ArgumentOutOfRangeException(nameof(index));
if (value == null) throw new ArgumentNullException(nameof(value));
if (value.Length != RowCount) throw new ArgumentOutOfRangeException(nameof(value), $"Invalid vector size. The vector has a length of {value.Length} while the column of this matrix is expecting {RowCount} elements.");
for (int i = 0; i < RowCount; i++)
{
_values[ColumnCount * i + index] = value[i];
}
}
public KalkMatrix<T> Clone()
{
return new KalkMatrix<T>(this);
}
public override IScriptObject Clone(bool deep)
{
return Clone();
}
public override string TypeName
{
get
{
if (RowCount > 4 || ColumnCount > 4)
{
return $"matrix({ElementTypeName}, {RowCount.ToString(CultureInfo.InvariantCulture)}, {ColumnCount.ToString(CultureInfo.InvariantCulture)})";
}
else
{
return $"{ElementTypeName}{RowCount.ToString(CultureInfo.InvariantCulture)}x{ColumnCount.ToString(CultureInfo.InvariantCulture)}";
}
}
}
private string ElementTypeName => typeof(T) == typeof(KalkBool) ? "bool" : typeof(T).ScriptPrettyName();
public object Transform(Func<T, T> apply)
{
var newValue = Clone();
for (int i = 0; i < _values.Length; i++)
{
var result = apply(_values[i]);
newValue._values[i] = (T)result;
}
return newValue;
}
public object TransformToBool(Func<T, KalkBool> apply)
{
var newValue = new KalkMatrix<KalkBool>(RowCount, ColumnCount);
for (int i = 0; i < _values.Length; i++)
{
var result = apply(_values[i]);
newValue._values[i] = result;
}
return newValue;
}
public override bool CanTransform(Type transformType)
{
return typeof(T) == typeof(int) && (typeof(long) == transformType || typeof(int) == transformType) ||
(typeof(T) == typeof(float) && (typeof(double) == transformType || typeof(float) == transformType) ||
typeof(T) == typeof(double) && typeof(double) == transformType) ||
(typeof(T) == typeof(KalkHalf) && (typeof(double) == transformType || typeof(float) == transformType || typeof(KalkHalf) == transformType));
}
public override Span<byte> AsSpan() => MemoryMarshal.AsBytes(new Span<T>(_values));
public override bool Visit(TemplateContext context, SourceSpan span, Func<object, bool> visit)
{
for (int i = 0; i < _values.Length; i++)
{
if (!visit(_values[i]))
{
return false;
}
}
return true;
}
public override object Transform(TemplateContext context, SourceSpan span, Func<object, object> apply, Type destType)
{
if (destType == typeof(KalkBool))
{
return TransformToBool(value =>
{
var result = apply(value);
return context.ToObject<KalkBool>(span, result);
});
}
return this.Transform(value =>
{
var result = apply(value);
return context.ToObject<T>(span, result);
});
}
public void CopyTo(Array array, int index)
{
((ICollection) _values).CopyTo(array, index);
}
public override int Count => RowCount;
public bool IsSynchronized => ((ICollection) _values).IsSynchronized;
public object SyncRoot => ((ICollection) _values).SyncRoot;
public override IEnumerable<string> GetMembers()
{
if (_values == null) yield break;
if (RowCount > 8 || ColumnCount > 8) yield break;
for(int y = 0; y < RowCount; y++)
{
for (int x = 0; x < ColumnCount; x++)
{
yield return string.Format(CultureInfo.InvariantCulture, "_m{0}{1}", y, x);
}
}
}
public override bool Contains(string member)
{
// At least _m00 or _00 so 3 characters
if (member.Length < 3) return false;
var emptySpan = new SourceSpan();
foreach (var cell in ForEachMemberPart(emptySpan, member, false))
{
// We say that 0 or 1 are members as well
if (cell.Item1 < 0 || cell.Item2 < 0) return false;
}
return true;
}
public int Add(object value) => throw new NotSupportedException("This type does not support this operation");
public void Clear()
{
Array.Clear(_values, 0, _values.Length);
}
public bool Contains(object value) => throw new NotSupportedException("This type does not support this operation");
public int IndexOf(object value) => throw new NotSupportedException("This type does not support this operation");
public void Insert(int index, object value) => throw new NotSupportedException("This type does not support this operation");
public void Remove(object value) => throw new NotSupportedException("This type does not support this operation");
public void RemoveAt(int index) => throw new NotSupportedException("This type does not support this operation");
public bool IsFixedSize => ((IList) _values).IsFixedSize;
public override bool IsReadOnly
{
get => false;
set => throw new InvalidOperationException("Cannot set this instance readonly");
}
object IList.this[int index]
{
get => GetRow(index);
set
{
var tValue = value as KalkVector<T>;
if (value == null)
{
tValue = new KalkVector<T>(Count);
}
else
{
if (tValue == null)
{
if (value is KalkColor color)
{
if (color is KalkColorRgb && (ColumnCount < 3 || ColumnCount > 4))
{
throw new ArgumentException($"Cannot set a rgb color for a column count {ColumnCount}. Expecting 3 or 4 columns.");
}
if (color is KalkColorRgba && (ColumnCount != 4))
{
throw new ArgumentException($"Cannot set a rgba color for a column count {ColumnCount}. Expecting 4 columns.");
}
var colorVector = color.GetFloatVector(ColumnCount);
tValue = new KalkVector<T>(ColumnCount);
for (int i = 0; i < ColumnCount; i++)
{
tValue[i] = (T) Convert.ChangeType(colorVector[i], typeof(T));
}
}
else if (value is KalkVector vector)
{
if (vector.Length != ColumnCount)
{
throw new ArgumentException($"Invalid vector length {vector.Length}. Expecting {ColumnCount} columns.");
}
tValue = new KalkVector<T>(ColumnCount);
for (int i = 0; i < ColumnCount; i++)
{
tValue[i] = (T) Convert.ChangeType(vector.GetComponent(i), typeof(T));
}
}
else if (value.GetType().IsNumber())
{
tValue = new KalkVector<T>(ColumnCount);
for (int i = 0; i < ColumnCount; i++)
{
tValue[i] = (T) Convert.ChangeType(value, typeof(T));
}
}
else
{
throw new ArgumentException($"Invalid type for setting a row {value.GetType().ScriptPrettyName()}.");
}
}
}
SetRow(index, tValue);
}
}
public override bool TryGetValue(TemplateContext context, SourceSpan span, string member, out object result)
{
result = null;
if (_values == null) return false;
if (member.Length < 1) return false;
List<T> list = null;
foreach (var cell in ForEachMemberPart(span, member, true))
{
var value = _values[cell.Item1 * ColumnCount + cell.Item2];
if (result == null)
{
result = value;
}
else
{
if (list == null)
{
list = new List<T>() { (T)result };
}
list.Add(value);
}
}
if (list != null)
{
result = new KalkVector<T>((IReadOnlyList<T>)list);
}
return true;
}
private struct CharIterator
{
public CharIterator(string text)
{
Text = text;
Index = 0;
}
public readonly string Text;
public int Index;
public bool HasNext => Index < Text.Length;
public char NextChar()
{
if (Index >= Text.Length) return (char)0;
return Text[Index++];
}
}
private IEnumerable<(int, int)> ForEachMemberPart(SourceSpan span, string member, bool throwIfInvalid)
{
var it = new CharIterator(member);
// -1 None
// 0 Zero based
// 1 One based
int kindOfMember = -1;
string error = null;
while (it.HasNext)
{
var c = it.NextChar();
if (c == '_')
{
c = it.NextChar();
bool oneBased = true;
if (c == 'm')
{
oneBased = false;
c = it.NextChar();
}
if (kindOfMember < 0)
{
kindOfMember = oneBased ? 1 : 0;
}
else
{
var newKindOfMember = oneBased ? 1 : 0;
if (newKindOfMember != kindOfMember)
{
if (throwIfInvalid)
{
throw new ScriptRuntimeException(span, $"Invalid mix of 1 based _m11 and zero based _00");
}
else
{
yield return (-1, -1);
yield break;
}
}
}
// parse _00
if (CharHelper.TryParseDigit(c, out var rowIndex))
{
if (oneBased) rowIndex--;
c = it.NextChar();
if (CharHelper.TryParseDigit(c, out var columnIndex))
{
if (oneBased) columnIndex--;
if (rowIndex >= RowCount)
{
if (throwIfInvalid)
{
error = $"Invalid row index {rowIndex} >= {RowCount} for member {member}.";
}
}
else if (columnIndex >= ColumnCount)
{
if (throwIfInvalid)
{
error = $"Invalid column index {columnIndex} >= {ColumnCount} for member {member}.";
}
}
else
{
yield return (rowIndex, columnIndex);
continue;
}
}
}
}
if (error == null && throwIfInvalid)
{
error = c == 0 ? $"Invalid end of member text found when parsing `{member}`." : $"Invalid matrix character `{StringFunctions.Escape(c.ToString())}` found.";
}
if (throwIfInvalid)
{
throw new ScriptRuntimeException(span, error);
}
yield return (-1, -1);
yield break;
}
}
public override bool CanWrite(string member)
{
return true;
}
protected override KalkVector GenericMultiplyLeft(KalkVector x)
{
return MultiplyLeft((KalkVector<T>)x);
}
protected override KalkVector GenericMultiplyRight(KalkVector y)
{
return MultiplyRight((KalkVector<T>)y);
}
protected override KalkMatrix GenericMultiply(KalkMatrix y)
{
return Multiply((KalkMatrix<T>)y);
}
protected override object GenericDeterminant()
{
return Determinant();
}
protected override KalkMatrix GenericInverse()
{
return Inverse();
}
protected KalkVector<T> MultiplyLeft(KalkVector<T> x)
{
if (typeof(T) == typeof(float))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, (float[])(object)_values);
var vector = (Vector<float>)x.AsMathNetVector();
var result = new KalkVector<float>(matrix.LeftMultiply(vector));
return Unsafe.As<KalkVector<float>, KalkVector<T>>(ref result);
}
if (typeof(T) == typeof(KalkHalf))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, KalkHalf.ToFloatValues((KalkHalf[])(object)_values));
var vector = (Vector<float>)x.AsMathNetVector();
var floatResult = matrix.LeftMultiply(vector);
var result = new KalkVector<KalkHalf>(floatResult.Count);
for (int i = 0; i < result.Length; i++)
{
result[i] = (KalkHalf)floatResult[i];
}
return Unsafe.As<KalkVector<KalkHalf>, KalkVector<T>>(ref result);
}
if (typeof(T) == typeof(double))
{
var matrix = Matrix<double>.Build.Dense(RowCount, ColumnCount, (double[])(object)_values);
var vector = (Vector<double>)x.AsMathNetVector();
var result = new KalkVector<double>(matrix.LeftMultiply(vector));
return Unsafe.As<KalkVector<double>, KalkVector<T>>(ref result);
}
throw new ArgumentException($"The type {ElementType.ScriptPrettyName()} is not supported for this mul operation", nameof(x));
}
protected KalkVector<T> MultiplyRight(KalkVector<T> y)
{
if (typeof(T) == typeof(float))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, (float[])(object)_values);
var vector = (Vector<float>)y.AsMathNetVector();
var result = new KalkVector<float>(matrix.Multiply(vector));
return Unsafe.As<KalkVector<float>, KalkVector<T>>(ref result);
}
if (typeof(T) == typeof(KalkHalf))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, KalkHalf.ToFloatValues((KalkHalf[])(object)_values));
var vector = (Vector<float>)y.AsMathNetVector();
var floatResult = matrix.Multiply(vector);
var result = new KalkVector<KalkHalf>(floatResult.Count);
for (int i = 0; i < result.Length; i++)
{
result[i] = (KalkHalf)floatResult[i];
}
return Unsafe.As<KalkVector<KalkHalf>, KalkVector<T>>(ref result);
}
if (typeof(T) == typeof(double))
{
var matrix = Matrix<double>.Build.Dense(RowCount, ColumnCount, (double[])(object)_values);
var vector = (Vector<double>)y.AsMathNetVector();
var result = new KalkVector<double>(matrix.Multiply(vector));
return Unsafe.As<KalkVector<double>, KalkVector<T>>(ref result);
}
throw new ArgumentException($"The type {ElementType.ScriptPrettyName()} is not supported for this mul operation", nameof(y));
}
protected KalkMatrix Multiply(KalkMatrix<T> y)
{
if (typeof(T) == typeof(float))
{
var mx = Matrix<float>.Build.Dense(RowCount, ColumnCount, (float[])(object)_values);
var my = Matrix<float>.Build.Dense(y.RowCount, y.ColumnCount, (float[])(object)y._values);
var result = mx.Multiply(my);
var kalkResult = new KalkMatrix<float>(result.RowCount, result.ColumnCount, result.Storage.AsColumnMajorArray());
return kalkResult;
}
if (typeof(T) == typeof(KalkHalf))
{
var mx = Matrix<float>.Build.Dense(RowCount, ColumnCount, KalkHalf.ToFloatValues((KalkHalf[])(object)_values));
var my = Matrix<float>.Build.Dense(y.RowCount, y.ColumnCount, KalkHalf.ToFloatValues((KalkHalf[])(object)y._values));
var result = mx.Multiply(my);
var kalkResult = new KalkMatrix<KalkHalf>(result.RowCount, result.ColumnCount, KalkHalf.ToHalfValues(result.Storage.AsColumnMajorArray()));
return kalkResult;
}
if (typeof(T) == typeof(double))
{
var mx = Matrix<double>.Build.Dense(RowCount, ColumnCount, (double[])(object)_values);
var my = Matrix<double>.Build.Dense(y.RowCount, y.ColumnCount, (double[])(object)y._values);
var result = mx.Multiply(my);
var kalkResult = new KalkMatrix<double>(result.RowCount, result.ColumnCount, result.Storage.AsColumnMajorArray());
return kalkResult;
}
throw new ArgumentException($"The type {ElementType.ScriptPrettyName()} is not supported for this mul operation", nameof(y));
}
protected T Determinant()
{
if (typeof(T) == typeof(float))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, (float[])(object)_values);
var value = matrix.Determinant();
return Unsafe.As<float, T>(ref value);
}
else if (typeof(T) == typeof(KalkHalf))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, KalkHalf.ToFloatValues((KalkHalf[])(object)_values));
var value = (KalkHalf)matrix.Determinant();
return Unsafe.As<KalkHalf, T>(ref value);
}
else if (typeof(T) == typeof(double))
{
var matrix = Matrix<double>.Build.Dense(RowCount, ColumnCount, (double[])(object)_values);
var value = matrix.Determinant();
return Unsafe.As<double, T>(ref value);
}
else
{
throw new InvalidOperationException($"Determinant can only be calculated for float or double matrices but not for {TypeName}.");
}
}
protected KalkMatrix<T> Inverse()
{
if (typeof(T) == typeof(float))
{
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, (float[])(object)_values);
matrix = matrix.Inverse();
var result = new KalkMatrix<float>(matrix.RowCount, matrix.ColumnCount, matrix.Storage.AsColumnMajorArray());
return Unsafe.As<KalkMatrix<float>, KalkMatrix<T>>(ref result);
}
else if (typeof(T) == typeof(KalkHalf))
{
// Use float for calculation with half
var values = KalkHalf.ToFloatValues((KalkHalf[]) (object) _values);
var matrix = Matrix<float>.Build.Dense(RowCount, ColumnCount, values);
matrix = matrix.Inverse();
var result = new KalkMatrix<KalkHalf>(matrix.RowCount, matrix.ColumnCount, KalkHalf.ToHalfValues(matrix.Storage.AsColumnMajorArray()));
return Unsafe.As<KalkMatrix<KalkHalf>, KalkMatrix<T>>(ref result);
}
else if (typeof(T) == typeof(double))
{
var matrix = Matrix<double>.Build.Dense(RowCount, ColumnCount, (double[])(object)_values);
matrix = matrix.Inverse();
var result = new KalkMatrix<double>(matrix.RowCount, matrix.ColumnCount, matrix.Storage.AsColumnMajorArray());
return Unsafe.As<KalkMatrix<double>, KalkMatrix<T>>(ref result);
}
else
{
throw new InvalidOperationException("Determinant can only be calculated for float or double or half matrices.");
}
}
protected override KalkMatrix Transpose()
{
var transpose = new KalkMatrix<T>(ColumnCount, RowCount);
for(int y =0; y < RowCount; y++)
{
for(int x = 0; x < ColumnCount; x++)
{
transpose[RowCount * x + y] = this[ColumnCount * y + x];
}
}
return transpose;
}
protected override KalkMatrix Identity()
{
if (RowCount != ColumnCount) throw new InvalidOperationException($"Matrix must be square nxn instead of {TypeName}");
var transpose = new KalkMatrix<T>(ColumnCount, RowCount);
if (typeof(T) == typeof(KalkBool))
{
var m = (KalkMatrix<KalkBool>) (KalkMatrix) transpose;
int count = RowCount;
for (int i = 0; i < count; i++)
m[ColumnCount * i + i] = true;
}
else if (typeof(T) == typeof(int))
{
var m = (KalkMatrix<int>)(KalkMatrix)transpose;
int count = RowCount;
for (int i = 0; i < count; i++)
m[ColumnCount * i + i] = 1;
}
else if (typeof(T) == typeof(float))
{
var m = (KalkMatrix<float>)(KalkMatrix)transpose;
int count = RowCount;
for (int i = 0; i < count; i++)
m[ColumnCount * i + i] = 1.0f;
}
else if (typeof(T) == typeof(KalkHalf))
{
var m = (KalkMatrix<KalkHalf>)(KalkMatrix)transpose;
int count = RowCount;
for (int i = 0; i < count; i++)
m[ColumnCount * i + i] = (KalkHalf)1.0f;
}
else if (typeof(T) == typeof(double))
{
var m = (KalkMatrix<double>)(KalkMatrix)transpose;
int count = RowCount;
for (int i = 0; i < count; i++)
m[ColumnCount * i + i] = 1.0;
}
return transpose;
}
protected override KalkVector Diagonal()
{
if (RowCount != ColumnCount) throw new InvalidOperationException($"Matrix must be square nxn instead of {TypeName}");
var m = this;
int count = RowCount;
var result = new KalkVector<T>(RowCount);
for (int i = 0; i < count; i++)
result[i] = m[ColumnCount * i + i];
return result;
}
public override bool TrySetValue(TemplateContext context, SourceSpan span, string member, object value, bool readOnly)
{
// TODO: could cache this list
var list = new List<int>();
foreach (var cell in ForEachMemberPart(span, member, true))
{
int nextIndex = cell.Item1 * ColumnCount + cell.Item2;
// Otherwise check that we are not trying to override an existing element.
if (list.Contains(nextIndex))
{
throw new ScriptRuntimeException(span, $"The cell({cell.Item1}, {cell.Item2}) was already set and cannot be set more than once by a same set.");
}
list.Add(nextIndex);
}
// Handle the case where the value to set is a vector or an array
if (value is IList inputList)
{
if (inputList.Count != list.Count) throw new ScriptRuntimeException(span, $"Invalid number of components. Expecting {list.Count} but the value {value} has {inputList.Count} components.");
if (value is KalkVector<T> vector)
{
int i = 0;
foreach (var index in list)
{
this[index] = vector[i];
i++;
}
}
else
{
int i = 0;
foreach (var index in list)
{
this[index] = context.ToObject<T>(span, inputList[i]); ;
i++;
}
}
}
else
{
// otherwise it is a single value that we dispatch
var tValue = context.ToObject<T>(span, value);
foreach (var index in list)
{
this[index] = tValue;
}
}
return true;
}
public override bool Remove(string member)
{
throw new InvalidOperationException($"Cannot remove member {member} for {this.GetType()}");
}
public override void SetReadOnly(string member, bool readOnly)
{
throw new InvalidOperationException($"Cannot set readonly member {member} for {this.GetType()}");
}
public bool TryEvaluate(TemplateContext context, SourceSpan span, ScriptBinaryOperator op, SourceSpan leftSpan, object leftValue, SourceSpan rightSpan, object rightValue, out object result)
{
if (leftValue is KalkExpression leftExpression)
{
return leftExpression.TryEvaluate(context, span, op, leftSpan, leftValue, rightSpan, rightValue, out result);
}
if (rightValue is KalkExpression rightExpression)
{
return rightExpression.TryEvaluate(context, span, op, leftSpan, leftValue, rightSpan, rightValue, out result);
}
result = null;
var leftMatrix = leftValue as KalkMatrix<T>;
var rightMatrix = rightValue as KalkMatrix<T>;
if (leftMatrix == null && rightMatrix == null) return false;
if (leftMatrix != null && rightMatrix != null && (leftMatrix.RowCount != rightMatrix.RowCount || leftMatrix.ColumnCount != rightMatrix.ColumnCount))
{
return false;
}
if (leftMatrix == null)
{
leftMatrix = new KalkMatrix<T>(rightMatrix.RowCount, rightMatrix.ColumnCount);
var leftComponentValue = context.ToObject<T>(span, leftValue);
for (int i = 0; i < leftMatrix._values.Length; i++)
{
leftMatrix[i] = leftComponentValue;
}
}
if (rightMatrix == null)
{
rightMatrix = new KalkMatrix<T>(leftMatrix.RowCount, leftMatrix.ColumnCount);
var rightComponentValue = context.ToObject<T>(span, rightValue);
for (int i = 0; i < rightMatrix._values.Length; i++)
{
rightMatrix[i] = rightComponentValue;
}
}
switch (op)
{
case ScriptBinaryOperator.CompareEqual:
case ScriptBinaryOperator.CompareNotEqual:
case ScriptBinaryOperator.CompareLessOrEqual:
case ScriptBinaryOperator.CompareGreaterOrEqual:
case ScriptBinaryOperator.CompareLess:
case ScriptBinaryOperator.CompareGreater:
var vbool = new KalkMatrix<KalkBool>(leftMatrix.RowCount, leftMatrix.ColumnCount);
for (int i = 0; i < vbool._values.Length; i++)
{
vbool[i] = (bool)ScriptBinaryExpression.Evaluate(context, span, op, leftMatrix._values[i], rightMatrix._values[i]);
}
result = vbool;
return true;
case ScriptBinaryOperator.Add:
case ScriptBinaryOperator.Subtract:
case ScriptBinaryOperator.Multiply:
case ScriptBinaryOperator.Divide:
case ScriptBinaryOperator.DivideRound:
case ScriptBinaryOperator.Modulus:
case ScriptBinaryOperator.ShiftLeft:
case ScriptBinaryOperator.ShiftRight:
case ScriptBinaryOperator.Power:
{
var opResult = new KalkMatrix<T>(leftMatrix.RowCount, leftMatrix.ColumnCount);
for (int i = 0; i < opResult._values.Length; i++)
{
opResult[i] = context.ToObject<T>(span, ScriptBinaryExpression.Evaluate(context, span, op, leftMatrix._values[i], rightMatrix._values[i]));
}
result = opResult;
return true;
}
}
return false;
}
public bool TryConvertTo(TemplateContext context, SourceSpan span, Type type, out object value)
{
value = null;
return false;
}
public bool Equals(KalkMatrix<T> other)
{
if (other == null) return false;
if (RowCount != other.RowCount || ColumnCount != other.ColumnCount) return false;
var values = _values;
var otherValues = other._values;
for (int i = 0; i < values.Length; i++)
{
if (values[i].Equals(otherValues[i])) return false;
}
return true;
}
public override bool Equals(object obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((KalkMatrix<T>) obj);
}
public override int GetHashCode()
{
var values = _values;
int hashCode = (RowCount * 397) ^ ColumnCount;
for (int i = 0; i < values.Length; i++)
{
hashCode = (hashCode * 397) ^ values[i].GetHashCode();
}
return hashCode;
}
public IEnumerator GetEnumerator()
{
for (int i = 0; i < RowCount; i++)
{
yield return GetRow(i);
}
}
public bool TryEvaluate(TemplateContext context, SourceSpan span, ScriptUnaryOperator op, object rightValue, out object result)
{
throw new NotImplementedException();
}
public static bool operator ==(KalkMatrix<T> left, KalkMatrix<T> right)
{
return Equals(left, right);
}
public static bool operator !=(KalkMatrix<T> left, KalkMatrix<T> right)
{
return !Equals(left, right);
}
public string ToString(string format, IFormatProvider formatProvider)
{
var engine = formatProvider as KalkEngine;
var builder = new StringBuilder();
if (RowCount > 4 || ColumnCount > 4)
{
builder.Append("matrix");
builder.Append("(");
builder.Append(engine != null ? engine.GetTypeName(_values[0]) : typeof(T).ScriptPrettyName());
builder.Append(", ");
builder.Append(RowCount.ToString(CultureInfo.InvariantCulture));
builder.Append(", ");
builder.Append(ColumnCount.ToString(CultureInfo.InvariantCulture));
builder.Append(", ");
}
else
{
builder.Append(engine != null ? engine.GetTypeName(_values[0]) : typeof(T).ScriptPrettyName());
builder.Append(RowCount.ToString(CultureInfo.InvariantCulture)).Append('x').Append(ColumnCount.ToString(CultureInfo.InvariantCulture));
builder.Append('(');
}
for (int i = 0; i < _values.Length; i++)
{
if (i > 0) builder.Append(", ");
var valueToFormat = _values[i];
builder.Append(engine != null ? engine.ObjectToString(valueToFormat) : valueToFormat is IFormattable formattable ? formattable.ToString(null, formatProvider) : valueToFormat.ToString());
}
builder.Append(')');
return builder.ToString();
}
private static bool StartsWithNegative(ref T value)
{
if (typeof(T) == typeof(float))
{
return Unsafe.As<T, float>(ref value) < 0.0f;
}
else if (typeof(T) == typeof(double))
{
return Unsafe.As<T, double>(ref value) < 0.0;
}
else if (typeof(T) == typeof(int))
{
return Unsafe.As<T, int>(ref value) < 0;
}
return false;
}
public override void Display(KalkEngine engine, KalkDisplayMode mode)
{
// Don't display anything beyond 4x4
if (RowCount > 4 || ColumnCount > 4) return;
var columnCountDisplay = ColumnCount + 2;
var rowCountDisplay = RowCount + 1;
var cells = new string[rowCountDisplay * columnCountDisplay];
var cellWidth = new int[columnCountDisplay];
var columnHasNeg = new bool[ColumnCount];
// Setup a minimum width
for (int i = 1; i < cellWidth.Length - 1; i++)
{
cellWidth[i] = 10;
}
for (int x = 0; x < ColumnCount; x++)
{
for (int y = 0; y < RowCount; y++)
{
var cellValue = _values[y * ColumnCount + x];
if (StartsWithNegative(ref cellValue))
{
columnHasNeg[x] = true;
break;
}
}
}
var rowTypeName = $"{engine.GetTypeName(_values[0])}{ColumnCount.ToString(CultureInfo.InvariantCulture)}";
for (int y = 0; y < RowCount; y++)
{
cells[(y + 1) * columnCountDisplay + ColumnCount + 1] = $") # {y}";
var col0 = $"{rowTypeName}(";
cells[(y + 1) * columnCountDisplay + 0] = col0;
for (int x = 0; x < ColumnCount; x++)
{
var cellValue = _values[y * ColumnCount + x];
var text = engine.ObjectToString(cellValue);
// Put a space in place of the -, for positive number
if (columnHasNeg[x] && !StartsWithNegative(ref cellValue))
{
text = $" {text}";
}
cells[(y + 1) * columnCountDisplay + x + 1] = text;
if (text.Length > cellWidth[x + 1])
{
cellWidth[x + 1] = text.Length;
}
}
}
cells[0] = "# col";
cellWidth[0] = rowTypeName.Length + 1; // typeName(
cells[ColumnCount + 1] = " / row";
for (int x = 0; x < ColumnCount; x++)
{
cells[1 + x] = columnHasNeg[x] ? $" {x}" : $"{x}";
}
var builder = new StringBuilder();
for (int y = 0; y < RowCount + 1; y++)
{
builder.Append(" ");
WriteRow(y * columnCountDisplay, columnCountDisplay, y == 0 ? " " : ", ");
engine.WriteHighlightLine(builder.ToString());
builder.Clear();
}
void WriteRow(int index, int column, string sep)
{
for (int i = 0; i < column; i++)
{
var width = cellWidth[i];
var text = cells[index + i].PadRight(width);
builder.Append(text);
if (i >= 1 && i + 2 < column)
{
builder.Append(sep);
}
}
}
}
}
}