This commit is contained in:
HoshinoKoji 2024-07-28 23:36:17 +08:00
commit 0fd1073341
4 changed files with 335 additions and 0 deletions

115
combandit.js Normal file
View File

@ -0,0 +1,115 @@
import { randn, shuffle, choice } from "./random.js";
import Utils from "./utils.js";
export class StaticCombandit {
constructor({
mu = [35, 15, -5, -25],
sigma = [15, 15, 15, 15]
}) {
if (mu.length !== sigma.length) {
throw new Error('mu and sigma must have the same length');
} else {
const orders = Utils.argsort(mu);
this.mu = orders.map(i => mu[i]);
this.sigma = orders.map(i => sigma[i]);
}
}
reward(state, action) {
// Action suppposed to be a Boolean array,
// Where action[i] = true means arm i is pulled,
// And action[i] = false means arm i is not pulled.
let reward = 0;
action.forEach((a, i) => {
if (a) {
reward += randn(state.mu[i], state.sigma[i]);
}
});
return reward;
}
transition(state, action) {
return structuredClone(state);
}
step(state, action) {
return { reward: this.reward(state, action), nextState: this.transition(state, action) };
}
getInitialState() {
const orders = shuffle(Utils.range(this.nArms));
return {
mu: this.mu.map((_, i) => this.mu[orders[i]]),
sigma: this.sigma.map((_, i) => this.sigma[orders[i]])
}
}
get nArms() {
return this.mu.length;
}
}
export class ExpCombandit extends StaticCombandit {
constructor({
mu = [35, 15, -5, -25],
sigma = [15, 15, 15, 15],
transitProb = 0.05,
maxDCReward = 50,
minDCReward = 10,
stepDCReward = 10,
recoverRunLength = 3
}) {
super({ mu, sigma });
if (mu.length !== 4) {
throw new Error('mu and sigma must have length 4');
}
this.swapTarget = [[this.mu[0], this.mu[3]],
[this.mu[1], this.mu[2]]];
this.transitProb = transitProb;
this.maxDCReward = maxDCReward;
this.minDCReward = minDCReward;
this.stepDCReward = stepDCReward;
this.recoverRunLength = recoverRunLength;
}
reward(state, action) {
return action.some(a => a) ? super.reward(state, action) : state.dcReward;
}
transition(state, action) {
const nextState = structuredClone(state);
// Deal with default choice
if (action.some(a => a)) {
nextState.runLength++;
if (nextState.runLength >= this.recoverRunLength) {
nextState.dcReward = Math.min(nextState.dcReward + this.stepDCReward, this.maxDCReward);
nextState.runLength = 0;
}
} else {
nextState.runLength = 0;
nextState.dcReward = Math.max(nextState.dcReward - this.stepDCReward, this.minDCReward);
}
// Deal with distribution change
if (Math.random() <= this.transitProb) {
const swapTarget = choice(this.swapTarget);
const swapIdx = [nextState.mu.indexOf(swapTarget[0]), nextState.mu.indexOf(swapTarget[1])];
nextState.mu[swapIdx[0]] = swapTarget[1];
nextState.mu[swapIdx[1]] = swapTarget[0];
nextState.sigma[swapIdx[0]] = this.sigma[swapIdx[1]];
nextState.sigma[swapIdx[1]] = this.sigma[swapIdx[0]];
}
return nextState;
}
getInitialState() {
return {
...super.getInitialState(),
dcReward: this.maxDCReward,
runLength: 0
};
}
}

133
index.html Normal file
View File

