Commit 6e2aa414 authored by PoroCYon's avatar PoroCYon Committed by PoroCYon
Browse files

more optimizations (64-bit only for now)

parent 518a15d2
...@@ -27,6 +27,8 @@ ASFLAGS += -f elf64 ...@@ -27,6 +27,8 @@ ASFLAGS += -f elf64
endif endif
LDFLAGS_=$(LDFLAGS) -T $(LDDIR)/link.ld --oformat=binary LDFLAGS_=$(LDFLAGS) -T $(LDDIR)/link.ld --oformat=binary
SMOLFLAGS ?= #--libsep
CFLAGS += -m$(BITS) $(shell pkg-config --cflags sdl2) CFLAGS += -m$(BITS) $(shell pkg-config --cflags sdl2)
CXXFLAGS += -m$(BITS) $(shell pkg-config --cflags sdl2) CXXFLAGS += -m$(BITS) $(shell pkg-config --cflags sdl2)
...@@ -37,7 +39,7 @@ ASFLAGS += -DUSE_INTERP -DALIGN_STACK ...@@ -37,7 +39,7 @@ ASFLAGS += -DUSE_INTERP -DALIGN_STACK
NASM ?= nasm NASM ?= nasm
PYTHON3 ?= python3 PYTHON3 ?= python3
all: $(BINDIR)/hello-crt $(BINDIR)/sdl-crt $(BINDIR)/flag-crt all: $(BINDIR)/hello-crt $(BINDIR)/sdl-crt $(BINDIR)/flag-crt $(BINDIR)/hello-_start
LIBS += $(filter-out -pthread,$(shell pkg-config --libs sdl2)) -lX11 #-lGL LIBS += $(filter-out -pthread,$(shell pkg-config --libs sdl2)) -lX11 #-lGL
...@@ -58,7 +60,7 @@ $(OBJDIR)/%.start.o: $(OBJDIR)/%.o $(OBJDIR)/crt1.o ...@@ -58,7 +60,7 @@ $(OBJDIR)/%.start.o: $(OBJDIR)/%.o $(OBJDIR)/crt1.o
$(LD) $(LDFLAGS) -r -o "$@" $^ $(LD) $(LDFLAGS) -r -o "$@" $^
$(OBJDIR)/symbols.%.asm: $(OBJDIR)/%.o $(OBJDIR)/symbols.%.asm: $(OBJDIR)/%.o
$(PYTHON3) ./smol.py $(LIBS) "$<" "$@" $(PYTHON3) ./smol.py $(SMOLFLAGS) $(LIBS) "$<" "$@"
$(OBJDIR)/stub.%.o: $(OBJDIR)/symbols.%.asm $(SRCDIR)/header32.asm \ $(OBJDIR)/stub.%.o: $(OBJDIR)/symbols.%.asm $(SRCDIR)/header32.asm \
$(SRCDIR)/loader32.asm $(SRCDIR)/loader32.asm
......
...@@ -9,9 +9,12 @@ SECTIONS { ...@@ -9,9 +9,12 @@ SECTIONS {
_smol_text_start = .; _smol_text_start = .;
_smol_text_off = _smol_text_start - _smol_origin; _smol_text_off = _smol_text_start - _smol_origin;
.text : { .text : {
KEEP(*(.rodata.dynamic))
KEEP(*(.rodata.interp .rodata.neededlibs))
*(.rdata .rdata.* .rodata .rodata.*)
KEEP(*(.text.startup.smol)) KEEP(*(.text.startup.smol))
KEEP(*(.text.startup._start)) KEEP(*(.text.startup._start))
*(.text .text.* .rdata .rdata.* .rodata .rodata.*) *(.text .text.*)
} }
_smol_text_end = .; _smol_text_end = .;
_smol_text_size = _smol_text_end - _smol_text_start; _smol_text_size = _smol_text_end - _smol_text_start;
......
...@@ -29,6 +29,10 @@ def main(): ...@@ -29,6 +29,10 @@ def main():
help="which scanelf binary to use") help="which scanelf binary to use")
parser.add_argument('--readelf', default=shutil.which('readelf'), \ parser.add_argument('--readelf', default=shutil.which('readelf'), \
help="which readelf binary to use") help="which readelf binary to use")
parser.add_argument('--libsep', default=False, action='store_true', \
help="Separate import symbols per library, instead of looking at every library when resolving a symbol.")
parser.add_argument('input', nargs='+', help="input object file") parser.add_argument('input', nargs='+', help="input object file")
parser.add_argument('output', type=argparse.FileType('w'), \ parser.add_argument('output', type=argparse.FileType('w'), \
help="output nasm file", default=sys.stdout) help="output nasm file", default=sys.stdout)
...@@ -60,7 +64,7 @@ def main(): ...@@ -60,7 +64,7 @@ def main():
symbols.setdefault(library, []) symbols.setdefault(library, [])
symbols[library].append((symbol, reloc)) symbols[library].append((symbol, reloc))
output(arch, symbols, args.output) output(arch, symbols, args.libsep, args.output)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
......
...@@ -3,25 +3,34 @@ import sys ...@@ -3,25 +3,34 @@ import sys
from smolshared import * from smolshared import *
def output_x86(libraries, outf): def output_x86(libraries, libsep, outf):
outf.write('; vim: set ft=nasm:\n') # be friendly outf.write('; vim: set ft=nasm:\n') # be friendly
outf.write('bits 32\n') outf.write('bits 32\n')
if libsep:
outf.write('%define LIBSEP\n')
shorts = { l: l.split('.', 1)[0].lower().replace('-', '_') for l in libraries } shorts = { l: l.split('.', 1)[0].lower().replace('-', '_') for l in libraries }
outf.write('%include "header32.asm"\n') outf.write('%include "header32.asm"\n')
outf.write('dynamic.needed:\n') outf.write('dynamic.needed:\n')
for library in libraries: for library in libraries:
outf.write('dd 1;DT_NEEDED\n') outf.write('dd 1;DT_NEEDED\n')
outf.write('dd (_symbols.{} - _symbols)\n'.format(shorts[library])) outf.write('dd (_symbols.{} - _strtab)\n'.format(shorts[library]))
outf.write('dynamic.end:\n') outf.write('dynamic.end:\n')
# if needgot: # if needgot:
# outf.write('global _GLOBAL_OFFSET_TABLE_\n') # outf.write('global _GLOBAL_OFFSET_TABLE_\n')
# outf.write('_GLOBAL_OFFSET_TABLE_:\n') # outf.write('_GLOBAL_OFFSET_TABLE_:\n')
# outf.write('dd dynamic\n') # outf.write('dd dynamic\n')
outf.write('_strtab:\n')
if not libsep:
for library, symrels in libraries.items():
outf.write('\t_symbols.{}: db "{}",0\n'.format(shorts[library], library))
outf.write('_symbols:\n') outf.write('_symbols:\n')
for library, symrels in libraries.items(): for library, symrels in libraries.items():
outf.write('\t_symbols.{}: db "{}",0\n'.format(shorts[library], library)) if libsep:
outf.write('\t_symbols.{}: db "{}",0\n'.format(shorts[library], library))
for sym, reloc in symrels: for sym, reloc in symrels:
# meh # meh
...@@ -41,27 +50,45 @@ def output_x86(libraries, outf): ...@@ -41,27 +50,45 @@ def output_x86(libraries, outf):
outf.write('_symbols.end:\n') outf.write('_symbols.end:\n')
outf.write('%include "loader32.asm"\n') outf.write('%include "loader32.asm"\n')
# end output_x86
def output_amd64(libraries, outf): def output_amd64(libraries, libsep, outf):
outf.write('; vim: set ft=nasm:\n') outf.write('; vim: set ft=nasm:\n')
outf.write('bits 64\n') outf.write('bits 64\n')
if libsep:
outf.write('%define LIBSEP\n')
shorts = { l: l.split('.', 1)[0].lower().replace('-', '_') for l in libraries } shorts = { l: l.split('.', 1)[0].lower().replace('-', '_') for l in libraries }
outf.write('%include "header64.asm"\n') outf.write('%include "header64.asm"\n')
outf.write('dynamic.needed:\n') outf.write('dynamic.needed:\n')
for library in libraries: for library in libraries:
outf.write('dq 1;DT_NEEDED\n') outf.write('dq 1;DT_NEEDED\n')
outf.write('dq (_symbols.{} - _symbols)\n'.format(shorts[library])) outf.write('dq (_symbols.{} - _strtab)\n'.format(shorts[library]))
outf.write('dynamic.end:\n') outf.write('dynamic.end:\n')
outf.write('[section .data.smolgot]\n') if libsep:
outf.write('[section .data.smolgot]\n')
else:
outf.write('[section .rodata.neededlibs]\n')
# if needgot: # if needgot:
# outf.write('global _GLOBAL_OFFSET_TABLE_\n') # outf.write('global _GLOBAL_OFFSET_TABLE_\n')
# outf.write('_GLOBAL_OFFSET_TABLE_:\n') # outf.write('_GLOBAL_OFFSET_TABLE_:\n')
# outf.write('dq dynamic\n') # outf.write('dq dynamic\n')
outf.write('_strtab:\n')
if not libsep:
for library, symrels in libraries.items():
outf.write('\t_symbols.{}: db "{}",0\n'.format(shorts[library], library))
if not libsep:
outf.write('[section .data.smolgot]\n')
outf.write('_symbols:\n') outf.write('_symbols:\n')
for library, symrels in libraries.items(): for library, symrels in libraries.items():
outf.write('\t_symbols.{}: db "{}",0\n'.format(shorts[library], library)) if libsep:
outf.write('\t_symbols.{}: db "{}",0\n'.format(shorts[library], library))
for sym, reloc in symrels: for sym, reloc in symrels:
if reloc != 'R_X86_64_PLT32' and reloc != 'R_X86_64_GOTPCRELX': if reloc != 'R_X86_64_PLT32' and reloc != 'R_X86_64_GOTPCRELX':
...@@ -78,7 +105,9 @@ global {name} ...@@ -78,7 +105,9 @@ global {name}
outf.write('\t\t_symbols.{lib}.{name}: dq 0x{hash:x}\n'\ outf.write('\t\t_symbols.{lib}.{name}: dq 0x{hash:x}\n'\
.format(lib=shorts[library],name=sym,hash=hash)) .format(lib=shorts[library],name=sym,hash=hash))
outf.write('\tdq 0\n') if libsep:
outf.write('\tdq 0\n')
outf.write('db 0\n') outf.write('db 0\n')
outf.write('_symbols.end:\n') outf.write('_symbols.end:\n')
...@@ -95,9 +124,12 @@ global {name} ...@@ -95,9 +124,12 @@ global {name}
outf.write('_smolplt.end:\n') outf.write('_smolplt.end:\n')
outf.write('%include "loader64.asm"\n') outf.write('%include "loader64.asm"\n')
def output(arch, libraries, outf): # end output_amd64
if arch == 'i386': output_x86(libraries, outf)
elif arch == 'x86_64': output_amd64(libraries, outf)
def output(arch, libraries, libsep, outf):
if arch == 'i386': output_x86(libraries, libsep, outf)
elif arch == 'x86_64': output_amd64(libraries, libsep, outf)
else: else:
eprintf("E: cannot emit for arch '" + str(arch) + "'") eprintf("E: cannot emit for arch '" + str(arch) + "'")
sys.exit(1) sys.exit(1)
......
...@@ -87,7 +87,7 @@ _DYNAMIC: ...@@ -87,7 +87,7 @@ _DYNAMIC:
dynamic: dynamic:
dynamic.strtab: dynamic.strtab:
dd DT_STRTAB ; d_tag dd DT_STRTAB ; d_tag
dd _symbols ; d_un.d_ptr dd _strtab ; d_un.d_ptr
dynamic.symtab: dynamic.symtab:
; this is required to be present or ld.so will crash, but it can be bogus ; this is required to be present or ld.so will crash, but it can be bogus
dd DT_SYMTAB ; d_tag: 6 = DT_SYMTAB dd DT_SYMTAB ; d_tag: 6 = DT_SYMTAB
......
...@@ -87,17 +87,19 @@ phdr.load2: ...@@ -87,17 +87,19 @@ phdr.load2:
phdr.end: phdr.end:
%ifdef USE_INTERP %ifdef USE_INTERP
[section .rodata.interp]
interp: interp:
db "/lib64/ld-linux-x86-64.so.2", 0 db "/lib64/ld-linux-x86-64.so.2", 0
interp.end: interp.end:
%endif %endif
[section .rodata.dynamic]
global _DYNAMIC global _DYNAMIC
_DYNAMIC: _DYNAMIC:
dynamic: dynamic:
dynamic.strtab: dynamic.strtab:
dq DT_STRTAB ; d_tag dq DT_STRTAB ; d_tag
dq _symbols ; d_un.d_ptr dq _strtab ; d_un.d_ptr
dynamic.symtab: dynamic.symtab:
dq DT_SYMTAB ; d_tag dq DT_SYMTAB ; d_tag
dq 0 ; d_un.d_ptr dq 0 ; d_un.d_ptr
......
; vim: set ft=nasm: ; vim: set ft=nasm:
; TODO: r13 -> something else
; TODO: smaller!!
%define R10_BIAS (0x2B8)
%include "rtld.inc" %include "rtld.inc"
%ifdef ELF_TYPE %ifdef ELF_TYPE
...@@ -9,9 +14,10 @@ ...@@ -9,9 +14,10 @@
[section .text] [section .text]
%endif %endif
; rbx: ptrdiff_t glibc_vercompat_extra_hi_field_off ; r9 : ptrdiff_t glibc_vercompat_extra_hi_field_off
; r10: struct link_map* entry + far correction factor ; r10: struct link_map* entry + far correction factor
; r12: struct link_map* entry ; r12: struct link_map* entry
; r14: struct link_map* root
; r13: _dl_fini address (reqd by the ABI) ; r13: _dl_fini address (reqd by the ABI)
%ifndef ELF_TYPE %ifndef ELF_TYPE
...@@ -26,48 +32,36 @@ _smol_start: ...@@ -26,48 +32,36 @@ _smol_start:
%ifdef USE_DT_DEBUG %ifdef USE_DT_DEBUG
mov r12, [rel _DEBUG] mov r12, [rel _DEBUG]
mov r12, [r14 + 8] mov r12, [r12 + 8]
%else %else
mov r12, [rsp - 8] ; return address of _dl_init mov r12, [rsp - 8] ; return address of _dl_init
mov r11d, dword [r12 - 20] ; decode part of 'mov rdi, [rel _rtld_global]' mov ebx, dword [r12 - 20] ; decode part of 'mov rdi, [rel _rtld_global]'
mov r12, [r12 + r11 - 16] ; ??? mov r12, [r12 + rbx - 16] ; ???
%endif %endif
; struct link_map* root = r12 ; struct link_map* root = r12
%ifdef SKIP_ENTRIES %ifdef SKIP_ENTRIES
mov r12, [r12 + L_NEXT_OFF] ; skip this binary mov r12, [r12 + L_NEXT_OFF] ; skip this binary
mov r12, [r12 + L_NEXT_OFF] ; skip the vdso ; mov r12, [r12 + L_NEXT_OFF] ; skip the vdso
; the second one isn't needed anymore, see code below (.next_link)
%endif %endif
; mov rsi, r12
; size_t* field = (size_t*)root;
; for (; *field != _smol_start; ++field) ;
; .next_off:
; lodsq
; cmp rax, _smol_start
; jne short .next_off
; // rbx = offsetof(struct link_map* rsi, l_entry) - DEFAULT_OFFSET
; rbx = field - root - offsetof(struct link_map, l_entry)
; sub rsi, r12
; sub rsi, LF_ENTRY_OFF+8
; xchg rbx, rsi
mov rdi, r12 mov rdi, r12
push -1 push -1
pop rcx pop rcx
;mov rax, _smol_start ;mov rax, _smol_start
lea rax, [rel _smol_start] lea rax, [rel _smol_start] ; TODO: make offset positive!
repne scasq repne scasq
sub rdi, r12 sub rdi, r12
sub rdi, LF_ENTRY_OFF+8 sub rdi, LF_ENTRY_OFF+8
xchg rbx, rdi xchg r9 , rdi
;mov esi, _symbols ;mov edi, _symbols
lea esi, [rel _symbols] lea edi, [rel _symbols]
; for (rsi = (uint8_t*)_symbols; *rsi; ++rsi) { %ifdef LIBSEP
; for (rdi = (uint8_t*)_symbols; *rdi; ++rdi) {
.next_needed: .next_needed:
cmp byte [rsi], 0 cmp byte [rdi], 0
je .needed_end je .needed_end
; do { // iter over the link_map ; do { // iter over the link_map
...@@ -75,14 +69,17 @@ repne scasq ...@@ -75,14 +69,17 @@ repne scasq
; entry = entry->l_next; ; entry = entry->l_next;
mov r12, [r12 + L_NEXT_OFF] ; skip the first one (this is our main mov r12, [r12 + L_NEXT_OFF] ; skip the first one (this is our main
; binary, it has no symbols) ; binary, it has no symbols)
lea r10, [r12 + r9 + R10_BIAS]
; keep the current symbol in a backup reg ; keep the current symbol in a backup reg
mov rdx, rsi push rdi
pop rdx
; r11 = basename(rsi = entry->l_name) ; r11 = basename(rsi = entry->l_name)
mov rsi, [r12 + L_NAME_OFF] mov rsi, [r12 + L_NAME_OFF]
.basename: .basename:
mov r11, rsi push rsi
pop r11
.basename.next: .basename.next:
lodsb lodsb
cmp al, '/' cmp al, '/'
...@@ -92,7 +89,10 @@ repne scasq ...@@ -92,7 +89,10 @@ repne scasq
.basename.done: .basename.done:
; and place it back ; and place it back
mov rsi, rdx ; rsi == _symbol push rdx
push rdx
pop rdi ; rdi == _symbol
pop rsi
; strcmp(rsi, r11) -> flags; rsi == first hash if matches ; strcmp(rsi, r11) -> flags; rsi == first hash if matches
.strcmp: .strcmp:
...@@ -105,8 +105,7 @@ repne scasq ...@@ -105,8 +105,7 @@ repne scasq
inc r11 inc r11
jmp short .strcmp jmp short .strcmp
.strcmp.done: .strcmp.done:
xchg rsi, rdi
;mov rsi, rdx
; if (strcmp(...)) goto next_link; ; if (strcmp(...)) goto next_link;
;cmovnz r12, [r12 + L_NEXT_OFF] ; this is guaranteed to be nonzero ;cmovnz r12, [r12 + L_NEXT_OFF] ; this is guaranteed to be nonzero
...@@ -119,61 +118,121 @@ repne scasq ...@@ -119,61 +118,121 @@ repne scasq
; do { ; do {
.next_hash: .next_hash:
; if (!*phash) break; ; if (!*phash) break;
lodsq ;lodsq
mov eax, dword [rdi]
or eax, eax or eax, eax
jz short .next_needed ; done the last hash, so move to the next lib jz short .next_needed ; done the last hash, so move to the next lib
;link_symbol(struct link_map* entry = r12, size_t* phash = rsi, uint32_t hash = eax) ;link_symbol(struct link_map* entry = r12, size_t* phash = rsi, uint32_t hash = eax)
lea r10, [r12 + rbx]
mov r11, rax push rax
pop r11
; uint32_t bkt_ind(edx) = hash % entry->l_nbuckets ; uint32_t bkt_ind(edx) = hash % entry->l_nbuckets
xor edx, edx xor edx, edx
mov ecx, dword [r10 + LF_NBUCKETS_OFF] mov ecx, dword [r10 + LF_NBUCKETS_OFF - R10_BIAS]
div ecx div ecx
; shift right because we don't want to compare the lowest bit ; shift right because we don't want to compare the lowest bit
shr r11, 1 shr r11, 1
; uint32_t bucket(edx) = entry->l_gnu_buckets[bkt_ind] ; uint32_t bucket(edx) = entry->l_gnu_buckets[bkt_ind]
mov r8, [r10 + LF_GNU_BUCKETS_OFF] mov r8, [r10 + LF_GNU_BUCKETS_OFF - R10_BIAS]
mov edx, dword [r8 + rdx * 4] mov edx, dword [r8 + rdx * 4]
; do { ; do {
.next_chain: .next_chain:
; uint32_t luhash(ecx) = entry->l_gnu_chain_zero[bucket] >> 1 ; uint32_t luhash(ecx) = entry->l_gnu_chain_zero[bucket] >> 1
mov rcx, [r10 + LF_GNU_CHAIN_ZERO_OFF] mov rcx, [r10 + LF_GNU_CHAIN_ZERO_OFF - R10_BIAS]
mov ecx, dword [rcx + rdx * 4] mov ecx, dword [rcx + rdx * 4]
shr ecx, 1 shr ecx, 1
; if (luhash == hash) break; ; if (luhash == hash) break;
cmp ecx, r11d cmp ecx, r11d
je short .chain_break je short .chain_break
; ++bucket; } while (1); ; ++bucket; } while (LIBSEP || (luhash & 1))
inc edx inc edx
jne short .next_chain jne short .next_chain
%else
; !LIBSEP
push r12
pop r11 ; back up link_map root
.next_hash:
;lodsq
mov eax, dword [rdi]
or al, al
jz short .needed_end
push r11
push rax
push rax
pop rbx
pop r14
pop r12
; shift right because we don't want to compare the lowest bit
shr ebx, 1
.next_link:
mov r12, [r12 + L_NEXT_OFF]
lea r10, [r12 + r9 + R10_BIAS]
; uint32_t bkt_ind(edx) = hash % entry->l_nbuckets
xor edx, edx
push r14
pop rax
mov ecx, dword [r10 + LF_NBUCKETS_OFF - R10_BIAS]
div ecx
; uint32_t bucket(edx) = entry->l_gnu_buckets[bkt_ind]
mov r8 , [r10 + LF_GNU_BUCKETS_OFF - R10_BIAS]
mov edx, dword [r8 + rdx * 4]
or edx, edx
jz short .next_link
.next_chain:
; uint32_t luhash(ecx) = entry->l_gnu_chain_zero[bucket] >> 1
mov rcx, [r10 + LF_GNU_CHAIN_ZERO_OFF - R10_BIAS]
mov ecx, dword [rcx + rdx * 4]
; if (luhash & 1) goto next_link; // tested after the hash compare below: nothing more to be found in this lib.
mov al, cl
shr ecx, 1
; if (luhash == hash) break;
cmp ecx, ebx
je short .chain_break
; ++bucket; } while (!(luhash & 1));
and al, 1
jnz short .next_link
inc edx
jmp short .next_chain
%endif
.chain_break: .chain_break:
; ElfW(Sym)* symtab = entry->l_info[DT_SYMTAB]->d_un.d_ptr ; ElfW(Sym)* symtab = entry->l_info[DT_SYMTAB]->d_un.d_ptr
; ElfW(Sym)* sym = &symtab[bucket] ; ElfW(Sym)* sym = &symtab[bucket]
; *phash = sym->st_value + entry->l_addr ; *phash = sym->st_value + entry->l_addr
; ElfW(Dyn)* dyn(rax) = entry->l_info[DT_SYMTAB] ; ElfW(Dyn)* dyn(rax) = entry->l_info[DT_SYMTAB]
mov rax, [r12 + L_INFO_DT_SYMTAB_OFF] mov rax, [r12 + L_INFO_DT_SYMTAB_OFF]
; ElfW(Sym)* symtab(rax) = dyn->d_un.d_ptr ; ElfW(Sym)* symtab(rax) = dyn->d_un.d_ptr
mov rax, [rax + D_UN_PTR_OFF] mov rax, [rax + D_UN_PTR_OFF]
; ElfW(Addr) symoff(rax) = symtab[bucket].st_value ; ElfW(Addr) symoff(rax) = symtab[bucket].st_value
lea rdx, [rdx + rdx * 2] lea rdx, [rdx + rdx * 2]
mov rax, [rax + rdx * 8 + ST_VALUE_OFF] mov rax, [rax + rdx * 8 + ST_VALUE_OFF]
; void* finaladdr(rax) = symoff + entry->l_addr ; void* finaladdr(rax) = symoff + entry->l_addr
mov rcx, [r12 + L_ADDR_OFF] mov rcx, [r12 + L_ADDR_OFF]
add rax, rcx add rax, rcx
; *phash = finaladdr ; *phash = finaladdr
mov [rsi-8], rax ;mov [rsi-8], rax
; mov [rsi], rax
; } while (1) ; lodsq
jmp short .next_hash stosq
; } while (1)