// Patch metadata (originally a git format-patch header), kept for reference:
//   From 0fd1073341b9e896057e60c34a714e7606d46f6d Mon Sep 17 00:00:00 2001
//   From: HoshinoKoji
//   Date: Sun, 28 Jul 2024 23:36:17 +0800
//   Subject: [PATCH] Init
//   combandit.js | 115, index.html | 133, random.js | 48, utils.js | 39
//   4 files changed, 335 insertions(+)
//
// ===== combandit.js (new file, index 0000000..f8d6f92) =====
import { randn, shuffle, choice } from "./random.js";
import Utils from "./utils.js";

/**
 * Combinatorial bandit whose arm distributions never change during an episode.
 *
 * The environment object itself is stateless between steps: every method takes
 * the current `state` ({ mu, sigma }) and returns new values instead of
 * mutating it.
 */
export class StaticCombandit {
  /**
   * @param {object} [options]
   * @param {number[]} [options.mu] - Per-arm `mu` passed to `randn`.
   * @param {number[]} [options.sigma] - Per-arm `sigma` passed to `randn`.
   * @throws {Error} If `mu` and `sigma` differ in length.
   */
  constructor({
    mu = [35, 15, -5, -25],
    sigma = [15, 15, 15, 15],
  } = {}) {
    if (mu.length !== sigma.length) {
      throw new Error('mu and sigma must have the same length');
    }
    // Store the arms in the order given by Utils.argsort(mu), keeping each
    // sigma paired with its mu.
    const orders = Utils.argsort(mu);
    this.mu = orders.map((i) => mu[i]);
    this.sigma = orders.map((i) => sigma[i]);
  }

  /**
   * Sample the total reward for one action.
   *
   * @param {{mu: number[], sigma: number[]}} state - Current arm parameters.
   * @param {boolean[]} action - Action is supposed to be a Boolean array, where
   *   action[i] = true means arm i is pulled and action[i] = false means arm i
   *   is not pulled.
   * @returns {number} Sum of one `randn` draw per pulled arm (0 if none).
   */
  reward(state, action) {
    let reward = 0;
    action.forEach((pulled, i) => {
      if (pulled) {
        reward += randn(state.mu[i], state.sigma[i]);
      }
    });
    return reward;
  }

  /**
   * Compute the next state. The static bandit never changes, so this is a
   * deep copy of `state`; `action` is accepted only for interface parity with
   * subclasses.
   *
   * @param {object} state
   * @param {boolean[]} action
   * @returns {object} A structured clone of `state`.
   */
  transition(state, action) {
    return structuredClone(state);
  }

  /**
   * Run one environment step: the reward is sampled from the current state
   * first, then the successor state is computed.
   *
   * @param {object} state
   * @param {boolean[]} action
   * @returns {{reward: number, nextState: object}}
   */
  step(state, action) {
    return {
      reward: this.reward(state, action),
      nextState: this.transition(state, action),
    };
  }

  /**
   * Build a fresh state in which the arms appear in a random order. `mu` and
   * `sigma` are permuted with the same index order so pairs stay together.
   *
   * @returns {{mu: number[], sigma: number[]}}
   */
  getInitialState() {
    const orders = shuffle(Utils.range(this.nArms));
    return {
      mu: this.mu.map((_, i) => this.mu[orders[i]]),
      sigma: this.sigma.map((_, i) => this.sigma[orders[i]]),
    };
  }

  /** @returns {number} Number of arms (length of `mu`). */
  get nArms() {
    return this.mu.length;
  }
}

/**
 * Four-armed bandit with two extra dynamics on top of StaticCombandit:
 *  - a "default choice" (no arm pulled) that pays `state.dcReward`, which
 *    shrinks each time it is used and recovers after a run of arm pulls;
 *  - random swaps between fixed pairs of arms.
 */
