1 module gbaid.gba.register;
2 
3 import gbaid.util;
4 
/// Instruction set selected by the CPSR T bit.
public enum Set {
    ARM,   // 32-bit instructions; PC advances by 4 (T = 0)
    THUMB  // 16-bit instructions; PC advances by 2 (T = 1)
}
9 
/// Processor modes, encoded exactly as they appear in the low five bits of the CPSR.
public enum Mode {
    USER = 0x10,
    FIQ = 0x11,
    IRQ = 0x12,
    SUPERVISOR = 0x13,
    ABORT = 0x17,
    UNDEFINED = 0x1B,
    SYSTEM = 0x1F
}
19 
/// Register numbers 0-15; R13, R14 and R15 are named SP, LR and PC.
public enum Register {
    R0, R1, R2, R3, R4, R5, R6, R7,
    R8, R9, R10, R11, R12,
    SP, // R13
    LR, // R14
    PC, // R15
}
38 
/// Bit positions of individual flags within the CPSR.
public enum CPSRFlag {
    // Condition flags, tested by Registers.checkCondition (bits 31-28)
    N = 31, Z = 30, C = 29, V = 28,
    // Control bits; T selects the instruction set (see Set)
    I = 7, F = 6, T = 5,
}
48 
/**
 * Register file: general purpose registers (including the copies banked per
 * mode), the CPSR and one SPSR slot per mode.
 *
 * A (mode, register) pair is mapped to a slot of `registers` through the
 * precomputed `registerIndices` table, so banked registers are resolved with
 * a single table lookup.
 */
public struct Registers {
    // Compile-time (mode, register) -> slot table, identical for every
    // instance. It is `static` so each Registers value does not carry its own
    // REGISTER_LOOKUP_LENGTH-entry copy, and so the struct stays assignable
    // (D rejects `a = b` on a struct that has a non-static immutable field).
    private static immutable size_t[REGISTER_LOOKUP_LENGTH] registerIndices = createRegisterLookupTable();
    // Physical storage: shared R0-R15 in slots 0-15, banked copies after them.
    private int[REGISTER_COUNT] registers;
    private int cpsrRegister;
    // Indexed by `mode & 0xF`; the USER and SYSTEM slots are never accessed
    // because getSPSR/setSPSR reject those modes.
    private int[1 << MODE_BITS] spsrRegisters;
    // Set by any write to PC through set()/setPC(), cleared by wasPCModified().
    private bool modifiedPC = false;

    /// Current processor mode, taken from the low five bits of the CPSR.
    @property public Mode mode() {
        return cast(Mode) (cpsrRegister & 0x1F);
    }

    /// Current instruction set, taken from the CPSR T bit.
    @property public Set instructionSet() {
        return cast(Set) cpsrRegister.getBit(CPSRFlag.T);
    }

    /// Reads `register` (0-15) as seen from the current mode.
    public int get(int register) {
        return get(mode, register);
    }

    /**
     * Reads `register` as seen from `mode`, resolving banked copies.
     *
     * NOTE(review): `register` is not range-checked; values above 15 would
     * spill into the mode bits of the lookup key.
     */
    public int get(Mode mode, int register) {
        return registers[registerIndices[(mode & 0xF) << REGISTER_BITS | register]];
    }

    /// Reads PC directly; R15 maps to slot 15 in every mode, so no lookup is needed.
    public int getPC() {
        return registers[Register.PC];
    }

    /// Returns the raw CPSR value.
    public int getCPSR() {
        return cpsrRegister;
    }

    /// Returns the SPSR of the current mode (see getSPSR(Mode)).
    public int getSPSR() {
        return getSPSR(mode);
    }

    /**
     * Returns the SPSR of `mode`.
     *
     * Throws: Exception when `mode` is SYSTEM or USER, which have no SPSR.
     */
    public int getSPSR(Mode mode) {
        if (mode == Mode.SYSTEM || mode == Mode.USER) {
            throw new Exception("The SPSR register does not exist in the system and user modes");
        }
        return spsrRegisters[mode & 0xF];
    }

