Init
This commit is contained in:
commit
0fd1073341
115
combandit.js
Normal file
115
combandit.js
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
import { randn, shuffle, choice } from "./random.js";
|
||||||
|
import Utils from "./utils.js";
|
||||||
|
|
||||||
|
export class StaticCombandit {
|
||||||
|
constructor({
|
||||||
|
mu = [35, 15, -5, -25],
|
||||||
|
sigma = [15, 15, 15, 15]
|
||||||
|
}) {
|
||||||
|
if (mu.length !== sigma.length) {
|
||||||
|
throw new Error('mu and sigma must have the same length');
|
||||||
|
} else {
|
||||||
|
const orders = Utils.argsort(mu);
|
||||||
|
this.mu = orders.map(i => mu[i]);
|
||||||
|
this.sigma = orders.map(i => sigma[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reward(state, action) {
|
||||||
|
// Action suppposed to be a Boolean array,
|
||||||
|
// Where action[i] = true means arm i is pulled,
|
||||||
|
// And action[i] = false means arm i is not pulled.
|
||||||
|
let reward = 0;
|
||||||
|
action.forEach((a, i) => {
|
||||||
|
if (a) {
|
||||||
|
reward += randn(state.mu[i], state.sigma[i]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return reward;
|
||||||
|
}
|
||||||
|
|
||||||
|
transition(state, action) {
|
||||||
|
return structuredClone(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
step(state, action) {
|
||||||
|
return { reward: this.reward(state, action), nextState: this.transition(state, action) };
|
||||||
|
}
|
||||||
|
|
||||||
|
getInitialState() {
|
||||||
|
const orders = shuffle(Utils.range(this.nArms));
|
||||||
|
return {
|
||||||
|
mu: this.mu.map((_, i) => this.mu[orders[i]]),
|
||||||
|
sigma: this.sigma.map((_, i) => this.sigma[orders[i]])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
get nArms() {
|
||||||
|
return this.mu.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ExpCombandit extends StaticCombandit {
|
||||||
|
constructor({
|
||||||
|
mu = [35, 15, -5, -25],
|
||||||
|
sigma = [15, 15, 15, 15],
|
||||||
|
transitProb = 0.05,
|
||||||
|
maxDCReward = 50,
|
||||||
|
minDCReward = 10,
|
||||||
|
stepDCReward = 10,
|
||||||
|
recoverRunLength = 3
|
||||||
|
}) {
|
||||||
|
super({ mu, sigma });
|
||||||
|
if (mu.length !== 4) {
|
||||||
|
throw new Error('mu and sigma must have length 4');
|
||||||
|
}
|
||||||
|
this.swapTarget = [[this.mu[0], this.mu[3]],
|
||||||
|
[this.mu[1], this.mu[2]]];
|
||||||
|
|
||||||
|
this.transitProb = transitProb;
|
||||||
|
this.maxDCReward = maxDCReward;
|
||||||
|
this.minDCReward = minDCReward;
|
||||||
|
this.stepDCReward = stepDCReward;
|
||||||
|
this.recoverRunLength = recoverRunLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
reward(state, action) {
|
||||||
|
return action.some(a => a) ? super.reward(state, action) : state.dcReward;
|
||||||
|
}
|
||||||
|
|
||||||
|
transition(state, action) {
|
||||||
|
const nextState = structuredClone(state);
|
||||||
|
|
||||||
|
// Deal with default choice
|
||||||
|
if (action.some(a => a)) {
|
||||||
|
nextState.runLength++;
|
||||||
|
if (nextState.runLength >= this.recoverRunLength) {
|
||||||
|
nextState.dcReward = Math.min(nextState.dcReward + this.stepDCReward, this.maxDCReward);
|
||||||
|
nextState.runLength = 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nextState.runLength = 0;
|
||||||
|
nextState.dcReward = Math.max(nextState.dcReward - this.stepDCReward, this.minDCReward);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deal with distribution change
|
||||||
|
if (Math.random() <= this.transitProb) {
|
||||||
|
const swapTarget = choice(this.swapTarget);
|
||||||
|
const swapIdx = [nextState.mu.indexOf(swapTarget[0]), nextState.mu.indexOf(swapTarget[1])];
|
||||||
|
nextState.mu[swapIdx[0]] = swapTarget[1];
|
||||||
|
nextState.mu[swapIdx[1]] = swapTarget[0];
|
||||||
|
nextState.sigma[swapIdx[0]] = this.sigma[swapIdx[1]];
|
||||||
|
nextState.sigma[swapIdx[1]] = this.sigma[swapIdx[0]];
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextState;
|
||||||
|
}
|
||||||
|
|
||||||
|
getInitialState() {
|
||||||
|
return {
|
||||||
|
...super.getInitialState(),
|
||||||
|
dcReward: this.maxDCReward,
|
||||||
|
runLength: 0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
133
index.html
Normal file
133
index.html
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>Home</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
padding: 0;
|
||||||
|
width: 500px;
|
||||||
|
margin: 50px auto;
|
||||||
|
background-color: #f0f0f0;
|
||||||
|
}
|
||||||
|
h1, h2 {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.wrapper {
|
||||||
|
margin: 0 auto;
|
||||||
|
width: 320px;
|
||||||
|
height: 320px;
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(2, 1fr);
|
||||||
|
grid-gap: 10px;
|
||||||
|
grid-auto-rows: minmax(100px, auto);
|
||||||
|
}
|
||||||
|
.gridcell, .gridcell-highlighted {
|
||||||
|
background-color: #fafafa;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
}
|
||||||
|
.gridcell-selected {
|
||||||
|
background-color: #d4380d88;
|
||||||
|
}
|
||||||
|
.gridcell-unselected {
|
||||||
|
background-color: #44444422;
|
||||||
|
}
|
||||||
|
.indicator {
|
||||||
|
margin-left: 50%;
|
||||||
|
margin-top: 50%;
|
||||||
|
transform: translate(-50%, -50%);
|
||||||
|
width: 15%;
|
||||||
|
height: 15%;
|
||||||
|
}
|
||||||
|
.gridcell-highlighted .indicator {
|
||||||
|
border: 2px solid #d4380d;
|
||||||
|
border-radius: 50%;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<script type="module">
|
||||||
|
import { shuffle, randint } from "./random.js";
|
||||||
|
import { ExpCombandit } from "./combandit.js";
|
||||||
|
import Utils from "./utils.js";
|
||||||
|
|
||||||
|
const combandit = new ExpCombandit({});
|
||||||
|
let combanditState = combandit.getInitialState();
|
||||||
|
|
||||||
|
const handler = {
|
||||||
|
getCell(idx) {
|
||||||
|
return document.querySelector(`#gridcell-${idx}`);
|
||||||
|
},
|
||||||
|
|
||||||
|
getCellList() {
|
||||||
|
const query = '.gridcell, .gridcell-highlighted, .gridcell-selected, .gridcell-unselected'
|
||||||
|
return Array.from(document.querySelectorAll(query));
|
||||||
|
},
|
||||||
|
|
||||||
|
init() {
|
||||||
|
this.orders = shuffle(Utils.range(4));
|
||||||
|
this.currentIdx = 0;
|
||||||
|
this.action = Array(4).fill(false);
|
||||||
|
|
||||||
|
document.querySelector('#prompt').innerHTML = ' ';
|
||||||
|
this.getCellList().forEach(cell => cell.className = 'gridcell');
|
||||||
|
this.highlightCurrent();
|
||||||
|
|
||||||
|
this.completed = false;
|
||||||
|
},
|
||||||
|
|
||||||
|
highlightCurrent() {
|
||||||
|
const { orders, currentIdx } = this;
|
||||||
|
const cell = this.getCell(orders[currentIdx]);
|
||||||
|
cell.className = 'gridcell-highlighted';
|
||||||
|
},
|
||||||
|
|
||||||
|
onChoice(choice) {
|
||||||
|
const { orders, currentIdx, completed } = this;
|
||||||
|
const cell = this.getCell(orders[currentIdx]);
|
||||||
|
|
||||||
|
this.action[orders[currentIdx]] = choice;
|
||||||
|
cell.className = choice ? 'gridcell-selected' : 'gridcell-unselected';
|
||||||
|
if (currentIdx < orders.length - 1) {
|
||||||
|
this.currentIdx++;
|
||||||
|
const nextCell = this.getCell(orders[this.currentIdx]);
|
||||||
|
nextCell.className = 'gridcell-highlighted';
|
||||||
|
} else if (currentIdx === orders.length - 1) {
|
||||||
|
let reward;
|
||||||
|
this.completed = true;
|
||||||
|
({ reward, nextState: combanditState } = combandit.step(combanditState, this.action));
|
||||||
|
|
||||||
|
const rewardText = reward.toFixed(0).replace('-', '−');
|
||||||
|
document.querySelector('#prompt').innerHTML = `You got ${rewardText} points!`;
|
||||||
|
setTimeout(() => this.init(), 2000);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
const keysHeld = new Set();
|
||||||
|
document.addEventListener('DOMContentLoaded', () => handler.init());
|
||||||
|
document.addEventListener('keydown', function(event) {
|
||||||
|
const keys = ['f', 'j', 'F', 'J'];
|
||||||
|
if (keys.includes(event.key) && !keysHeld.has(event.key) && !handler.completed) {
|
||||||
|
keysHeld.add(event.key);
|
||||||
|
handler.onChoice(event.key.toLowerCase() === 'f');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
document.addEventListener('keyup', function(event) {
|
||||||
|
if (event.key.match(/[a-z]/i)) {
|
||||||
|
keysHeld.delete(event.key);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Press <span style="color: #d4380d;">F</span> or <span style="color: #444444;">J</span></h1>
|
||||||
|
<h2 id="prompt"> </h2>
|
||||||
|
|
||||||
|
<div class="wrapper" id="grid">
|
||||||
|
<div class="gridcell" id="gridcell-0"><div class="indicator"></div></div>
|
||||||
|
<div class="gridcell" id="gridcell-1"><div class="indicator"></div></div>
|
||||||
|
<div class="gridcell" id="gridcell-2"><div class="indicator"></div></div>
|
||||||
|
<div class="gridcell" id="gridcell-3"><div class="indicator"></div></div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
48
random.js
Normal file
48
random.js
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
export function random(min, max) {
|
||||||
|
// Return a random number in the range [min, max)
|
||||||
|
return Math.random() * (max - min) + min;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function randint(min, max, high = false) {
|
||||||
|
// Return a random integer in the range [min, max) or [min, max]
|
||||||
|
return Math.floor(Math.random() * (max - min + high)) + min;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function randn(mu = 0, sigma = 1, n = 1) {
|
||||||
|
const list = [];
|
||||||
|
function generatePair() {
|
||||||
|
let u = 0, v = 0;
|
||||||
|
while (u === 0) u = Math.random();
|
||||||
|
while (v === 0) v = Math.random();
|
||||||
|
list.push(mu + sigma * Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v));
|
||||||
|
list.push(mu + sigma * Math.sqrt(-2.0 * Math.log(u)) * Math.sin(2.0 * Math.PI * v));
|
||||||
|
}
|
||||||
|
|
||||||
|
while (list.length < n) {
|
||||||
|
generatePair();
|
||||||
|
}
|
||||||
|
return n === 1 ? list[0] : list.slice(0, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function choice(list) {
|
||||||
|
return list[randint(0, list.length)];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function shuffle(list) {
|
||||||
|
list = list.slice();
|
||||||
|
for (let i = 0; i < list.length; i++) {
|
||||||
|
const j = randint(i, list.length);
|
||||||
|
[list[i], list[j]] = [list[j], list[i]];
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const Random = {
|
||||||
|
random,
|
||||||
|
randint,
|
||||||
|
randn,
|
||||||
|
choice,
|
||||||
|
shuffle
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Random;
|
39
utils.js
Normal file
39
utils.js
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
export function range(...args) {
|
||||||
|
let start, end, step;
|
||||||
|
if (args.length === 1) {
|
||||||
|
start = 0;
|
||||||
|
end = args[0];
|
||||||
|
step = 1;
|
||||||
|
} else if (args.length === 2) {
|
||||||
|
[start, end] = args;
|
||||||
|
step = 1;
|
||||||
|
} else {
|
||||||
|
[start, end, step] = args;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a list of numbers in the range [start, end)
|
||||||
|
const list = [];
|
||||||
|
for (let i = start; i < end; i += step) {
|
||||||
|
list.push(i);
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function argsort(list) {
|
||||||
|
return list.map((_, i) => i).sort((a, b) => list[a] - list[b]);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function offerFile(filename, content, type = 'text/plain') {
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = URL.createObjectURL(new Blob([content], { type: type }));
|
||||||
|
a.download = filename;
|
||||||
|
a.click();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Utils = {
|
||||||
|
range,
|
||||||
|
offerFile,
|
||||||
|
argsort
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Utils;
|
Loading…
Reference in New Issue
Block a user