1 module gbaid.gba.io;
2 
3 import std.meta : Alias;
4 import std.format : format;
5 import std.container.array : Array;
6 
7 import gbaid.util;
8 
9 public enum uint IO_REGISTERS_SIZE = 1 * BYTES_PER_KIB;
10 
/// Table of memory-mapped I/O registers.
///
/// The I/O space is split into 32-bit words; each word owns a set of
/// `Register` mappings, every one describing a bit field (mask + shift)
/// backed by an external variable. Reads and writes of 8, 16 or 32 bits are
/// routed to the mappings whose bit field intersects the accessed bytes.
public struct IoRegisters {
    // One register set per 32-bit word of the I/O space
    private Array!Register[IO_REGISTERS_SIZE / int.sizeof] registerSets;

    /// Maps the bit field `mask << shift` of the word at `address` onto the
    /// variable pointed to by `valuePtr`.
    ///
    /// Params:
    ///   address  = byte address of the word; must be 4 byte aligned and
    ///              below IO_REGISTERS_SIZE
    ///   valuePtr = pointer to the backing variable (or `null`); forwarded
    ///              to the `Register` constructor
    ///   mask     = unshifted mask of the field
    ///   shift    = bit position of the field inside the word
    ///   readable = whether `get` may read this field
    ///   writable = whether `set` may write this field
    ///
    /// Returns: a builder used to attach read / pre-write / post-write
    /// monitors to the new mapping. The builder holds a pointer to the
    /// element stored in the register set, so it should be used right away,
    /// before any other mapping is added to or removed from the same word.
    ///
    /// Throws: Exception if the address is invalid or if the field overlaps
    /// a field already mapped on the same word.
    public auto mapAddress(T)(uint address, T valuePtr, int mask, int shift,
            bool readable = true, bool writable = true) {
        checkAddress(address);
        // Convert the byte address to a word index
        address >>>= 2;
        // Refuse any mapping whose bits collide with an existing one
        immutable newFieldMask = mask << shift;
        foreach (register; registerSets[address]) {
            if (newFieldMask & (register.mask << register.shift)) {
                throw new Exception("Overlapping masks");
            }
        }
        // Store the new mapping
        registerSets[address] ~= Register(valuePtr, mask, shift, readable, writable);
        // Anonymous builder for attaching listeners to the mapping
        struct Builder {
            Register* register;

            Builder readMonitor(ReadMonitor monitor) {
                register.onRead = monitor;
                return this;
            }

            Builder preWriteMonitor(PreWriteMonitor monitor) {
                register.onPreWrite = monitor;
                return this;
            }

            Builder postWriteMonitor(PostWriteMonitor monitor) {
                register.onPostWrite = monitor;
                return this;
            }
        }
        return Builder(&registerSets[address][$ - 1]);
    }

    /// Removes every mapping of the word at `address` whose bit field is
    /// fully covered by `mask << shift`. Mappings that do not intersect the
    /// mask are left alone.
    ///
    /// The operation is all-or-nothing: the register set is validated before
    /// anything is removed, so a failed call leaves the mappings untouched.
    ///
    /// Throws: Exception if the address is invalid or if the mask only
    /// partially covers one of the mappings.
    public void unmapAddress(uint address, int mask, int shift) {
        checkAddress(address);
        // Convert the byte address to a word index
        address >>>= 2;
        auto removeMask = mask << shift;
        auto registerSet = registerSets[address];
        // First pass: reject partial matches before mutating anything
        foreach (register; registerSet) {
            auto registerMask = register.mask << register.shift;
            auto maskIntersection = removeMask & registerMask;
            if (maskIntersection != 0 && maskIntersection != registerMask) {
                throw new Exception(format("Partial mask match when unmapping: %08x < %08x",
                        maskIntersection, registerMask));
            }
        }
        // Second pass: every intersecting mapping is now known to be fully covered
        size_t i = 0;
        while (i < registerSet.length) {
            auto registerMask = registerSet[i].mask << registerSet[i].shift;
            if ((removeMask & registerMask) == 0) {
                i++;
                continue;
            }
            // The following elements shift down by one, so the index stays put
            registerSet.linearRemove(registerSet[i .. i + 1]);
        }
        // Writeback the updated register set
        registerSets[address] = registerSet;
    }