export class ExpCombandit extends StaticCombandit {
  /**
   * @param {object} [options]
   * @param {number[]} [options.mu] - Exactly four values.
   * @param {number[]} [options.sigma] - Exactly four values.
   * @param {number} [options.transitProb] - Per-step probability of a swap.
   * @param {number} [options.maxDCReward] - Upper bound (and initial value) of dcReward.
   * @param {number} [options.minDCReward] - Lower bound of dcReward.
   * @param {number} [options.stepDCReward] - Amount dcReward moves per change.
   * @param {number} [options.recoverRunLength] - Consecutive arm-pulling steps
   *   needed before dcReward recovers by one step.
   * @throws {Error} If `mu`/`sigma` lengths differ or `mu` is not of length 4.
   */
  constructor({
    mu = [35, 15, -5, -25],
    sigma = [15, 15, 15, 15],
    transitProb = 0.05,
    maxDCReward = 50,
    minDCReward = 10,
    stepDCReward = 10,
    recoverRunLength = 3,
  } = {}) {
    super({ mu, sigma });
    if (mu.length !== 4) {
      throw new Error('mu and sigma must have length 4');
    }
    // Swap candidates are identified by their mu values: the first/last and
    // the two middle entries of the argsort-ordered this.mu.
    this.swapTarget = [
      [this.mu[0], this.mu[3]],
      [this.mu[1], this.mu[2]],
    ];

    this.transitProb = transitProb;
    this.maxDCReward = maxDCReward;
    this.minDCReward = minDCReward;
    this.stepDCReward = stepDCReward;
    this.recoverRunLength = recoverRunLength;
  }

  /**
   * @param {object} state
   * @param {boolean[]} action
   * @returns {number} The sampled arm reward if any arm is pulled, otherwise
   *   the current default-choice reward `state.dcReward`.
   */
  reward(state, action) {
    return action.some((a) => a)
      ? super.reward(state, action)
      : state.dcReward;
  }

  /**
   * Compute the successor state without mutating `state`.
   *
   * @param {object} state - { mu, sigma, dcReward, runLength }.
   * @param {boolean[]} action
   * @returns {object} The next state.
   */
  transition(state, action) {
    const nextState = structuredClone(state);

    // Deal with default choice
    if (action.some((a) => a)) {
      nextState.runLength++;
      if (nextState.runLength >= this.recoverRunLength) {
        nextState.dcReward = Math.min(
          nextState.dcReward + this.stepDCReward,
          this.maxDCReward,
        );
        nextState.runLength = 0;
      }
    } else {
      nextState.runLength = 0;
      nextState.dcReward = Math.max(
        nextState.dcReward - this.stepDCReward,
        this.minDCReward,
      );
    }

    // Deal with distribution change.
    // Math.random() is in [0, 1), so a strict `<` makes transitProb = 0 mean
    // "never" and transitProb = 1 mean "always".
    if (Math.random() < this.transitProb) {
      const [muA, muB] = choice(this.swapTarget);
      // NOTE(review): arms are located by mu value, so this assumes the four
      // mu values are distinct; confirm callers never pass duplicates.
      const idxA = nextState.mu.indexOf(muA);
      const idxB = nextState.mu.indexOf(muB);
      nextState.mu[idxA] = muB;
      nextState.mu[idxB] = muA;
      // Swap the state's own sigma entries so each sigma follows its mu.
      // (Reading this.sigma here would be wrong: it is in constructor order,
      // while idxA/idxB index the shuffled per-state arrays.)
      [nextState.sigma[idxA], nextState.sigma[idxB]] =
        [nextState.sigma[idxB], nextState.sigma[idxA]];
    }

    return nextState;
  }

  /**
   * @returns {{mu: number[], sigma: number[], dcReward: number, runLength: number}}
   *   A shuffled arm layout with the default-choice reward at its maximum and
   *   an empty run counter.
   */
  getInitialState() {
    return {
      ...super.getInitialState(),
      dcReward: this.maxDCReward,
      runLength: 0,
    };
  }
}

// ===== index.html (new file, index 0000000..b6c2aa6, @@ -0,0 +1,133 @@) =====
// The markup of this file was lost in extraction; only the text "Home"
// survives at this point of the patch.

Press F or J

+

 

+ +
+
+
+
+
+
// (end of index.html hunk)
//
// ===== random.js (new file, index 0000000..5e92391) =====

/**
 * Draw a uniform random number.
 *
 * @param {number} min - Inclusive lower bound.
 * @param {number} max - Exclusive upper bound.
 * @returns {number} A value in the range [min, max).
 */
export function random(min, max) {
  return Math.random() * (max - min) + min;
}

/**
 * Draw a uniform random integer.
 *
 * @param {number} min - Inclusive lower bound.
 * @param {number} max - Upper bound.
 * @param {boolean} [high=false] - When true the upper bound is inclusive.
 * @returns {number} An integer in [min, max), or in [min, max] if `high`.
 */