    /// Writes `register` (0-15) as seen from the current mode.
    public void set(int register, int value) {
        set(mode, register, value);
    }

    /**
     * Writes `register` as seen from `mode`, resolving banked copies.
     * A write to PC is recorded so that wasPCModified() reports it.
     */
    public void set(Mode mode, int register, int value) {
        registers[registerIndices[(mode & 0xF) << REGISTER_BITS | register]] = value;
        if (register == Register.PC) {
            modifiedPC = true;
        }
    }

    /// Writes PC directly and records the modification.
    public void setPC(int value) {
        registers[Register.PC] = value;
        modifiedPC = true;
    }

    /// Replaces the whole CPSR.
    public void setCPSR(int value) {
        cpsrRegister = value;
    }

    /// Writes the SPSR of the current mode (see setSPSR(Mode, int)).
    public void setSPSR(int value) {
        setSPSR(mode, value);
    }

    /**
     * Writes the SPSR of `mode`.
     *
     * Throws: Exception when `mode` is SYSTEM or USER, which have no SPSR.
     */
    public void setSPSR(Mode mode, int value) {
        if (mode == Mode.SYSTEM || mode == Mode.USER) {
            throw new Exception("The SPSR register does not exist in the system and user modes");
        }
        spsrRegisters[mode & 0xF] = value;
    }

    /// Reads a single CPSR flag bit.
    public int getFlag(CPSRFlag flag) {
        return cpsrRegister.getBit(flag);
    }

    /// Writes a single CPSR flag bit.
    public void setFlag(CPSRFlag flag, int b) {
        cpsrRegister.setBit(flag, b);
    }

    // The setApsrFlags overloads update the N, Z, C and V bits (31-28) with a
    // single masked write. Only bit 0 of each argument is used; flags that are
    // not passed keep their current value.

    /// Sets N and Z; C and V are preserved.
    public void setApsrFlags(int n, int z) {
        auto newFlags = (n << 3) & 0b1000 | (z << 2) & 0b0100;
        cpsrRegister = cpsrRegister & 0x3FFFFFFF | (newFlags << 28);
    }

    /// Sets N, Z and C; V is preserved.
    public void setApsrFlags(int n, int z, int c) {
        auto newFlags = (n << 3) & 0b1000 | (z << 2) & 0b0100 | (c << 1) & 0b0010;
        cpsrRegister = cpsrRegister & 0x1FFFFFFF | (newFlags << 28);
    }

    /// Sets N, Z, C and V.
    public void setApsrFlags(int n, int z, int c, int v) {
        auto newFlags = (n << 3) & 0b1000 | (z << 2) & 0b0100 | (c << 1) & 0b0010 | v & 0b0001;
        cpsrRegister = cpsrRegister & 0x0FFFFFFF | (newFlags << 28);
    }

    /// Sets N, Z, C and V from the low four bits of `nzcv` (N in bit 3, V in bit 0).
    public void setApsrFlagsPacked(int nzcv) {
        cpsrRegister = cpsrRegister & 0x0FFFFFFF | (nzcv << 28);
    }

    /**
     * Writes `mode` into the mode field of the CPSR.
     *
     * NOTE(review): relies on gbaid.util.setBits(0, 4, ...) covering bits 0-4
     * inclusive, matching the 0x1F mask read by `mode`; confirm.
     */
    public void setMode(Mode mode) {
        cpsrRegister.setBits(0, 4, mode);
    }

    /**
     * Advances PC by one instruction of the current set, first aligning it
     * down to the instruction size (4 for ARM, 2 for THUMB).
     * Does not mark PC as modified.
     */
    public void incrementPC() {
        final switch (instructionSet) {
            case Set.ARM:
                registers[Register.PC] = (registers[Register.PC] & ~3) + 4;
                break;
            case Set.THUMB:
                registers[Register.PC] = (registers[Register.PC] & ~1) + 2;
                break;
        }
    }

    /// Address of the executing instruction: PC minus two instruction widths (8 for ARM, 4 for THUMB).
    public int getExecutedPC() {
        final switch (instructionSet) {
            case Set.ARM:
                return registers[Register.PC] - 8;
            case Set.THUMB:
                return registers[Register.PC] - 4;
        }
    }

