Init
This commit is contained in:
commit
0fd1073341
115
combandit.js
Normal file
115
combandit.js
Normal file
@ -0,0 +1,115 @@
|
||||
import { randn, shuffle, choice } from "./random.js";
|
||||
import Utils from "./utils.js";
|
||||
|
||||
export class StaticCombandit {
|
||||
constructor({
|
||||
mu = [35, 15, -5, -25],
|
||||
sigma = [15, 15, 15, 15]
|
||||
}) {
|
||||
if (mu.length !== sigma.length) {
|
||||
throw new Error('mu and sigma must have the same length');
|
||||
} else {
|
||||
const orders = Utils.argsort(mu);
|
||||
this.mu = orders.map(i => mu[i]);
|
||||
this.sigma = orders.map(i => sigma[i]);
|
||||
}
|
||||
}
|
||||
|
||||
reward(state, action) {
|
||||
// Action suppposed to be a Boolean array,
|
||||
// Where action[i] = true means arm i is pulled,
|
||||
// And action[i] = false means arm i is not pulled.
|
||||
let reward = 0;
|
||||
action.forEach((a, i) => {
|
||||
if (a) {
|
||||
reward += randn(state.mu[i], state.sigma[i]);
|
||||
}
|
||||
});
|
||||
return reward;
|
||||
}
|
||||
|
||||
transition(state, action) {
|
||||
return structuredClone(state);
|
||||
}
|
||||
|
||||
step(state, action) {
|
||||
return { reward: this.reward(state, action), nextState: this.transition(state, action) };
|
||||
}
|
||||
|
||||
getInitialState() {
|
||||
const orders = shuffle(Utils.range(this.nArms));
|
||||
return {
|
||||
mu: this.mu.map((_, i) => this.mu[orders[i]]),
|
||||
sigma: this.sigma.map((_, i) => this.sigma[orders[i]])
|
||||
}
|
||||
}
|
||||
|
||||
get nArms() {
|
||||
return this.mu.length;
|
||||
}
|
||||
}
|
||||
|
||||
export class ExpCombandit extends StaticCombandit {
|
||||
constructor({
|
||||
mu = [35, 15, -5, -25],
|
||||
sigma = [15, 15, 15, 15],
|
||||
transitProb = 0.05,
|
||||
maxDCReward = 50,
|
||||
minDCReward = 10,
|
||||
stepDCReward = 10,
|
||||
recoverRunLength = 3
|
||||
}) {
|
||||
super({ mu, sigma });
|
||||
if (mu.length !== 4) {
|
||||
throw new Error('mu and sigma must have length 4');
|
||||
}
|
||||
this.swapTarget = [[this.mu[0], this.mu[3]],
|
||||
[this.mu[1], this.mu[2]]];
|
||||
|
||||
this.transitProb = transitProb;
|
||||
this.maxDCReward = maxDCReward;
|
||||
this.minDCReward = minDCReward;
|
||||
this.stepDCReward = stepDCReward;
|
||||
this.recoverRunLength = recoverRunLength;
|
||||
}
|
||||
|
||||
reward(state, action) {
|
||||
return action.some(a => a) ? super.reward(state, action) : state.dcReward;
|
||||
}
|
||||
|
||||
transition(state, action) {
|
||||
const nextState = structuredClone(state);
|
||||
|
||||
// Deal with default choice
|
||||
if (action.some(a => a)) {
|
||||
nextState.runLength++;
|
||||
if (nextState.runLength >= this.recoverRunLength) {
|
||||
nextState.dcReward = Math.min(nextState.dcReward + this.stepDCReward, this.maxDCReward);
|
||||
nextState.runLength = 0;
|
||||
}
|
||||
} else {
|
||||
nextState.runLength = 0;
|
||||
nextState.dcReward = Math.max(nextState.dcReward - this.stepDCReward, this.minDCReward);
|
||||
}
|
||||
|
||||
// Deal with distribution change
|
||||
if (Math.random() <= this.transitProb) {
|
||||
const swapTarget = choice(this.swapTarget);
|
||||
const swapIdx = [nextState.mu.indexOf(swapTarget[0]), nextState.mu.indexOf(swapTarget[1])];
|
||||
nextState.mu[swapIdx[0]] = swapTarget[1];
|
||||
nextState.mu[swapIdx[1]] = swapTarget[0];
|
||||
nextState.sigma[swapIdx[0]] = this.sigma[swapIdx[1]];
|
||||
nextState.sigma[swapIdx[1]] = this.sigma[swapIdx[0]];
|
||||
}
|
||||
|
||||
return nextState;
|
||||
}
|
||||
|
||||
getInitialState() {
|
||||
return {
|
||||
...super.getInitialState(),
|
||||
dcReward: this.maxDCReward,
|
||||
runLength: 0
|
||||
};
|
||||
}
|
||||
}
|
133
index.html
Normal file
133
index.html
Normal file
@ -0,0 +1,133 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Home</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
padding: 0;
|
||||
width: 500px;
|
||||
margin: 50px auto;
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
h1, h2 {
|
||||
text-align: center;
|
||||
}
|
||||
.wrapper {
|
||||
margin: 0 auto;
|
||||
width: 320px;
|
||||
height: 320px;
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
grid-gap: 10px;
|
||||
grid-auto-rows: minmax(100px, auto);
|
||||
}
|
||||
.gridcell, .gridcell-highlighted {
|
||||
background-color: #fafafa;
|
||||
border: 1px solid #ddd;
|
||||
}
|
||||
.gridcell-selected {
|
||||
background-color: #d4380d88;
|
||||
}
|
||||
.gridcell-unselected {
|
||||
background-color: #44444422;
|
||||
}
|
||||
.indicator {
|
||||
margin-left: 50%;
|
||||
margin-top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
width: 15%;
|
||||
height: 15%;
|
||||
}
|
||||
.gridcell-highlighted .indicator {
|
||||
border: 2px solid #d4380d;
|
||||
border-radius: 50%;
|
||||
}
|
||||
</style>
|
||||
<script type="module">
|
||||
import { shuffle, randint } from "./random.js";
|
||||
import { ExpCombandit } from "./combandit.js";
|
||||
import Utils from "./utils.js";
|
||||
|
||||
const combandit = new ExpCombandit({});
|
||||
let combanditState = combandit.getInitialState();
|
||||
|
||||
const handler = {
|
||||
getCell(idx) {
|
||||
return document.querySelector(`#gridcell-${idx}`);
|
||||
},
|
||||
|
||||
getCellList() {
|
||||
const query = '.gridcell, .gridcell-highlighted, .gridcell-selected, .gridcell-unselected'
|
||||
return Array.from(document.querySelectorAll(query));
|
||||
},
|
||||
|
||||
init() {
|
||||
this.orders = shuffle(Utils.range(4));
|
||||
this.currentIdx = 0;
|
||||
this.action = Array(4).fill(false);
|
||||
|
||||
document.querySelector('#prompt').innerHTML = ' ';
|
||||
this.getCellList().forEach(cell => cell.className = 'gridcell');
|
||||
this.highlightCurrent();
|
||||
|
||||
this.completed = false;
|
||||
},
|
||||
|
||||
highlightCurrent() {
|
||||
const { orders, currentIdx } = this;
|
||||
const cell = this.getCell(orders[currentIdx]);
|
||||
cell.className = 'gridcell-highlighted';
|
||||
},
|
||||
|
||||
onChoice(choice) {
|
||||
const { orders, currentIdx, completed } = this;
|
||||
const cell = this.getCell(orders[currentIdx]);
|
||||
|
||||
this.action[orders[currentIdx]] = choice;
|
||||
cell.className = choice ? 'gridcell-selected' : 'gridcell-unselected';
|
||||
if (currentIdx < orders.length - 1) {
|
||||
this.currentIdx++;
|
||||
const nextCell = this.getCell(orders[this.currentIdx]);
|
||||
nextCell.className = 'gridcell-highlighted';
|
||||
} else if (currentIdx === orders.length - 1) {
|
||||
let reward;
|
||||
this.completed = true;
|
||||
({ reward, nextState: combanditState } = combandit.step(combanditState, this.action));
|
||||
|
||||
const rewardText = reward.toFixed(0).replace('-', '−');
|
||||
document.querySelector('#prompt').innerHTML = `You got ${rewardText} points!`;
|
||||
setTimeout(() => this.init(), 2000);
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const keysHeld = new Set();
|
||||
document.addEventListener('DOMContentLoaded', () => handler.init());
|
||||
document.addEventListener('keydown', function(event) {
|
||||
const keys = ['f', 'j', 'F', 'J'];
|
||||
if (keys.includes(event.key) && !keysHeld.has(event.key) && !handler.completed) {
|
||||
keysHeld.add(event.key);
|
||||
handler.onChoice(event.key.toLowerCase() === 'f');
|
||||
}
|
||||
});
|
||||
document.addEventListener('keyup', function(event) {
|
||||
if (event.key.match(/[a-z]/i)) {
|
||||
keysHeld.delete(event.key);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Press <span style="color: #d4380d;">F</span> or <span style="color: #444444;">J</span></h1>
|
||||
<h2 id="prompt"> </h2>
|
||||
|
||||
<div class="wrapper" id="grid">
|
||||
<div class="gridcell" id="gridcell-0"><div class="indicator"></div></div>
|
||||
<div class="gridcell" id="gridcell-1"><div class="indicator"></div></div>
|
||||
<div class="gridcell" id="gridcell-2"><div class="indicator"></div></div>
|
||||
<div class="gridcell" id="gridcell-3"><div class="indicator"></div></div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
48
random.js
Normal file
48
random.js
Normal file
@ -0,0 +1,48 @@
|
||||
export function random(min, max) {
|
||||
// Return a random number in the range [min, max)
|
||||
return Math.random() * (max - min) + min;
|
||||
}
|
||||
|
||||
export function randint(min, max, high = false) {
|
||||
// Return a random integer in the range [min, max) or [min, max]
|
||||
return Math.floor(Math.random() * (max - min + high)) + min;
|
||||
}
|
||||
|
||||
export function randn(mu = 0, sigma = 1, n = 1) {
|
||||
const list = [];
|
||||
function generatePair() {
|
||||
let u = 0, v = 0;
|
||||
while (u === 0) u = Math.random();
|
||||
while (v === 0) v = Math.random();
|
||||
list.push(mu + sigma * Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v));
|
||||
list.push(mu + sigma * Math.sqrt(-2.0 * Math.log(u)) * Math.sin(2.0 * Math.PI * v));
|
||||
}
|
||||
|
||||
while (list.length < n) {
|
||||
generatePair();
|
||||
}
|
||||
return n === 1 ? list[0] : list.slice(0, n);
|
||||
}
|
||||
|
||||
export function choice(list) {
|
||||
return list[randint(0, list.length)];
|
||||
}
|
||||
|
||||
export function shuffle(list) {
|
||||
list = list.slice();
|
||||
for (let i = 0; i < list.length; i++) {
|
||||
const j = randint(i, list.length);
|
||||
[list[i], list[j]] = [list[j], list[i]];
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
export const Random = {
|
||||
random,
|
||||
randint,
|
||||
randn,
|
||||
choice,
|
||||
shuffle
|
||||
};
|
||||
|
||||
export default Random;
|
39
utils.js
Normal file
39
utils.js
Normal file
@ -0,0 +1,39 @@
|
||||
export function range(...args) {
|
||||
let start, end, step;
|
||||
if (args.length === 1) {
|
||||
start = 0;
|
||||
end = args[0];
|
||||
step = 1;
|
||||
} else if (args.length === 2) {
|
||||
[start, end] = args;
|
||||
step = 1;
|
||||
} else {
|
||||
[start, end, step] = args;
|
||||
}
|
||||
|
||||
// Return a list of numbers in the range [start, end)
|
||||
const list = [];
|
||||
for (let i = start; i < end; i += step) {
|
||||
list.push(i);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
export function argsort(list) {
|
||||
return list.map((_, i) => i).sort((a, b) => list[a] - list[b]);
|
||||
}
|
||||
|
||||
export function offerFile(filename, content, type = 'text/plain') {
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(new Blob([content], { type: type }));
|
||||
a.download = filename;
|
||||
a.click();
|
||||
}
|
||||
|
||||
const Utils = {
|
||||
range,
|
||||
offerFile,
|
||||
argsort
|
||||
};
|
||||
|
||||
export default Utils;
|
Loading…
Reference in New Issue
Block a user