export function randint(min, max, high = false) {
  // Number(high) makes the former implicit boolean coercion explicit.
  const span = max - min + Number(high);
  return Math.floor(Math.random() * span) + min;
}

/**
 * Draw normally distributed samples using the Box-Muller transform.
 *
 * @param {number} [mu=0] - Mean.
 * @param {number} [sigma=1] - Standard deviation.
 * @param {number} [n=1] - Number of samples.
 * @returns {number|number[]} A single number when `n === 1`, otherwise an
 *   array of `n` samples.
 */
export function randn(mu = 0, sigma = 1, n = 1) {
  const samples = [];
  while (samples.length < n) {
    // Redraw exact zeros so Math.log(u) stays finite.
    let u = 0;
    let v = 0;
    while (u === 0) u = Math.random();
    while (v === 0) v = Math.random();
    // Each (u, v) pair yields two samples sharing one radius.
    const radius = sigma * Math.sqrt(-2.0 * Math.log(u));
    const angle = 2.0 * Math.PI * v;
    samples.push(mu + radius * Math.cos(angle));
    samples.push(mu + radius * Math.sin(angle));
  }
  return n === 1 ? samples[0] : samples.slice(0, n);
}

/**
 * Pick one element uniformly at random.
 *
 * @param {Array} list
 * @returns {*} A random element (`undefined` for an empty list).
 */
export function choice(list) {
  return list[randint(0, list.length)];
}

/**
 * Return a shuffled copy of `list` (Fisher-Yates); the input is not modified.
 *
 * @param {Array} list
 * @returns {Array}
 */
export function shuffle(list) {
  const result = list.slice();
  for (let i = 0; i < result.length; i++) {
    const j = randint(i, result.length);
    [result[i], result[j]] = [result[j], result[i]];
  }
  return result;
}

export const Random = {
  random,
  randint,
  randn,
  choice,
  shuffle,
};

export default Random;

// ===== utils.js (new file, index 0000000..83f5b91) =====

/**
 * Build a list of evenly spaced numbers.
 *
 * Call forms: `range(end)`, `range(start, end)`, `range(start, end, step)`.
 * `start` defaults to 0 and `step` to 1. A negative `step` counts down from
 * `start` while the value is greater than `end`.
 *
 * @param {...number} args
 * @returns {number[]} Numbers from `start` (inclusive) towards `end` (exclusive).
 * @throws {RangeError} If `step` is 0 (the loop could never terminate).
 */
export function range(...args) {
  let start;
  let end;
  let step;
  if (args.length === 1) {
    start = 0;
    [end] = args;
    step = 1;
  } else if (args.length === 2) {
    [start, end] = args;
    step = 1;
  } else {
    [start, end, step] = args;
  }

  if (step === 0) {
    throw new RangeError('range() step must not be zero');
  }

  const list = [];
  if (step < 0) {
    for (let i = start; i > end; i += step) {
      list.push(i);
    }
  } else {
    for (let i = start; i < end; i += step) {
      list.push(i);
    }
  }
  return list;
}

/**
 * Return the indices that would sort `list` in ascending numeric order.
 *
 * @param {number[]} list
 * @returns {number[]} Index permutation; `list` itself is left unchanged.
 */
export function argsort(list) {
  return list.map((_, i) => i).sort((a, b) => list[a] - list[b]);
}
// (utils.js, continued)

/**
 * Offer `content` to the user as a file download by clicking a temporary
 * anchor element that points at a Blob URL. Requires a browser `document`.
 *
 * @param {string} filename - Suggested name for the downloaded file.
 * @param {*} content - Data placed in the Blob (passed as its single part).
 * @param {string} [type='text/plain'] - MIME type of the Blob.
 * @returns {void}
 */
export function offerFile(filename, content, type = 'text/plain') {
  const a = document.createElement('a');
  const url = URL.createObjectURL(new Blob([content], { type }));
  a.href = url;
  a.download = filename;
  a.click();
  // Object URLs keep their Blob alive until revoked; release it once the
  // click has been dispatched so repeated downloads do not leak memory.
  setTimeout(() => URL.revokeObjectURL(url), 0);
}

const Utils = {
  range,
  offerFile,
  argsort,
};

export default Utils;