    // Validates an address used for (un)mapping: word aligned and in range
    private void checkAddress(uint address) {
        if ((address & 0b11) != 0) {
            throw new Exception(format("Address %08x is not 4 byte aligned", address));
        }
        if (address >= IO_REGISTERS_SIZE) {
            throw new Exception(format("Address out of bounds: %08x >= %08x", address, IO_REGISTERS_SIZE));
        }
    }

    // Mask applied to the low address bits to get the byte offset of a T
    // inside its word (presumably IntSizeLog2!T is log2(T.sizeof), giving
    // 0b11 for bytes, 0b10 for shorts and 0 for ints - confirm in gbaid.util)
    private alias lsb(T) = Alias!(((1 << IntSizeLog2!T) - 1) ^ 0b11);
    // All-ones mask as wide as T, as an int
    private alias bits(T) = Alias!(cast(int) ((1L << T.sizeof * 8) - 1));

    /// Same as `get`, but without invoking the read monitors.
    public alias getUnMonitored(T) = get!(T, false);

    /// Reads a value of type `T` from `address`.
    ///
    /// Every readable mapping of the containing word whose field intersects
    /// the accessed bytes contributes its bits; unmapped or unreadable bits
    /// read as zero. When `monitored` is true, the read monitor of each
    /// contributing mapping is called with the field-relative mask and a
    /// reference to the value, which it may modify before it is merged.
    ///
    /// Returns: the assembled word for 32-bit types, otherwise the accessed
    /// bytes shifted down and cast to `T`.
    public T get(T, bool monitored = true)(uint address) if (IsInt8to32Type!T) {
        // Bit offset of the access inside the word, and the bits it covers
        auto shift = (address & lsb!T) << 3;
        auto mask = bits!T << shift;
        auto registers = registerSets[address >>> 2];
        int readValue = 0;
        foreach (register; registers) {
            if (!register.readable) {
                continue;
            }
            // Part of the access mask that falls inside this field, field-relative
            auto modifiedMask = (mask >>> register.shift) & register.mask;
            if (modifiedMask == 0) {
                continue;
            }
            auto value = register.value & modifiedMask;
            static if (monitored) {
                if (register.onRead !is null) {
                    register.onRead(modifiedMask, value);
                }
            }
            // Re-mask in case the monitor set bits outside the field
            readValue |= (value & modifiedMask) << register.shift;
        }
        static if (is(T == uint) || is(T == int)) {
            return readValue;
        } else {
            return cast(T) (readValue >>> shift);
        }
    }

    /// Same as `set`, but without invoking the write monitors.
    public alias setUnMonitored(T) = set!(T, false);

    /// Writes a value of type `T` to `address`.
    ///
    /// Every writable mapping of the containing word whose field intersects
    /// the accessed bytes receives its share of the bits; bits of the field
    /// outside the access keep their current value. When `monitored` is
    /// true, the pre-write monitor (if any) is called first and may alter
    /// the new value or veto the write by returning false; the post-write
    /// monitor (if any) is called after the value has been stored.
    public void set(T, bool monitored = true)(uint address, T value) if (IsInt8to32Type!T) {
        // Bit offset of the access inside the word, and the bits it covers
        auto shift = (address & lsb!T) << 3;
        auto mask = bits!T << shift;
        // Position the written value inside the word
        static if (is(T == uint) || is(T == int)) {
            int intValue = value;
        } else {
            int intValue = value.ucast() << shift;
        }
        auto registers = registerSets[address >>> 2];
        foreach (register; registers) {
            if (!register.writable) {
                continue;
            }
            // Part of the access mask that falls inside this field, field-relative
            auto modifiedMask = (mask >>> register.shift) & register.mask;
            if (modifiedMask == 0) {
                continue;
            }
            auto newValue = (intValue >>> register.shift) & modifiedMask;
            static if (monitored) {
                if (register.onPreWrite is null || register.onPreWrite(modifiedMask, newValue)) {
                    auto oldValue = register.value & modifiedMask;
                    // Re-mask in case the monitor set bits outside the field
                    newValue &= modifiedMask;
                    register.value = newValue | register.value & ~modifiedMask;
                    if (register.onPostWrite !is null) {
                        register.onPostWrite(modifiedMask, oldValue, newValue);
                    }
                }
            } else {
                register.value = newValue | register.value & ~modifiedMask;
            }
        }
    }
}
151 
/// Called on a monitored read with the field-relative mask of the bits being
/// read and the value read, which the monitor may modify.
private alias ReadMonitor = void delegate(int mask, ref int value);
/// Called before a monitored write with the field-relative mask and the value
/// about to be written, which the monitor may modify; returning false cancels
/// the write.
private alias PreWriteMonitor = bool delegate(int mask, ref int newValue);
/// Called after a monitored write with the field-relative mask, the previous
/// value of the written bits and the value that was stored.
private alias PostWriteMonitor = void delegate(int mask, int oldValue, int newValue);
155 
/// Untyped storage for the pointer to a register's backing variable. All
/// members share the same memory; which one is valid is tracked separately
/// by the owner of the union.
private union ValuePtr {
    /// Backing variable is a `bool`
    bool* valueBool;
    /// Backing variable is a `byte`
    byte* valueByte;
    /// Backing variable is a `ubyte`
    ubyte* valueUbyte;
    /// Backing variable is a `short`
    short* valueShort;
    /// Backing variable is an `int`
    int* valueInt;
}
163 
/// Kind of variable a register mapping is backed by.
private enum ValueSize {
    /// No backing variable
    NULL,
    /// Backed by a `bool`
    BOOL,
    /// Backed by a `byte`
    BYTE,
    /// Backed by a `ubyte`
    UBYTE,
    /// Backed by a `short`
    SHORT,
    /// Backed by an `int`
    INT
}
167 
/// A single mapping of a bit field onto an external variable, together with
/// its optional access monitors.
private struct Register {
    private ReadMonitor onRead = null;
    private PreWriteMonitor onPreWrite = null;
    private PostWriteMonitor onPostWrite = null;
    // Pointer to the backing variable; the valid member is given by valueSize
    private ValuePtr valuePtr;
    // Typed as the enum (rather than int) so that the final switches below
    // are checked for exhaustiveness at compile time
    private ValueSize valueSize;
    // Unshifted mask of the field and its bit position
    private int mask;
    private int shift;
    private bool readable;
    private bool writable;

