пятница, 8 сентября 2023 г.

[prog.flame] Может ли аккуратно написанный код выглядеть и пахнуть как говнокод?

Таки да!

Вот реально, если бы этот фрагмент прилетел бы ко мне на сопровождение, то у меня бы нехило пригорело бы. Уже хотя бы от того, что даже в его оформлении используются элементы, которые меня лично сильно подбешивают. Но там и парочки объективных проблем хватило бы, чтобы я сильно приуныл.

Во-первых, количество аргументов метода search. Тут как в старой программерской мудрости: если ваша функция получает 10 аргументов, то вы наверняка забыли передать туда еще что-то :) Кроме того, есть подряд идущие аргументы одного типа (например, три аргумента типа dim_t, затем три аргумента типа bool). Мой опыт показывает, что это прямой путь к тому, что рано или поздно какое-то из значений будет перепутано.

Во-вторых, объем метода. Ну он просто слишком большой. Вот честно завидую людям, которые в состоянии все происходящее в методе такого объема удержать в памяти. И не просто удержать, а еще и убедится, что там все корректно.

Вот, скажем, мой взгляд зацепился за объявление переменной biased_decoder типа std::unique_ptr<BiasedDecoder>. Она инициализируется реальным значением только при выполнении некоторых условий, в остальных случаях там будет nullptr. Используется же потом этот biased_decoder всего лишь один раз, довольно далеко от того места, в котором он был создан. Где гарантии, что никто к biased_decoder не обратится еще раз, не проверив предварительно bias_towards_prefix?

Или вот этот вот keep_batches, который так же unique_ptr, но который объявляется вне ветки if-а, в которой он затем создается и используется. Нафига, спрашивается? Там же вообще, насколько я понимаю, можно было обойтись без unique_ptr и объявить keep_batches как обычную переменную на стеке прямо в том месте, где она потребовалась...