    /// Returns whether PC was written since the last call, then clears the flag.
    public bool wasPCModified() {
        auto value = modifiedPC;
        modifiedPC = false;
        return value;
    }

    /**
     * Barrel shifter: applies shift `shiftType` (0 = LSL, 1 = LSR, 2 = ASR,
     * 3 = ROR) by `shift` bits to `op`.
     *
     * Params:
     *   registerShift = true when the amount comes from a register (0-255;
     *                   0 passes `op` through with the current C flag as
     *                   carry); false for an immediate amount, where 0
     *                   encodes LSL #0, LSR #32, ASR #32 and RRX respectively.
     *   shiftType = shift operation selector, 0-3.
     *   shift = shift amount.
     *   op = operand to shift.
     *   carry = receives the shifter carry-out.
     * Returns: the shifted operand.
     */
    public int applyShift(bool registerShift)(int shiftType, ubyte shift, int op, out int carry) {
        final switch (shiftType) {
            // LSL
            case 0:
                static if (registerShift) {
                    if (shift == 0) {
                        carry = getFlag(CPSRFlag.C);
                        return op;
                    } else if (shift < 32) {
                        carry = op.getBit(32 - shift);
                        return op << shift;
                    } else if (shift == 32) {
                        carry = op & 0b1;
                        return 0;
                    } else {
                        carry = 0;
                        return 0;
                    }
                } else {
                    if (shift == 0) {
                        carry = getFlag(CPSRFlag.C);
                        return op;
                    } else {
                        carry = op.getBit(32 - shift);
                        return op << shift;
                    }
                }
            // LSR
            case 1:
                static if (registerShift) {
                    if (shift == 0) {
                        carry = getFlag(CPSRFlag.C);
                        return op;
                    } else if (shift < 32) {
                        carry = op.getBit(shift - 1);
                        return op >>> shift;
                    } else if (shift == 32) {
                        carry = op.getBit(31);
                        return 0;
                    } else {
                        carry = 0;
                        return 0;
                    }
                } else {
                    if (shift == 0) {
                        // Immediate 0 encodes LSR #32
                        carry = op.getBit(31);
                        return 0;
                    } else {
                        carry = op.getBit(shift - 1);
                        return op >>> shift;
                    }
                }
            // ASR
            case 2:
                static if (registerShift) {
                    if (shift == 0) {
                        carry = getFlag(CPSRFlag.C);
                        return op;
                    } else if (shift < 32) {
                        carry = op.getBit(shift - 1);
                        return op >> shift;
                    } else {
                        carry = op.getBit(31);
                        return carry ? 0xFFFFFFFF : 0;
                    }
                } else {
                    if (shift == 0) {
                        // Immediate 0 encodes ASR #32
                        carry = op.getBit(31);
                        return carry ? 0xFFFFFFFF : 0;
                    } else {
                        carry = op.getBit(shift - 1);
                        return op >> shift;
                    }
                }
            // ROR
            case 3:
                static if (registerShift) {
                    if (shift == 0) {
                        carry = getFlag(CPSRFlag.C);
                        return op;
                    } else if (shift & 0b11111) {
                        shift &= 0b11111;
                        carry = op.getBit(shift - 1);
                        return op.rotateRight(shift);
                    } else {
                        // Non-zero multiple of 32: value unchanged, carry is bit 31
                        carry = op.getBit(31);
                        return op;
                    }
                } else {
                    if (shift == 0) {
                        // RRX
                        carry = op & 0b1;
                        return getFlag(CPSRFlag.C) << 31 | op >>> 1;
                    } else {
                        carry = op.getBit(shift - 1);
                        return op.rotateRight(shift);
                    }
                }
        }
    }

