"""
SN Getter Functions for Parameter Extraction.
Native Python implementations for extracting parameters from
network structures including arrival rates, throughputs, and
product-form chain parameters.
Key functions:
sn_get_arvr_from_tput: Compute arrival rates from throughputs
sn_get_node_arvr_from_tput: Compute node arrival rates from throughputs
sn_get_node_tput_from_tput: Compute node throughputs from station throughputs
sn_get_product_form_chain_params: Extract chain-aggregated parameters
References:
Original MATLAB: matlab/src/api/sn/sn_get_*.m
"""
import numpy as np
from typing import Tuple, Optional
from dataclasses import dataclass
from .network_struct import NetworkStruct, NodeType
@dataclass
class ChainParams:
    """Chain-aggregated product-form parameters.

    Plain container whose fields are all NumPy arrays. Shapes are not
    validated here; the field order below is the positional constructor
    order and must not be changed.
    """

    lambda_vec: np.ndarray  # Chain arrival rates
    D: np.ndarray  # Chain service demands at queuing stations
    N: np.ndarray  # Chain populations
    Z: np.ndarray  # Chain think times
    mu: np.ndarray  # Load-dependent service capacity scaling
    S: np.ndarray  # Number of servers at queuing stations
    V: np.ndarray  # Chain visit ratios
def _arvr_cache_params(sn, ind):
    """Return the hit/miss descriptors of cache node ``ind``, or None.

    The result is the tuple ``(hitclass, missclass, actualhitprob,
    actualmissprob)`` with every entry flattened to a 1-D array. None is
    returned when the node parameters, or any of those four attributes, are
    missing or None; the caller then falls back to the station throughput.
    """
    all_params = getattr(sn, 'nodeparam', None)
    if all_params is None:
        return None
    # nodeparam may be a dict keyed by node index or a list indexed by node
    if isinstance(all_params, dict):
        param = all_params.get(ind, None)
    else:
        param = all_params[ind] if ind < len(all_params) else None
    if param is None:
        return None
    fields = []
    for name in ('hitclass', 'missclass', 'actualhitprob', 'actualmissprob'):
        value = getattr(param, name, None)
        if value is None:
            return None
        # Flatten so that row/column vectors are indexed element by element
        fields.append(np.atleast_1d(value).ravel())
    return tuple(fields)


def _arvr_fill_cache_row(sn, TN, row, cache, n_classes):
    """Write hit/miss class throughputs of one cache node into ``row`` in place."""
    if not (hasattr(sn, 'nchains') and hasattr(sn, 'inchain') and hasattr(sn, 'refstat')):
        return
    hitclass, missclass, hit_prob, miss_prob = cache
    # NOTE(review): every chain overwrites the same hit/miss entries, so the
    # last chain whose reference station is in range determines the result.
    # This mirrors the previous behavior - confirm it is intended.
    for c in range(sn.nchains):
        inchain = sn.inchain[c] if sn.inchain is not None else []
        refstat = int(sn.refstat[c]) if sn.refstat is not None else 0
        if refstat >= TN.shape[0]:
            continue
        chain_classes = np.asarray(list(inchain)).ravel().astype(int)
        total_tput = np.sum(TN[refstat, chain_classes]) if chain_classes.size > 0 else 0
        # Only values strictly between 0 and n_classes are treated as valid
        # hit/miss class indices (unchanged from the previous behavior).
        for orig_class, hc_raw in enumerate(hitclass):
            hc = int(hc_raw)
            mc = int(missclass[orig_class]) if orig_class < len(missclass) else -1
            if 0 < hc < n_classes and orig_class < len(hit_prob):
                row[hc] = total_tput * hit_prob[orig_class]
            if 0 < mc < n_classes and orig_class < len(miss_prob):
                row[mc] = total_tput * miss_prob[orig_class]