Конечно, очень смело с моей стороны судить о коде, в котором не разбираешься (а ведь я не в зуб ногой в этой предметной области). Но тут просто профессиональный нюх подсказывает, что этот код хрупок. И нипривидихоспади столкнуться с сопровождением такого кода. Хотя, конечно, доводилось видеть и пострашнее. Но там сразу было видно, что говно-говном. А здесь вроде как все чистенько и аккуратненько, но с запашком, однако :(

Для тех, кому лень переходить по ссылке на github, под катом весь фрагмент, к которому я докопался.

ЗЫ. На этот проект вышел через здесь, захотелось посмотреть, что там в других местах с качеством C++ного кода.

  std::vector<DecodingResult>
  BeamSearch::search(layers::Decoder& decoder,
                     layers::DecoderState& state,
                     const Sampler& sampler,
                     const std::vector<size_t>& start_ids,
                     const std::vector<size_t>& end_ids,
                     const dim_t start_step,
                     const dim_t max_length,
                     const dim_t min_length,
                     const bool return_scores,
                     const bool return_attention,
                     const bool return_prefix,
                     const size_t num_hypotheses,
                     const bool include_eos_in_hypotheses,
                     const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
                     const std::vector<std::vector<size_t>>* prefix_ids) const {
    PROFILE("beam_search");
    const Device device = decoder.device();
    const DataType dtype = decoder.output_type();
    const dim_t vocabulary_size = decoder.output_size();
    const dim_t batch_size = start_ids.size();

    // We get more candidates than the beam size so that if half the candidates are EOS,
    // we can replace finished hypotheses with active beams.
    const dim_t num_candidates = _beam_size * 2;

    // Only the first beam is considered in the first step. As an additional optimization
    // we try to run the first step without expanding the batch size.
    const bool expand_after_first_step = (device == Device::CPU
                                          && num_candidates <= vocabulary_size);

    // We can exit early when the first beam finishes and no penalties are used.
    const bool allow_early_exit = (_length_penalty == 0 && _coverage_penalty == 0);

    StorageView topk_ids({batch_size}, DataType::INT32);
    StorageView topk_scores(dtype);

    std::vector<bool> top_beam_finished(batch_size, false);
    std::vector<dim_t> batch_offset(batch_size);
    std::vector<DecodingResult> results(batch_size);
    for (dim_t i = 0; i < batch_size; ++i) {
      batch_offset[i] = i;
      topk_ids.at<int32_t>(i) = start_ids[i];
    }

    if (!expand_after_first_step) {
      decoder.replicate_state(state, _beam_size);
      repeat_batch(topk_ids, _beam_size);
      TYPE_DISPATCH(dtype, initialize_beam_scores<T>(topk_scores, batch_size, _beam_size));
    }

    std::unique_ptr<BiasedDecoder> biased_decoder;
    std::vector<std::vector<bool>> beams_diverged_from_prefix;
    bool bias_towards_prefix = prefix_ids && _prefix_bias_beta > 0;
    if (bias_towards_prefix) {
      biased_decoder = std::make_unique<BiasedDecoder>(_prefix_bias_beta, *prefix_ids);
      beams_diverged_from_prefix.resize(batch_size, std::vector<bool>(_beam_size, false));
    }
    const bool use_hard_prefix = prefix_ids && !bias_towards_prefix;

    StorageView logits(dtype, device);
    StorageView alive_seq(topk_ids.dtype());
    StorageView alive_attention;

    const dim_t max_step = get_max_step(max_length,
                                        return_prefix,
                                        use_hard_prefix ? prefix_ids : nullptr);

    for (dim_t step = 0; step < max_step; ++step) {
      const bool is_expanded = (!expand_after_first_step || step > 0);

      // Compute log probs for the current step.
      StorageView attention_step(dtype, device);
      convert_to_original_word_ids(decoder, topk_ids);
      decoder(start_step + step,
              topk_ids.to(device),
              state,
              &logits,  // output shape: (cur_batch_size*beam_size x vocab_size), if not expanded beam_size is 1
              (return_attention || _coverage_penalty != 0) ? &attention_step : nullptr);

      const dim_t cur_batch_size = is_expanded ? logits.dim(0) / _beam_size : logits.dim(0);

      DisableTokens disable_tokens(logits);

      // Prevent the generation of end_ids until the minimum length is reached.
      apply_min_length(step,
                       min_length,
                       end_ids,
                       disable_tokens,
                       batch_offset,
                       return_prefix,
                       prefix_ids);

      if (!logits_processors.empty()) {
        if (alive_seq)
          merge_batch_beam(alive_seq);
        for (const auto& logits_processor : logits_processors)
          logits_processor->apply(step, logits, disable_tokens, alive_seq, batch_offset, prefix_ids);
        if (alive_seq)
          split_batch_beam(alive_seq, _beam_size);
      }

      disable_tokens.apply();

      StorageView log_probs(dtype, device);
      if (bias_towards_prefix) {
        biased_decoder->decode(cur_batch_size,
                               step,
                               batch_offset,
                               beams_diverged_from_prefix,
                               logits,
                               log_probs);
      } else {
        ops::LogSoftMax()(logits);
        log_probs.shallow_copy(logits);
      }

      // Multiply by the current beam log probs.
      if (topk_scores) {
        DEVICE_AND_TYPE_DISPATCH(log_probs.device(), log_probs.dtype(),
                                 primitives<D>::add_depth_broadcast(topk_scores.to(device).data<T>(),
                                                                    log_probs.data<T>(),
                                                                    topk_scores.size(),
                                                                    log_probs.size()));
      }

      // Flatten the probs into a list of candidates.
      log_probs.reshape({cur_batch_size, -1});

      // TopK candidates.
      sampler(log_probs, topk_ids, topk_scores, num_candidates);

      // Unflatten the ids.
      StorageView gather_indices = unflatten_ids(topk_ids, _beam_size, vocabulary_size, is_expanded);

      if (prefix_ids) {
        if (use_hard_prefix) {
          update_sample_with_prefix(step,
                                    topk_ids,
                                    topk_scores,
                                    *prefix_ids,
                                    end_ids,
                                    batch_offset,
                                    _beam_size,
                                    &gather_indices,
                                    is_expanded);
        } else if (bias_towards_prefix) {
          beams_diverged_from_prefix = get_beams_divergence_from_prefix(beams_diverged_from_prefix,
                                                                        step,
                                                                        topk_ids,
                                                                        *prefix_ids,
                                                                        batch_offset);
        }
      }

      // Append last prediction.
      append_step_output(alive_seq, topk_ids, &gather_indices);

      if (attention_step) {
        if (!is_expanded)
          repeat_batch(attention_step, _beam_size);
        split_batch_beam(attention_step, _beam_size);
        append_step_output(alive_attention, attention_step.to_float32().to(Device::CPU));
        gather_beam_flat(alive_attention, gather_indices, num_candidates);
      }

      // Check if some hypotheses are finished.
      std::vector<int32_t> non_finished_index;
      non_finished_index.reserve(cur_batch_size);

      // Only keep the first beam_size candidates.
      StorageView active_beams({cur_batch_size * _beam_size}, DataType::INT32);

      for (dim_t i = 0; i < cur_batch_size; ++i) {
        const dim_t batch_id = batch_offset[i];
        const dim_t prefix_length = use_hard_prefix ? prefix_ids->at(batch_id).size() : 0;
        const bool is_last_step_for_batch = is_last_step(step,
                                                         max_length,
                                                         prefix_length,
                                                         return_prefix);

        auto& result = results[batch_id];
        dim_t secondary_candidates_offset = _beam_size;

        for (dim_t k = 0; k < _beam_size; ++k) {
          const size_t last_id = topk_ids.at<int32_t>({i, k});
          dim_t next_beam_id = k;

          if ((is_eos(last_id, end_ids) && step >= prefix_length) || is_last_step_for_batch) {
            if (k == 0)
              top_beam_finished[i] = true;

            const bool ignore_last_token = is_eos(last_id, end_ids) && !include_eos_in_hypotheses;
            const dim_t start = return_prefix ? 0 : prefix_length;
            const dim_t end = ignore_last_token ? step : step + 1;

            // Register this hypothesis.
            result.scores.emplace_back(topk_scores.scalar_at<float>({i, k}));
            result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, start, end));
            if (alive_attention)
              result.attention.emplace_back(build_attention(alive_attention, i, k, start, end));

            // Move another active beam to this position.
            for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) {
              const auto candidate = topk_ids.at<int32_t>({i, j});
              if (!is_eos(candidate, end_ids)) {
                next_beam_id = j;
                secondary_candidates_offset = j + 1;
                break;
              }
            }
          }

          active_beams.at<int32_t>(i * _beam_size + k) = i * num_candidates + next_beam_id;
        }

        bool is_finished = false;
        if (is_last_step_for_batch)
          is_finished = true;
        else if (allow_early_exit)
          is_finished = top_beam_finished[i] && result.hypotheses.size() >= num_hypotheses;
        else
          is_finished = result.hypotheses.size() >= _max_candidates;

        if (is_finished) {
          finalize_result(result,
                          num_hypotheses,
                          _length_penalty,
                          _coverage_penalty,
                          return_scores,
                          return_attention);
        } else {
          non_finished_index.emplace_back(i);
        }
      }

      const dim_t next_batch_size = non_finished_index.size();

      // If all remaining sentences are finished, no need to go further.
      if (next_batch_size == 0) {
        if (!is_expanded) {
          // We should ensure that states are replicated before exiting this function.
          decoder.replicate_state(state, _beam_size);
        }
        break;
      }

      gather(gather_indices, active_beams);
      gather_beam_flat(topk_ids, active_beams, _beam_size);
      gather_beam_flat(topk_scores, active_beams, _beam_size);
      gather_beam_flat(alive_seq, active_beams, _beam_size);
      if (alive_attention)
        gather_beam_flat(alive_attention, active_beams, _beam_size);

      // If some sentences finished on this step, ignore them for the next step.
      std::unique_ptr<StorageView> keep_batches;
      if (next_batch_size != cur_batch_size) {
        batch_offset = index_vector(batch_offset, non_finished_index);
        top_beam_finished = index_vector(top_beam_finished, non_finished_index);
        if (bias_towards_prefix)
          beams_diverged_from_prefix = index_vector(beams_diverged_from_prefix, non_finished_index);

        keep_batches = std::make_unique<StorageView>(Shape{next_batch_size}, non_finished_index);
        gather(topk_ids, *keep_batches);
        gather(topk_scores, *keep_batches);
        gather(alive_seq, *keep_batches);
        if (alive_attention)
          gather(alive_attention, *keep_batches);
        if (keep_batches->device() != device)
          *keep_batches = keep_batches->to(device);
      }

      if (gather_indices.device() != device)
        gather_indices = gather_indices.to(device);
      decoder.update_state(state, gather_indices, _beam_size, keep_batches.get());

      topk_ids.reshape({next_batch_size * _beam_size});
      topk_scores.reshape({next_batch_size * _beam_size});

      if (bias_towards_prefix)
        bias_towards_prefix = !all_beams_diverged_from_prefix(beams_diverged_from_prefix);
    }

    return results;
  }

