Added matrices
This commit is contained in:
parent
d838756d0e
commit
a930045006
1 changed files with 112 additions and 0 deletions
112
domains/kaar.ml
Normal file
112
domains/kaar.ml
Normal file
|
@ -0,0 +1,112 @@
|
|||
let swap arr i j =
|
||||
let tmp = arr.(i) in
|
||||
arr.(i) <- arr.(j);
|
||||
arr.(j) <- tmp
|
||||
|
||||
module Matrix : sig
|
||||
type t
|
||||
|
||||
val init : int -> int -> (int -> int -> Q.t) -> t
|
||||
val copy : t -> t
|
||||
val zero : int -> int -> t
|
||||
val neg : t -> t
|
||||
val add : t -> t -> t
|
||||
val sub : t -> t -> t
|
||||
val mul : t -> t -> t
|
||||
val transpose : t -> t
|
||||
val gauss : t * t -> t * t
|
||||
val print : Format.formatter -> t -> unit
|
||||
end = struct
|
||||
type t = { height : int; width : int; data : Q.t array array }
|
||||
|
||||
exception Incorrect_matrix_size
|
||||
|
||||
let init n m f =
|
||||
{
|
||||
height = n;
|
||||
width = m;
|
||||
data = Array.init n (fun i -> Array.init m (fun j -> f i j));
|
||||
}
|
||||
|
||||
let copy m = init m.height m.width (fun i j -> m.data.(i).(j))
|
||||
let zero n m = init n m (fun _ _ -> Q.zero)
|
||||
let neg mat = init mat.width mat.height (fun i j -> Q.neg mat.data.(i).(j))
|
||||
|
||||
let add mat1 mat2 =
|
||||
if mat1.width <> mat2.width || mat1.height <> mat2.height then
|
||||
raise Incorrect_matrix_size
|
||||
else
|
||||
init mat1.height mat1.width (fun i j ->
|
||||
Q.add mat1.data.(i).(j) mat2.data.(i).(j))
|
||||
|
||||
let sub mat1 mat2 = add mat1 (neg mat2)
|
||||
|
||||
let mul mat1 mat2 =
|
||||
if mat1.width <> mat2.height then raise Incorrect_matrix_size
|
||||
else
|
||||
init mat1.height mat2.width (fun i j ->
|
||||
List.fold_left
|
||||
(fun sum k -> Q.add sum (Q.mul mat1.data.(i).(k) mat2.data.(k).(j)))
|
||||
Q.zero
|
||||
(List.init mat1.width (fun x -> x)))
|
||||
|
||||
let transpose mat = init mat.width mat.height (fun i j -> mat.data.(j).(i))
|
||||
|
||||
let gauss (m, c) =
|
||||
let m' = copy m in
|
||||
let c' = copy c in
|
||||
let pivot line column =
|
||||
let rec search_pivot l =
|
||||
if l >= m'.height then None
|
||||
else if not (Q.equal m'.data.(l).(column) Q.zero) then Some l
|
||||
else search_pivot (l + 1)
|
||||
in
|
||||
let sub_line l sl coef =
|
||||
for i = 0 to m'.width - 1 do
|
||||
m'.data.(l).(i) <- Q.sub m'.data.(l).(i) (Q.mul coef m'.data.(sl).(i))
|
||||
done;
|
||||
for i = 0 to c'.width - 1 do
|
||||
c'.data.(l).(i) <- Q.sub c'.data.(l).(i) (Q.mul coef c'.data.(sl).(i))
|
||||
done
|
||||
in
|
||||
match search_pivot line with
|
||||
| Some pline ->
|
||||
swap m'.data line pline;
|
||||
swap c'.data line pline;
|
||||
let d = m'.data.(line).(column) in
|
||||
for i = 0 to m'.width - 1 do
|
||||
m'.data.(line).(i) <- Q.div m'.data.(line).(i) d
|
||||
done;
|
||||
for i = 0 to c'.width - 1 do
|
||||
c'.data.(line).(i) <- Q.div c'.data.(line).(i) d
|
||||
done;
|
||||
if line <> m'.height - 1 then begin
|
||||
for l = line + 1 to m'.height - 1 do
|
||||
sub_line l line m'.data.(l).(column)
|
||||
done;
|
||||
end;
|
||||
line + 1
|
||||
| None -> line
|
||||
in
|
||||
|
||||
if m'.height <> c'.height then raise Incorrect_matrix_size
|
||||
else
|
||||
let line = ref 0 in
|
||||
for i = 0 to m'.width - 1 do
|
||||
line := pivot !line i
|
||||
done;
|
||||
(m', c')
|
||||
|
||||
let print fmt mat =
|
||||
Format.pp_print_string fmt "[";
|
||||
Array.iter
|
||||
(fun arr ->
|
||||
Format.pp_print_string fmt "\n";
|
||||
Array.iter
|
||||
(fun elt ->
|
||||
Format.pp_print_string fmt " ";
|
||||
Q.pp_print fmt elt)
|
||||
arr)
|
||||
mat.data;
|
||||
Format.pp_print_string fmt "\n]"
|
||||
end
|
Loading…
Reference in a new issue