"""
SN utility functions.
Native Python implementations of network structure utility functions
for printing, debugging, and data conversion.
Port from:
- matlab/src/api/sn/sn_print.m
- matlab/src/api/sn/sn_print_routing_matrix.m
- matlab/src/api/sn/sn_refresh_process_fields.m
- matlab/src/api/sn/sn_rtnodes_to_rtorig.m
"""
import numpy as np
from typing import Optional, Dict, Any, Tuple, List
import sys
from .network_struct import NetworkStruct, NodeType, RoutingStrategy
[docs]
def sn_print(sn: NetworkStruct, file=None) -> None:
    """
    Print the fields of a NetworkStruct in a compact, line-oriented format.

    Scalar counters, matrices, index maps, name lists, routing tables and the
    per-chain dictionaries are written one entry per line, which is useful for
    debugging and inspection of network structures.

    Numeric arrays are rendered as ``[a b; c d]``. Finite whole numbers are
    shown without a decimal part; ``inf`` and ``nan`` entries are printed
    as-is instead of aborting the dump.

    Args:
        sn: Network structure to inspect
        file: Output file (default: sys.stdout)

    References:
        MATLAB: matlab/src/api/sn/sn_print.m
    """
    out = sys.stdout if file is None else file

    def p(s):
        print(s, file=out)

    def format_scalar(x):
        """Render one array entry, dropping the decimal part of whole numbers."""
        try:
            # Finiteness must be tested BEFORE int(): int(inf) raises
            # OverflowError and int(nan) raises ValueError.
            if np.isfinite(x) and x == int(x):
                return str(int(x))
        except TypeError:
            # Non-numeric entry (e.g. a string in an object array): print verbatim.
            pass
        return str(x)

    def format_matrix(m):
        """Format a matrix for compact display."""
        if m is None:
            return 'null'
        if not isinstance(m, np.ndarray):
            return str(m)
        if m.size == 0:
            return '[]'
        if m.ndim == 0:
            # A 0-d array has no rows to iterate; show it as a 1-element vector.
            m = m.reshape(1)
        if m.ndim == 1:
            return '[' + ' '.join(format_scalar(x) for x in m) + ']'
        if m.ndim == 2:
            rows = [' '.join(format_scalar(x) for x in row) for row in m]
            return '[' + '; '.join(rows) + ']'
        # Higher-rank arrays have no "a b; c d" form; use NumPy's own rendering.
        return str(m)

    def format_names(names):
        """Format a sequence of names (or a single name) as a quoted list."""
        if isinstance(names, (list, tuple, np.ndarray)):
            return '[' + ', '.join(f'"{n}"' for n in names) + ']'
        return f'["{names}"]'

    # Basic integer fields
    for field in ('nstations', 'nstateful', 'nnodes', 'nclasses',
                  'nclosedjobs', 'nchains'):
        p(f'{field}: {getattr(sn, field)}')

    # Matrix fields, mapping arrays and rate fields, in output order.
    # The flag marks fields that are omitted entirely when they are None;
    # the others are always printed (None shows as 'null').
    matrix_fields = (
        ('refstat', False),
        ('njobs', False),
        ('nservers', False),
        ('connmatrix', True),
        ('scv', False),
        ('nodeToStateful', False),
        ('nodeToStation', False),
        ('stationToNode', False),
        ('stationToStateful', False),
        ('statefulToStation', False),
        ('statefulToNode', False),
        ('rates', False),
        ('classprio', True),
        ('phases', True),
    )
    for field, skip_if_none in matrix_fields:
        value = getattr(sn, field)
        if skip_if_none and value is None:
            continue
        p(f'{field}: {format_matrix(value)}')

    # Node type list
    if sn.nodetype is not None:
        nodetype_names = [NodeType.toText(nt) for nt in sn.nodetype]
        p(f'nodetype: [{", ".join(nodetype_names)}]')

    # Class and node names (both go through the same formatter)
    if sn.classnames is not None:
        p(f'classnames: {format_names(sn.classnames)}')
    if sn.nodenames is not None:
        p(f'nodenames: {format_names(sn.nodenames)}')

    # Routing tables
    if sn.rt is not None:
        p(f'rt: {format_matrix(sn.rt)}')
    if sn.rtnodes is not None:
        p(f'rtnodes: {format_matrix(sn.rtnodes)}')

    # Visit ratios and node visits: one matrix per chain
    for field in ('visits', 'nodevisits'):
        per_chain = getattr(sn, field)
        if per_chain:
            p(f'{field}: {{')
            for chain_id, values in per_chain.items():
                p(f' {chain_id}: {format_matrix(values)}')
            p('}')

    # Chain info: the classes belonging to each chain
    if sn.inchain:
        p('inchain: {')
        for chain_id, classes in sn.inchain.items():
            if isinstance(classes, np.ndarray):
                classes_str = ', '.join(str(int(c)) for c in classes.flatten())
            elif isinstance(classes, list):
                classes_str = ', '.join(str(c) for c in classes)
            else:
                classes_str = str(classes)
            p(f' {chain_id}: [{classes_str}]')
        p('}')