def sn_get_arvr_from_tput(sn: "NetworkStruct", TN: np.ndarray,
                          TH: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Compute average arrival rates at stations from throughputs.

    Builds a throughput row for every stateful node (station throughput, or
    hit/miss class throughputs for Cache nodes) and multiplies it by the
    routing matrix ``sn.rt`` to obtain the arrival rate at each station.
    Source stations get zero arrival rate. If ``sn.fj`` has any non-zero
    entry, station rows are replaced by the node-level arrival rates from
    ``sn_get_node_arvr_from_tput``.

    Args:
        sn: Network structure
        TN: Average throughputs at stations (M x R)
        TH: Throughput handles (optional); only forwarded to
            ``sn_get_node_arvr_from_tput`` in the fork-join case

    Returns:
        AN: Average arrival rates at stations (M x R), or an empty array
            when TN is None or contains no elements

    References:
        Original MATLAB: matlab/src/api/sn/sn_get_arvr_from_tput.m
    """
    if TN is None:
        return np.array([])
    TN = np.asarray(TN, dtype=float)
    # Test the element count rather than len(): [[]] or an (M, 0) array is empty too
    if TN.size == 0:
        return np.array([])
    TN = np.atleast_2d(TN)

    M = sn.nstations
    R = sn.nclasses
    AN = np.zeros((M, R))

    # Rows/columns of sn.rt are indexed by (position in stateful list) * R + class
    stateful_nodes = np.where(sn.isstateful)[0]
    n_stateful = len(stateful_nodes)

    # Throughput of every stateful node; rows that are never filled stay zero
    TN_stateful = np.zeros((n_stateful, R))
    n_tn_rows = TN.shape[0]
    for sf, ind in enumerate(stateful_nodes):
        ist = sn.nodeToStation[ind]
        if sn.nodetype[ind] == NodeType.CACHE:
            # Cache nodes get special handling whether or not they are stations
            cache = _arvr_cache_params(sn, ind)
            if cache is None:
                # Hit/miss data unavailable: use station throughput if there is one
                if 0 <= ist < n_tn_rows:
                    TN_stateful[sf, :] = TN[ist, :]
            else:
                _arvr_fill_cache_row(sn, TN, TN_stateful[sf], cache, R)
        elif ist >= 0:
            # Stateful node that is a station: use its throughput directly
            TN_stateful[sf, :] = TN[ist, :]

    # Arrival rate = sum over stateful nodes/classes of throughput * routing prob
    for ist in range(M):
        ind_ist = sn.stationToNode[ist]
        if sn.nodetype[ind_ist] == NodeType.SOURCE:
            AN[ist, :] = 0
            continue
        matches = np.where(stateful_nodes == ind_ist)[0]
        if len(matches) == 0:
            continue
        rt = sn.rt
        n_rt_rows, n_rt_cols = rt.shape[0], rt.shape[1]
        to_base = matches[0] * R
        for sf_jst in range(n_stateful):
            from_base = sf_jst * R
            for k in range(R):
                to_idx = to_base + k
                if to_idx >= n_rt_cols:
                    continue
                for r in range(R):
                    from_idx = from_base + r
                    if from_idx < n_rt_rows:
                        AN[ist, k] += TN_stateful[sf_jst, r] * rt[from_idx, to_idx]

    # Fork-join special handling: use node-level arrival rates instead
    fj = getattr(sn, 'fj', None)
    if fj is not None and np.any(fj):
        ANn = sn_get_node_arvr_from_tput(sn, TN, TH, AN)
        for ist in range(M):
            ind = sn.stationToNode[ist]
            if 0 <= ind < ANn.shape[0]:
                AN[ist, :] = ANn[ind, :]
    return AN
def _node_arvr_param(sn, ind):
    """Return the ``sn.nodeparam`` entry of node ``ind`` (dict or list), or None."""
    all_params = getattr(sn, 'nodeparam', None)
    if all_params is None:
        return None
    if isinstance(all_params, dict):
        return all_params.get(ind, None)
    return all_params[ind] if ind < len(all_params) else None


def _node_arvr_flat(param, name):
    """Return attribute ``name`` of ``param`` as a flat array (empty if absent)."""
    if param is None or not hasattr(param, name):
        return np.array([])
    return np.atleast_1d(getattr(param, name)).flatten()


def _node_arvr_optional(param, name):
    """Return attribute ``name`` as a flat array, or None if absent or None."""
    value = getattr(param, name, None)
    return None if value is None else np.atleast_1d(value).flatten()


def _node_arvr_first_positive(TN, cls, n_stations):
    """Return the first strictly positive throughput of class ``cls``.

    Scans the first ``n_stations`` rows of TN (bounded by the rows TN really
    has) and returns 0.0 when no positive, non-NaN value exists.
    """
    if cls >= TN.shape[1]:
        return 0.0
    for ist in range(min(n_stations, TN.shape[0])):
        val = TN[ist, cls]
        if val > 0 and not np.isnan(val):
            return val
    return 0.0


def _node_arvr_cache_flows(sn, TN, n_typed, n_stations):
    """Collect the hit/miss flows leaving every Cache node.

    Returns a list of ``(hit_class, miss_class, hit_flow, miss_flow)`` tuples
    in (cache node, requesting class) scan order. ``miss_class`` is None when
    the cache has no miss-class entry for that requesting class. Requesting
    classes without positive throughput or without an actual hit probability
    produce no tuple.
    """
    flows = []
    for cache_ind in range(n_typed):
        if sn.nodetype[cache_ind] != NodeType.CACHE:
            continue
        param = _node_arvr_param(sn, cache_ind)
        if param is None:
            continue
        hit_prob = _node_arvr_optional(param, 'actualhitprob')
        if hit_prob is None:
            continue
        miss_prob = _node_arvr_optional(param, 'actualmissprob')
        hitclass = _node_arvr_flat(param, 'hitclass')
        missclass = _node_arvr_flat(param, 'missclass')
        for orig_class in range(min(len(hitclass), len(hit_prob))):
            req_tput = _node_arvr_first_positive(TN, orig_class, n_stations)
            if not req_tput > 0:
                continue
            hit_flow = req_tput * hit_prob[orig_class]
            if orig_class < len(missclass):
                if miss_prob is not None and orig_class < len(miss_prob):
                    p_miss = miss_prob[orig_class]
                else:
                    # No explicit miss probability: use the complement
                    p_miss = 1 - hit_prob[orig_class]
                flows.append((hitclass[orig_class], missclass[orig_class],
                              hit_flow, req_tput * p_miss))
            else:
                flows.append((hitclass[orig_class], None, hit_flow, 0.0))
    return flows


def sn_get_node_arvr_from_tput(sn: "NetworkStruct", TN: np.ndarray,
                               TH: Optional[np.ndarray] = None,
                               AN: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Compute node arrival rates from station throughputs.

    This function handles:
    - Station nodes: station arrival rates are copied to the node rows
    - Cache nodes: only requesting classes arrive (not hit/miss classes)
    - ClassSwitch nodes: hit/miss classes arrive from Cache nodes
    - Other non-station nodes (e.g. Sink): nodevisits-based computation
    - Source nodes are skipped

    Args:
        sn: Network structure
        TN: Station throughputs (M x R)
        TH: Throughput handles (optional). When None, an all-zero matrix
            is returned.
        AN: Station arrival rates (optional, computed with
            ``sn_get_arvr_from_tput`` if not provided)

    Returns:
        ANn: Node arrival rates (I x R). NaN entries are replaced by 0 via
            ``np.nan_to_num``.

    References:
        Original MATLAB: matlab/src/api/sn/sn_get_node_arvr_from_tput.m
    """
    I = sn.nnodes
    M = sn.nstations
    C = sn.nchains
    R = sn.nclasses
    ANn = np.zeros((I, R))

    # Validate before doing any work: without handles/throughputs the result is zero
    if TH is None or TN is None:
        return ANn
    TN = np.asarray(TN, dtype=float)
    if TN.size == 0:
        return ANn
    TN = np.atleast_2d(TN)

    if AN is None:
        AN = sn_get_arvr_from_tput(sn, TN, TH)
    AN = np.atleast_2d(np.asarray(AN, dtype=float))

    # First, copy station arrival rates to station nodes
    for ist in range(M):
        ind = sn.stationToNode[ist]
        if 0 <= ind < I:
            ANn[ind, :] = AN[ist, :]

    nodetypes = sn.nodetype
    n_typed = 0 if nodetypes is None else min(I, len(nodetypes))
    cache_flows = None  # built lazily, only if a ClassSwitch node exists

    for ind in range(n_typed):
        node_type = nodetypes[ind]
        if node_type == NodeType.SOURCE:
            continue
        is_cache = node_type == NodeType.CACHE
        is_cswitch = (not is_cache) and node_type == NodeType.CLASSSWITCH

        if is_cache:
            param = _node_arvr_param(sn, ind)
            hitclass = _node_arvr_flat(param, 'hitclass')
            missclass = _node_arvr_flat(param, 'missclass')
        elif is_cswitch:
            if cache_flows is None:
                cache_flows = _node_arvr_cache_flows(sn, TN, n_typed, M)
        else:
            node_to_station = sn.nodeToStation[ind] if ind < len(sn.nodeToStation) else -1
            if not (node_to_station < 0 or np.isnan(node_to_station)):
                # Station node: its arrival rates were copied above
                continue

        for c in range(C):
            if c not in sn.inchain:
                continue
            inchain = sn.inchain[c].flatten().astype(int)
            chain_classes = [r for r in inchain if r < R]

            if is_cache:
                # Hit/miss classes only depart from the cache; a requesting
                # class arrives with its own throughput (from any station).
                for r in chain_classes:
                    if (r in hitclass) or (r in missclass):
                        continue
                    val = _node_arvr_first_positive(TN, r, M)
                    if val > 0:
                        ANn[ind, r] = val
            elif is_cswitch:
                # Hit/miss classes arrive from the Cache; the last matching
                # (cache, requesting class) pair determines the value.
                for r in chain_classes:
                    for h, m, hit_flow, miss_flow in cache_flows:
                        if h == r:
                            ANn[ind, r] = hit_flow
                        elif m is not None and m == r:
                            ANn[ind, r] = miss_flow
            else:
                # Other non-station nodes: scale the chain throughput at the
                # reference station by the node visit ratio.
                if c not in sn.nodevisits or c not in sn.visits:
                    continue
                nodevisits_c = sn.nodevisits[c]
                visits_c = sn.visits[c]
                if ind >= nodevisits_c.shape[0]:
                    continue
                rows = [r for r in chain_classes if r < nodevisits_c.shape[1]]
                if not rows:
                    continue
                refstat_idx = int(sn.refstat[c]) if c < len(sn.refstat) else 0
                if refstat_idx < len(sn.stationToStateful):
                    stateful_refstat = sn.stationToStateful[refstat_idx]
                else:
                    stateful_refstat = refstat_idx
                visit_sum = 0.0
                tput_sum = 0.0
                for s in inchain:
                    if stateful_refstat < visits_c.shape[0] and s < visits_c.shape[1]:
                        visit_sum += visits_c[stateful_refstat, s]
                    if refstat_idx < TN.shape[0] and s < TN.shape[1]:
                        tput_sum += TN[refstat_idx, s]
                if visit_sum > 0:
                    for r in rows:
                        ANn[ind, r] = (nodevisits_c[ind, r] / visit_sum) * tput_sum

    # Replace NaN with 0
    ANn = np.nan_to_num(ANn, nan=0.0)
    return ANn
def _node_tput_param(sn, ind):
    """Return the ``sn.nodeparam`` entry of node ``ind`` (dict or list), or None."""
    all_params = getattr(sn, 'nodeparam', None)
    if all_params is None:
        return None
    if isinstance(all_params, dict):
        return all_params.get(ind, None)
    return all_params[ind] if ind < len(all_params) else None


def _node_tput_flat(param, name):
    """Return attribute ``name`` of ``param`` as a flat array (empty if absent)."""
    if param is None or not hasattr(param, name):
        return np.array([])
    return np.atleast_1d(getattr(param, name)).flatten()


def _node_tput_optional(param, name):
    """Return attribute ``name`` as a flat array, or None if absent or None."""
    value = getattr(param, name, None)
    return None if value is None else np.atleast_1d(value).flatten()


def _node_tput_first_positive(TN, cls, n_rows):
    """Return the first positive, non-NaN TN[ist, cls] over ``n_rows`` rows, else 0.0."""
    if cls >= TN.shape[1]:
        return 0.0
    for ist in range(min(n_rows, TN.shape[0])):
        val = TN[ist, cls]
        if val > 0 and not np.isnan(val):
            return val
    return 0.0


def _node_tput_fill_cache(sn, TN, TNn, ind, n_chains, n_classes):
    """First pass: set hit/miss class throughputs of cache node ``ind`` in TNn."""
    param = _node_tput_param(sn, ind)
    hitclass = _node_tput_flat(param, 'hitclass')
    missclass = _node_tput_flat(param, 'missclass')
    hit_prob = _node_tput_optional(param, 'actualhitprob')
    miss_prob = _node_tput_optional(param, 'actualmissprob')
    if hit_prob is None:
        # Without actual hit probabilities nothing can be assigned here
        return
    for orig_class in range(min(len(hitclass), len(missclass))):
        h = hitclass[orig_class]
        m = missclass[orig_class]
        # A requesting class is one with non-negative hit AND miss classes
        if not (h >= 0 and m >= 0):
            continue
        # Reference station of the requesting class (station 0 by default).
        # NOTE(review): the class index is used as key into sn.inchain and
        # sn.refstat here, and the refstat lookup is assumed to be nested
        # under the inchain test - confirm against the upstream source.
        refstat_idx = 0
        if orig_class < n_chains and orig_class in sn.inchain:
            if orig_class < len(sn.refstat):
                refstat_idx = int(sn.refstat[orig_class])
        req_tput = 0.0
        if refstat_idx < TN.shape[0] and orig_class < TN.shape[1]:
            val = TN[refstat_idx, orig_class]
            if not np.isnan(val):
                req_tput = val
        if req_tput == 0:
            # No throughput at refstat: take it from any station
            req_tput = _node_tput_first_positive(TN, orig_class, TN.shape[0])
        if req_tput > 0 and orig_class < len(hit_prob):
            # Cast to int: class arrays may be stored as floats
            h_idx, m_idx = int(h), int(m)
            if h_idx < n_classes:
                TNn[ind, h_idx] = req_tput * hit_prob[orig_class]
            if m_idx < n_classes:
                if miss_prob is not None and orig_class < len(miss_prob):
                    p_miss = miss_prob[orig_class]
                else:
                    p_miss = 1 - hit_prob[orig_class]
                TNn[ind, m_idx] = req_tput * p_miss


def _node_tput_route(TNn, ANn, rtnodes, ind, r, targets, n_nodes, n_classes, skip_self):
    """Add the class-``r`` arrivals at node ``ind`` routed to every (node, class).

    For each target class ``s`` and destination node, adds
    ``ANn[ind, r] * rtnodes[ind * R + r, jnd * R + s]`` to ``TNn[ind, s]``.
    Entries outside ``rtnodes`` are ignored; ``skip_self`` excludes the
    node's own column block.
    """
    from_idx = ind * n_classes + r
    if from_idx >= rtnodes.shape[0]:
        return
    for s in targets:
        for jnd in range(n_nodes):
            if skip_self and jnd == ind:
                continue
            to_idx = jnd * n_classes + s
            if to_idx < rtnodes.shape[1]:
                TNn[ind, s] += ANn[ind, r] * rtnodes[from_idx, to_idx]


def sn_get_node_tput_from_tput(sn: "NetworkStruct", TN: np.ndarray,
                               TH: Optional[np.ndarray] = None,
                               ANn: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Compute node throughputs from station throughputs.

    This function handles:
    - Cache nodes: uses actual hit/miss probabilities if available
    - Station nodes: station throughputs are copied to the node rows
    - ClassSwitch nodes: the requesting class gets its own throughput
    - Other non-station nodes and Join nodes: node arrival rate times the
      node routing matrix (``sn.rtnodes``)
    - Source and Sink nodes get no routing-based throughput

    Args:
        sn: Network structure
        TN: Station throughputs (M x R)
        TH: Throughput handles (optional). When None, an all-zero matrix
            is returned.
        ANn: Node arrival rates (optional, computed with
            ``sn_get_node_arvr_from_tput`` if not provided)

    Returns:
        TNn: Node throughputs (I x R). NaN entries are replaced by 0 via
            ``np.nan_to_num``.

    References:
        Original MATLAB: matlab/src/api/sn/sn_get_node_tput_from_tput.m
    """
    I = sn.nnodes
    M = sn.nstations
    C = sn.nchains
    R = sn.nclasses
    TNn = np.zeros((I, R))

    # Validate before coercing TN or computing arrival rates
    if TH is None or TN is None:
        return TNn
    TN = np.asarray(TN, dtype=float)
    if TN.size == 0:
        return TNn
    TN = np.atleast_2d(TN)

    if ANn is None:
        ANn = sn_get_node_arvr_from_tput(sn, TN, TH)
    ANn = np.asarray(ANn)

    nodetypes = sn.nodetype
    n_typed = 0 if nodetypes is None else min(I, len(nodetypes))

    # First pass: Cache nodes with actual hit/miss probabilities
    for ind in range(n_typed):
        if nodetypes[ind] == NodeType.CACHE:
            _node_tput_fill_cache(sn, TN, TNn, ind, C, R)

    # Second pass: copy station throughputs to station nodes, except Cache
    # nodes whose rows were computed in the first pass
    for ist in range(M):
        ind = sn.stationToNode[ist]
        if 0 <= ind < I:
            if ind < n_typed and nodetypes[ind] == NodeType.CACHE:
                continue
            TNn[ind, :] = TN[ist, :]

    rtnodes = getattr(sn, 'rtnodes', None)
    has_routing = rtnodes is not None and rtnodes.size > 0
    cache_classes = None  # (hitclass, missclass) per cache, built lazily

    # Third pass: non-station nodes
    for ind in range(n_typed):
        node_type = nodetypes[ind]
        if node_type in [NodeType.SOURCE, NodeType.SINK, NodeType.JOIN]:
            continue
        node_to_station = sn.nodeToStation[ind] if ind < len(sn.nodeToStation) else -1
        if node_to_station >= 0 and not np.isnan(node_to_station):
            continue  # station node, handled in the second pass

        is_cache = node_type == NodeType.CACHE
        is_cswitch = (not is_cache) and node_type == NodeType.CLASSSWITCH

        if is_cache:
            probs = _node_tput_optional(_node_tput_param(sn, ind), 'actualhitprob')
            if probs is not None and np.any(probs > 0):
                continue  # already computed with actual probs in first pass
        elif is_cswitch and cache_classes is None:
            cache_classes = []
            for cache_ind in range(n_typed):
                if nodetypes[cache_ind] != NodeType.CACHE:
                    continue
                cache_param = _node_tput_param(sn, cache_ind)
                if cache_param is None:
                    continue
                cache_classes.append((_node_tput_flat(cache_param, 'hitclass'),
                                      _node_tput_flat(cache_param, 'missclass')))

        for c in range(C):
            if c not in sn.inchain:
                continue
            inchain = sn.inchain[c].flatten().astype(int)
            targets = [s for s in inchain if s < R]
            for r in targets:
                if is_cswitch:
                    # r is a hit or miss class of some requesting class: the
                    # ClassSwitch outputs the requesting class, whose
                    # throughput equals the hit + miss arrival rates.
                    for hitclass, missclass in cache_classes:
                        for orig_class in range(len(hitclass)):
                            is_hit = hitclass[orig_class] == r
                            is_miss = (orig_class < len(missclass)
                                       and missclass[orig_class] == r)
                            if not (is_hit or is_miss) or orig_class >= R:
                                continue
                            req_tput = _node_tput_first_positive(TN, orig_class, M)
                            if req_tput > 0:
                                TNn[ind, orig_class] = req_tput
                    continue
                # Other node types: skip classes never visiting a stateful node
                if c not in sn.visits:
                    continue
                visits_c = sn.visits[c]
                if not (r < visits_c.shape[1] and np.any(visits_c[:, r] > 0)):
                    continue
                if has_routing:
                    # Cache nodes exclude self-loops; other nodes (Router, etc.) do not
                    _node_tput_route(TNn, ANn, rtnodes, ind, r, targets, I, R,
                                     skip_self=is_cache)

    # Handle Join nodes (regardless of station status, no visits filter)
    if has_routing:
        for ind in range(n_typed):
            if nodetypes[ind] != NodeType.JOIN:
                continue
            for c in range(C):
                if c not in sn.inchain:
                    continue
                inchain = sn.inchain[c].flatten().astype(int)
                targets = [s for s in inchain if s < R]
                for r in targets:
                    _node_tput_route(TNn, ANn, rtnodes, ind, r, targets, I, R,
                                     skip_self=False)

    # Replace NaN with 0
    TNn = np.nan_to_num(TNn, nan=0.0)
    return TNn
def sn_set_routing_prob(sn: "NetworkStruct", from_stateful: int, from_class: int,
                        to_stateful: int, to_class: int, prob: float,
                        auto_refresh: bool = False) -> "NetworkStruct":
    """
    Set a routing probability between two stateful node-class pairs.

    Updates a single entry of ``sn.rt`` in place. The entry is addressed as
    ``rt[from_stateful * K + from_class, to_stateful * K + to_class]`` with
    ``K = sn.nclasses``.

    Args:
        sn: Network structure (modified in place)
        from_stateful: Source stateful node index (0-based)
        from_class: Source class index (0-based)
        to_stateful: Destination stateful node index (0-based)
        to_class: Destination class index (0-based)
        prob: Routing probability [0, 1] (the value is stored as given; it
            is not range-checked)
        auto_refresh: If True, refresh visit ratios (default False)

    Returns:
        Modified network structure (the same object as ``sn``)

    Raises:
        IndexError: If a class index is outside ``[0, K)``, a stateful node
            index is negative, or the computed position lies outside ``sn.rt``.

    References:
        Original MATLAB: matlab/src/api/sn/sn_set_routing_prob.m
    """
    K = sn.nclasses

    # A class index outside [0, K) or a negative node index would silently
    # address the entry of a *different* node-class pair, so reject it.
    for label, cls in (("from_class", from_class), ("to_class", to_class)):
        if not 0 <= cls < K:
            raise IndexError(f"{label}={cls} is out of range for {K} classes")
    for label, node in (("from_stateful", from_stateful), ("to_stateful", to_stateful)):
        if node < 0:
            raise IndexError(f"{label}={node} must be non-negative")

    # Calculate indices in rt matrix
    from_idx = from_stateful * K + from_class
    to_idx = to_stateful * K + to_class

    # Update rt matrix
    sn.rt[from_idx, to_idx] = prob

    # Auto-refresh visit ratios if requested
    if auto_refresh:
        # Imported lazily, only when a refresh is requested.
        # NOTE(review): the return value is discarded, so this assumes
        # sn_refresh_visits updates sn in place - confirm.
        from .transforms import sn_refresh_visits
        sn_refresh_visits(sn)
    return sn
# Public API of this module: the names exported by a star-import.
# NOTE(review): 'sn_get_product_form_chain_params' is listed here but is not
# defined in the portion of the file visible at this point - confirm it
# exists, since a star-import raises AttributeError for a missing name.
__all__ = [
    'ChainParams',
    'sn_get_arvr_from_tput',
    'sn_get_node_arvr_from_tput',
    'sn_get_node_tput_from_tput',
    'sn_get_product_form_chain_params',
    'sn_set_routing_prob',
]