001    /*
002     * JastAddJ is covered by the modified BSD License.
003     * You should have received a copy of the
004     * modified BSD license with this compiler.
005     * 
006     * Copyright (c) 2011, Jesper Öqvist <jesper.oqvist@cs.lth.se>
007     * All rights reserved.
008     */
009    
010    /**
011     * <p>This aspect adds the Project Coin/JSR 334 Strings in Switch language
012     * change to the JastAddJ backend.
013     *
014     * <p>The following features were modified:
015     * <ul>
016     * <li>code generation for switch statement</li>
017     * </ul>
018     */
019    aspect StringsInSwitch {
020    
/**
 * Whether the switch expression has type String, in which case this switch
 * is compiled with the two-level hash/equals scheme in this aspect.
 */
syn boolean SwitchStmt.isSwitchWithString() {
    return getExpr().type().isString();
}

/**
 * Type of the local that holds the evaluated switch expression (presumably
 * java.lang.String). This is only the inherited-attribute declaration; the
 * equation supplying its value is defined elsewhere.
 */
inh TypeDecl SwitchStmt.typeString();
026    
/**
 * Reserves two extra local-variable slots before the locals of this switch's
 * children: an int slot for the case selector (localNumA) followed by a slot
 * for the evaluated switch expression (localNumB).
 *
 * NOTE(review): the slots are reserved for every switch, not only for
 * switches on String; a version conditional on isSwitchWithString() had been
 * commented out in the original source. Confirm why before narrowing it.
 */
eq SwitchStmt.getChild().localNum() {
    int selectorSize = typeInt().variableSize();
    int valueSize = typeString().variableSize();
    return localNum() + selectorSize + valueSize;
}
034    
/**
 * Local slot of the int selector that the string-switch code generation
 * writes in the first switch and reads in the second one; it is the first
 * of the two slots reserved by this statement.
 */
syn int SwitchStmt.localNumA() {
    return localNum();
}
040    
/**
 * Local slot that holds the evaluated switch expression; it directly
 * follows the int selector slot returned by localNumA().
 */
syn int SwitchStmt.localNumB() {
    int selectorSize = typeInt().variableSize();
    return selectorSize + localNum();
}
046    
/**
 * A bucket of case labels whose string constants share one hash code.
 * The bucket gets its own code label, which the hash switch jumps to before
 * the bucket's labels are compared one by one with equals().
 */
class CaseGroup {
    /** Code label at the start of this bucket's comparisons. */
    int lbl;
    /** Hash code shared by every label in this bucket. */
    int hashCode;
    /** Case labels in this bucket, in the order they were added. */
    java.util.List<CaseLbl> lbls = new ArrayList<CaseLbl>();

    public CaseGroup(SwitchStmt owner, int hash) {
        this.hashCode = hash;
        this.lbl = owner.hostType().constantPool().newLabel();
    }

    public void addCase(CaseLbl caseLabel) {
        lbls.add(caseLabel);
    }
}
064    
/**
 * One case label of a string switch together with the statements that
 * follow it, up to the next case label.
 */
class CaseLbl {
    /** Code label marking the start of this case's statements. */
    int lbl;
    /** Serial number assigned by the switch code generator (0 unless set). */
    int serial;
    /** String constant of a constant case label; null when built from a bare label. */
    String value;
    /** Statements attached to this label, in source order. */
    java.util.List<Stmt> stmts = new ArrayList<Stmt>();

    /** Creates a label without a string constant (used for the default case). */
    CaseLbl(int lbl) {
        this.lbl = lbl;
    }

    /** Creates the label for a constant case and records its string value. */
    CaseLbl(ConstCase cc, CodeGeneration gen) {
        this(cc.label(gen));
        value = cc.getValue().constant().stringValue();
    }

    void addStmt(Stmt stmt) {
        stmts.add(stmt);
    }

    /** Emits the bytecode of every statement attached to this label, in order. */
    void createBCode(CodeGeneration gen) {
        for (int i = 0; i < stmts.size(); i++) {
            stmts.get(i).createBCode(gen);
        }
    }
}
096    
/**
 * Signed distance in bytes from label {@code lbl2} to label {@code lbl1},
 * as encoded in switch jump tables. Presumably both labels must already have
 * been placed in {@code gen}; every call site in this aspect ensures that.
 */
syn int SwitchStmt.labelOffset(CodeGeneration gen, int lbl1, int lbl2) {
    int target = gen.addressOf(lbl1);
    int origin = gen.addressOf(lbl2);
    return target - origin;
}
102    
/**
 * Code generation for a switch statement. A switch on a String expression is
 * handled here; any other switch is delegated to the refined implementation.
 *
 * <p>Two switch instructions are generated:
 * <ol>
 * <li>The first switches on hashCode() of the switch expression. Each hash
 * bucket compares the expression against its string constants with
 * equals() and, on a match, stores the serial number (1..n, in source order)
 * of the matching constant case label in the selector local localNumA().
 * The selector is initialised to 0, meaning "no constant case matched".</li>
 * <li>The second is a tableswitch on the selector: 0 jumps to the default
 * label (or past the body when there is none), i jumps to the i-th constant
 * case label.</li>
 * </ol>
 *
 * <p>All case bodies, including the default case, are emitted in source order
 * so that fall-through between consecutive labels behaves as written.
 */
refine AutoBoxingCodegen public void SwitchStmt.createBCode(CodeGeneration gen) {
    if (isSwitchWithString()) {
        // add line number for start of statement
        super.createBCode(gen);

        // Constant case labels grouped by the hash code of their string.
        TreeMap<Integer, CaseGroup> groups = new TreeMap<Integer, CaseGroup>();
        // All case labels, including the default label, in source order.
        java.util.List<CaseLbl> labels = new LinkedList<CaseLbl>();

        CaseLbl defaultLbl = null;
        CaseLbl caseLbl = null;
        int serial = 1;
        for (Stmt stmt : getBlock().getStmts()) {
            if (stmt instanceof ConstCase) {
                ConstCase cc = (ConstCase) stmt;
                caseLbl = new CaseLbl(cc, gen);
                caseLbl.serial = serial++;
                labels.add(caseLbl);
                int key = caseLbl.value.hashCode();
                CaseGroup group = groups.get(key);
                if (group == null) {
                    group = new CaseGroup(this, key);
                    groups.put(key, group);
                }
                group.addCase(caseLbl);
            } else if (stmt instanceof DefaultCase) {
                defaultLbl = new CaseLbl(hostType().constantPool().newLabel());
                // Keep the default label at its source position so that
                // fall-through into and out of it is preserved.
                labels.add(defaultLbl);
                caseLbl = defaultLbl;
            } else if (caseLbl != null) {
                caseLbl.addStmt(stmt);
            }
        }

        int index_a = localNumA();
        genFirstSwitch(gen, groups, index_a);
        genSecondSwitch(gen, labels, index_a, defaultLbl);
    } else {
        refined(gen);
    }
}

/**
 * Emits the first switch, on the hash code of the switch expression, which
 * computes the selector for the second switch.
 *
 * <p>Layout: a goto to the selector initialisation, then one comparison
 * sequence per hash bucket, then the initialisation (selector = 0, evaluate
 * and store the expression, push its hash) and the switch instruction.
 * Buckets are placed before the switch so their offsets are known when the
 * jump table is written.
 *
 * @param groups constant case labels keyed by hash code (sorted, as
 *        lookupswitch requires ascending keys)
 * @param index_a local slot of the int selector
 */
private void SwitchStmt.genFirstSwitch(
        CodeGeneration gen,
        TreeMap<Integer, CaseGroup> groups,
        int index_a) {
    int cond_label = hostType().constantPool().newLabel();
    int switch_label = hostType().constantPool().newLabel();
    int end_label1 = hostType().constantPool().newLabel();
    int index_b = localNumB();

    gen.emitGoto(cond_label);

    // Code for each hash bucket. A hash match may be a collision, so every
    // candidate constant is compared with equals() before setting the selector.
    for (CaseGroup group : groups.values()) {
        gen.addLabel(group.lbl);

        Iterator<CaseLbl> iter = group.lbls.iterator();
        while (iter.hasNext()) {
            CaseLbl lbl = iter.next();
            // On mismatch try the next candidate; after the last one give up
            // and leave the selector at 0.
            int thenLbl = iter.hasNext()
                    ? hostType().constantPool().newLabel()
                    : end_label1;

            typeString().emitLoadLocal(gen, index_b);
            StringLiteral.push(gen, lbl.value);
            equalsMethod().emitInvokeMethod(gen,
                    lookupType("java.lang", "Object"));
            gen.emitCompare(Bytecode.IFEQ, thenLbl);
            IntegerLiteral.push(gen, lbl.serial);
            typeInt().emitStoreLocal(gen, index_a);
            gen.emitGoto(end_label1);

            if (iter.hasNext())
                gen.addLabel(thenLbl);
        }
    }

    gen.addLabel(cond_label);

    // Initialize switch variable for second switch
    IntegerLiteral.push(gen, 0);
    typeInt().emitStoreLocal(gen, index_a);

    // Store the value of the switch expr so that it is only evaluated once.
    getExpr().createBCode(gen);

    // Push the hash code for the switch instruction
    if (getExpr().isConstant()) {
        typeString().emitStoreLocal(gen, index_b);

        int hashCode = getExpr().constant().stringValue().hashCode();
        IntegerLiteral.push(gen, hashCode);
    } else {
        typeString().emitDup(gen);
        typeString().emitStoreLocal(gen, index_b);
        hashCodeMethod().emitInvokeMethod(gen,
                lookupType("java.lang", "Object"));
    }

    // Emit switch instruction
    gen.addLabel(switch_label);
    long low = groups.isEmpty() ? 0 : groups.firstKey();
    long high = groups.isEmpty() ? 0 : groups.lastKey();

    long tableSwitchSize = 8L + (high - low + 1L) * 4L;
    long lookupSwitchSize = 4L + groups.size() * 8L;

    // Select the switch type which produces the smallest switch instr.
    // In both forms the default target is the address right after the
    // instruction (end_label1), which is always followed by code emitted by
    // genSecondSwitch.
    if (tableSwitchSize < lookupSwitchSize) {
        gen.emit(Bytecode.TABLESWITCH);
        int pad = emitPad(gen);
        int defaultOffset = 1 + pad + 4 + 4 + 4 +
                4 * (int) (high - low + 1);
        gen.add4(defaultOffset);
        gen.add4((int) low);
        gen.add4((int) high);
        for (long i = low; i <= high; i++) {
            CaseGroup group = groups.get((int) i);
            if (group != null) {
                gen.add4(labelOffset(gen, group.lbl, switch_label));
            } else {
                gen.add4(defaultOffset);
            }
        }
    } else {
        gen.emit(Bytecode.LOOKUPSWITCH);
        int pad = emitPad(gen);
        int defaultOffset = 1 + pad + 4 + 4 + 8 * groups.size();
        gen.add4(defaultOffset);
        gen.add4(groups.size());
        for (CaseGroup group : groups.values()) {
            gen.add4(group.hashCode);
            gen.add4(labelOffset(gen, group.lbl, switch_label));
        }
    }
    gen.addLabel(end_label1);
}

/**
 * Emits the case bodies followed by a tableswitch on the selector computed
 * by genFirstSwitch.
 *
 * @param labels all case labels, including the default label if present,
 *        in source order; constant case labels carry serials 1..n in order
 * @param index_a local slot of the int selector (0 = no constant case matched)
 * @param defaultLbl the default case label, or null if there is none
 */
private void SwitchStmt.genSecondSwitch(
        CodeGeneration gen,
        java.util.List<CaseLbl> labels,
        int index_a,
        CaseLbl defaultLbl) {
    int cond_label = hostType().constantPool().newLabel();
    int switch_label = hostType().constantPool().newLabel();

    gen.emitGoto(cond_label);

    // Emit all case bodies in source order so that fall-through between
    // labels (including into and out of the default label) is preserved.
    for (CaseLbl lbl : labels) {
        gen.addLabel(lbl.lbl);
        lbl.createBCode(gen);
    }

    // Target for selector 0. Without a default label it skips the body.
    int default_label;
    if (defaultLbl != null) {
        default_label = defaultLbl.lbl;
    } else {
        default_label = hostType().constantPool().newLabel();
        gen.addLabel(default_label);
    }
    if (canCompleteNormally())
        gen.emitGoto(end_label());

    gen.addLabel(cond_label);

    // push the switch variable
    typeInt().emitLoadLocal(gen, index_a);

    // Emit switch instruction: selector 0 maps to the default target and
    // selector i (1..numCases) to the i-th constant case label.
    int numCases = defaultLbl != null ? labels.size() - 1 : labels.size();
    gen.addLabel(switch_label);
    gen.emit(Bytecode.TABLESWITCH);
    emitPad(gen);
    int defaultOffset = labelOffset(gen, default_label, switch_label);
    // The selector never leaves [0, numCases], so the instruction's own
    // default target is never taken. It must still be the start of an
    // instruction within the method. The address right after the instruction
    // is past the end of the code when the switch cannot complete normally,
    // so the default label is reused instead.
    gen.add4(defaultOffset);
    gen.add4(0);
    gen.add4(numCases);
    gen.add4(defaultOffset);
    for (CaseLbl lbl : labels) {
        if (lbl != defaultLbl)
            gen.add4(labelOffset(gen, lbl.lbl, switch_label));
    }

    if (canCompleteNormally())
        gen.addLabel(end_label());
}
312    
/**
 * Looks up the zero-parameter {@code hashCode} method among the member
 * methods of {@code java.lang.Object}.
 *
 * @throws Error if java.lang.Object or the method cannot be found
 */
private MethodDecl SwitchStmt.hashCodeMethod() {
    TypeDecl object = lookupType("java.lang", "Object");
    if (object == null) {
        throw new Error("Could not find java.lang.Object");
    }
    Iterator<MethodDecl> candidates =
            ((Collection<MethodDecl>) object.memberMethods("hashCode")).iterator();
    while (candidates.hasNext()) {
        MethodDecl candidate = candidates.next();
        if (candidate.getNumParameter() == 0) {
            return candidate;
        }
    }
    throw new Error("Could not find java.lang.Object.hashCode()");
}
328    
/**
 * Looks up the {@code equals} method of {@code java.lang.Object} that takes
 * exactly one parameter of type {@code java.lang.Object}.
 *
 * @throws Error if java.lang.Object or the method cannot be found
 */
private MethodDecl SwitchStmt.equalsMethod() {
    TypeDecl object = lookupType("java.lang", "Object");
    if (object == null) {
        throw new Error("Could not find java.lang.Object");
    }
    Iterator<MethodDecl> candidates =
            ((Collection<MethodDecl>) object.memberMethods("equals")).iterator();
    while (candidates.hasNext()) {
        MethodDecl candidate = candidates.next();
        boolean takesOneObject = candidate.getNumParameter() == 1
                && candidate.getParameter(0).getTypeAccess().type() == object;
        if (takesOneObject) {
            return candidate;
        }
    }
    throw new Error("Could not find java.lang.Object.equals()");
}
345    }