@ -0,0 +1,133 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Home</title>
<style>
body {
font-family: Arial, sans-serif;
padding: 0;
width: 500px;
margin: 50px auto;
background-color: #f0f0f0;
}
h1, h2 {
text-align: center;
}
.wrapper {
margin: 0 auto;
width: 320px;
height: 320px;
display: grid;
grid-template-columns: repeat(2, 1fr);
grid-gap: 10px;
grid-auto-rows: minmax(100px, auto);
}
.gridcell, .gridcell-highlighted {
background-color: #fafafa;
border: 1px solid #ddd;
}
.gridcell-selected {
background-color: #d4380d88;
}
.gridcell-unselected {
background-color: #44444422;
}
.indicator {
margin-left: 50%;
margin-top: 50%;
transform: translate(-50%, -50%);
width: 15%;
height: 15%;
}
.gridcell-highlighted .indicator {
border: 2px solid #d4380d;
border-radius: 50%;
}
</style>
<script type="module">
import { shuffle, randint } from "./random.js";
import { ExpCombandit } from "./combandit.js";
import Utils from "./utils.js";
const combandit = new ExpCombandit({});
let combanditState = combandit.getInitialState();
const handler = {
getCell(idx) {
return document.querySelector(`#gridcell-${idx}`);
},
getCellList() {
const query = '.gridcell, .gridcell-highlighted, .gridcell-selected, .gridcell-unselected'
return Array.from(document.querySelectorAll(query));
},
init() {
this.orders = shuffle(Utils.range(4));
this.currentIdx = 0;
this.action = Array(4).fill(false);
document.querySelector('#prompt').innerHTML = '&nbsp;';
this.getCellList().forEach(cell => cell.className = 'gridcell');
this.highlightCurrent();
this.completed = false;
},
highlightCurrent() {
const { orders, currentIdx } = this;
const cell = this.getCell(orders[currentIdx]);
cell.className = 'gridcell-highlighted';
},
onChoice(choice) {
const { orders, currentIdx, completed } = this;
const cell = this.getCell(orders[currentIdx]);
this.action[orders[currentIdx]] = choice;
cell.className = choice ? 'gridcell-selected' : 'gridcell-unselected';
if (currentIdx < orders.length - 1) {
this.currentIdx++;
const nextCell = this.getCell(orders[this.currentIdx]);
nextCell.className = 'gridcell-highlighted';
} else if (currentIdx === orders.length - 1) {
let reward;
this.completed = true;
({ reward, nextState: combanditState } = combandit.step(combanditState, this.action));
const rewardText = reward.toFixed(0).replace('-', '&minus;');
document.querySelector('#prompt').innerHTML = `You got ${rewardText} points!`;
setTimeout(() => this.init(), 2000);
}
},
}
const keysHeld = new Set();
document.addEventListener('DOMContentLoaded', () => handler.init());
document.addEventListener('keydown', function(event) {
const keys = ['f', 'j', 'F', 'J'];
if (keys.includes(event.key) && !keysHeld.has(event.key) && !handler.completed) {
keysHeld.add(event.key);
handler.onChoice(event.key.toLowerCase() === 'f');
}
});
document.addEventListener('keyup', function(event) {
if (event.key.match(/[a-z]/i)) {
keysHeld.delete(event.key);
}
});
</script>
</head>
<body>
<h1>Press <span style="color: #d4380d;">F</span> or <span style="color: #444444;">J</span></h1>
<h2 id="prompt">&nbsp;</h2>
<div class="wrapper" id="grid">
<div class="gridcell" id="gridcell-0"><div class="indicator"></div></div>
<div class="gridcell" id="gridcell-1"><div class="indicator"></div></div>
<div class="gridcell" id="gridcell-2"><div class="indicator"></div></div>
<div class="gridcell" id="gridcell-3"><div class="indicator"></div></div>
</div>
</body>
</html>

48
random.js Normal file
View File

@ -0,0 +1,48 @@
export function random(min, max) {
// Return a random number in the range [min, max)
return Math.random() * (max - min) + min;
}
export function randint(min, max, high = false) {
// Return a random integer in the range [min, max) or [min, max]
return Math.floor(Math.random() * (max - min + high)) + min;
}
export function randn(mu = 0, sigma = 1, n = 1) {
const list = [];
function generatePair() {
let u = 0, v = 0;
while (u === 0) u = Math.random();
while (v === 0) v = Math.random();
list.push(mu + sigma * Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v));
list.push(mu + sigma * Math.sqrt(-2.0 * Math.log(u)) * Math.sin(2.0 * Math.PI * v));
}
while (list.length < n) {
generatePair();
}
return n === 1 ? list[0] : list.slice(0, n);
}
export function choice(list) {
return list[randint(0, list.length)];
}
export function shuffle(list) {
list = list.slice();
for (let i = 0; i < list.length; i++) {
const j = randint(i, list.length);
[list[i], list[j]] = [list[j], list[i]];
}
return list;
}
export const Random = {
random,
randint,
randn,
choice,
shuffle
};
export default Random;

39
utils.js Normal file
View File

@ -0,0 +1,39 @@
export function range(...args) {
let start, end, step;
if (args.length === 1) {
start = 0;
end = args[0];
step = 1;
} else if (args.length === 2) {
[start, end] = args;
step = 1;
} else {
[start, end, step] = args;
}
// Return a list of numbers in the range [start, end)
const list = [];
for (let i = start; i < end; i += step) {
list.push(i);
}
return list;
}
export function argsort(list) {
return list.map((_, i) => i).sort((a, b) => list[a] - list[b]);
}
export function offerFile(filename, content, type = 'text/plain') {
const a = document.createElement('a');
a.href = URL.createObjectURL(new Blob([content], { type: type }));
a.download = filename;
a.click();
}
const Utils = {
range,
offerFile,
argsort
};
export default Utils;