C++: Mapování hodnot z runtime na integrální konstanty

5. 6. 2023 8:28 (aktualizováno) Ondřej Novák

Dneska to bude relativně krátké. S každou další verzí C++ lze víc a víc algoritmů přesunout do constexpr „domény“, ve které se výpočty provádí během překladu a v runtime se již používají výsledky toho výpočtu. Často ale vstupem do takového výpočtu je hodnota získaná až v runtime.

Proč to potřebujeme?

Doména constexpr je velice silná a s každou novou verzí C++ se rozšiřují možnosti toho co všechno si lze dovolit v contexpr provádět. K tomu se pak vážou další možnosti optimalizace kódu, a výsledkem je samozřejmě efektivnější běh kódu.

Pokud se nám „nějak“ podaří namapovat hodnotu z runtime na integrální konstantu (integral constant), pak se nám otevírá obrovský svět, ve kterém můžeme tuhle konstantu použít nejen k výpočtu během překladu, ale tato konstanta může použita jako parametr šablony. Představte si, že bych mohl udělat následující:

template<int i>
struct NejakaStruktura;

auto funkce_ktera_nejde_prelozit(int v) {
    return NejakaStruktura<v>();
}

Výše zmíněný příklad nepřeložíme, protože proměnná v není integralni konstantou, je dodaná až v runtime. Překladač ale musí znát podobu typu NejakaStruktura<int> už během překladu.

Jiné použití je například u typu std::variant<>. Tento typ umožňuje definovat varianty pro jednu proměnnou (jako union). Například

std::variant<int, std::string, bool, double> my_variant(42);

S takovou proměnnou mohu pracovat následovně

  • Dotázat se, zda je v proměnné uložena nějaká alternativa (std::holds_alternative<std::strin­g>(my_variant) )- kde výsledkem volání je true nebo false podle aktuálního stavu
  • Získat hodnotu, pokud jí tam očekávám v určitém typu (std::get<int>(my_variant) )
  • Zavolat funktor (visitor) a pracovat s obsahem (std::visit([](auto val){…}, my_variant) )
  • Zjistit index aktuální varianty. Jednotlivé varianty jsou očíslované čísly 0,1,2…, v tomto případě 0 = int, 1 = std::string, 2 = bool, 3 = double.

Ten index lze použít například při serializaci, kdy do výsledného streamu uložíme index varianty a pak provedeme serializaci vlastní varianty. Při deserializaci stačí jen vyzvednout index, inicializovat variant podle indexu a provést deserializace a hotovo. 

Až na jeden problém, neexistuje standardní způsob, jak inicializovat variantu na základě indexu z runtime.

Standardní cestou je v zásadě switch-case s variantami

using MyVariant = std::variant<int, std::string, bool, double>
MyVariant deserialize(Stream stream,int index) {
    switch (index) {
        case 0: return deserialize<int>(stream);
        case 1: return deserialize<std::string>(stream);
        case 2: return deserialize<bool>(stream);
        case 3: return deserialize<double>(stream);
    }
}

Nutnost použít switch-case není vůbec generické, a nedokážu si představit podobný nástroj napsaný pro obecný std::variant<Types…>. Ze části STL, která nabízí šablony pro metaprograming můžeme použít typ std::variant_alternati­ve_t<index,std::variant<…> >. Jako by se nabízelo to napsat takto:

using MyVariant = std::variant<int, std::string, bool, double>
MyVariant deserialize(Stream stream,int index) { using Type = std::variant_alternative_t<index,MyVariant>; return deserialize<Type>(stream); } //toto se nepřeloží

Problém je, že index musí být integralní konstanta. 

Řešení

Zdá se naprosto nelogické chtít v constexpr doméně používat hodnotu z runtime. Přece v době běhu programu je tahle doména pevně „zapečena“ do kódu a nemůže dynamicky reagovat na stav runtime. Proto to je koneckonců const-expression. 

Leda, že by constexpr doména připravila všechny varianty a pak na základě dodané hodnoty vybrala příslušnou variant. To by šlo, ne?

Vraťme se k řešení switchem. Co to vlastně představuje? Switch je často řešen jako jump-table. Překladač vygeneruje tabulku adres na kterých začínají jednotlivé varianty. Při běhu pak mapuje proměnnou, podle které se vybírá varianta na index do této tabulky a skáče na adresu uloženou právě na daném indexu této tabulce. Skok realizuje přímo instrukce jmp.

Stejnou myšlenku by šlo použít i v tomto případě. Stačilo by si programově připravit tabulku adres na funkce, které představují jednotlivé varianty. Tyto varianty budou volat nějakou naši lambda funkci podobně jako funguje std::visit. Představuju si následující rozhraní 

template<int min, int max, typename Fn>
auto number_to_constant(int number, Fn &&fn);

A použití třeba takto:

MyVariant create_variant_by_index(int index, Stream stream) {
    constexpr int size = std::variant_size_v<MyVariant>;
    return number_to_constant<0,size-1 >(index, [&](auto v){
            using Type = std::variant_alternative_t<v.value,MyVariant>;
            return deserialize<Type>(stream);
}); }

