nssi-fmri-demo/combandit.js
2024-07-28 23:36:17 +08:00

115 lines
3.5 KiB
JavaScript

import { randn, shuffle, choice } from "./random.js";
import Utils from "./utils.js";
/**
 * Bandit environment whose arm parameters do not change between steps.
 *
 * Every arm `i` is described by `mu[i]` and `sigma[i]`. An action is a Boolean
 * array selecting any subset of arms; the reward of an action is the sum of
 * one `randn(mu[i], sigma[i])` draw per selected arm.
 */
export class StaticCombandit {
  /**
   * @param {object} [options] - May be omitted entirely; all fields have defaults.
   * @param {number[]} [options.mu] - Per-arm first parameter passed to `randn`.
   * @param {number[]} [options.sigma] - Per-arm second parameter passed to
   *   `randn`; must be parallel to `mu`.
   * @throws {Error} If `mu` and `sigma` have different lengths.
   */
  constructor({
    mu = [35, 15, -5, -25],
    sigma = [15, 15, 15, 15],
  } = {}) {
    if (mu.length !== sigma.length) {
      throw new Error('mu and sigma must have the same length');
    }
    // Keep both arrays in the order produced by Utils.argsort(mu) so that
    // index k always refers to the same canonical arm for mu and sigma.
    // NOTE(review): argsort is assumed to return the sorting permutation of
    // `mu` — confirm against utils.js.
    const orders = Utils.argsort(mu);
    this.mu = orders.map((i) => mu[i]);
    this.sigma = orders.map((i) => sigma[i]);
  }

  /**
   * Samples the reward for pulling the selected arms in `state`.
   *
   * @param {{mu: number[], sigma: number[]}} state - Current arm parameters.
   * @param {boolean[]} action - `action[i] === true` means arm `i` is pulled,
   *   `false` means it is not.
   * @returns {number} Sum of one `randn` draw per pulled arm (0 if none).
   */
  reward(state, action) {
    let reward = 0;
    action.forEach((pulled, i) => {
      if (pulled) {
        reward += randn(state.mu[i], state.sigma[i]);
      }
    });
    return reward;
  }

  /**
   * Returns the state that follows `state`; for this environment that is an
   * independent deep copy of the input, regardless of the action.
   *
   * @param {object} state - Current state.
   * @param {boolean[]} action - Unused here.
   * @returns {object} Deep copy of `state`.
   */
  transition(state, action) {
    return structuredClone(state);
  }

  /**
   * Performs one environment step.
   *
   * @param {object} state - Current state.
   * @param {boolean[]} action - Arms to pull.
   * @returns {{reward: number, nextState: object}} Sampled reward and the
   *   state produced by `transition`.
   */
  step(state, action) {
    return {
      reward: this.reward(state, action),
      nextState: this.transition(state, action),
    };
  }

  /**
   * Builds a starting state in which the canonical arms are assigned to
   * positions according to `shuffle(Utils.range(nArms))`; `mu` and `sigma`
   * use the same permutation so they stay paired.
   *
   * @returns {{mu: number[], sigma: number[]}} Fresh state object.
   */
  getInitialState() {
    const orders = shuffle(Utils.range(this.nArms));
    return {
      mu: this.mu.map((_, i) => this.mu[orders[i]]),
      sigma: this.sigma.map((_, i) => this.sigma[orders[i]]),
    };
  }

  /** @returns {number} Number of arms. */
  get nArms() {
    return this.mu.length;
  }
}
/**
 * Four-arm environment extending StaticCombandit with two extra mechanics:
 *
 * 1. A "default choice" (pulling no arm at all) that pays `state.dcReward`.
 *    Taking it lowers `dcReward` by `stepDCReward` (not below `minDCReward`);
 *    every `recoverRunLength` consecutive steps with at least one arm pulled
 *    raise it by `stepDCReward` (not above `maxDCReward`).
 * 2. With probability `transitProb` per step, one of two fixed pairs of arms
 *    exchange positions in the state.
 */
