USE16

stage1:
    ; Initialize segment registers
    xor ax, ax ; Set ax to 0
    mov ds, ax
    mov es, ax
    mov ss, ax

    ; Initialize stack pointer
    ; (the stack grows down, starting just below 0x7C00)
    ; Must directly follow `mov ss`: that instruction delays interrupts
    ; for one instruction, so SS:SP is updated atomically.
    mov sp, 0x7C00

    ; Initialize CS
    ;
    ; `retf` sets both CS and IP to a known-good state.
    ; This is necessary because we don't know where the BIOS put us at startup.
    ; (could be 0x0000:0x7C00, could be 0x07C0:0x0000. Not everybody follows spec.)
    push ax ; `ax` is still 0
    push word .set_cs
    retf

.set_cs:
    ; Save disk number.
    ; BIOS sets `dl` to the number of
    ; the disk we're booting from.
    mov [disk], dl

    ; Print "Stage 1"
    mov si, stage_msg
    call print
    mov al, '1'
    call print_char
    call print_line

    ; Read CHS geometry (int 0x13, ah=0x08) and save it into [chs]:
    ;  CL (bits 0-5) = maximum sector number
    ;  CL (bits 6-7) = high bits of max cylinder number
    ;  CH = low bits of maximum cylinder number
    ;  DH = maximum head number
    ; This BIOS call may return a pointer in ES:DI, so ES is saved and
    ; restored to keep the ES = 0 set up above.
    mov ah, 0x08
    mov dl, [disk]
    xor di, di ; ES:DI = 0000:0000 going in (ES is already 0)
    push es
    int 0x13
    pop es ; `pop` leaves CF untouched
    jc error ; carry flag set on error (status in ah)
    mov bl, ch
    mov bh, cl
    shr bh, 6 ; bx = maximum cylinder index (10 bits)
    mov [chs.c], bx
    shr dx, 8 ; dx = dh = maximum head index
    inc dx ; BIOS returns heads - 1; store the head count
    mov [chs.h], dx
    and cl, 0x3f ; sectors per track (sectors are numbered from 1)
    mov [chs.s], cl

    ; First sector of stage 2
    mov eax, STAGE2_SECTOR

    ; Where to load stage 2
    mov bx, stage2

    ; length of stage2 + stage3 on disk, in sectors, rounded up so a
    ; trailing partial sector is still loaded
    mov cx, (stage3.end - stage2 + 511) / 512
    xor dx, dx ; destination segment 0

    ; Consume eax, bx, cx, dx
    ; and load code from disk.
    call load

    jmp stage2.entry

; Load sectors from disk to memory.
; Cannot load more than 1MiB (the destination segment in dx is 16 bits).
;
; Performs one BIOS read per sector. For each sector, the routine recurses
; once with cx = 1. It then advances the LBA by one sector and the
; destination segment by 512 bytes (0x20 paragraphs), leaving bx unchanged.
; Each single read first prints its parameters.
; A BIOS error or an out-of-range CHS address jumps to `error` /
; `error_chs`, which halt and never return.
;
; Input:
;   eax: start sector (32-bit LBA)
;   bx:  offset of buffer
;   cx:  number of sectors (512 Bytes each); 0 reads nothing
;   dx:  segment of buffer
;
; Clobbers eax, bx, ecx, edx, si and flags (plus anything the print_*
; helpers clobber). ES is saved/restored around the CHS read.
load:
    ; Every "replace 1" comment means that the `1`
    ; on that line could be bigger.
    ;
    ; See https://stackoverflow.com/questions/58564895/problem-with-bios-int-13h-read-sectors-from-drive
    ; We have to load one sector at a time to avoid the boundary error
    ; described there. Would be nice to read more sectors at a time,
    ; though, that's faster.

    jcxz .done ; cx == 0: never issue a BIOS read with a count of 0

    cmp cx, 1 ; replace 1
    jbe .good_size

    ; pushad/popad rather than pusha/popa: the read below overwrites all
    ; of eax (CHS path: `mov eax, [DAPACK.addr]` + `div`). A 16-bit save
    ; would lose the upper half of the LBA for the next iteration.
    pushad
    mov cx, 1 ; replace 1
    call load
    popad
    add eax, 1 ; replace 1
    add dx, 1 * 512 / 16 ; replace 1
    sub cx, 1 ; replace 1

    jmp load

.done:
    ret