Tento kód by se přeložil za předpokladu, že proměnná auto v byla nějakého typu, který nese integralni konstantu v proměnné value (takže v.value je integralní konstanta, kterou lze použít ve std::variant_alternative_t). Takový typ může vypadat následovně

template<int i>
struct ConstInteger {
    static constexpr int value = i;
};

Struktura deklaruje proměnnou value která je konstantou. A tuto hodnotu lze skutečně použít jako parametr šablony. Pak by mělo být možné kód přeložit. Jak naprogramovat funkci number_to_constant.

Programově vytvořená jump-table

Nutnost definovat číselný rozsah min a max ve funkci number_to_constant má samozřejmě význam. Z rozsahu si totiž odvodíme velikost naší skákací tabulky. Není možné mapovat libovolnou hodnotu na konstantu, protože celkové množství intů je … hodně moc… a kód připravený na tolik variant by byl obrovský. To není potřeba, množství variant bude vždycky výrazně menší.

Naši jump-table nám pomůže vytvořit constexpr objekt. To je objekt, který se instanciuje během překladu a do finální binárky vstupuje již inicializovaný.

 1 template<typename Fn, int min, int max>
 2 class JumpTable {
 3 public:
 4   static_assert(min <= max);
 5
 6   using Ret = decltype(std::declval<Fn>()(std::declval<ConstInteger<min> >()));
 7   static constexpr unsigned int size = max - min + 1;
 8
 9   constexpr JumpTable() {
10        init_jump_table<min>(_jumpTable);
11    }
12
13    constexpr Ret visit(int value, Fn &&fn) const  {
14        auto index = static_cast<unsigned int>(value - min);
15        if (index >= size) fn(ConstError{});
16        return _jumpTable[index](std::forward<Fn>(fn));
17    }
18
19 protected:
20
21    template<int i>
22    static Ret call_fn(Fn &&fn) {return fn(ConstInteger<i>());}
23
24    using FnPtr = Ret (*)(Fn &&fn);
25    FnPtr _jumpTable[size] = {};
26
27    template<int i>
28    constexpr void init_jump_table(FnPtr *ptr) {
29        if constexpr(i <= max) {
30            *ptr = &call_fn<i>;
31            init_jump_table<i+1>(ptr+1);
32        }
33    }
34 };
  •  Řádek 1–2 - deklarace šablony JumpTable, vstupními parametry je min, max a typ funkce, který budeme volat pro každou alternativu. Tady by se mělo jednat o funktor, protože se předpokládá, že funkce je šablonou. Předpokládá se tedy následující tvar lambda funkce: [](auto val){…}
  • Řádek 4 - ověří, že min <= max. Obě hodnoty z rozsahu patří do rozsahu
  • Řádek 6 – zjistí typ návratové hodnoty naší funkce. Předpokládá se, že všechny alternativy pro naší funkci vrací stejný typ, proto stačí zjistit typ z funkce pro variantu „min“
  • Řádek 7 - spočítá velikost tabulky. Výsledkem je konstanta
  • Řádek 9 - deklaruje constexpr konstruktor. Tím dáváme překladači najevo, že tato objekt lze konstruovat v constexpr doméně. 
  • Řádek 10 - v konstruktoru zavoláme funkci pro inicializaci jump-table. Tento kód se vykonává během překladu
  • Řádky 13–16 - Tahle funkce provádí mapování hodnoty na index do jump-table a na řádku 16 dochází k zavolání připravené funkce. To je konkrétní varianta pro zadaný index. Na řádku 15 kontrolujeme, že dodaná hodnota je v rozsahu a pokud není, zavoláme funkor a předáme mu speciální chybový objekt. Tato varianta je podobná jako default větev u switch-case
  • Řádky 21–22 - Ukazatel na tuhle funkci se bude ukládat do naší tabulky, přičemž je definovaná jako šablona. Šablona přijímá konstantu. Tato funkce pak instanciuje třídu ConstInteger<int> s konstantou předanou parametrem šablony a volá funktor s touto instancí jako parametr. Překladač si odvodí typ a přeloží funktor ve variantě pro konkrétní konstantu.
  • Řádek 24 - definuje typ ukazatel na funkci z řádku 22
  • Řádek 25 – deklaruje vlastní tabulku adres
  • Řádky 27–33 se provádí inicializace tabulky adres. Ta se musí provést rekurzivně, kdy parametrem šablony je nejprve minimální hodnotu a s každou rekurzí se index zvyšuje až překročí maximální hodnotu, kde se rekurze ukončí. V každém kroku se instanciuje jedna varianta funkce call_fn, která instanciuje variantu dodaného funktoru. Adresa této varianty se pak uloží do tabulky na patřičné místo a pokračuje se v rekurzi pro další položku v tabulce. Je třeba si uvědomit, že tahle funkce opět běží během překladu. Ve výsledné binárce je již tabulka adres připravená od spuštění a pouze se používá. Pro velký počet variant se může stát, že narazíme na velikost zásobníku překladače, pak by se kód musel trochu upravit

