import { randn, shuffle, choice } from "./random.js";
import Utils from "./utils.js";

/**
 * Combinatorial bandit whose arm parameters never change between steps.
 *
 * A state is a plain object `{ mu, sigma }` holding, per arm position, the
 * parameters handed to `randn` when that arm is pulled. The environment
 * itself keeps the arms in the order returned by `Utils.argsort(mu)`
 * (presumably ascending by mean -- confirm against utils.js).
 */
export class StaticCombandit {
  /**
   * @param {object} [options]
   * @param {number[]} [options.mu] - First `randn` parameter of each arm.
   * @param {number[]} [options.sigma] - Second `randn` parameter of each arm.
   * @throws {Error} If `mu` and `sigma` differ in length.
   */
  constructor({ mu = [35, 15, -5, -25], sigma = [15, 15, 15, 15] } = {}) {
    if (mu.length !== sigma.length) {
      throw new Error('mu and sigma must have the same length');
    }
    // Reorder both arrays with the same permutation so that mu[i] and
    // sigma[i] keep describing the same arm.
    const orders = Utils.argsort(mu);
    this.mu = orders.map((i) => mu[i]);
    this.sigma = orders.map((i) => sigma[i]);
  }

  /**
   * Sum one `randn` draw for every pulled arm.
   *
   * @param {{mu: number[], sigma: number[]}} state - Current arm parameters.
   * @param {boolean[]} action - Action is supposed to be a Boolean array where
   *   action[i] = true means arm i is pulled and action[i] = false means arm i
   *   is not pulled.
   * @returns {number} Total reward; 0 when no arm is pulled.
   */
  reward(state, action) {
    return action.reduce(
      (total, pulled, i) =>
        (pulled ? total + randn(state.mu[i], state.sigma[i]) : total),
      0,
    );
  }

  /**
   * The static environment never changes: the next state is a deep copy of
   * the current one, whatever the action.
   *
   * @param {object} state - Current state.
   * @param {boolean[]} action - Ignored here.
   * @returns {object} Deep copy of `state`.
   */
  transition(state, action) {
    return structuredClone(state);
  }

  /**
   * Evaluate the reward on the current state, then compute the next state.
   *
   * @param {object} state - Current state.
   * @param {boolean[]} action - Which arms are pulled.
   * @returns {{reward: number, nextState: object}}
   */
  step(state, action) {
    return {
      reward: this.reward(state, action),
      nextState: this.transition(state, action),
    };
  }

  /**
   * Build a starting state in which the arms are placed according to the
   * result of `shuffle`; mu and sigma are permuted together.
   *
   * @returns {{mu: number[], sigma: number[]}}
   */
  getInitialState() {
    const orders = shuffle(Utils.range(this.nArms));
    return {
      mu: this.mu.map((_, i) => this.mu[orders[i]]),
      sigma: this.sigma.map((_, i) => this.sigma[orders[i]]),
    };
  }

  /** Number of arms. */
  get nArms() {
    return this.mu.length;
  }
}

/**
 * Four-armed combinatorial bandit with a "default choice" (pulling no arm)
 * and occasional swaps between paired arms.
 *
 * On top of `{ mu, sigma }` the state carries `dcReward` (the payoff for
 * pulling no arm) and `runLength` (consecutive steps with at least one arm
 * pulled since the last reset).
 */
export class ExpCombandit extends StaticCombandit {
  /**
   * @param {object} [options]
   * @param {number[]} [options.mu] - Arm means; must have length 4.
   * @param {number[]} [options.sigma] - Arm spreads; must have length 4.
   * @param {number} [options.transitProb] - Per-step probability of a swap.
   * @param {number} [options.maxDCReward] - Upper bound (and initial value) of
   *   the default-choice reward.
   * @param {number} [options.minDCReward] - Lower bound of the default-choice
   *   reward.
   * @param {number} [options.stepDCReward] - Amount the default-choice reward
   *   moves per adjustment.
   * @param {number} [options.recoverRunLength] - Number of consecutive
   *   arm-pulling steps needed for one upward adjustment.
   * @throws {Error} If `mu` and `sigma` differ in length or `mu` is not of
   *   length 4.
   */
  constructor({
    mu = [35, 15, -5, -25],
    sigma = [15, 15, 15, 15],
    transitProb = 0.05,
    maxDCReward = 50,
    minDCReward = 10,
    stepDCReward = 10,
    recoverRunLength = 3,
  } = {}) {
    super({ mu, sigma });
    if (mu.length !== 4) {
      throw new Error('mu and sigma must have length 4');
    }
    // Pairs of means that may trade places: first with last and the two
    // middle ones, in the constructor's argsort order.
    this.swapTarget = [
      [this.mu[0], this.mu[3]],
      [this.mu[1], this.mu[2]],
    ];
    this.transitProb = transitProb;
    this.maxDCReward = maxDCReward;
    this.minDCReward = minDCReward;
    this.stepDCReward = stepDCReward;
    this.recoverRunLength = recoverRunLength;
  }

  /**
   * Reward of the pulled arms, or the default-choice reward when no arm is
   * pulled.
   *
   * @param {object} state - Current state.
   * @param {boolean[]} action - Which arms are pulled.
   * @returns {number}
   */
  reward(state, action) {
    return action.some((a) => a) ? super.reward(state, action) : state.dcReward;
  }

  /**
   * Compute the next state without mutating `state`.
   *
   * @param {object} state - Current state.
   * @param {boolean[]} action - Which arms are pulled.
   * @returns {object} New state.
   */
  transition(state, action) {
    const nextState = structuredClone(state);

    // Deal with default choice
    if (action.some((a) => a)) {
      // Pulling arms lets the default-choice reward recover one step every
      // `recoverRunLength` consecutive pulls, capped at maxDCReward.
      nextState.runLength++;
      if (nextState.runLength >= this.recoverRunLength) {
        nextState.dcReward = Math.min(
          nextState.dcReward + this.stepDCReward,
          this.maxDCReward,
        );
        nextState.runLength = 0;
      }
    } else {
      // Taking the default choice depletes it (floored at minDCReward) and
      // restarts the recovery run.
      nextState.runLength = 0;
      nextState.dcReward = Math.max(
        nextState.dcReward - this.stepDCReward,
        this.minDCReward,
      );
    }

    // Deal with distribution change.
    // Math.random() lies in [0, 1), so a strict comparison fires with
    // probability exactly transitProb and never when transitProb is 0.
    if (Math.random() < this.transitProb) {
      const [firstMu, secondMu] = choice(this.swapTarget);
      // Arms are located by mean value; the state is expected to hold this
      // environment's means (as produced by getInitialState).
      const firstIdx = nextState.mu.indexOf(firstMu);
      const secondIdx = nextState.mu.indexOf(secondMu);
      nextState.mu[firstIdx] = secondMu;
      nextState.mu[secondIdx] = firstMu;
      // Each sigma must travel with its mean. The indices are positions in
      // the (shuffled) state, so the swap is done on the state's own sigma
      // array rather than on this.sigma, which is in a different order.
      [nextState.sigma[firstIdx], nextState.sigma[secondIdx]] = [
        nextState.sigma[secondIdx],
        nextState.sigma[firstIdx],
      ];
    }

    return nextState;
  }

  /**
   * Starting state: shuffled arms, default-choice reward at its maximum and
   * an empty recovery run.
   *
   * @returns {{mu: number[], sigma: number[], dcReward: number, runLength: number}}
   */
  getInitialState() {
    return {
      ...super.getInitialState(),
      dcReward: this.maxDCReward,
      runLength: 0,
    };
  }
}