[docs]
def sn_print_routing_matrix(
    sn: NetworkStruct,
    onlyclass: Optional[Any] = None,
    file=None
) -> None:
    """
    Print every positive entry of the node-level routing matrix.

    One line is written per (source node, source class) ->
    (destination node, destination class) pair whose probability in
    ``sn.rtnodes`` is greater than zero. Entries leaving a Sink node, or
    leaving a node/class pair whose routing strategy is DISABLED, are
    omitted. Entries leaving a Cache node are labelled 'state-dependent'
    instead of showing a number.

    Args:
        sn: Network structure
        onlyclass: Optional class filter. Its 'name' attribute (or its string
            form when it has none) is compared case-insensitively with the
            source and destination class names; a line is printed when either
            one matches.
        file: Output file (default: sys.stdout)

    References:
        MATLAB: matlab/src/api/sn/sn_print_routing_matrix.m
    """
    out = sys.stdout if file is None else file

    # Fall back to generated labels when the structure carries no names.
    node_labels = sn.nodenames if sn.nodenames else [f'Node{i}' for i in range(sn.nnodes)]
    class_labels = sn.classnames if sn.classnames else [f'Class{i}' for i in range(sn.nclasses)]

    table = sn.rtnodes
    n_nodes = sn.nnodes
    n_classes = sn.nclasses

    if table is None:
        print("No routing matrix available.", file=out)
        return

    for src_node in range(n_nodes):
        for src_class in range(n_classes):
            row = src_node * n_classes + src_class
            for dst_node in range(n_nodes):
                for dst_class in range(n_classes):
                    col = dst_node * n_classes + dst_class

                    # Ignore index pairs that fall outside the stored matrix.
                    if row >= table.shape[0] or col >= table.shape[1]:
                        continue

                    prob = table[row, col]
                    if not prob > 0:
                        continue

                    node_types = sn.nodetype
                    has_type = node_types is not None and src_node < len(node_types)

                    # Skip sinks
                    if has_type and node_types[src_node] == NodeType.SINK:
                        continue

                    if has_type and node_types[src_node] == NodeType.CACHE:
                        # Cache routing depends on the cache state
                        label = 'state-dependent'
                    else:
                        strategies = sn.routing
                        if (
                            strategies is not None
                            and src_node < strategies.shape[0]
                            and src_class < strategies.shape[1]
                            and strategies[src_node, src_class] == RoutingStrategy.DISABLED
                        ):
                            continue
                        label = f'{prob:.6f}'

                    # Apply class filter if specified
                    if onlyclass is not None:
                        wanted = getattr(onlyclass, 'name', str(onlyclass)).lower()
                        if class_labels[src_class].lower() != wanted and class_labels[dst_class].lower() != wanted:
                            continue

                    print(f'{node_labels[src_node]} [{class_labels[src_class]}] => {node_labels[dst_node]} [{class_labels[dst_class]}] : Pr={label}', file=out)
