115 lines
3.5 KiB
JavaScript
115 lines
3.5 KiB
JavaScript
import { randn, shuffle, choice } from "./random.js";
|
|
import Utils from "./utils.js";
|
|
|
|
/**
 * Stationary combinatorial bandit.
 *
 * Each arm i is described by a mean `mu[i]` and a spread `sigma[i]`; pulling it
 * yields one draw of `randn(mu[i], sigma[i])`. Any subset of arms may be pulled
 * in the same step, and the arm parameters never change between steps.
 */
export class StaticCombandit {
    /**
     * @param {Object} [options] - May be omitted entirely; every field has a default.
     * @param {number[]} [options.mu=[35, 15, -5, -25]] - Mean reward of each arm.
     * @param {number[]} [options.sigma=[15, 15, 15, 15]] - Spread of each arm, passed
     *     to `randn` as its second argument (presumably a standard deviation).
     * @throws {Error} If `mu` and `sigma` have different lengths.
     */
    constructor({ mu = [35, 15, -5, -25], sigma = [15, 15, 15, 15] } = {}) {
        if (mu.length !== sigma.length) {
            throw new Error('mu and sigma must have the same length');
        }
        // Canonicalize the arm order with Utils.argsort (presumably ascending by
        // mean), reordering sigma identically so each mean keeps its own spread.
        const orders = Utils.argsort(mu);
        this.mu = orders.map((i) => mu[i]);
        this.sigma = orders.map((i) => sigma[i]);
    }

    /**
     * Samples the total reward for one action.
     *
     * @param {{mu: number[], sigma: number[]}} state - Arm parameters in state order.
     * @param {boolean[]} action - action[i] truthy means arm i is pulled,
     *     falsy means arm i is not pulled.
     * @returns {number} Sum of one `randn(state.mu[i], state.sigma[i])` draw per
     *     pulled arm; 0 when no arm is pulled.
     */
    reward(state, action) {
        let reward = 0;
        action.forEach((pulled, i) => {
            if (pulled) {
                reward += randn(state.mu[i], state.sigma[i]);
            }
        });
        return reward;
    }

    /**
     * The environment is stationary: the next state is an independent deep copy
     * of the current one, regardless of the action.
     *
     * @param {Object} state - Current state.
     * @param {boolean[]} action - Ignored here; kept for subclass compatibility.
     * @returns {Object} A structuredClone of `state`.
     */
    transition(state, action) {
        return structuredClone(state);
    }

    /**
     * Performs one environment step.
     *
     * @param {Object} state - Current state.
     * @param {boolean[]} action - Which arms to pull.
     * @returns {{reward: number, nextState: Object}} Sampled reward and successor state.
     */
    step(state, action) {
        return { reward: this.reward(state, action), nextState: this.transition(state, action) };
    }

    /**
     * Builds a starting state whose arms are presented in the order returned by
     * `shuffle(Utils.range(nArms))` (presumably a random permutation).
     *
     * @returns {{mu: number[], sigma: number[]}} Fresh arrays; `this.mu` and
     *     `this.sigma` are not shared with the returned state.
     */
    getInitialState() {
        const orders = shuffle(Utils.range(this.nArms));
        return {
            mu: this.mu.map((_, i) => this.mu[orders[i]]),
            sigma: this.sigma.map((_, i) => this.sigma[orders[i]]),
        };
    }

    /** @returns {number} Number of arms. */
    get nArms() {
        return this.mu.length;
    }
}
|
|
|
|
/**
 * Four-armed combinatorial bandit with a "default choice" and drifting arms.
 *
 * - Pulling no arm at all is the default choice: it pays `state.dcReward`
 *   deterministically, and each use lowers `dcReward` by `stepDCReward`
 *   (down to `minDCReward`). Every `recoverRunLength` consecutive arm-pulling
 *   actions raise it by `stepDCReward` (up to `maxDCReward`).
 * - On every transition, with probability `transitProb`, one pair of arms
 *   trades means: either (this.mu[0], this.mu[3]) or (this.mu[1], this.mu[2]),
 *   where this.mu is in the order set by the base constructor.
 */
