Added matrices

This commit is contained in:
soyouzpanda 2024-06-08 22:40:20 +02:00
parent d838756d0e
commit a930045006

112
domains/kaar.ml Normal file
View file

@ -0,0 +1,112 @@
let swap arr i j =
let tmp = arr.(i) in
arr.(i) <- arr.(j);
arr.(j) <- tmp
module Matrix : sig
type t
val init : int -> int -> (int -> int -> Q.t) -> t
val copy : t -> t
val zero : int -> int -> t
val neg : t -> t
val add : t -> t -> t
val sub : t -> t -> t
val mul : t -> t -> t
val transpose : t -> t
val gauss : t * t -> t * t
val print : Format.formatter -> t -> unit
end = struct
type t = { height : int; width : int; data : Q.t array array }
exception Incorrect_matrix_size
let init n m f =
{
height = n;
width = m;
data = Array.init n (fun i -> Array.init m (fun j -> f i j));
}
let copy m = init m.height m.width (fun i j -> m.data.(i).(j))
let zero n m = init n m (fun _ _ -> Q.zero)
let neg mat = init mat.width mat.height (fun i j -> Q.neg mat.data.(i).(j))
let add mat1 mat2 =
if mat1.width <> mat2.width || mat1.height <> mat2.height then
raise Incorrect_matrix_size
else
init mat1.height mat1.width (fun i j ->
Q.add mat1.data.(i).(j) mat2.data.(i).(j))
let sub mat1 mat2 = add mat1 (neg mat2)
let mul mat1 mat2 =
if mat1.width <> mat2.height then raise Incorrect_matrix_size
else
init mat1.height mat2.width (fun i j ->
List.fold_left
(fun sum k -> Q.add sum (Q.mul mat1.data.(i).(k) mat2.data.(k).(j)))
Q.zero
(List.init mat1.width (fun x -> x)))
let transpose mat = init mat.width mat.height (fun i j -> mat.data.(j).(i))
let gauss (m, c) =
let m' = copy m in
let c' = copy c in
let pivot line column =
let rec search_pivot l =
if l >= m'.height then None
else if not (Q.equal m'.data.(l).(column) Q.zero) then Some l
else search_pivot (l + 1)
in
let sub_line l sl coef =
for i = 0 to m'.width - 1 do
m'.data.(l).(i) <- Q.sub m'.data.(l).(i) (Q.mul coef m'.data.(sl).(i))
done;
for i = 0 to c'.width - 1 do
c'.data.(l).(i) <- Q.sub c'.data.(l).(i) (Q.mul coef c'.data.(sl).(i))
done
in
match search_pivot line with
| Some pline ->
swap m'.data line pline;
swap c'.data line pline;
let d = m'.data.(line).(column) in
for i = 0 to m'.width - 1 do
m'.data.(line).(i) <- Q.div m'.data.(line).(i) d
done;
for i = 0 to c'.width - 1 do
c'.data.(line).(i) <- Q.div c'.data.(line).(i) d
done;
if line <> m'.height - 1 then begin
for l = line + 1 to m'.height - 1 do
sub_line l line m'.data.(l).(column)
done;
end;
line + 1
| None -> line
in
if m'.height <> c'.height then raise Incorrect_matrix_size
else
let line = ref 0 in
for i = 0 to m'.width - 1 do
line := pivot !line i
done;
(m', c')
let print fmt mat =
Format.pp_print_string fmt "[";
Array.iter
(fun arr ->
Format.pp_print_string fmt "\n";
Array.iter
(fun elt ->
Format.pp_print_string fmt " ";
Q.pp_print fmt elt)
arr)
mat.data;
Format.pp_print_string fmt "\n]"
end