safety_net/
graph.rs

1/*!
2
  Netlist analyses: fan-out tables, combinational depth, and (behind the `graph` feature) a petgraph multigraph view.
4
5*/
6
7use crate::circuit::{Instantiable, Net};
8use crate::error::Error;
9#[cfg(feature = "graph")]
10use crate::netlist::Connection;
11use crate::netlist::{InputPort, NetRef, Netlist};
12#[cfg(feature = "graph")]
13use petgraph::graph::DiGraph;
14use std::cmp::Reverse;
15use std::collections::hash_map::Entry;
16use std::collections::{BinaryHeap, HashMap, HashSet};
17
18/// A common trait of analyses than can be performed on a netlist.
19/// An analysis becomes stale when the netlist is modified.
20pub trait Analysis<'a, I: Instantiable>
21where
22    Self: Sized + 'a,
23{
24    /// Construct the analysis to the current state of the netlist.
25    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error>;
26}
27
28/// A table that maps nets to the circuit nodes they drive
29pub struct FanOutTable<'a, I: Instantiable> {
30    /// A reference to the underlying netlist
31    _netlist: &'a Netlist<I>,
32    /// Maps a net to the list of nodes it drives
33    net_fan_out: HashMap<Net, Vec<NetRef<I>>>,
34    /// Maps a node to the list of nodes it drives
35    node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>>,
36    /// Contains nets which are outputs
37    is_an_output: HashSet<Net>,
38}
39
40impl<I> FanOutTable<'_, I>
41where
42    I: Instantiable,
43{
44    /// Returns an iterator to the circuit nodes that use `net`.
45    pub fn get_net_users(&self, net: &Net) -> impl Iterator<Item = NetRef<I>> {
46        self.net_fan_out
47            .get(net)
48            .into_iter()
49            .flat_map(|users| users.iter().cloned())
50    }
51
52    /// Returns an iterator to the circuit nodes that use `node`.
53    pub fn get_node_users(&self, node: &NetRef<I>) -> impl Iterator<Item = NetRef<I>> {
54        self.node_fan_out
55            .get(node)
56            .into_iter()
57            .flat_map(|users| users.iter().cloned())
58    }
59
60    /// Returns `true` if the net has any used by any cells in the circuit
61    /// This does incude nets that are only used as outputs.
62    pub fn net_has_uses(&self, net: &Net) -> bool {
63        (self.net_fan_out.contains_key(net) && !self.net_fan_out.get(net).unwrap().is_empty())
64            || self.is_an_output.contains(net)
65    }
66}
67
68impl<'a, I> Analysis<'a, I> for FanOutTable<'a, I>
69where
70    I: Instantiable,
71{
72    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
73        let mut net_fan_out: HashMap<Net, Vec<NetRef<I>>> = HashMap::new();
74        let mut node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>> = HashMap::new();
75        let mut is_an_output: HashSet<Net> = HashSet::new();
76
77        // We can only build the fanout table if netlist is mostly intact
78        if let Err(e) = netlist.verify() {
79            match e {
80                Error::NoOutputs => (),
81                _ => return Err(e),
82            }
83        }
84
85        for c in netlist.connections() {
86            if let Entry::Vacant(e) = net_fan_out.entry(c.net()) {
87                e.insert(vec![c.target().unwrap()]);
88            } else {
89                net_fan_out
90                    .get_mut(&c.net())
91                    .unwrap()
92                    .push(c.target().unwrap());
93            }
94
95            if let Entry::Vacant(e) = node_fan_out.entry(c.src().unwrap()) {
96                e.insert(vec![c.target().unwrap()]);
97            } else {
98                node_fan_out
99                    .get_mut(&c.src().unwrap())
100                    .unwrap()
101                    .push(c.target().unwrap());
102            }
103        }
104
105        for (o, n) in netlist.outputs() {
106            is_an_output.insert(o.as_net().clone());
107            is_an_output.insert(n);
108        }
109
110        Ok(FanOutTable {
111            _netlist: netlist,
112            net_fan_out,
113            node_fan_out,
114            is_an_output,
115        })
116    }
117}
118
/// Result of combinational depth analysis for a single circuit node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CombDepthResult {
    /// Depth is unknown: some input in the node's combinational fan-in has no driver
    Undefined,
    /// The node lies on a combinational cycle, or is fed by one through combinational logic
    CombCycle,
    /// Length of the longest combinational path reaching the node from a primary
    /// input or sequential element (which themselves have depth 0)
    Depth(usize),
}
131
132/// Computes the combinational depth of each net in a netlist.
133///
134/// Each net is classified as having a defined depth, being undefined,
135/// or participating in a combinational cycle.
136pub struct CombDepthInfo<'a, I: Instantiable> {
137    _netlist: &'a Netlist<I>,
138    /// The total distance from a sequential element
139    results: HashMap<NetRef<I>, CombDepthResult>,
140    /// The critical predecessor to the node
141    critical_par: HashMap<NetRef<I>, InputPort<I>>,
142    /// Critical endpoints to build paths from
143    critical_ends: BinaryHeap<(Reverse<usize>, NetRef<I>)>,
144    /// Max will be `None` if the entire circuit is part of a combinational cycle or has undriven elements
145    max_depth: Option<usize>,
146}
147
148impl<I> CombDepthInfo<'_, I>
149where
150    I: Instantiable,
151{
152    /// Max number of critical endpoints to keep in the heap.
153    const SIZE_HEAP: usize = 10;
154
155    /// Returns the logic level of a node in the circuit.
156    pub fn get_comb_depth(&self, node: &NetRef<I>) -> Option<CombDepthResult> {
157        self.results.get(node).copied()
158    }
159
160    /// Returns the critical input port
161    pub fn get_crit_input(&self, node: &NetRef<I>) -> Option<&InputPort<I>> {
162        self.critical_par.get(node)
163    }
164
165    /// Returns the most critical endpoints in the circuit
166    pub fn get_critical_points(&self) -> impl IntoIterator<Item = &NetRef<I>> {
167        let mut v = self.critical_ends.iter().collect::<Vec<_>>();
168        v.sort_by_key(|(d, _)| *d);
169        v.into_iter().map(|(_, n)| n)
170    }
171
172    /// Builds the most critical path
173    pub fn build_critical_path(&self) -> Option<Vec<NetRef<I>>> {
174        let mut path = Vec::new();
175        let mut current = self.get_critical_points().into_iter().next()?.clone();
176        while let Some(crit) = self.critical_par.get(&current) {
177            path.push(current.clone());
178            current = self
179                ._netlist
180                .get_driver(current, crit.get_input_num())
181                .unwrap();
182        }
183        path.push(current);
184        Some(path)
185    }
186
187    /// Returns the maximum logic level of the circuit.
188    pub fn get_max_depth(&self) -> Option<usize> {
189        self.max_depth
190    }
191}
192
193impl<'a, I> Analysis<'a, I> for CombDepthInfo<'a, I>
194where
195    I: Instantiable,
196{
197    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
198        let mut results: HashMap<NetRef<I>, CombDepthResult> = HashMap::new();
199        let mut critical_par: HashMap<NetRef<I>, InputPort<I>> = HashMap::new();
200        let mut critical_ends: BinaryHeap<(_, NetRef<I>)> = BinaryHeap::new();
201        let mut visiting: HashSet<NetRef<I>> = HashSet::new();
202        let mut max_depth: Option<usize> = None;
203
204        fn compute<I: Instantiable>(
205            node: NetRef<I>,
206            netlist: &Netlist<I>,
207            results: &mut HashMap<NetRef<I>, CombDepthResult>,
208            critical_par: &mut HashMap<NetRef<I>, InputPort<I>>,
209            critical_ends: &mut BinaryHeap<(Reverse<usize>, NetRef<I>)>,
210            visiting: &mut HashSet<NetRef<I>>,
211        ) -> CombDepthResult {
212            // Memoized result
213            if let Some(&r) = results.get(&node) {
214                return r;
215            }
216
217            // Cycle detection
218            if visiting.contains(&node) {
219                for n in visiting.iter() {
220                    results.insert(n.clone(), CombDepthResult::CombCycle);
221                }
222                return CombDepthResult::CombCycle;
223            }
224
225            // Input nodes and reg have depth 0
226            if node.is_an_input() || node.get_instance_type().is_some_and(|inst| inst.is_seq()) {
227                let r = CombDepthResult::Depth(0);
228                results.insert(node.clone(), r);
229                return r;
230            }
231
232            visiting.insert(node.clone());
233
234            let mut max_depth = 0;
235            let mut crit: Option<InputPort<I>> = None;
236            let mut is_undefined = false;
237
238            for i in 0..node.get_num_input_ports() {
239                let driver = match netlist.get_driver(node.clone(), i) {
240                    Some(d) => d,
241                    None => {
242                        is_undefined = true;
243                        continue;
244                    }
245                };
246
247                if let Some(inst) = driver.get_instance_type()
248                    && inst.is_seq()
249                {
250                    continue;
251                }
252
253                match compute(
254                    driver,
255                    netlist,
256                    results,
257                    critical_par,
258                    critical_ends,
259                    visiting,
260                ) {
261                    CombDepthResult::Depth(d) => {
262                        if d > max_depth {
263                            max_depth = d;
264                            crit = Some(node.get_input(i));
265                        }
266                    }
267                    CombDepthResult::Undefined => {
268                        is_undefined = true;
269                    }
270                    CombDepthResult::CombCycle => {
271                        let r = CombDepthResult::CombCycle;
272                        results.insert(node.clone(), r);
273                        visiting.remove(&node);
274                        return r;
275                    }
276                }
277            }
278
279            visiting.remove(&node);
280            let r = if is_undefined {
281                CombDepthResult::Undefined
282            } else {
283                if let Some(crit) = crit {
284                    critical_par.insert(node.clone(), crit);
285                }
286                let d = max_depth + 1;
287                critical_ends.push((Reverse(d), node.clone()));
288                if critical_ends.len() > CombDepthInfo::<I>::SIZE_HEAP {
289                    critical_ends.pop();
290                }
291                CombDepthResult::Depth(d)
292            };
293            results.insert(node.clone(), r);
294            r
295        }
296
297        for (driven, _) in netlist.outputs() {
298            let node = driven.unwrap();
299            let r = compute(
300                node,
301                netlist,
302                &mut results,
303                &mut critical_par,
304                &mut critical_ends,
305                &mut visiting,
306            );
307
308            if let CombDepthResult::Depth(d) = r {
309                max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
310            }
311        }
312
313        for node in netlist.matches(|inst| inst.is_seq()) {
314            compute(
315                node.clone(),
316                netlist,
317                &mut results,
318                &mut critical_par,
319                &mut critical_ends,
320                &mut visiting,
321            );
322            for i in 0..node.get_num_input_ports() {
323                if let Some(driver) = netlist.get_driver(node.clone(), i) {
324                    if driver.get_instance_type().is_some_and(|inst| inst.is_seq()) {
325                        continue;
326                    }
327
328                    let r = compute(
329                        driver,
330                        netlist,
331                        &mut results,
332                        &mut critical_par,
333                        &mut critical_ends,
334                        &mut visiting,
335                    );
336                    if let CombDepthResult::Depth(d) = r {
337                        max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
338                    }
339                }
340            }
341        }
342
343        Ok(CombDepthInfo {
344            _netlist: netlist,
345            results,
346            critical_par,
347            critical_ends,
348            max_depth,
349        })
350    }
351}
352
/// A graph node: either a real circuit node or a user-defined pseudo-node.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Node<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A circuit node taken from the netlist
    NetRef(NetRef<I>),
    /// A user-programmable node with no counterpart in the circuit
    Pseudo(T),
}

