AStat/iterator/iterable.ml
2024-06-09 18:11:53 +02:00

118 lines
4.3 KiB
OCaml

open Cfg
open Domain
(** Abstract values attached to CFG nodes by the iterator.  An [ITERABLE]
    knows how to push a value across an arc ([do_compute]) and how to merge
    the result into the value already stored at the arc's destination
    ([accumulate]). *)
module type ITERABLE = sig
  type t
  (** Abstract value attached to a CFG node. *)

  val bottom : t
  (** The empty abstract value. *)

  val init : int -> t
  (** [init x] builds an initial abstract value from [x].
      NOTE(review): the meaning of [x] is defined by the underlying
      domain's [init]; confirm against [Domain]. *)

  val do_compute : arc -> t -> (arc -> unit) -> (func -> t -> t) -> t
  (** [do_compute a src cb_fail cb_fun] transfers [src], the value at the
      source of [a], through the instruction carried by [a], and returns the
      value to accumulate at the destination of [a].
      [cb_fail a] is invoked when an assertion on [a] cannot be proven;
      [cb_fun f v] is used to compute the effect of calling [f] on [v]. *)

  val accumulate : arc -> t -> t -> t * bool
  (** [accumulate a dst_old dst_toacc] merges [dst_toacc] (the value just
      computed along [a]) into [dst_old] (the value currently stored at the
      destination of [a]), joining or, when the destination is a widening
      target, widening.  Returns the merged value and [false] when
      [dst_toacc] added nothing to [dst_old]. *)

  val print : Format.formatter -> t -> unit
  (** Pretty-prints an abstract value. *)
end
(* Non-disjunctive iterable: each CFG node carries a single value of [D]. *)
module SimpleIterable (D : DOMAIN) : ITERABLE = struct
  type t = D.t

  let bottom = D.bottom
  let init = fun n -> D.init n
  let print = fun ppf v -> D.print ppf v

  (* Transfer [src] through the instruction of arc [a].  For an assertion,
     the state is first intersected with the negated condition: if that is
     not empty the assertion may fail, so [cb_fail a] is reported and the
     state is refined by the condition itself. *)
  let do_compute a src cb_fail cb_fun =
    match a.arc_inst with
    | CFG_skip _ -> src
    | CFG_assign (var, e) -> D.assign src var e
    | CFG_guard cond -> D.guard src cond
    | CFG_assert (cond, _) ->
        let violating = D.guard src (CFG_bool_unary (AST_NOT, cond)) in
        if not (D.is_bottom violating) then begin
          cb_fail a;
          D.guard src cond
        end
        else src
    | CFG_call f -> cb_fun f src

  (* Merge [dst_toacc] into [dst_old]; widen instead of join when the
     destination node is a widening target.  The flag tells whether the
     stored value changed. *)
  let accumulate a dst_old dst_toacc =
    if D.subset dst_toacc dst_old then (dst_old, false)
    else begin
      let combine = if a.arc_dst.widen_target then D.widen else D.join in
      (combine dst_old dst_toacc, true)
    end
end
(* Sparse Conditional Record: a key listing branch decisions as
   (node id, parity) pairs.  Lists must be sorted by strictly increasing
   node id (DisjunctiveIterable.tag_key maintains this). *)
module SCR = struct
  type t = (int * bool) list

  (* Total order on keys.  The empty key is the smallest; otherwise entries
     are compared position by position: a smaller node id ranks *higher*
     (kept from the original head comparison), then [false] < [true].
     A strict prefix is smaller than its extensions.  The recursion applies
     the same order to every position (previously the tail fell through to
     polymorphic [Stdlib.compare], which ordered node ids the other way). *)
  let rec compare (v1 : t) (v2 : t) =
    match v1, v2 with
    | [], [] -> 0
    | _ :: _, [] -> 1
    | [], _ :: _ -> -1
    | (i, w) :: q, (i', w') :: q' ->
        if i <> i' then Int.compare i' i (* smaller id => v1 > v2 *)
        else
          let c = Bool.compare w w' in
          if c <> 0 then c else compare q q'
end

(* Maps from branch-decision keys to values. *)
module SCRMap = Map.Make (SCR)
(* Disjunctive iterable: a node carries a finite disjunction of [D] states,
   each indexed by the branch decisions (an [SCR] key) that led to it. *)
module DisjunctiveIterable (D : DOMAIN) : ITERABLE = struct
  (* Invariant: in every key, node ids are strictly increasing. *)
  type t = D.t SCRMap.t

  let bottom = SCRMap.empty
  let init x = SCRMap.singleton [] (D.init x)

  (* Prints "[(id:parity,...) state(id:parity,...) state...]". *)
  let print fmt abst =
    let print_key k =
      Format.fprintf fmt "(";
      List.iter (fun (i, b) -> Format.fprintf fmt "%d:%b," i b) k;
      Format.fprintf fmt ")"
    in
    Format.fprintf fmt "@[<h 0>[";
    SCRMap.iter (fun k x -> print_key k; Format.fprintf fmt " %a" D.print x) abst;
    Format.fprintf fmt "]@]"

  (* Apply the instruction of arc [a] to every disjunct of [src]. *)
  let do_compute a src cb_fail cb_fun =
    match a.arc_inst with
    | CFG_skip _ -> src
    | CFG_assign (v, iexpr) -> SCRMap.map (fun d -> D.assign d v iexpr) src
    | CFG_guard bexpr -> SCRMap.map (fun d -> D.guard d bexpr) src
    | CFG_assert (bexpr, _) ->
        (* The assertion may fail iff some disjunct is compatible with the
           negated condition. *)
        let may_fail =
          SCRMap.exists
            (fun _ d ->
              not (D.is_bottom (D.guard d (CFG_bool_unary (AST_NOT, bexpr)))))
            src
        in
        if may_fail then begin
          cb_fail a;
          SCRMap.map (fun d -> D.guard d bexpr) src
        end
        else src
    | CFG_call f -> cb_fun f src

  (* [tag_key a key] records in [key] that the branch at [a.arc_src] was
     taken with parity [a.arc_parity], keeping [key] sorted by node id; an
     existing entry for that node is overwritten. *)
  let rec tag_key a key =
    let n = a.arc_src.node_id in
    match key with
    | [] -> [ (n, a.arc_parity) ]
    | ((ci, _) as entry) :: q ->
        let c = Int.compare ci n in
        if c < 0 then entry :: tag_key a q
        else if c > 0 then (n, a.arc_parity) :: key
        else (n, a.arc_parity) :: q

  (* Merge [dst_toacc] into [dst_old] key by key.  On an arc leaving a
     branch node, the keys of [dst_toacc] are first tagged with the branch
     taken.  Returns the merged map and whether anything changed. *)
  let accumulate a dst_old dst_toacc =
    (* Tagging can map distinct keys onto the same key (e.g. when a key
       already holds an entry for this node, as around loops); such
       disjuncts are joined instead of one of them being dropped, which
       [Map.of_list] (last binding wins) used to do. *)
    let tounion =
      if a.arc_src.branch_node then
        SCRMap.fold
          (fun key d acc ->
            SCRMap.update (tag_key a key)
              (function None -> Some d | Some d0 -> Some (D.join d0 d))
              acc)
          dst_toacc SCRMap.empty
      else dst_toacc
    in
    let acctor = if a.arc_dst.widen_target then D.widen else D.join in
    let changed = ref false in
    let merged =
      SCRMap.merge
        (fun _ old_d new_d ->
          match old_d, new_d with
          | None, None -> None
          | Some d, None -> Some d (* keep the old state *)
          | None, Some d -> (changed := true; Some d) (* newly reached branch *)
          | Some d, Some d' ->
              if D.subset d' d then Some d
              else (changed := true; Some (acctor d d')))
        dst_old tounion
    in
    (merged, !changed)
end