LINE Solver
MATLAB API documentation
predictEnsemble.m
function [pred, scores] = predictEnsemble(learners, X)
% PREDICTENSEMBLE Custom predict function for ensembles of decision trees.
%
%   [pred, scores] = predictEnsemble(learners, X) combines the class
%   scores of an ensemble of decision trees trained using either
%   umrbBagging or smoteBagging.
%
%   Inputs:
%     learners - cell array of MATLAB decision tree objects. Any cell
%                shape is accepted; it is flattened internally. Each
%                learner's predict method must return (as its second
%                output) an N x K score matrix with the same K for all
%                learners.
%     X        - N x M array, where N is the number of test instances and
%                M is the number of features used to train the learners.
%
%   Outputs:
%     pred   - N x 1 vector. pred(i) is the COLUMN INDEX into scores of
%              the highest-scoring class for instance i. It is not the
%              class label itself. Ties are broken by choosing uniformly
%              at random among the tied columns.
%     scores - N x K matrix, the element-wise mean of the learners'
%              score matrices.
%
%   Raises an error with identifier 'predictEnsemble:noLearners' if
%   learners is empty.

if isempty(learners)
    error('predictEnsemble:noLearners', 'At least one learner is required.');
end

% Flatten to a column so that any cell shape (row, column, or 2-D)
% is iterated the same way.
learners = learners(:);
numLearners = numel(learners);
numRows = size(X, 1);

% Average the per-learner score matrices explicitly across learners.
% Calling mean() on a stacked slice without a dimension argument would
% average across classes when there is only one learner, because that
% slice is then a 1 x K row vector.
scores = 0;
for j = 1:numLearners
    learner = learners{j};
    [~, learnerScores] = learner.predict(X);
    scores = scores + learnerScores;
end
scores = scores / numLearners;

pred = zeros(numRows, 1);
for i = 1:numRows
    rowScores = scores(i, :);
    tiedIdx = find(rowScores == max(rowScores));
    pred(i) = tiedIdx(randi(numel(tiedIdx))); % Break ties using random choice
end
end