#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NetRef(net_ref) => net_ref.fmt(f),
            // Fully qualified: `T` is bounded by both `Debug` and `Display`
            Self::Pseudo(label) => std::fmt::Display::fmt(label, f),
        }
    }
}
376
/// A graph edge: either a real circuit connection or a user-defined pseudo-edge.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Edge<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A 'real' circuit connection
    Connection(Connection<I>),
    /// Any other user-programmable edge
    Pseudo(T),
}

#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Edge<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Edge::Connection(c) => c.fmt(f),
            // Fully qualified: `T` is bounded by both `Debug` and `Display`
            Edge::Pseudo(t) => std::fmt::Display::fmt(t, f),
        }
    }
}
400
/// A petgraph representation of the netlist as a directed multigraph.
///
/// Every circuit object becomes a [`Node::NetRef`] node, and every connection
/// becomes an [`Edge::Connection`] edge. Each circuit output adds a
/// [`Node::Pseudo`] sink, reached from the output's driver by an
/// [`Edge::Pseudo`] edge labeled with the output net.
#[cfg(feature = "graph")]
pub struct MultiDiGraph<'a, I: Instantiable> {
    /// The netlist the graph was built from
    _netlist: &'a Netlist<I>,
    /// The graph itself, with `Node<I, String>` node weights and `Edge<I, Net>` edge weights
    graph: DiGraph<Node<I, String>, Edge<I, Net>>,
}
407
#[cfg(feature = "graph")]
impl<I> MultiDiGraph<'_, I>
where
    I: Instantiable,
{
    /// Borrows the petgraph graph built by this analysis.
    pub fn get_graph(&self) -> &DiGraph<Node<I, String>, Edge<I, Net>> {
        &self.graph
    }

    /// Iterates over the circuit connections in a
    /// [greedy feedback arc set](https://doi.org/10.1016/0020-0190(93)90079-O) of the graph.
    pub fn greedy_feedback_arcs(&self) -> impl Iterator<Item = Connection<I>> {
        let arcs = petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&self.graph);
        arcs.map(|edge| match edge.weight() {
            Edge::Connection(conn) => conn.clone(),
            // Pseudo edges end at output sinks, which cannot lie on a cycle
            Edge::Pseudo(_) => unreachable!("Outputs should be sinks"),
        })
    }

    /// Returns the circuit nodes grouped by strongly connected component.
    ///
    /// Output pseudo-nodes are dropped, and components left empty are skipped.
    pub fn sccs(&self) -> Vec<Vec<NetRef<I>>> {
        petgraph::algo::tarjan_scc(&self.graph)
            .into_iter()
            .map(|component| {
                component
                    .into_iter()
                    .filter_map(|idx| match &self.graph[idx] {
                        Node::NetRef(nr) => Some(nr.clone()),
                        Node::Pseudo(_) => None,
                    })
                    .collect::<Vec<NetRef<I>>>()
            })
            .filter(|component| !component.is_empty())
            .collect()
    }
}
446
#[cfg(feature = "graph")]
impl<'a, I> Analysis<'a, I> for MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Builds the multigraph.
    ///
    /// The graph gets one node per circuit object and one edge per connection.
    /// Each circuit output adds a pseudo sink node, with an edge from the
    /// output's driver.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Netlist::verify`] if the netlist is malformed.
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
        // Only build graphs of well-formed netlists
        netlist.verify()?;

        // Map circuit nodes straight to graph indices (as the other analyses key
        // by `NetRef`), instead of allocating a name string per lookup.
        let mut mapping: HashMap<NetRef<I>, _> = HashMap::new();
        let mut graph = DiGraph::new();

        for obj in netlist.objects() {
            let id = graph.add_node(Node::NetRef(obj.clone()));
            mapping.insert(obj.clone(), id);
        }

        for connection in netlist.connections() {
            let s_id = mapping[&connection.src().unwrap()];
            let t_id = mapping[&connection.target().unwrap()];
            graph.add_edge(s_id, t_id, Edge::Connection(connection));
        }

        // Finally, add the output connections
        for (o, n) in netlist.outputs() {
            let s_id = mapping[&o.clone().unwrap()];
            let t_id = graph.add_node(Node::Pseudo(format!("Output({n})")));
            graph.add_edge(s_id, t_id, Edge::Pseudo(o.as_net().clone()));
        }

        Ok(Self {
            _netlist: netlist,
            graph,
        })
    }
}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{format_id, netlist::*};

    /// Width, in bits, of the adder built by `ripple_adder`.
    const BITWIDTH: usize = 4;

    fn full_adder() -> Gate {
        Gate::new_logical_multi(
            "FA".into(),
            vec!["CIN".into(), "A".into(), "B".into()],
            vec!["S".into(), "COUT".into()],
        )
    }

    /// Builds a `BITWIDTH`-bit ripple-carry adder out of full adders named `fa_0`, `fa_1`, ...
    fn ripple_adder() -> GateNetlist {
        let netlist = Netlist::new("ripple_adder".to_string());

        // Add the inputs
        let a = netlist.insert_input_escaped_logic_bus("a".to_string(), BITWIDTH);
        let b = netlist.insert_input_escaped_logic_bus("b".to_string(), BITWIDTH);
        let mut carry: DrivenNet<Gate> = netlist.insert_input("cin".into());

        for (i, (a, b)) in a.into_iter().zip(b.into_iter()).enumerate() {
            // Instantiate a full adder for each bit
            let fa = netlist
                .insert_gate(full_adder(), format_id!("fa_{i}"), &[carry, a, b])
                .unwrap();

            // Expose the sum
            fa.expose_net(&fa.get_net(0)).unwrap();

            carry = fa.find_output(&"COUT".into()).unwrap();

            if i == BITWIDTH - 1 {
                // Last full adder, expose the carry out
                fa.get_output(1).expose_with_name("cout".into()).unwrap();
            }
        }

        netlist.reclaim().unwrap()
    }

    /// Instance name of the most significant full adder, whose carry-out is a circuit output.
    fn last_adder_name() -> String {
        format!("fa_{}", BITWIDTH - 1)
    }

    #[test]
    fn fanout_table() {
        let netlist = ripple_adder();
        let analysis = FanOutTable::build(&netlist).expect("fan-out table should build");
        assert!(netlist.verify().is_ok());

        for item in netlist.objects().filter(|o| !o.is_an_input()) {
            // Sum bit has no users (it is a direct output)
            assert!(
                analysis
                    .get_net_users(&item.find_output(&"S".into()).unwrap().as_net())
                    .next()
                    .is_none(),
                "Sum bit should not have users"
            );
            // ...but it still counts as used, because it is an output
            assert!(
                analysis.net_has_uses(&item.find_output(&"S".into()).unwrap().as_net()),
                "Sum bit is an output and should count as used"
            );

            assert!(
                item.get_instance_name().is_some(),
                "Item should have a name. Filtered inputs"
            );

            let net = item.find_output(&"COUT".into()).unwrap().as_net().clone();
            assert!(
                analysis.net_has_uses(&net),
                "Carry bit is either consumed by the next adder or an output"
            );
            let mut cout_users = analysis.get_net_users(&net);
            if item.get_instance_name().unwrap().to_string() != last_adder_name() {
                assert!(cout_users.next().is_some(), "Carry bit should have users");
            }

            assert!(
                cout_users.next().is_none(),
                "Carry bit should have 1 or 0 user"
            );
        }
    }

    #[test]
    fn comb_depth() {
        let netlist = ripple_adder();
        let analysis = CombDepthInfo::build(&netlist).expect("depth analysis should build");

        // Each full adder waits on the previous carry, so the chain is BITWIDTH deep
        assert_eq!(analysis.get_max_depth(), Some(BITWIDTH));

        for item in netlist.objects().filter(|o| !o.is_an_input()) {
            let name = item.get_instance_name().unwrap().to_string();
            let bit: usize = name
                .strip_prefix("fa_")
                .and_then(|i| i.parse().ok())
                .expect("adders are named fa_<bit>");
            assert_eq!(
                analysis.get_comb_depth(&item),
                Some(CombDepthResult::Depth(bit + 1)),
                "{name} should be {} levels deep",
                bit + 1
            );
        }

        // The critical path follows the carry chain from the last adder back to the first
        let path = analysis
            .build_critical_path()
            .expect("adder should have a critical path");
        let names: Vec<String> = path
            .iter()
            .map(|n| n.get_instance_name().unwrap().to_string())
            .collect();
        let expected: Vec<String> = (0..BITWIDTH).rev().map(|i| format!("fa_{i}")).collect();
        assert_eq!(names, expected);
    }
}