[docs]
def sn_refresh_process_fields(
    sn: NetworkStruct,
    station_idx: int,
    class_idx: int
) -> NetworkStruct:
    """
    Refresh process fields based on rate and SCV values.

    Rebuilds the (D0, D1) representation for one station/class pair from
    ``sn.rates`` and ``sn.scv`` and stores the derived quantities in
    ``sn.proc``, ``sn.mu``, ``sn.phi``, ``sn.pie``, ``sn.phases``,
    ``sn.phasessz`` and the station's row of ``sn.phaseshift``:

    - SCV = 1.0 (or NaN): Exponential (1 phase)
    - SCV < 1.0: Erlang-k with k = ceil(1/SCV)
    - SCV > 1.0: two-phase hyperexponential with balanced means

    The structure is returned untouched when the pair is out of range, when
    the rate is NaN, infinite or non-positive, or when the SCV is
    non-positive or infinite (no finite phase-type fit exists for those).

    Args:
        sn: Network structure (modified in place)
        station_idx: Station index (0-based)
        class_idx: Class index (0-based)

    Returns:
        Modified network structure (the same object as ``sn``)

    References:
        MATLAB: matlab/src/api/sn/sn_refresh_process_fields.m
    """
    rates = sn.rates
    if rates is None or station_idx >= rates.shape[0] or class_idx >= rates.shape[1]:
        return sn

    rate = rates[station_idx, class_idx]
    scv = sn.scv[station_idx, class_idx] if sn.scv is not None else 1.0

    # Skip if rate is invalid
    if np.isnan(rate) or np.isinf(rate) or rate <= 0:
        return sn

    # SCV <= 0 would need ceil(1/scv) = infinitely many Erlang phases (the
    # original code crashed here), and SCV = inf produces NaN rates in the
    # H2 fit. Leave the existing representation in place for both.
    if scv <= 0 or np.isinf(scv):
        return sn

    mean = 1.0 / rate

    # Build (D0, D1) plus the initial phase distribution.
    if np.isnan(scv) or abs(scv - 1.0) < 1e-10:
        # Exponential
        D0 = np.array([[-rate]])
        D1 = np.array([[rate]])
        pie_vec = np.array([1.0])
    elif scv < 1.0:
        # Erlang: k = ceil(1/scv) phases in series, each with rate k * rate
        k = max(1, int(np.ceil(1.0 / scv)))
        phase_rate = k * rate
        D0 = np.diag(np.full(k, -phase_rate)) + np.diag(np.full(k - 1, phase_rate), k=1)
        # A completion happens when leaving the last phase; restart in phase 0.
        D1 = np.zeros((k, k))
        D1[k - 1, 0] = phase_rate
        pie_vec = np.zeros(k)
        pie_vec[0] = 1.0
    else:
        # Hyperexponential (scv > 1): two-moment match with balanced means.
        # Phase i is chosen with probability alpha[i] and has rate
        # 2 * alpha[i] / mean, so each branch contributes mean / 2.
        p = 0.5 * (1 + np.sqrt((scv - 1) / (scv + 1)))
        p = min(max(p, 0.01), 0.99)
        alpha = np.array([p, 1.0 - p])
        branch_rates = 2.0 * alpha / mean
        D0 = np.diag(-branch_rates)
        # After every completion the next phase is drawn from alpha again.
        # A diagonal D1 would lock the process in its starting phase and
        # collapse it to a single exponential with the wrong mean.
        D1 = np.outer(branch_rates, alpha)
        pie_vec = alpha

    n_phases = D0.shape[0]

    # Update phases
    if sn.phases is None:
        sn.phases = np.ones((sn.nstations, sn.nclasses))
    if station_idx < sn.phases.shape[0] and class_idx < sn.phases.shape[1]:
        sn.phases[station_idx, class_idx] = n_phases

    # Update phasessz
    if sn.phasessz is None:
        sn.phasessz = np.ones((sn.nstations, sn.nclasses))
    if station_idx < sn.phasessz.shape[0] and class_idx < sn.phasessz.shape[1]:
        sn.phasessz[station_idx, class_idx] = n_phases

    # Recompute phaseshift for this station: column c + 1 holds the cumulative
    # phase count of classes 0..c, column 0 is always 0.
    if sn.phaseshift is None:
        sn.phaseshift = np.zeros((sn.nstations, sn.nclasses + 1))
    offsets = np.cumsum(sn.phasessz[station_idx, :sn.nclasses])
    width = min(len(offsets), sn.phaseshift.shape[1] - 1)
    sn.phaseshift[station_idx, 0] = 0
    sn.phaseshift[station_idx, 1:width + 1] = offsets[:width]

    key = (station_idx, class_idx)

    # Update mu (rates from -diag(D0)); strictly positive because rate > 0
    if sn.mu is None:
        sn.mu = {}
    mu_vec = -np.diag(D0)
    sn.mu[key] = mu_vec

    # Update phi (completion probabilities): share of each phase's outflow
    # that corresponds to a completion (row sum of D1 over the exit rate)
    if sn.phi is None:
        sn.phi = {}
    sn.phi[key] = D1.sum(axis=1) / mu_vec

    # Update pie (initial phase distribution)
    if sn.pie is None:
        sn.pie = {}
    sn.pie[key] = pie_vec

    # Update proc (MAP representation)
    if sn.proc is None:
        sn.proc = {}
    sn.proc[key] = [D0, D1]

    return sn