    /// Creates a mapping backed by `valuePtr`, which must be `null` or a
    /// pointer to `bool`, `byte`, `ubyte`, `short` or `int`; any other type
    /// is rejected at compile time.
    private this(T)(T valuePtr, int mask, int shift, bool readable, bool writable) {
        this.mask = mask;
        this.shift = shift;
        this.readable = readable;
        this.writable = writable;

        static if (is(T == typeof(null))) {
            valueSize = ValueSize.NULL;
        } else static if (is(T == bool*)) {
            this.valuePtr.valueBool = valuePtr;
            valueSize = ValueSize.BOOL;
        } else static if (is(T == byte*)) {
            this.valuePtr.valueByte = valuePtr;
            valueSize = ValueSize.BYTE;
        } else static if (is(T == ubyte*)) {
            this.valuePtr.valueUbyte = valuePtr;
            valueSize = ValueSize.UBYTE;
        } else static if (is(T == short*)) {
            this.valuePtr.valueShort = valuePtr;
            valueSize = ValueSize.SHORT;
        } else static if (is(T == int*)) {
            this.valuePtr.valueInt = valuePtr;
            valueSize = ValueSize.INT;
        } else {
            static assert (0, "Unsupported register value pointer type: " ~ T.stringof);
        }
    }

    /// Reads the backing variable, zero-extended to an int. A mapping with
    /// no backing variable reads as 0.
    @property
    private int value() {
        final switch (valueSize) with (ValueSize) {
            case NULL:
                return 0;
            case BOOL:
                return *valuePtr.valueBool & 0b1;
            case BYTE:
                return *valuePtr.valueByte & 0xFF;
            case UBYTE:
                return *valuePtr.valueUbyte & 0xFF;
            case SHORT:
                return *valuePtr.valueShort & 0xFFFF;
            case INT:
                return *valuePtr.valueInt;
        }
    }

