Таки да!
Вот реально, если бы этот фрагмент прилетел бы ко мне на сопровождение, то у меня бы нехило пригорело бы. Уже хотя бы от того, что даже в его оформлении используются элементы, которые меня лично сильно подбешивают. Но там и парочки объективных проблем хватило бы, чтобы я сильно приуныл.
Во-первых, количество аргументов метода search. Тут как в старой программерской мудрости: если ваша функция получает 10 аргументов, то вы наверняка забыли передать туда еще что-то :) Кроме того, есть подряд идущие аргументы одного типа (например, три аргумента типа dim_t, затем три аргумента типа bool). Мой опыт показывает, что это прямой путь к тому, что рано или поздно какое-то из значений будет перепутано.
Во-вторых, объем метода. Ну он просто слишком большой. Вот честно завидую людям, которые в состоянии все происходящее в методе такого объема удержать в памяти. И не просто удержать, а еще и убедится, что там все корректно.
Вот, скажем, мой взгляд зацепился за объявление переменной biased_decoder типа std::unique_ptr<BiasedDecoder>. Она инициализируется реальным значением только при выполнении некоторых условий, в остальных случаях там будет nullptr. Используется же потом этот biased_decoder всего лишь один раз, довольно далеко от того места, в котором он был создан. Где гарантии, что никто к biased_decoder не обратится еще раз, не проверив предварительно bias_towards_prefix?
Или вот этот вот keep_batches, который так же unique_ptr, но который объявляется вне ветки if-а, в которой он затем создается и используется. Нафига, спрашивается? Там же вообще, насколько я понимаю, можно было обойтись без unique_ptr и объявить keep_batches как обычную переменную на стеке прямо в том месте, где она потребовалась...
Конечно, очень смело с моей стороны судить о коде, в котором не разбираешься (а ведь я не в зуб ногой в этой предметной области). Но тут просто профессиональный нюх подсказывает, что этот код хрупок. И нипривидихоспади столкнуться с сопровождением такого кода. Хотя, конечно, доводилось видеть и пострашнее. Но там сразу было видно, что говно-говном. А здесь вроде как все чистенько и аккуратненько, но с запашком, однако :(
Для тех, кому лень переходить по ссылке на github, под катом весь фрагмент, к которому я докопался.
ЗЫ. На этот проект вышел через здесь, захотелось посмотреть, что там в других местах с качеством C++ного кода.
std::vector<DecodingResult> BeamSearch::search(layers::Decoder& decoder, layers::DecoderState& state, const Sampler& sampler, const std::vector<size_t>& start_ids, const std::vector<size_t>& end_ids, const dim_t start_step, const dim_t max_length, const dim_t min_length, const bool return_scores, const bool return_attention, const bool return_prefix, const size_t num_hypotheses, const bool include_eos_in_hypotheses, const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors, const std::vector<std::vector<size_t>>* prefix_ids) const { PROFILE("beam_search"); const Device device = decoder.device(); const DataType dtype = decoder.output_type(); const dim_t vocabulary_size = decoder.output_size(); const dim_t batch_size = start_ids.size(); // We get more candidates than the beam size so that if half the candidates are EOS, // we can replace finished hypotheses with active beams. const dim_t num_candidates = _beam_size * 2; // Only the first beam is considered in the first step. As an additional optimization // we try to run the first step without expanding the batch size. const bool expand_after_first_step = (device == Device::CPU && num_candidates <= vocabulary_size); // We can exit early when the first beam finishes and no penalties are used. const bool allow_early_exit = (_length_penalty == 0 && _coverage_penalty == 0); StorageView topk_ids({batch_size}, DataType::INT32); StorageView topk_scores(dtype); std::vector<bool> top_beam_finished(batch_size, false); std::vector<dim_t> batch_offset(batch_size); std::vector<DecodingResult> results(batch_size); for (dim_t i = 0; i < batch_size; ++i) { batch_offset[i] = i; topk_ids.at<int32_t>(i) = start_ids[i]; } if (!expand_after_first_step) { decoder.replicate_state(state, _beam_size); repeat_batch(topk_ids, _beam_size); TYPE_DISPATCH(dtype, initialize_beam_scores<T>(topk_scores, batch_size, _beam_size)); } std::unique_ptr<BiasedDecoder> biased_decoder; std::vector<std::vector<bool>> beams_diverged_from_prefix; bool bias_towards_prefix = prefix_ids && _prefix_bias_beta > 0; if (bias_towards_prefix) { biased_decoder = std::make_unique<BiasedDecoder>(_prefix_bias_beta, *prefix_ids); beams_diverged_from_prefix.resize(batch_size, std::vector<bool>(_beam_size, false)); } const bool use_hard_prefix = prefix_ids && !bias_towards_prefix; StorageView logits(dtype, device); StorageView alive_seq(topk_ids.dtype()); StorageView alive_attention; const dim_t max_step = get_max_step(max_length, return_prefix, use_hard_prefix ? prefix_ids : nullptr); for (dim_t step = 0; step < max_step; ++step) { const bool is_expanded = (!expand_after_first_step || step > 0); // Compute log probs for the current step. StorageView attention_step(dtype, device); convert_to_original_word_ids(decoder, topk_ids); decoder(start_step + step, topk_ids.to(device), state, &logits, // output shape: (cur_batch_size*beam_size x vocab_size), if not expanded beam_size is 1 (return_attention || _coverage_penalty != 0) ? &attention_step : nullptr); const dim_t cur_batch_size = is_expanded ? logits.dim(0) / _beam_size : logits.dim(0); DisableTokens disable_tokens(logits); // Prevent the generation of end_ids until the minimum length is reached. apply_min_length(step, min_length, end_ids, disable_tokens, batch_offset, return_prefix, prefix_ids); if (!logits_processors.empty()) { if (alive_seq) merge_batch_beam(alive_seq); for (const auto& logits_processor : logits_processors) logits_processor->apply(step, logits, disable_tokens, alive_seq, batch_offset, prefix_ids); if (alive_seq) split_batch_beam(alive_seq, _beam_size); } disable_tokens.apply(); StorageView log_probs(dtype, device); if (bias_towards_prefix) { biased_decoder->decode(cur_batch_size, step, batch_offset, beams_diverged_from_prefix, logits, log_probs); } else { ops::LogSoftMax()(logits); log_probs.shallow_copy(logits); } // Multiply by the current beam log probs. if (topk_scores) { DEVICE_AND_TYPE_DISPATCH(log_probs.device(), log_probs.dtype(), primitives<D>::add_depth_broadcast(topk_scores.to(device).data<T>(), log_probs.data<T>(), topk_scores.size(), log_probs.size())); } // Flatten the probs into a list of candidates. log_probs.reshape({cur_batch_size, -1}); // TopK candidates. sampler(log_probs, topk_ids, topk_scores, num_candidates); // Unflatten the ids. StorageView gather_indices = unflatten_ids(topk_ids, _beam_size, vocabulary_size, is_expanded); if (prefix_ids) { if (use_hard_prefix) { update_sample_with_prefix(step, topk_ids, topk_scores, *prefix_ids, end_ids, batch_offset, _beam_size, &gather_indices, is_expanded); } else if (bias_towards_prefix) { beams_diverged_from_prefix = get_beams_divergence_from_prefix(beams_diverged_from_prefix, step, topk_ids, *prefix_ids, batch_offset); } } // Append last prediction. append_step_output(alive_seq, topk_ids, &gather_indices); if (attention_step) { if (!is_expanded) repeat_batch(attention_step, _beam_size); split_batch_beam(attention_step, _beam_size); append_step_output(alive_attention, attention_step.to_float32().to(Device::CPU)); gather_beam_flat(alive_attention, gather_indices, num_candidates); } // Check if some hypotheses are finished. std::vector<int32_t> non_finished_index; non_finished_index.reserve(cur_batch_size); // Only keep the first beam_size candidates. StorageView active_beams({cur_batch_size * _beam_size}, DataType::INT32); for (dim_t i = 0; i < cur_batch_size; ++i) { const dim_t batch_id = batch_offset[i]; const dim_t prefix_length = use_hard_prefix ? prefix_ids->at(batch_id).size() : 0; const bool is_last_step_for_batch = is_last_step(step, max_length, prefix_length, return_prefix); auto& result = results[batch_id]; dim_t secondary_candidates_offset = _beam_size; for (dim_t k = 0; k < _beam_size; ++k) { const size_t last_id = topk_ids.at<int32_t>({i, k}); dim_t next_beam_id = k; if ((is_eos(last_id, end_ids) && step >= prefix_length) || is_last_step_for_batch) { if (k == 0) top_beam_finished[i] = true; const bool ignore_last_token = is_eos(last_id, end_ids) && !include_eos_in_hypotheses; const dim_t start = return_prefix ? 0 : prefix_length; const dim_t end = ignore_last_token ? step : step + 1; // Register this hypothesis. result.scores.emplace_back(topk_scores.scalar_at<float>({i, k})); result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, start, end)); if (alive_attention) result.attention.emplace_back(build_attention(alive_attention, i, k, start, end)); // Move another active beam to this position. for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) { const auto candidate = topk_ids.at<int32_t>({i, j}); if (!is_eos(candidate, end_ids)) { next_beam_id = j; secondary_candidates_offset = j + 1; break; } } } active_beams.at<int32_t>(i * _beam_size + k) = i * num_candidates + next_beam_id; } bool is_finished = false; if (is_last_step_for_batch) is_finished = true; else if (allow_early_exit) is_finished = top_beam_finished[i] && result.hypotheses.size() >= num_hypotheses; else is_finished = result.hypotheses.size() >= _max_candidates; if (is_finished) { finalize_result(result, num_hypotheses, _length_penalty, _coverage_penalty, return_scores, return_attention); } else { non_finished_index.emplace_back(i); } } const dim_t next_batch_size = non_finished_index.size(); // If all remaining sentences are finished, no need to go further. if (next_batch_size == 0) { if (!is_expanded) { // We should ensure that states are replicated before exiting this function. decoder.replicate_state(state, _beam_size); } break; } gather(gather_indices, active_beams); gather_beam_flat(topk_ids, active_beams, _beam_size); gather_beam_flat(topk_scores, active_beams, _beam_size); gather_beam_flat(alive_seq, active_beams, _beam_size); if (alive_attention) gather_beam_flat(alive_attention, active_beams, _beam_size); // If some sentences finished on this step, ignore them for the next step. std::unique_ptr<StorageView> keep_batches; if (next_batch_size != cur_batch_size) { batch_offset = index_vector(batch_offset, non_finished_index); top_beam_finished = index_vector(top_beam_finished, non_finished_index); if (bias_towards_prefix) beams_diverged_from_prefix = index_vector(beams_diverged_from_prefix, non_finished_index); keep_batches = std::make_unique<StorageView>(Shape{next_batch_size}, non_finished_index); gather(topk_ids, *keep_batches); gather(topk_scores, *keep_batches); gather(alive_seq, *keep_batches); if (alive_attention) gather(alive_attention, *keep_batches); if (keep_batches->device() != device) *keep_batches = keep_batches->to(device); } if (gather_indices.device() != device) gather_indices = gather_indices.to(device); decoder.update_state(state, gather_indices, _beam_size, keep_batches.get()); topk_ids.reshape({next_batch_size * _beam_size}); topk_scores.reshape({next_batch_size * _beam_size}); if (bias_towards_prefix) bias_towards_prefix = !all_beams_diverged_from_prefix(beams_diverged_from_prefix); } return results; } |
Похоже, что вот здесь как раз тот случай, когда автор кода просто не думал о всех тех вещах, которые Вы описали. Причём видно, что он в курсе дела за модный C++, но видимо скучных книжек написанных серьёзными дядями не читал, ну или не дочитал.
ОтветитьУдалитьМеня бы уже лично взорвало вот от этого:
for (dim_t step = 0; step < max_step; ++step) {
const bool is_expanded = (!expand_after_first_step || step > 0);
Вот серьёзно? Каждый раз заново вычислять значение переменной, которое не меняется после первого витка цикла! Я догадываюсь, что компилятор оптимизирует, но дело не в этом. Просто мне кажется, что это же элементарная вещь: не нужно вычислять заново то, что уже известно. Но автор этого не видит, для него это не важно.
И что-то мне подсказывает, что я одинок в своём мнении ;)
@Stanislav Mischenko
ОтветитьУдалитьТам в цикле столько всего происходит, что вычисление одного bool-а вряд ли как-то повлияет на производительность.
Меня больше смущает то, что конкретно это условие записано так, что лично я сходу не могу понять какое значение оно должно принять :(
Ну и как от его вычисления можно избавиться? Я вижу два пути:
1. Сделать первую итерацию вне цикла и как раз для нее вычислять is_expanded как !expand_after_first_step. А последующие итерации делать в цикле, в котором начальным значением для step будет 1. Но это требует вообще тотального рефакторинга всего метода.
2. До цикла вычислить is_expanded как !expand_after_first_step, а в конце цикла тупо присваивать is_expanded true, без всяких условий.
@eao197
ОтветитьУдалить> вычисление одного bool-а вряд ли как-то повлияет на производительность.
Безусловно. Речь о другом - вычислять в цикле уже известное значение это как-то "непрофессионально" что-ли, как-то другого слова на ум не приходит.
> Ну и как от его вычисления можно избавиться? Я вижу два пути:
Всё так. Я вообще больше сторонник того, чтобы иметь несколько маленьких циклов без "бранчинга", чем один большой но с "бранчингами".