    /**
     * Evaluates a 4-bit condition code (0x0 EQ ... 0xF NV) against the CPSR
     * flags. NV (0xF) never passes. Values outside 0x0-0xF match no case.
     */
    public bool checkCondition(int condition) {
        final switch (condition) {
            case 0x0:
                // EQ
                return cpsrRegister.checkBit(CPSRFlag.Z);
            case 0x1:
                // NE
                return !cpsrRegister.checkBit(CPSRFlag.Z);
            case 0x2:
                // CS/HS
                return cpsrRegister.checkBit(CPSRFlag.C);
            case 0x3:
                // CC/LO
                return !cpsrRegister.checkBit(CPSRFlag.C);
            case 0x4:
                // MI
                return cpsrRegister.checkBit(CPSRFlag.N);
            case 0x5:
                // PL
                return !cpsrRegister.checkBit(CPSRFlag.N);
            case 0x6:
                // VS
                return cpsrRegister.checkBit(CPSRFlag.V);
            case 0x7:
                // VC
                return !cpsrRegister.checkBit(CPSRFlag.V);
            case 0x8:
                // HI
                return cpsrRegister.checkBit(CPSRFlag.C) && !cpsrRegister.checkBit(CPSRFlag.Z);
            case 0x9:
                // LS
                return !cpsrRegister.checkBit(CPSRFlag.C) || cpsrRegister.checkBit(CPSRFlag.Z);
            case 0xA:
                // GE
                return cpsrRegister.checkBit(CPSRFlag.N) == cpsrRegister.checkBit(CPSRFlag.V);
            case 0xB:
                // LT
                return cpsrRegister.checkBit(CPSRFlag.N) != cpsrRegister.checkBit(CPSRFlag.V);
            case 0xC:
                // GT
                return !cpsrRegister.checkBit(CPSRFlag.Z)
                        && cpsrRegister.checkBit(CPSRFlag.N) == cpsrRegister.checkBit(CPSRFlag.V);
            case 0xD:
                // LE
                return cpsrRegister.checkBit(CPSRFlag.Z)
                    || cpsrRegister.checkBit(CPSRFlag.N) != cpsrRegister.checkBit(CPSRFlag.V);
            case 0xE:
                // AL
                return true;
            case 0xF:
                // NV
                return false;
        }
    }