    /// Stores `value` into the backing variable, cast to its type. A mapping
    /// with no backing variable discards the value.
    @property
    private void value(int value) {
        final switch (valueSize) with (ValueSize) {
            case NULL:
                break;
            case BOOL:
                *valuePtr.valueBool = cast(bool) value;
                break;
            case BYTE:
                *valuePtr.valueByte = cast(byte) value;
                break;
            case UBYTE:
                *valuePtr.valueUbyte = cast(ubyte) value;
                break;
            case SHORT:
                *valuePtr.valueShort = cast(short) value;
                break;
            case INT:
                *valuePtr.valueInt = value;
                break;
        }
    }
}
248 
// Exercises mapping, monitored reads and writes of several widths, and unmapping
unittest {
    // Monitor that checks the arguments it receives against preset expectations
    class CheckingMonitor {
        int wantedMask;
        int wantedValue;
        int wantedOldValue;
        int wantedNewValue;

        // Expectation for a read, or for the pre-write half of a write
        void expect(int mask, int value) {
            wantedMask = mask;
            wantedValue = value;
        }

        // Expectation for a full write: pre-write value, then old and new values
        void expect(int mask, int preWriteValue, int oldValue, int newValue) {
            expect(mask, preWriteValue);
            wantedOldValue = oldValue;
            wantedNewValue = newValue;
        }

        void onRead(int mask, ref int value) {
            assert (wantedMask == mask);
            assert (wantedValue == value);
        }

        bool onPreWrite(int mask, ref int newValue) {
            assert (wantedMask == mask);
            assert (wantedValue == newValue);
            return true;
        }

        void onPostWrite(int mask, int oldValue, int newValue) {
            assert (wantedMask == mask);
            assert (wantedOldValue == oldValue);
            assert (wantedNewValue == newValue);
        }
    }

    auto io = IoRegisters();
    auto boolMonitor = new CheckingMonitor();
    auto byteMonitor = new CheckingMonitor();
    auto shortMonitor = new CheckingMonitor();
    auto intMonitor = new CheckingMonitor();

    bool boolField = false;
    byte byteField = 0;
    short shortField = 0;
    int intField = 0;

    // Two adjacent fields on word 0x10, one field each on words 0x14 and 0x18
    io.mapAddress(0x10, &boolField, 0b1, 9)
            .readMonitor(&boolMonitor.onRead)
            .preWriteMonitor(&boolMonitor.onPreWrite)
            .postWriteMonitor(&boolMonitor.onPostWrite);
    io.mapAddress(0x10, &byteField, 0xDF, 10)
            .readMonitor(&byteMonitor.onRead)
            .preWriteMonitor(&byteMonitor.onPreWrite)
            .postWriteMonitor(&byteMonitor.onPostWrite);
    io.mapAddress(0x14, &shortField, 0xFCCF, 16)
            .readMonitor(&shortMonitor.onRead)
            .preWriteMonitor(&shortMonitor.onPreWrite)
            .postWriteMonitor(&shortMonitor.onPostWrite);
    io.mapAddress(0x18, &intField, 0xFFFFFF, 0)
            .readMonitor(&intMonitor.onRead)
            .preWriteMonitor(&intMonitor.onPreWrite)
            .postWriteMonitor(&intMonitor.onPostWrite);

    // A word write is split between the two fields; the unmapped bit 8 is dropped
    boolMonitor.expect(0b1, 0b1, 0b0, 0b1);
    byteMonitor.expect(0xDF, 0b1, 0b0, 0b1);
    io.set!int(0x10, 0x700);
    assert (io.get!int(0x10) == 0x600);
    assert (boolField);
    assert (byteField == cast(byte) 0b1);

    // Only the bits selected by the field mask are stored
    shortMonitor.expect(0xFCCF, 0xFCCF, 0, 0xFCCF);
    io.set!int(0x14, 0xFFFFFFFF);
    assert (io.get!int(0x14) == 0xFCCF0000);
    assert (shortField == cast(short) 0xFCCF);

    intMonitor.expect(0x00FFFFFF, 0x00356789, 0, 0x00356789);
    io.set!int(0x18, 0x12356789);
    assert (io.get!int(0x18) == 0x00356789);
    assert (intField == 0x00356789);

    // A half-word write to the upper half only touches the mapped bits there
    intMonitor.expect(0x00FF0000, 0x00CD0000, 0x00350000, 0x00CD0000);
    io.set!short(0x1A, cast(short) 0xABCD);
    intMonitor.expect(0x00FFFFFF, 0x00CD6789);
    assert (io.get!int(0x18) == 0x00CD6789);
    assert (intField == 0x00CD6789);

    // A half-word read of the upper half
    intMonitor.expect(0x00FF0000, 0x00CD0000);
    assert (io.get!short(0x1A) == 0x00CD);

    // Unmapping bits 0-9 removes the bool field and keeps the byte field
    assert (io.registerSets[0x10 >>> 2].length == 2);
    io.unmapAddress(0x10, 0x3FF, 0);
    assert (io.registerSets[0x10 >>> 2].length == 1);
}