.good_size:
    mov [DAPACK.addr], eax
    mov [DAPACK.buf], bx
    mov [DAPACK.count], cx
    mov [DAPACK.seg], dx

    ; Print the data we're reading
    ; Prints AAAAAAAA#BBBB CCCC:DDDD, where:
    ; - A..A is the lba we're reading (high word, then low word)
    ; - BBBB is the number of sectors we're reading
    ; - CCCC is the segment of the destination buffer
    ; - DDDD is the offset of the destination buffer
    ; The print helpers clobber registers, so every value is re-read
    ; from DAPACK.
    mov bx, [DAPACK.addr + 2] ; high word of the 32-bit lba
    call print_hex
    mov bx, [DAPACK.addr] ; low word of the 32-bit lba
    call print_hex
    mov al, '#'
    call print_char
    mov bx, [DAPACK.count]
    call print_hex
    mov al, ' '
    call print_char
    mov bx, [DAPACK.seg]
    call print_hex
    mov al, ':'
    call print_char
    mov bx, [DAPACK.buf]
    call print_hex
    call print_line


    ; Read from disk.
    ; int13h, ah=0x42 does not work on some disks, so CHS reads
    ; (int13h, ah=0x02) are used whenever a geometry is known
    ; ([chs.s] != 0). The LBA extended read is the fallback.
    cmp byte [chs.s], 0
    jne .chs

    mov dl, [disk]
    mov si, DAPACK
    mov ah, 0x42
    int 0x13
    jc error ; carry flag set on error
    ret

.chs:
    ; calculate CHS
    xor edx, edx
    mov eax, [DAPACK.addr]
    div dword [chs.s] ; divide by sectors
    mov ecx, edx ; move sector remainder to ecx
    xor edx, edx
    div dword [chs.h] ; divide by heads
    ; eax has cylinders, edx has heads, ecx has sectors

    ; Sector cannot be greater than 63
    inc ecx ; Sector is base 1
    cmp ecx, 63
    ja error_chs

    ; Head cannot be greater than 255
    cmp edx, 255
    ja error_chs

    ; Cylinder cannot be greater than 1023
    cmp eax, 1023
    ja error_chs

    ; Move CHS values to parameters:
    ; ch = cylinder bits 0-7, cl = sector | (cylinder bits 8-9) << 6, dh = head
    mov ch, al
    shl ah, 6
    and cl, 0x3f
    or cl, ah
    shl dx, 8

    ; read from disk using CHS
    mov al, [DAPACK.count]
    mov ah, 0x02 ; disk read (CHS)
    mov bx, [DAPACK.buf]
    mov dl, [disk]
    push es ; save ES
    mov es, [DAPACK.seg]
    int 0x13
    pop es ; restore ES
    jc error ; carry flag set on error
    ret

;
; MARK: errors
;

error_chs:
    ; Entered when the computed CHS address is out of range. There is no
    ; BIOS status in this case, so report code 0.
    xor ah, ah

error:
    ; Fatal error: print a line break, then the status byte from ah as
    ; hex, then the error message. After that, halt forever.
    call print_line

    movzx bx, ah ; print_hex takes its value in bx (bh = 0, bl = ah)
    call print_hex

    mov si, stage1_error_msg
    call print
    call print_line

; halt after printing error details
.halt:
    cli ; with interrupts off, only an NMI can wake the CPU from hlt
    hlt
    jmp .halt ; and if one does, halt again

;
; MARK: data
;

%include "print.asm"

stage_msg: db "Stage ",0
stage1_error_msg: db " ERROR",0

; BIOS drive number, saved from dl at boot by stage1
disk: db 0

; Disk geometry from int 0x13, ah=0x08 (filled in by stage1).
; Stored as dwords so `load` can use them directly as 32-bit divisors.
chs:
    .c: dd 0 ; maximum cylinder index (0-based, as returned by the BIOS)
    .h: dd 0 ; number of heads (BIOS value + 1)
    .s: dd 0 ; sectors per track; 0 makes `load` use the LBA read (ah=0x42)

; Disk Address Packet for int 0x13, ah=0x42.
; `load` also uses it to hold the parameters of a CHS read.
DAPACK:
        db 0x10 ; size of this packet in bytes (16)
        db 0    ; reserved, 0
.count: dw 0 ; int 13 resets this to # of blocks actually read/written
.buf:   dw 0 ; offset of the memory buffer destination
.seg:   dw 0 ; segment of the memory buffer destination
.addr:  dq 0 ; put the lba to read in this spot