Funkce number_to_constant pak vypadá takto:

template<int min, int max, typename Fn>
auto number_to_constant(int number, Fn &&fn) {
    static constexpr JumpTable<Fn, min, max> jptable;
    return jptable.visit(number, std::forward<Fn>(fn));
}

Když je hodnota mimo rozsah?

Na tuhle situaci je třeba myslet a pro větší pohodlí upravíme třídu nesoucí integrální konstantu a třídu oznamující chybu

template<int i>
struct ConstInteger {
    static constexpr int value = i;
    static constexpr bool valid = true;
};

struct ConstError {
    static constexpr bool valid = false;
};

Proměnnou valid pak použijeme v deserializační funkci pro kontrolu, zda vstupem byl index v daném rozsahu.

MyVariant deserialize(int index, Stream stream) {
    constexpr int size = std::variant_size_v<MyVariant>;
    return number_to_constant<0, size-1 >(index, [](auto v){
        if constexpr(v.valid) {
            using Type = std::variant_alternative_t<v.value,MyVariant>;
            return MyVariant(Type{});
        } else {
            throw std::invalid_argument("Index out of range");
        }
    });
}

Použití if constexpr lze vyřídit situaci kdy v.valid == false, v takovém případě se část, kde se pracuje s v.value vůbec nebude překládat a překladač si nebude stěžovat na neexistenci v.value. Protože jde o integrální konstantu, lze použít if constexpr i v jiných situacích a na základě hodnoty kód větvit, přičemž překladač vždy přeloží jen tu variantu, která se týká dané hodnoty a ostatní tam vůbec nebudou. Lze se takhle úplně vyhnout podmínkám v kódu a konstruovat tak branchless kód.

Jiný příklad – převod utf-8 na wchar skoro branchless

template<typename InputIterator>
wchar_t utf8Towchar(InputIterator &at, InputIterator end) {
    unsigned char c = *at;
    ++at;
    int bytes = (c >= 0xC0) + (c >= 0xE0) + (c > 0xF0);
    return number_to_constant<0,3>(bytes, [&](auto b) ->wchar_t {
        if constexpr(!b.valid) {
            return static_cast<wchar_t>(c);
        } else if constexpr(b.value == 0) {
            return static_cast<wchar_t>(c);
        } else {
            wchar_t ret = static_cast<wchar_t>(c) & (0x3F >> b.value);
            for (int i = 0; i < b.value; i++) {
                if (at == end) return -1;
                unsigned char d = *at;
                ++at;
                ret = (ret << 6) | (d ^ 0x80);
            }
            return ret;
        }
    });
}
  • Kód spočítá počet dodatečných bajtů pomocí výrazu a uloží do bytes.
  • Na základě této proměnné, která může nabývat hodnot 0 až 3 se vykoná následující větev kódu
  • Je třeba řešit b.valid i když zde je jasné, že jiných hodnot proměnná nabývat nemůže, přesto, je to hodnota přicházející z runtime, takže se musí ošetřit
  • Pro nulu (c < 0×80) speciálně se char mapuje na wchar
  • Pro hodnoty 1,2,3 se spočítá maska (constexpr) a tou maskou se maskuje dolní čast bajtu, kde je začátek unicode znaku.
  • Pak následuje cyklus – a tady pozor: Překladač při -O2 rozvine cyklus do sekvence opakujícího se kódu.  V tomto cyklu se doplní zbyvající bajty z UTF-8 kódování
  • Výše uvedený kód se exekuuje branchless – bez podmíněných skoků – tedy přesněji, podmíněné skoky tam jsou, například při kontrole konce iterátoru, ale tyto podmínky jsou dobře předvídatelné a CPU si s nimi poradí.

Jak to vypadá v strojovém kódu?

; int bytes = (c >= 0xC0) + (c >= 0xE0) + (c > 0xF0);
xor     %eax,%eax
cmp     $0xbf,%dl       ; v dl je hodnota c
seta    %al
xor     %ecx,%ecx
cmp     $0xdf,%dl
seta    %cl
add     %ecx,%eax
cmp     $0xf0,%dl
lea     0x7313(%rip),%rcx
seta    %dl
movzbl  %dl,%edx
add     %edx,%eax       ; v eax máme index
call    *(%rcx,%rax,8)  ; skok na adresu která je obsahem [%rcx+%rax*8]
;gcc řeší jako call patřičné varianty

K čemu se to hodí?

Udělejme si shrnutí, k čemu se takový nástroj hodí

  • Konstrukci variantu dle zadaného indexu
  • Náhrada switch-case pro generické  algoritmy
  • Vynucení si optimalizace kódu pro jednotlivé varianty, například rozvinutí cyklů 
  • Konstrukce branchless kódu a tím vyšší výkon algoritmu

Demonstrační příklad

Nebojte se zeptat v komentářích. V přípravě mám další vychytávky v C++.

Sdílet