[docs]
def sn_rtnodes_to_rtorig(sn: NetworkStruct) -> Tuple[Dict, np.ndarray]:
    """
    Convert node routing matrix to the original routing matrix format.

    Nodes from the first one whose name starts with 'CS_' onwards are
    eliminated from ``sn.rtnodes`` by stochastic complementation. The
    remaining matrix, indexed by ``node * nclasses + class``, is returned
    both as a whole and split into one node-by-node matrix per
    (source class, destination class) pair.

    Rows of Sink nodes are left at zero in the per-class matrices. A node
    whose type is unknown (``sn.nodetype`` is None or too short) is treated
    as a regular node, so the per-class matrices stay consistent with the
    full matrix instead of being left all-zero.

    Args:
        sn: Network structure

    Returns:
        Tuple of (rtorigcell, rtorig) where:
            rtorigcell: Dictionary representation {(r,s): ndarray}
            rtorig: Dense matrix representation
        Both are empty when there is no routing matrix or nothing to keep.

    References:
        MATLAB: matlab/src/api/sn/sn_rtnodes_to_rtorig.m
    """
    K = sn.nclasses
    rtnodes = sn.rtnodes

    if rtnodes is None:
        return {}, np.array([])

    # Find where class-switching nodes start
    csshift = sn.nnodes
    if sn.nodenames is not None:
        csshift = next(
            (ind for ind, node_name in enumerate(sn.nodenames[:sn.nnodes])
             if node_name.startswith('CS_')),
            sn.nnodes,
        )

    # Columns to keep are those of the nodes before the first CS node, i.e.
    # ind * K + k for ind < csshift and k < K: the first csshift * K indices.
    col_to_keep = list(range(csshift * K))
    if len(col_to_keep) == 0:
        return {}, np.array([])

    # Perform stochastic complementation
    rtorig = _dtmc_stochcomp(rtnodes, col_to_keep)

    # Replace NaNs with 0
    rtorig = np.nan_to_num(rtorig, nan=0.0)

    # A node is only a sink when its type is known and says so.
    nodetype = sn.nodetype
    sink_rows = np.array(
        [
            nodetype is not None
            and ind < len(nodetype)
            and nodetype[ind] == NodeType.SINK
            for ind in range(csshift)
        ],
        dtype=bool,
    )

    # Build cell representation: entry [ind, jnd] of the (r, s) matrix is
    # rtorig[ind * K + r, jnd * K + s], i.e. a stride-K slice of rtorig.
    rtorigcell: Dict[Tuple[int, int], np.ndarray] = {}
    for r in range(K):
        for s in range(K):
            block = np.array(rtorig[r::K, s::K], dtype=float)
            block[sink_rows, :] = 0.0
            rtorigcell[(r, s)] = block

    return rtorigcell, rtorig
def _dtmc_stochcomp(P: np.ndarray, keep_states: List[int]) -> np.ndarray:
"""
Perform stochastic complementation on a transition matrix.
Removes transient states by computing the stochastic complement.
Args:
P: Transition probability matrix
keep_states: Indices of states to keep
Returns:
Reduced transition probability matrix
"""
n = P.shape[0]
keep_set = set(keep_states)
remove_states = [i for i in range(n) if i not in keep_set]
if len(remove_states) == 0:
# Nothing to remove
return P[np.ix_(keep_states, keep_states)]
if len(keep_states) == 0:
return np.array([])
# Partition the matrix
# P = [[Q_AA, Q_AB], [Q_BA, Q_BB]]
# where A = keep, B = remove
Q_AA = P[np.ix_(keep_states, keep_states)]
Q_AB = P[np.ix_(keep_states, remove_states)]
Q_BA = P[np.ix_(remove_states, keep_states)]
Q_BB = P[np.ix_(remove_states, remove_states)]
# Stochastic complement: P_A = Q_AA + Q_AB * (I - Q_BB)^(-1) * Q_BA
I_BB = np.eye(len(remove_states))
try:
inv_term = np.linalg.inv(I_BB - Q_BB)
P_A = Q_AA + Q_AB @ inv_term @ Q_BA
except np.linalg.LinAlgError:
# If inversion fails, just use the direct submatrix
P_A = Q_AA
return P_A