export class ExpCombandit extends StaticCombandit {
    /**
     * @param {Object} [options] - May be omitted entirely; every field has a default.
     * @param {number[]} [options.mu=[35, 15, -5, -25]] - Means of the four arms.
     * @param {number[]} [options.sigma=[15, 15, 15, 15]] - Spreads of the four arms.
     * @param {number} [options.transitProb=0.05] - Per-transition probability of a mean swap.
     * @param {number} [options.maxDCReward=50] - Upper bound and initial value of the default-choice reward.
     * @param {number} [options.minDCReward=10] - Lower bound of the default-choice reward.
     * @param {number} [options.stepDCReward=10] - Amount the default-choice reward moves per adjustment.
     * @param {number} [options.recoverRunLength=3] - Consecutive arm-pulling actions needed for one increase.
     * @throws {Error} If `mu` and `sigma` differ in length, or there are not exactly four arms.
     */
    constructor({
        mu = [35, 15, -5, -25],
        sigma = [15, 15, 15, 15],
        transitProb = 0.05,
        maxDCReward = 50,
        minDCReward = 10,
        stepDCReward = 10,
        recoverRunLength = 3,
    } = {}) {
        super({ mu, sigma });
        if (mu.length !== 4) {
            throw new Error('mu and sigma must have length 4');
        }

        // Swap partners, identified by mean value: the outer pair and the inner
        // pair of the base-class ordering.
        this.swapTarget = [
            [this.mu[0], this.mu[3]],
            [this.mu[1], this.mu[2]],
        ];

        this.transitProb = transitProb;
        this.maxDCReward = maxDCReward;
        this.minDCReward = minDCReward;
        this.stepDCReward = stepDCReward;
        this.recoverRunLength = recoverRunLength;
    }

    /**
     * @param {{mu: number[], sigma: number[], dcReward: number}} state - Current state.
     * @param {boolean[]} action - action[i] truthy means arm i is pulled.
     * @returns {number} `state.dcReward` when no arm is pulled, otherwise the
     *     base-class sampled reward of the pulled arms.
     */
    reward(state, action) {
        return action.some((a) => a) ? super.reward(state, action) : state.dcReward;
    }

    /**
     * Computes the successor state. The input state is not mutated.
     *
     * @param {{mu: number[], sigma: number[], dcReward: number, runLength: number}} state
     * @param {boolean[]} action - action[i] truthy means arm i is pulled.
     * @returns {Object} A new state with the updated default-choice bookkeeping
     *     and, with probability `transitProb`, one pair of arms swapped.
     */
    transition(state, action) {
        const nextState = structuredClone(state);

        // Deal with the default choice.
        if (action.some((a) => a)) {
            // An arm was pulled: count towards recovering the default reward.
            nextState.runLength++;
            if (nextState.runLength >= this.recoverRunLength) {
                nextState.dcReward = Math.min(nextState.dcReward + this.stepDCReward, this.maxDCReward);
                nextState.runLength = 0;
            }
        } else {
            // Default choice taken: break the run and decay its reward.
            nextState.runLength = 0;
            nextState.dcReward = Math.max(nextState.dcReward - this.stepDCReward, this.minDCReward);
        }

        // Deal with distribution change. Math.random() is in [0, 1), so `<`
        // fires with probability exactly transitProb (never when it is 0).
        if (Math.random() < this.transitProb) {
            const [muA, muB] = choice(this.swapTarget);
            // Arms are located by mean value because the state order is shuffled.
            const i = nextState.mu.indexOf(muA);
            const j = nextState.mu.indexOf(muB);
            if (i !== -1 && j !== -1) {
                nextState.mu[i] = muB;
                nextState.mu[j] = muA;
                // Swap sigma within the state so each mean keeps its own spread;
                // this.sigma is in the base-class order, not the state's order.
                [nextState.sigma[i], nextState.sigma[j]] = [nextState.sigma[j], nextState.sigma[i]];
            }
        }

        return nextState;
    }

    /**
     * @returns {Object} The base-class initial state, extended with
     *     `dcReward` = `maxDCReward` and `runLength` = 0.
     */
    getInitialState() {
        return {
            ...super.getInitialState(),
            dcReward: this.maxDCReward,
            runLength: 0,
        };
    }
}