Skip to main content

safety_net/
graph.rs

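/*!

  Netlist analyses: a fan-out table ([`FanOutTable`]), combinational depth
  ([`CombDepthInfo`]), and, with the `graph` feature, a petgraph view of the
  netlist (`MultiDiGraph`).

*/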
6
7use crate::circuit::{Instantiable, Net};
8use crate::error::Error;
9#[cfg(feature = "graph")]
10use crate::netlist::Connection;
11use crate::netlist::{DrivenNet, InputPort, NetRef, Netlist};
12#[cfg(feature = "graph")]
13use petgraph::graph::DiGraph;
14use std::cmp::Reverse;
15use std::collections::{BinaryHeap, HashMap, HashSet};
16
17/// A common trait of analyses than can be performed on a netlist.
18/// An analysis becomes stale when the netlist is modified.
19pub trait Analysis<'a, I: Instantiable>
20where
21    Self: Sized + 'a,
22{
23    /// Construct the analysis to the current state of the netlist.
24    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error>;
25}
26
27/// A table that maps nets to the circuit nodes they drive
28pub struct FanOutTable<'a, I: Instantiable> {
29    /// A reference to the underlying netlist
30    _netlist: &'a Netlist<I>,
31    /// Maps a net to the list of nodes it drives
32    net_fan_out: HashMap<Net, Vec<NetRef<I>>>,
33    /// Maps a driven net to a list of inputs it drives
34    dnet_fan_out: HashMap<DrivenNet<I>, Vec<InputPort<I>>>,
35    /// Maps a node to the list of nodes it drives
36    node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>>,
37    /// The number of references held by all the  data structures
38    ref_count: HashMap<NetRef<I>, usize>,
39    /// Contains nets which are outputs
40    is_an_output: HashSet<Net>,
41}
42
43impl<I> FanOutTable<'_, I>
44where
45    I: Instantiable,
46{
47    /// Returns an iterator to the circuit nodes that use `net`.
48    pub fn get_net_users(&self, net: &Net) -> impl Iterator<Item = NetRef<I>> {
49        self.net_fan_out
50            .get(net)
51            .into_iter()
52            .flat_map(|users| users.iter().cloned())
53    }
54
55    /// Returns an iterator to the circuit nodes that use `node`.
56    pub fn get_node_users(&self, node: &NetRef<I>) -> impl Iterator<Item = NetRef<I>> {
57        self.node_fan_out
58            .get(node)
59            .into_iter()
60            .flat_map(|users| users.iter().cloned())
61    }
62
63    /// Returns an iterator to the uses of `net`.
64    pub fn get_users(&self, net: &DrivenNet<I>) -> impl Iterator<Item = InputPort<I>> {
65        self.dnet_fan_out
66            .get(net)
67            .into_iter()
68            .flat_map(|users| users.iter().cloned())
69    }
70
71    /// Get the number of reference held by this table
72    pub fn get_ref_count(&self, node: &NetRef<I>) -> usize {
73        self.ref_count.get(node).copied().unwrap_or(0)
74    }
75
76    /// Returns `true` if the net has any used by any cells in the circuit
77    /// This does incude nets that are only used as outputs.
78    pub fn net_has_uses(&self, net: &Net) -> bool {
79        (self.net_fan_out.contains_key(net) && !self.net_fan_out[net].is_empty())
80            || self.is_an_output.contains(net)
81    }
82
83    /// Returns `true` if the net has any uses  in the circuit
84    pub fn has_uses(&self, net: &DrivenNet<I>) -> bool {
85        net.is_top_level_output()
86            || (self.dnet_fan_out.contains_key(net) && !self.dnet_fan_out[net].is_empty())
87    }
88}
89
90impl<'a, I> Analysis<'a, I> for FanOutTable<'a, I>
91where
92    I: Instantiable,
93{
94    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
95        let mut net_fan_out: HashMap<Net, Vec<NetRef<I>>> = HashMap::new();
96        let mut dnet_fan_out: HashMap<DrivenNet<I>, Vec<InputPort<I>>> = HashMap::new();
97        let mut node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>> = HashMap::new();
98        let mut is_an_output: HashSet<Net> = HashSet::new();
99        let mut ref_count: HashMap<NetRef<I>, usize> = HashMap::new();
100
101        // We can only build the fanout table if netlist is mostly intact
102        if let Err(e) = netlist.verify() {
103            match e {
104                Error::NoOutputs => (),
105                _ => return Err(e),
106            }
107        }
108
109        for c in netlist.connections() {
110            let e = net_fan_out.entry(c.net()).or_default();
111            e.push(c.target().unwrap());
112
113            let e = dnet_fan_out.entry(c.src()).or_default();
114            e.push(c.target());
115
116            let e = node_fan_out.entry(c.src().unwrap()).or_default();
117            e.push(c.target().unwrap());
118        }
119
120        for (o, n) in netlist.outputs() {
121            is_an_output.insert(o.as_net().clone());
122            is_an_output.insert(n);
123        }
124
125        for v in net_fan_out.values() {
126            for nr in v {
127                *ref_count.entry(nr.clone()).or_insert(1) += 1;
128            }
129        }
130
131        for (k, v) in &dnet_fan_out {
132            for nr in v {
133                *ref_count.entry(nr.clone().unwrap()).or_insert(1) += 1;
134            }
135            *ref_count.entry(k.clone().unwrap()).or_insert(1) += 1;
136        }
137
138        for (k, v) in &node_fan_out {
139            for nr in v {
140                *ref_count.entry(nr.clone()).or_insert(1) += 1;
141            }
142            *ref_count.entry(k.clone()).or_insert(1) += 1;
143        }
144
145        Ok(FanOutTable {
146            _netlist: netlist,
147            net_fan_out,
148            dnet_fan_out,
149            node_fan_out,
150            is_an_output,
151            ref_count,
152        })
153    }
154}
155
/// Result of combinational depth analysis for a single circuit node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CombDepthResult {
    /// The depth cannot be determined because an input is undriven, either on
    /// this node or on a node upstream of it
    Undefined,
    /// The node lies on a combinational cycle, or is fed through one
    CombCycle,
    /// Length of the longest chain of combinational cells back to a primary input
    /// or sequential element (which are themselves at depth 0)
    Depth(usize),
}
168
169/// Computes the combinational depth of each net in a netlist.
170///
171/// Each net is classified as having a defined depth, being undefined,
172/// or participating in a combinational cycle.
173pub struct CombDepthInfo<'a, I: Instantiable> {
174    _netlist: &'a Netlist<I>,
175    /// The total distance from a sequential element
176    results: HashMap<NetRef<I>, CombDepthResult>,
177    /// The critical predecessor to the node
178    critical_par: HashMap<NetRef<I>, InputPort<I>>,
179    /// Critical endpoints to build paths from
180    critical_ends: BinaryHeap<(Reverse<usize>, NetRef<I>)>,
181    /// Max will be `None` if the entire circuit is part of a combinational cycle or has undriven elements
182    max_depth: Option<usize>,
183}
184
185impl<I> CombDepthInfo<'_, I>
186where
187    I: Instantiable,
188{
189    /// Max number of critical endpoints to keep in the heap.
190    const SIZE_HEAP: usize = 10;
191
192    /// Returns the logic level of a node in the circuit.
193    pub fn get_comb_depth(&self, node: &NetRef<I>) -> Option<CombDepthResult> {
194        self.results.get(node).copied()
195    }
196
197    /// Returns the critical input port
198    pub fn get_crit_input(&self, node: &NetRef<I>) -> Option<&InputPort<I>> {
199        self.critical_par.get(node)
200    }
201
202    /// Returns the most critical endpoints in the circuit
203    pub fn get_critical_points(&self) -> impl IntoIterator<Item = DrivenNet<I>> {
204        let mut v = self.critical_ends.iter().collect::<Vec<_>>();
205        v.sort_by_key(|(d, _)| *d);
206        v.into_iter().flat_map(|(_, n)| n.outputs())
207    }
208
209    /// Builds the most critical path
210    pub fn build_critical_path(&self) -> Option<Vec<DrivenNet<I>>> {
211        let mut path = Vec::new();
212        let mut current = self.get_critical_points().into_iter().next()?;
213        while let Some(crit) = self.critical_par.get(&current.clone().unwrap()) {
214            path.push(current.clone());
215            current = self
216                ._netlist
217                .get_driver(current.unwrap(), crit.get_input_num())
218                .unwrap();
219        }
220        path.push(current);
221        Some(path)
222    }
223
224    /// Returns the maximum logic level of the circuit.
225    pub fn get_max_depth(&self) -> Option<usize> {
226        self.max_depth
227    }
228}
229
230impl<'a, I> Analysis<'a, I> for CombDepthInfo<'a, I>
231where
232    I: Instantiable,
233{
234    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
235        let mut results: HashMap<NetRef<I>, CombDepthResult> = HashMap::new();
236        let mut critical_par: HashMap<NetRef<I>, InputPort<I>> = HashMap::new();
237        let mut critical_ends: BinaryHeap<(_, NetRef<I>)> = BinaryHeap::new();
238        let mut visiting: HashSet<NetRef<I>> = HashSet::new();
239        let mut max_depth: Option<usize> = None;
240
241        fn compute<I: Instantiable>(
242            node: NetRef<I>,
243            netlist: &Netlist<I>,
244            results: &mut HashMap<NetRef<I>, CombDepthResult>,
245            critical_par: &mut HashMap<NetRef<I>, InputPort<I>>,
246            visiting: &mut HashSet<NetRef<I>>,
247        ) -> CombDepthResult {
248            // Memoized result
249            if let Some(&r) = results.get(&node) {
250                return r;
251            }
252
253            // Cycle detection
254            if visiting.contains(&node) {
255                for n in visiting.iter() {
256                    results.insert(n.clone(), CombDepthResult::CombCycle);
257                }
258                return CombDepthResult::CombCycle;
259            }
260
261            // Input nodes and reg have depth 0
262            if node.is_an_input() || node.get_instance_type().is_some_and(|inst| inst.is_seq()) {
263                let r = CombDepthResult::Depth(0);
264                results.insert(node.clone(), r);
265                return r;
266            }
267
268            visiting.insert(node.clone());
269
270            let mut max_depth = 0;
271            let mut crit: Option<InputPort<I>> = None;
272            let mut is_undefined = false;
273
274            for i in 0..node.get_num_input_ports() {
275                let driver = match netlist.get_driver(node.clone(), i) {
276                    Some(d) => d.unwrap(),
277                    None => {
278                        is_undefined = true;
279                        continue;
280                    }
281                };
282
283                if let Some(inst) = driver.get_instance_type()
284                    && inst.is_seq()
285                {
286                    continue;
287                }
288
289                match compute(driver, netlist, results, critical_par, visiting) {
290                    CombDepthResult::Depth(d) => {
291                        if d > max_depth {
292                            max_depth = d;
293                            crit = Some(node.get_input(i));
294                        }
295                    }
296                    CombDepthResult::Undefined => {
297                        is_undefined = true;
298                    }
299                    CombDepthResult::CombCycle => {
300                        let r = CombDepthResult::CombCycle;
301                        results.insert(node.clone(), r);
302                        visiting.remove(&node);
303                        return r;
304                    }
305                }
306            }
307
308            visiting.remove(&node);
309            let r = if is_undefined {
310                CombDepthResult::Undefined
311            } else {
312                if let Some(crit) = crit {
313                    critical_par.insert(node.clone(), crit);
314                }
315                let d = max_depth + 1;
316                CombDepthResult::Depth(d)
317            };
318            results.insert(node.clone(), r);
319            r
320        }
321
322        for (driven, _) in netlist.outputs() {
323            let node = driven.unwrap();
324            let r = compute(
325                node.clone(),
326                netlist,
327                &mut results,
328                &mut critical_par,
329                &mut visiting,
330            );
331
332            if let CombDepthResult::Depth(d) = r {
333                critical_ends.push((Reverse(d), node));
334                if critical_ends.len() > CombDepthInfo::<I>::SIZE_HEAP {
335                    critical_ends.pop();
336                }
337                max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
338            }
339        }
340
341        for node in netlist.matches(|inst| inst.is_seq()) {
342            compute(
343                node.clone(),
344                netlist,
345                &mut results,
346                &mut critical_par,
347                &mut visiting,
348            );
349            for i in 0..node.get_num_input_ports() {
350                if let Some(driver) = netlist.get_driver(node.clone(), i) {
351                    if driver.get_instance_type().is_some_and(|inst| inst.is_seq()) {
352                        continue;
353                    }
354
355                    let r = compute(
356                        driver.clone().unwrap(),
357                        netlist,
358                        &mut results,
359                        &mut critical_par,
360                        &mut visiting,
361                    );
362                    if let CombDepthResult::Depth(d) = r {
363                        critical_ends.push((Reverse(d), driver.unwrap()));
364                        if critical_ends.len() > CombDepthInfo::<I>::SIZE_HEAP {
365                            critical_ends.pop();
366                        }
367                        max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
368                    }
369                }
370            }
371        }
372
373        Ok(CombDepthInfo {
374            _netlist: netlist,
375            results,
376            critical_par,
377            critical_ends,
378            max_depth,
379        })
380    }
381}
382
/// A vertex of a [`MultiDiGraph`]: either a real circuit node or a user-defined pseudo-node.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Node<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A node that exists in the netlist
    NetRef(NetRef<I>),
    /// Any other user-programmable node
    Pseudo(T),
}

// Displays the wrapped circuit node, or the pseudo-node payload via its `Display` impl.
#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NetRef(net_ref) => net_ref.fmt(f),
            Self::Pseudo(label) => std::fmt::Display::fmt(label, f),
        }
    }
}
406
/// An edge of a [`MultiDiGraph`]: either a real circuit connection or a user-defined pseudo-edge.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Edge<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A 'real' circuit connection
    Connection(Connection<I>),
    /// Any other user-programmable edge
    Pseudo(T),
}

#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Edge<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Formats the wrapped connection, or the pseudo-edge payload via its `Display` impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Edge::Connection(c) => c.fmt(f),
            Edge::Pseudo(t) => std::fmt::Display::fmt(t, f),
        }
    }
}
430
/// A petgraph view of a netlist as a directed multigraph.
///
/// Every netlist object becomes a [`Node::NetRef`] vertex and every connection an
/// [`Edge::Connection`]. Each entry of the netlist's outputs additionally gets a
/// [`Node::Pseudo`] sink, reached by an [`Edge::Pseudo`] labeled with the output's net.
#[cfg(feature = "graph")]
pub struct MultiDiGraph<'a, I: Instantiable> {
    _netlist: &'a Netlist<I>,
    graph: DiGraph<Node<I, String>, Edge<I, Net>>,
}
437
#[cfg(feature = "graph")]
impl<I> MultiDiGraph<'_, I>
where
    I: Instantiable,
{
    /// Returns the petgraph graph built by this analysis.
    pub fn get_graph(&self) -> &DiGraph<Node<I, String>, Edge<I, Net>> {
        &self.graph
    }

    /// Returns the circuit connections in a
    /// [greedy feedback arc set](https://doi.org/10.1016/0020-0190(93)90079-O) of the graph.
    ///
    /// Output pseudo-edges end at sink vertices and are expected never to appear in
    /// the set; encountering one panics.
    pub fn greedy_feedback_arcs(&self) -> impl Iterator<Item = Connection<I>> {
        petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&self.graph).map(|edge| {
            let Edge::Connection(conn) = edge.weight() else {
                unreachable!("Outputs should be sinks");
            };
            conn.clone()
        })
    }

    /// Returns the circuit nodes grouped by strongly connected component.
    ///
    /// Pseudo-nodes are dropped, and components left empty by that are omitted.
    pub fn sccs(&self) -> Vec<Vec<NetRef<I>>> {
        petgraph::algo::tarjan_scc(&self.graph)
            .into_iter()
            .map(|component| {
                component
                    .into_iter()
                    .filter_map(|idx| match &self.graph[idx] {
                        Node::NetRef(nr) => Some(nr.clone()),
                        Node::Pseudo(_) => None,
                    })
                    .collect::<Vec<_>>()
            })
            .filter(|component| !component.is_empty())
            .collect()
    }
}
476
#[cfg(feature = "graph")]
impl<'a, I> Analysis<'a, I> for MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Builds the graph: one vertex per netlist object, one edge per connection,
    /// and one pseudo-vertex (reached by a pseudo-edge) per netlist output.
    ///
    /// # Errors
    ///
    /// Returns any error reported by [`Netlist::verify`].
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
        netlist.verify()?;
        let mut graph = DiGraph::new();
        // Vertices are indexed by circuit-node identity rather than by formatted name.
        // This avoids formatting two strings per connection and does not rely on
        // object names being unique.
        let mut index_of = HashMap::new();

        for obj in netlist.objects() {
            let node: NetRef<I> = obj.clone();
            let id = graph.add_node(Node::NetRef(node.clone()));
            index_of.insert(node, id);
        }

        for connection in netlist.connections() {
            let s_id = index_of[&connection.src().unwrap()];
            let t_id = index_of[&connection.target().unwrap()];
            graph.add_edge(s_id, t_id, Edge::Connection(connection));
        }

        // Finally, add the output connections
        for (o, n) in netlist.outputs() {
            let s_id = index_of[&o.clone().unwrap()];
            let t_id = graph.add_node(Node::Pseudo(format!("Output({n})")));
            graph.add_edge(s_id, t_id, Edge::Pseudo(o.as_net().clone()));
        }

        Ok(Self {
            _netlist: netlist,
            graph,
        })
    }
}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{format_id, netlist::*};

    /// A full adder cell with inputs `CIN`, `A`, `B` and outputs `S`, `COUT`.
    fn full_adder() -> Gate {
        Gate::new_logical_multi(
            "FA".into(),
            vec!["CIN".into(), "A".into(), "B".into()],
            vec!["S".into(), "COUT".into()],
        )
    }

    /// A 4-bit ripple-carry adder built from full adder cells `fa_0`..`fa_3`.
    /// Every sum bit is exposed as an output, as is the final carry out.
    fn ripple_adder() -> GateNetlist {
        let netlist = Netlist::new("ripple_adder".to_string());
        let bitwidth = 4;

        // Add the inputs
        let a = netlist.insert_input_escaped_logic_bus("a".to_string(), bitwidth);
        let b = netlist.insert_input_escaped_logic_bus("b".to_string(), bitwidth);
        let mut carry: DrivenNet<Gate> = netlist.insert_input("cin".into());

        for (i, (a, b)) in a.into_iter().zip(b).enumerate() {
            // Instantiate a full adder for each bit
            let fa = netlist
                .insert_gate(full_adder(), format_id!("fa_{i}"), &[carry, a, b])
                .unwrap();

            // Expose the sum
            fa.expose_net(&fa.get_net(0)).unwrap();

            carry = fa.find_output(&"COUT".into()).unwrap();

            if i == bitwidth - 1 {
                // Last full adder, expose the carry out
                fa.get_output(1).expose_with_name("cout".into()).unwrap();
            }
        }

        netlist.reclaim().unwrap()
    }

    #[test]
    fn fanout_table() {
        let netlist = ripple_adder();
        let analysis = FanOutTable::build(&netlist);
        assert!(analysis.is_ok());
        let analysis = analysis.unwrap();
        assert!(netlist.verify().is_ok());

        for item in netlist.objects().filter(|o| !o.is_an_input()) {
            // Sum bit has no users (it is a direct output)
            assert!(
                analysis
                    .get_net_users(&item.find_output(&"S".into()).unwrap().as_net())
                    .next()
                    .is_none(),
                "Sum bit should not have users"
            );

            assert!(
                item.get_instance_name().is_some(),
                "Item should have a name. Filtered inputs"
            );

            let net = item.find_output(&"COUT".into()).unwrap().as_net().clone();
            let mut cout_users = analysis.get_net_users(&net);
            if item.get_instance_name().unwrap().to_string() != "fa_3" {
                assert!(cout_users.next().is_some(), "Carry bit should have users");
            }

            assert!(
                cout_users.next().is_none(),
                "Carry bit should have 1 or 0 user"
            );
        }
    }

    #[test]
    fn comb_depth() {
        let netlist = ripple_adder();
        let analysis = CombDepthInfo::build(&netlist).unwrap();

        // Each full adder adds one level along the carry chain: fa_0 is at depth 1, fa_3 at 4.
        assert_eq!(analysis.get_max_depth(), Some(4));

        for item in netlist.objects().filter(|o| !o.is_an_input()) {
            let name = item.get_instance_name().unwrap().to_string();
            let bit: usize = name.trim_start_matches("fa_").parse().unwrap();
            assert_eq!(
                analysis.get_comb_depth(&item),
                Some(CombDepthResult::Depth(bit + 1)),
                "Unexpected depth for {name}"
            );
        }

        // The critical path runs back through the carry chain, one net per full adder.
        let path = analysis.build_critical_path().unwrap();
        assert_eq!(path.len(), 4);
    }
}