export class ExpCombandit extends StaticCombandit {
  /**
   * @param {object} [options] - May be omitted entirely; all fields have defaults.
   * @param {number[]} [options.mu] - Per-arm means; must have length 4.
   * @param {number[]} [options.sigma] - Per-arm spreads, parallel to `mu`.
   * @param {number} [options.transitProb] - Per-step probability of a swap.
   * @param {number} [options.maxDCReward] - Upper bound (and initial value) of
   *   the default-choice reward.
   * @param {number} [options.minDCReward] - Lower bound of the default-choice reward.
   * @param {number} [options.stepDCReward] - Amount by which the default-choice
   *   reward moves on each decrease or recovery.
   * @param {number} [options.recoverRunLength] - Number of consecutive
   *   arm-pulling steps required for one recovery.
   * @throws {Error} If `mu` and `sigma` differ in length, or `mu` is not of length 4.
   */
  constructor({
    mu = [35, 15, -5, -25],
    sigma = [15, 15, 15, 15],
    transitProb = 0.05,
    maxDCReward = 50,
    minDCReward = 10,
    stepDCReward = 10,
    recoverRunLength = 3,
  } = {}) {
    super({ mu, sigma });
    if (mu.length !== 4) {
      throw new Error('mu and sigma must have length 4');
    }
    // Pairs of mean values that may trade places: the two ends of the
    // canonical ordering, and the two middle entries.
    this.swapTarget = [
      [this.mu[0], this.mu[3]],
      [this.mu[1], this.mu[2]],
    ];
    this.transitProb = transitProb;
    this.maxDCReward = maxDCReward;
    this.minDCReward = minDCReward;
    this.stepDCReward = stepDCReward;
    this.recoverRunLength = recoverRunLength;
  }

  /**
   * @param {object} state - Current state (needs `dcReward` plus the base fields).
   * @param {boolean[]} action - Arms to pull; all `false` selects the default choice.
   * @returns {number} The base-class reward if any arm is pulled, otherwise
   *   `state.dcReward`.
   */
  reward(state, action) {
    return action.some((a) => a) ? super.reward(state, action) : state.dcReward;
  }

  /**
   * Computes the successor of `state` without mutating it.
   *
   * @param {object} state - Current state with `mu`, `sigma`, `dcReward`, `runLength`.
   * @param {boolean[]} action - Arms to pull; all `false` selects the default choice.
   * @returns {object} New state object.
   */
  transition(state, action) {
    const nextState = structuredClone(state);

    // Default-choice bookkeeping.
    if (action.some((a) => a)) {
      nextState.runLength++;
      if (nextState.runLength >= this.recoverRunLength) {
        nextState.dcReward = Math.min(
          nextState.dcReward + this.stepDCReward,
          this.maxDCReward,
        );
        nextState.runLength = 0;
      }
    } else {
      nextState.runLength = 0;
      nextState.dcReward = Math.max(
        nextState.dcReward - this.stepDCReward,
        this.minDCReward,
      );
    }

    // Distribution change. Math.random() lies in [0, 1), so a strict `<`
    // fires with probability exactly transitProb and never when it is 0.
    if (Math.random() < this.transitProb) {
      const [first, second] = choice(this.swapTarget);
      // Arms are located by their mean value, so this relies on the four
      // means being distinct (indexOf returns the first match).
      const firstIdx = nextState.mu.indexOf(first);
      const secondIdx = nextState.mu.indexOf(second);
      nextState.mu[firstIdx] = second;
      nextState.mu[secondIdx] = first;
      // Each sigma must travel with its mu. The indices refer to positions in
      // the (shuffled) state, so the values are exchanged within the state
      // itself rather than read from the canonically ordered this.sigma.
      const firstSigma = nextState.sigma[firstIdx];
      nextState.sigma[firstIdx] = nextState.sigma[secondIdx];
      nextState.sigma[secondIdx] = firstSigma;
    }
    return nextState;
  }

  /**
   * @returns {object} The base-class initial state extended with
   *   `dcReward = maxDCReward` and `runLength = 0`.
   */
  getInitialState() {
    return {
      ...super.getInitialState(),
      dcReward: this.maxDCReward,
      runLength: 0,
    };
  }
}