    debug (outputInstructions) {
        import std.stdio : writeln, writef, writefln;

        // Ring buffer holding the last CPU_LOG_SIZE logged instructions.
        private enum size_t CPU_LOG_SIZE = 32;
        private CpuState[CPU_LOG_SIZE] cpuLog;
        // Number of valid entries; saturates at CPU_LOG_SIZE.
        private size_t logSize = 0;
        // Slot the next entry is written to.
        private size_t index = 0;

        /// Logs `code` as the instruction at the currently executing address.
        public void logInstruction(int code, string mnemonic) {
            logInstruction(getExecutedPC(), code, mnemonic);
        }

        /**
         * Records mode, instruction, R0-R15 (as seen from the current mode),
         * CPSR and, for modes that have one, the SPSR into the ring buffer.
         * THUMB codes are truncated to 16 bits.
         */
        public void logInstruction(int address, int code, string mnemonic) {
            if (instructionSet == Set.THUMB) {
                code &= 0xFFFF;
            }
            cpuLog[index].mode = mode;
            cpuLog[index].address = address;
            cpuLog[index].code = code;
            cpuLog[index].mnemonic = mnemonic;
            cpuLog[index].set = instructionSet;
            foreach (i; 0 .. 16) {
                cpuLog[index].registers[i] = get(i);
            }
            cpuLog[index].cpsrRegister = cpsrRegister;
            // Left stale for SYSTEM/USER; dump() does not print it for those modes
            if (mode != Mode.SYSTEM && mode != Mode.USER) {
                cpuLog[index].spsrRegister = getSPSR();
            }
            index = (index + 1) % CPU_LOG_SIZE;
            if (logSize < CPU_LOG_SIZE) {
                logSize++;
            }
        }

        /// Prints every logged instruction, oldest first.
        public void dumpInstructions() {
            dumpInstructions(logSize);
        }

        /// Prints the last `amount` logged instructions (clamped to the number logged), oldest first.
        public void dumpInstructions(size_t amount) {
            amount = amount > logSize ? logSize : amount;
            // Oldest entry is at 0 until the buffer wraps, then at `index`
            auto start = (logSize < CPU_LOG_SIZE ? 0 : index) + logSize - amount;
            if (amount > 1) {
                writefln("Dumping last %s instructions executed:", amount);
            }
            foreach (i; 0 .. amount) {
                cpuLog[(i + start) % CPU_LOG_SIZE].dump();
            }
        }

        // Snapshot of the CPU state for one logged instruction.
        private static struct CpuState {
            private Mode mode;
            private int address;
            private int code;
            private string mnemonic;
            private Set set;
            private int[16] registers;
            private int cpsrRegister;
            private int spsrRegister;

            // Prints the mode, R0-R15 four per row, CPSR/SPSR and the instruction.
            private void dump() {
                writefln("%s", mode);
                // Dump register values
                foreach (i; 0 .. 4) {
                    writef("%-4s", cast(Register) (i * 4));
                    foreach (j; 0 .. 4) {
                        writef(" %08X", registers[i * 4 + j]);
                    }
                    writeln();
                }
                writef("CPSR %08X", cpsrRegister);
                if (mode != Mode.SYSTEM && mode != Mode.USER) {
                    writef(", SPSR %08X", spsrRegister);
                }
                writeln();
                // Dump instruction
                final switch (set) {
                    case Set.ARM:
                        writefln("%08X: %08X %s", address, code, mnemonic);
                        break;
                    case Set.THUMB:
                        writefln("%08X: %04X     %s", address, code, mnemonic);
                        break;
                }
                writeln();
            }
        }
    }
}
423 
// Physical register slots: 16 shared (R0-R15), 7 banked for FIQ (R8-R14)
// and 2 banked (R13, R14) for each of SUPERVISOR, ABORT, IRQ and UNDEFINED.
private enum REGISTER_COUNT = 16 + 7 + 2 * 4;
// Width of the register number (R0-R15) within a lookup key.
private enum REGISTER_BITS = 4;
// Width of the masked mode number (mode & 0xF) within a lookup key.
private enum MODE_BITS = 4;
// One lookup entry per possible (mode, register) key.
private enum REGISTER_LOOKUP_LENGTH = (1 << MODE_BITS) << REGISTER_BITS;
428 
/**
 * Builds the (mode, register) -> storage slot table used by `Registers`.
 *
 * The table is indexed by `(mode & 0xF) << REGISTER_BITS | register`.
 * Every mode first maps R0-R15 onto the shared slots 0-15; the banked
 * registers of FIQ, SUPERVISOR, ABORT, IRQ and UNDEFINED then receive their
 * own slots, allocated sequentially after the shared ones.
 *
 * Intended to run at compile time; the final assert turns a mismatch with
 * REGISTER_COUNT into a compile error instead of a runtime range error.
 *
 * Returns: a table of REGISTER_LOOKUP_LENGTH slot indices.
 */
private size_t[] createRegisterLookupTable() {
    auto table = new size_t[REGISTER_LOOKUP_LENGTH];

    void setIndex(int mode, int register, size_t slot) {
        table[(mode & 0xF) << REGISTER_BITS | register] = slot;
    }

    // For all modes: R0 - R15 = 0 - 15
    foreach (mode; 0 .. 1 << MODE_BITS) {
        foreach (register; 0 .. 1 << REGISTER_BITS) {
            setIndex(mode, register, register);
        }
    }

    // Banked slots are handed out after the shared ones
    size_t next = 1 << REGISTER_BITS;
    void bank(int mode, int first, int last) {
        foreach (register; first .. last + 1) {
            setIndex(mode, register, next++);
        }
    }
    // Except: R8_fiq - R14_fiq
    bank(Mode.FIQ, 8, 14);
    // Except: R13_svc - R14_svc
    bank(Mode.SUPERVISOR, 13, 14);
    // Except: R13_abt - R14_abt
    bank(Mode.ABORT, 13, 14);
    // Except: R13_irq - R14_irq
    bank(Mode.IRQ, 13, 14);
    // Except: R13_und - R14_und
    bank(Mode.UNDEFINED, 13, 14);

    // Every allocated slot must exist in Registers.registers
    assert(next == REGISTER_COUNT, "banked register allocation does not match REGISTER_COUNT");
    return table;
}