3 комментария:

Stanislav Mischenko комментирует...

Похоже, что вот здесь как раз тот случай, когда автор кода просто не думал о всех тех вещах, которые Вы описали. Причём видно, что он в курсе дела за модный C++, но видимо скучных книжек написанных серьёзными дядями не читал, ну или не дочитал.

Меня бы уже лично взорвало вот от этого:

for (dim_t step = 0; step < max_step; ++step) {
const bool is_expanded = (!expand_after_first_step || step > 0);

Вот серьёзно? Каждый раз заново вычислять значение переменной, которое не меняется после первого витка цикла! Я догадываюсь, что компилятор оптимизирует, но дело не в этом. Просто мне кажется, что это же элементарная вещь: не нужно вычислять заново то, что уже известно. Но автор этого не видит, для него это не важно.

И что-то мне подсказывает, что я одинок в своём мнении ;)

eao197 комментирует...

@Stanislav Mischenko

Там в цикле столько всего происходит, что вычисление одного bool-а вряд ли как-то повлияет на производительность.

Меня больше смущает то, что конкретно это условие записано так, что лично я сходу не могу понять какое значение оно должно принять :(

Ну и как от его вычисления можно избавиться? Я вижу два пути:

1. Сделать первую итерацию вне цикла и как раз для нее вычислять is_expanded как !expand_after_first_step. А последующие итерации делать в цикле, в котором начальным значением для step будет 1. Но это требует вообще тотального рефакторинга всего метода.

2. До цикла вычислить is_expanded как !expand_after_first_step, а в конце цикла тупо присваивать is_expanded true, без всяких условий.

Stanislav Mischenko комментирует...

@eao197
> вычисление одного bool-а вряд ли как-то повлияет на производительность.
Безусловно. Речь о другом - вычислять в цикле уже известное значение это как-то "непрофессионально" что-ли, как-то другого слова на ум не приходит.
> Ну и как от его вычисления можно избавиться? Я вижу два пути:
Всё так. Я вообще больше сторонник того, чтобы иметь несколько маленьких циклов без "бранчинга", чем один большой но с "бранчингами".